Files
ailine/app/agent.py
root 404efde282
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
添加长期存储,流式检查
2026-04-17 01:26:05 +08:00

277 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer，不负责管理连接生命周期
"""
import os
import json
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
# 本地模块
from app.graph_builder import GraphBuilder, GraphContext
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.logger import debug, info, warning, error
# Load variables from a .env file into the process environment (python-dotenv)
# so that the os.getenv() lookups performed by AIAgentService can see them.
load_dotenv()
class AIAgentService:
    """Async AI agent service that compiles one LangGraph graph per LLM backend.

    The checkpointer is supplied by the caller; this class only hands it to
    ``compile()`` for every graph and does not manage its connection lifecycle.
    """

    def __init__(self, checkpointer: "AsyncPostgresSaver"):
        """Store the shared checkpointer; graphs are built later by ``initialize()``.

        Args:
            checkpointer: An already-initialized ``AsyncPostgresSaver`` instance.
        """
        self.checkpointer = checkpointer
        # Model name ("local" / "deepseek" / "zhipu") -> compiled graph.
        self.graphs = {}

    def _create_zhipu_llm(self):
        """Create the hosted Zhipu GLM chat model.

        Raises:
            ValueError: If ``ZHIPUAI_API_KEY`` is not set.
        """
        api_key = os.getenv("ZHIPUAI_API_KEY")
        if not api_key:
            raise ValueError("ZHIPUAI_API_KEY not set in environment")
        return ChatZhipuAI(
            model="glm-4.7-flash",
            api_key=api_key,
            temperature=0.1,
            max_tokens=4096,
            timeout=120.0,  # request timeout in seconds (raised from 60)
            max_retries=3,  # automatic retries on failure (raised from 2)
            streaming=True,  # enable streaming output
        )

    def _create_deepseek_llm(self):
        """Create the DeepSeek chat model through its OpenAI-compatible API.

        Raises:
            ValueError: If ``DEEPSEEK_API_KEY`` is not set.
        """
        api_key = os.getenv("DEEPSEEK_API_KEY")
        if not api_key:
            raise ValueError("DEEPSEEK_API_KEY not set in environment")
        # NOTE(review): DeepSeek's docs state that deepseek-reasoner ignores
        # sampling params such as temperature, and that max_tokens also bounds
        # the chain-of-thought; confirm 4096 is enough for long answers.
        return ChatOpenAI(
            base_url="https://api.deepseek.com",
            api_key=SecretStr(api_key),
            model="deepseek-reasoner",  # deepseek-chat: no thinking; deepseek-reasoner: thinking mode
            temperature=0.1,
            max_tokens=4096,
            timeout=60.0,  # request timeout in seconds
            max_retries=2,  # automatic retries on failure
            streaming=True,  # enable streaming output
        )

    def _create_local_llm(self):
        """Create the chat model backed by a self-hosted vLLM (OpenAI-compatible) server."""
        # The endpoint comes from VLLM_BASE_URL so the same code works under
        # Docker, FRP tunnelling and local development.
        vllm_base_url = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
        return ChatOpenAI(
            base_url=vllm_base_url,
            # NOTE(review): the key is read from LLAMACPP_API_KEY although the
            # server is vLLM; presumably kept for existing deployments - confirm
            # before renaming.
            api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
            model="gemma-4-E2B-it",
            timeout=60.0,  # request timeout in seconds
            max_retries=2,  # automatic retries on failure
            streaming=True,  # enable streaming output
        )

    async def initialize(self):
        """Build and compile one graph per configured model with the shared checkpointer.

        A model whose creation or compilation raises is logged and skipped, so
        the remaining models stay usable.

        Returns:
            ``self``, allowing ``await AIAgentService(cp).initialize()`` chaining.

        Raises:
            RuntimeError: If no model could be initialized.
        """
        import traceback

        # Insertion order matters: when a requested model is unavailable,
        # process_message() falls back to the first successfully registered one.
        model_configs = {
            "local": self._create_local_llm,
            "deepseek": self._create_deepseek_llm,
            "zhipu": self._create_zhipu_llm,
        }
        for model_name, llm_creator in model_configs.items():
            try:
                info(f"🔄 正在初始化模型 '{model_name}'...")
                llm = llm_creator()
                builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
                self.graphs[model_name] = builder.compile(checkpointer=self.checkpointer)
                info(f"✅ 模型 '{model_name}' 初始化成功")
            except Exception as e:
                # Best effort: one broken backend must not take the others down.
                error_detail = traceback.format_exc()
                warning(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
                debug(f" 详细错误:\n{error_detail}")
        if not self.graphs:
            raise RuntimeError("没有可用的模型,请检查配置。可能的原因:\n"
                               "1. ZHIPUAI_API_KEY 未配置或无效\n"
                               "2. DEEPSEEK_API_KEY 未配置或无效\n"
                               "3. vLLM 服务未启动或地址错误 (VLLM_BASE_URL)\n"
                               "4. 网络连接问题")
        return self

    def _resolve_model_name(self, model: str) -> str:
        """Return ``model`` if it is registered, otherwise the first registered model.

        Raises:
            RuntimeError: If no model is registered at all.
        """
        if model in self.graphs:
            return model
        warning(f"警告: 模型 '{model}' 不可用,尝试切换到其他可用模型")
        fallback = next(iter(self.graphs), None)
        if fallback is None:
            raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}")
        info(f"已切换到可用模型: '{fallback}'")
        return fallback

    @staticmethod
    def _build_run_config(thread_id: str, user_id: str) -> dict:
        """Build the LangGraph run config shared by the blocking and streaming paths."""
        return {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id},  # written to metadata for history queries
        }

    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
        """Run one user message through the graph and return the final reply.

        If ``model`` is not registered, the first registered model is used instead.

        Returns:
            dict: {
                "reply": str,          # content of the last message in the final state
                "token_usage": dict,   # state["last_token_usage"], {} if absent
                "elapsed_time": float  # state["last_elapsed_time"], 0.0 if absent
            }

        Raises:
            RuntimeError: If no model is registered.
        """
        model = self._resolve_model_name(model)
        graph = self.graphs[model]
        config = self._build_run_config(thread_id, user_id)
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)
        result = await graph.ainvoke(input_state, config=config, context=context)
        return {
            "reply": result["messages"][-1].content,
            "token_usage": result.get("last_token_usage", {}),
            "elapsed_time": result.get("last_elapsed_time", 0.0),
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable data.

        - Objects with a ``content`` attribute (LangChain messages) become a dict
          with ``role`` (the object's ``type``, default "message"), ``content``,
          ``additional_kwargs`` and ``tool_calls``; each field is serialized
          recursively as well.
        - dicts are serialized value by value; lists and tuples become lists.
        - JSON scalars and anything else ``json.dumps`` accepts are returned
          unchanged; everything else is replaced by ``str(value)``.
        """
        if hasattr(value, "content"):
            return {
                "role": getattr(value, "type", "message"),
                "content": self._serialize_value(value.content),
                "additional_kwargs": self._serialize_value(getattr(value, "additional_kwargs", {})),
                "tool_calls": self._serialize_value(getattr(value, "tool_calls", [])),
            }
        if isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        # Fast path: these are always accepted by json.dumps, so skip the trial
        # encode (it would otherwise run once per streamed token).
        if value is None or isinstance(value, (str, int, float, bool)):
            return value
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return str(value)
        return value

    def _build_token_event(self, message_chunk, metadata: dict) -> dict:
        """Turn one ``messages`` stream part into an ``llm_token`` event (typewriter effect)."""
        # The chunk may not be a message object, so fall back to its string form.
        token_content = message_chunk.content if hasattr(message_chunk, "content") else str(message_chunk)
        # Thinking tokens from reasoning models arrive in additional_kwargs.
        # NOTE(review): langchain_openai.ChatOpenAI may not copy DeepSeek's
        # non-standard `reasoning_content` delta into additional_kwargs, in which
        # case this stays empty for the "deepseek" model - confirm.
        additional_kwargs = getattr(message_chunk, "additional_kwargs", None) or {}
        reasoning_token = additional_kwargs.get("reasoning_content") or ""
        if reasoning_token:
            debug(f"💡 [Reasoning Token 捕获]: {repr(reasoning_token)}")
        return {
            "type": "llm_token",
            "node": metadata.get("langgraph_node", "unknown"),
            "token": token_content,
            "reasoning_token": reasoning_token,
            "metadata": self._serialize_value(metadata),
        }

    def _build_update_event(self, updates_data) -> dict:
        """Turn one ``updates`` stream part (a node finished) into a ``state_update`` event."""
        serialized_data = self._serialize_value(updates_data)
        event = {"type": "state_update", "data": serialized_data}
        # Legacy top-level "messages" field kept for older frontends.
        # NOTE(review): in "updates" mode the top-level keys are normally node
        # names, so this presumably only matches a node named "messages" - confirm.
        if isinstance(serialized_data, dict) and "messages" in serialized_data:
            event["messages"] = serialized_data["messages"]
        return event

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Stream one user message through the graph as JSON-friendly event dicts.

        Unlike ``process_message`` there is no fallback model here.

        Args:
            message: User message.
            thread_id: Conversation thread ID (checkpoint key).
            model_name: Name of a registered model.
            user_id: User ID, passed in the run metadata and GraphContext.

        Yields:
            dict events, one of:
            - ``{"type": "llm_token", "node", "token", "reasoning_token", "metadata"}``
            - ``{"type": "state_update", "data"}`` (plus legacy ``"messages"``)
            - ``{"type": "custom", "data"}``

        Raises:
            ValueError: If ``model_name`` has no compiled graph.
        """
        graph = self.graphs.get(model_name)
        if graph is None:
            raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)
        async for chunk in graph.astream(
            input_state,
            config=self._build_run_config(thread_id, user_id),
            context=context,
            stream_mode=["messages", "updates", "custom"],  # combined modes incl. custom
            version="v2",  # unified {"type", "data", ...} chunk format
            subgraphs=True,  # also stream from subgraphs
        ):
            chunk_type = chunk["type"]
            if chunk_type == "messages":
                message_chunk, metadata = chunk["data"]
                yield self._build_token_event(message_chunk, metadata)
            elif chunk_type == "updates":
                yield self._build_update_event(chunk["data"])
            elif chunk_type == "custom":
                yield {"type": "custom", "data": self._serialize_value(chunk["data"])}
            # Any other stream part (debug, tasks, ...) is intentionally skipped.