This commit is contained in:
142
app/agent.py
142
app/agent.py
@@ -4,6 +4,7 @@ AI Agent 服务类 - 支持多模型动态切换
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
@@ -41,8 +42,9 @@ class AIAgentService:
|
||||
api_key=api_key,
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
timeout=120.0, # 增加请求超时时间(秒),原为60秒
|
||||
max_retries=3, # 增加重试次数,原为2次
|
||||
streaming=True, # 确保开启流式输出
|
||||
)
|
||||
|
||||
def _create_deepseek_llm(self):
|
||||
@@ -58,6 +60,7 @@ class AIAgentService:
|
||||
max_tokens=4096,
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
streaming=True, # 确保开启流式输出
|
||||
)
|
||||
|
||||
def _create_local_llm(self):
|
||||
@@ -65,7 +68,7 @@ class AIAgentService:
|
||||
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
|
||||
vllm_base_url = os.getenv(
|
||||
"VLLM_BASE_URL",
|
||||
"http://localhost:8081/v1"
|
||||
"http://127.0.0.1:8081/v1"
|
||||
)
|
||||
|
||||
return ChatOpenAI(
|
||||
@@ -74,14 +77,15 @@ class AIAgentService:
|
||||
model="gemma-4-E2B-it",
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
streaming=True, # 确保开启流式输出
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""预编译所有模型的 graph(使用传入的 checkpointer)"""
|
||||
model_configs = {
|
||||
"zhipu": self._create_zhipu_llm,
|
||||
"deepseek": self._create_deepseek_llm,
|
||||
"local": self._create_local_llm,
|
||||
"local": self._create_local_llm, # 本地模型作为第一个
|
||||
"deepseek": self._create_deepseek_llm, # DeepSeek 作为中间
|
||||
"zhipu": self._create_zhipu_llm, # GLM-4.7 作为最后一个
|
||||
}
|
||||
|
||||
for model_name, llm_creator in model_configs.items():
|
||||
@@ -107,7 +111,7 @@ class AIAgentService:
|
||||
|
||||
return self
|
||||
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
|
||||
"""
|
||||
处理用户消息,返回包含回复、token统计和耗时的字典
|
||||
|
||||
@@ -156,6 +160,28 @@ class AIAgentService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
def _serialize_value(self, value):
    """Recursively convert a value into a JSON-serializable structure.

    Conversion rules, applied in order:

    * An object exposing a ``content`` attribute is treated as a
      message-like object and flattened into a dict with the keys
      ``role``, ``content``, ``additional_kwargs`` and ``tool_calls``.
      The three payload fields are converted recursively, so nested
      objects inside them cannot leak through unconverted.
    * A ``dict`` is rebuilt with every value converted; keys that JSON
      cannot encode (anything other than str/int/float/bool/None) are
      replaced by their ``str()`` form.
    * A ``list`` or ``tuple`` becomes a ``list`` of converted items.
    * Anything else is returned unchanged when ``json.dumps`` accepts it,
      otherwise its ``str()`` form is returned.

    Args:
        value: The object to convert.

    Returns:
        A structure built only from dicts, lists and values accepted by
        ``json.dumps``.
    """
    if hasattr(value, 'content'):
        # Message-like object: recurse into the payload fields because they
        # may themselves hold lists, dicts or arbitrary objects.
        return {
            "role": getattr(value, 'type', 'message'),
            "content": self._serialize_value(getattr(value, 'content', '')),
            "additional_kwargs": self._serialize_value(
                getattr(value, 'additional_kwargs', {})
            ),
            "tool_calls": self._serialize_value(
                getattr(value, 'tool_calls', [])
            ),
        }

    if isinstance(value, dict):
        return {
            # json.dumps raises TypeError on keys such as tuples, so coerce
            # anything that is not a JSON-compatible key to a string.
            (
                key
                if key is None or isinstance(key, (str, int, float, bool))
                else str(key)
            ): self._serialize_value(item)
            for key, item in value.items()
        }

    if isinstance(value, (list, tuple)):
        return [self._serialize_value(item) for item in value]

    # Plain JSON primitives need no probing; this avoids one json.dumps call
    # per leaf value.
    if value is None or isinstance(value, (str, int, float, bool)):
        return value

    try:
        json.dumps(value)
    except (TypeError, ValueError):
        # Last resort: keep a readable representation instead of failing.
        return str(value)
    return value
|
||||
|
||||
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
|
||||
"""
|
||||
流式处理消息,返回异步生成器
|
||||
@@ -170,10 +196,9 @@ class AIAgentService:
|
||||
字典,包含事件类型和数据
|
||||
"""
|
||||
graph = self.graphs.get(model_name)
|
||||
|
||||
if not graph:
|
||||
warning(f"警告: 模型 '{model_name}' 不可用,使用默认模型")
|
||||
model_name = next(iter(self.graphs.keys()))
|
||||
graph = self.graphs[model_name]
|
||||
raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
|
||||
|
||||
config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
@@ -182,36 +207,71 @@ class AIAgentService:
|
||||
input_state = {"messages": [{"role": "user", "content": message}]}
|
||||
context = GraphContext(user_id=user_id)
|
||||
|
||||
# 使用 astream_events 获取流式事件
|
||||
async for event in graph.astream_events(input_state, config=config, context=context, version="v2"):
|
||||
kind = event["event"]
|
||||
|
||||
# 聊天模型流式输出
|
||||
if kind == "on_chat_model_stream":
|
||||
content = event["data"]["chunk"].content
|
||||
if content:
|
||||
yield {"type": "token", "content": content}
|
||||
|
||||
# 工具调用开始
|
||||
elif kind == "on_tool_start":
|
||||
tool_name = event["name"]
|
||||
yield {"type": "tool_start", "tool": tool_name}
|
||||
|
||||
# 工具调用结束
|
||||
elif kind == "on_tool_end":
|
||||
tool_name = event["name"]
|
||||
yield {"type": "tool_end", "tool": tool_name}
|
||||
|
||||
# 链结束,获取最终结果
|
||||
elif kind == "on_chain_end" and event["name"] == "LangGraph":
|
||||
output = event["data"]["output"]
|
||||
reply = output["messages"][-1].content if output.get("messages") else ""
|
||||
token_usage = output.get("last_token_usage", {})
|
||||
elapsed_time = output.get("last_elapsed_time", 0.0)
|
||||
async for chunk in graph.astream(
|
||||
input_state,
|
||||
config=config,
|
||||
context=context,
|
||||
stream_mode=["messages", "updates", "custom"], # 组合多种模式,添加 custom
|
||||
version="v2", # 使用统一的v2格式
|
||||
subgraphs=True # 如果你使用了子图,请开启此项
|
||||
):
|
||||
chunk_type = chunk["type"]
|
||||
processed_event = {}
|
||||
|
||||
# 1. 处理 LLM Token 流 (实现打字机效果)
|
||||
if chunk_type == "messages":
|
||||
message_chunk, metadata = chunk["data"]
|
||||
|
||||
yield {
|
||||
"type": "done",
|
||||
"reply": reply,
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time
|
||||
# 提取元数据
|
||||
node_name = metadata.get("langgraph_node", "unknown")
|
||||
# 使用 getattr 安全地获取内容,因为 message_chunk 可能不是字符串
|
||||
token_content = getattr(message_chunk, 'content', str(message_chunk))
|
||||
|
||||
# 提取 DeepSeek reasoner 的思考过程 token
|
||||
reasoning_token = ""
|
||||
if hasattr(message_chunk, 'additional_kwargs'):
|
||||
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
# [DEBUG] 临时添加:只在 reasoning_token 不为空时打印,方便你直观地看到它
|
||||
if reasoning_token:
|
||||
import logging
|
||||
logging.debug(f"💡 [Reasoning Token 捕获]: {repr(reasoning_token)}")
|
||||
|
||||
processed_event = {
|
||||
"type": "llm_token",
|
||||
"node": node_name,
|
||||
"token": token_content,
|
||||
"reasoning_token": reasoning_token,
|
||||
"metadata": metadata # 可选的元数据
|
||||
}
|
||||
|
||||
# 2. 处理状态更新 (节点执行完成)
|
||||
elif chunk_type == "updates":
|
||||
updates_data = chunk["data"]
|
||||
# 序列化 updates 中的所有数据
|
||||
serialized_data = self._serialize_value(updates_data)
|
||||
processed_event = {
|
||||
"type": "state_update",
|
||||
"data": serialized_data
|
||||
}
|
||||
# 为了兼容前端旧字段,也保留 messages 字段(可选)
|
||||
if "messages" in serialized_data:
|
||||
processed_event["messages"] = serialized_data["messages"]
|
||||
|
||||
# 3. 处理自定义数据 (如果需要)
|
||||
elif chunk_type == "custom":
|
||||
# 自定义事件同样需要序列化
|
||||
serialized_data = self._serialize_value(chunk["data"])
|
||||
processed_event = {
|
||||
"type": "custom",
|
||||
"data": serialized_data
|
||||
}
|
||||
|
||||
# 4. 其他类型(debug, tasks等)按需处理
|
||||
else:
|
||||
# 对于不需要的类型,直接跳过
|
||||
continue
|
||||
|
||||
# 确保事件有数据再发送
|
||||
if processed_event:
|
||||
yield processed_event
|
||||
Reference in New Issue
Block a user