Files
ailine/backend/app/agent/agent_service.py
root eb33203b5c
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m26s
feat: 优化后的流式方案:双协程 + 结束哨兵 + turn/phase 元数据
2026-05-07 02:21:09 +08:00

391 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent 服务类 - 优化版本:双协程 + 结束哨兵 + 完整的取消和异常处理
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
# 本地模块
from ..model_services import get_cached_chat_services
from ..main_graph.main_graph_builder import build_agent_graph
from backend.app.logger import debug, info, warning, error
from ..main_graph.state import AgentState
from .stream_context import token_queue_var
class AIAgentService:
    """Agent service that runs a compiled LangGraph graph in one-shot or streaming mode.

    The checkpointer is supplied by the caller; this class only passes it to
    ``compile()`` and never opens or closes it.
    """

    def __init__(self, checkpointer):
        """Store the injected checkpointer; the real setup happens in ``initialize``."""
        self.checkpointer = checkpointer
        self.graph = None
        self.chat_services = None
        # Mem0 client, created in initialize().
        self.mem0_client = None

    async def initialize(self):
        """Create the Mem0 client, load the cached chat models and compile the graph.

        Returns:
            This instance, so the call can be chained.
        """
        # 0. Create the Mem0 client (imported lazily, as in the original code).
        from ..memory.mem0_client import Mem0Client
        self.mem0_client = Mem0Client()
        # 1. Fetch the cached model dictionary.
        self.chat_services = get_cached_chat_services()
        info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")
        # 2. Build and compile the graph with the injected checkpointer.
        info("🔄 构建 Agent 图...")
        graph_builder = build_agent_graph(
            chat_services=self.chat_services,
            mem0_client=self.mem0_client,
        )
        self.graph = graph_builder.compile(checkpointer=self.checkpointer)
        info("✅ Agent 图初始化完成")
        return self

    def _resolve_model(self, model: str) -> str:
        """Validate the requested model name, falling back to the first available one.

        Args:
            model: Requested model name (may be empty).

        Returns:
            The model name that will actually be used.

        Raises:
            RuntimeError: If no chat models are loaded (``initialize`` was not
                called or returned an empty mapping). Previously this surfaced
                as an opaque ``TypeError`` / ``StopIteration``.
        """
        if not self.chat_services:
            raise RuntimeError(
                "No chat models are available; call initialize() with at least one configured model"
            )
        if model and model in self.chat_services:
            return model
        fallback = next(iter(self.chat_services))
        warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
        return fallback

    def _build_invocation(
        self, message: str, thread_id: str, model: str, user_id: str
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Build the ``config`` and ``input_state`` used to invoke the graph.

        Args:
            message: User message text.
            thread_id: Conversation/thread identifier.
            model: Resolved model name.
            user_id: User identifier.

        Returns:
            A ``(config, input_state)`` tuple.
        """
        from langchain_core.messages import HumanMessage

        # NOTE(review): ``model`` is accepted but not placed in the config or the
        # input state, so the graph cannot see the resolved model through this
        # path - confirm whether that is intentional.
        config = {
            "configurable": {
                "thread_id": thread_id,
            },
            "metadata": {"user_id": user_id},
        }
        input_state = {
            "messages": [HumanMessage(content=message)],
            "user_id": user_id,
        }
        return config, input_state

    async def process_message(
        self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
    ) -> dict:
        """Run the graph once and return the reply, token usage, elapsed time and model used."""
        resolved_model = self._resolve_model(model)
        config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
        result = await self.graph.ainvoke(input_state, config=config)

        reply = ""
        if result.get("messages"):
            # The last message in the final state is treated as the reply.
            reply = result["messages"][-1].content

        return {
            "reply": reply,
            "token_usage": result.get("last_token_usage", {}),
            "elapsed_time": result.get("last_elapsed_time", 0.0),
            "model_used": resolved_model,
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable data.

        Any object exposing ``content`` is treated as a message and flattened
        to a dict. ``tool_call_id`` and ``name`` are kept when present so that
        tool results can be matched to their originating calls. Values that
        ``json.dumps`` rejects are replaced by ``str(value)``.
        """
        if hasattr(value, 'content'):
            payload = {
                "role": getattr(value, 'type', 'message'),
                "content": self._serialize_value(getattr(value, 'content', '')),
                # Recurse so nested non-JSON values cannot break later encoding.
                "additional_kwargs": self._serialize_value(getattr(value, 'additional_kwargs', {})),
                "tool_calls": self._serialize_value(getattr(value, 'tool_calls', [])),
            }
            for extra in ("tool_call_id", "name"):
                extra_value = getattr(value, extra, None)
                if extra_value is not None:
                    payload[extra] = self._serialize_value(extra_value)
            return payload
        if isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return str(value)
        return value

    @staticmethod
    def _iter_update_messages(serialized_data):
        """Yield serialized message dicts found in an "updates" payload.

        Looks at a top-level ``"messages"`` entry and at ``"messages"`` nested
        one level down (payloads keyed by node name).
        """
        if not isinstance(serialized_data, dict):
            return
        containers = [serialized_data]
        containers.extend(v for v in serialized_data.values() if isinstance(v, dict))
        for container in containers:
            messages = container.get("messages")
            if isinstance(messages, dict):
                messages = [messages]
            if not isinstance(messages, list):
                continue
            for msg in messages:
                if isinstance(msg, dict):
                    yield msg

    async def _handle_message_chunk(
        self, chunk: Dict[str, Any], current_node: Optional[str], tool_calls_in_progress: Dict[str, Any]
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Translate a "messages" chunk into front-end events.

        Yields ``node_end`` / ``node_start`` when the emitting node changes,
        then at most one of: a reasoning ``llm_token``, ``tool_call_start``
        events, or a content ``llm_token``. The final event is always an
        internal ``_update_state`` carrying the new ``current_node``.
        """
        message_chunk, metadata = chunk["data"]
        node_name = metadata.get("langgraph_node", "unknown")
        new_current_node = current_node

        # Emit node transition events when the emitting node changes.
        if node_name != current_node:
            if current_node:
                yield {"type": "node_end", "node": current_node}
            yield {"type": "node_start", "node": node_name}
            new_current_node = node_name

        token_content = getattr(message_chunk, 'content', str(message_chunk))
        reasoning_token = ""
        if hasattr(message_chunk, 'additional_kwargs'):
            reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")

        # NOTE(review): the branches are mutually exclusive, so a chunk carrying
        # reasoning drops its tool calls/content - kept as in the original.
        if reasoning_token:
            yield {
                "type": "llm_token",
                "node": node_name,
                "reasoning_token": reasoning_token,
            }
        elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
            for tool_call in message_chunk.tool_calls:
                tool_call_id = tool_call.get("id", "")
                tool_name = tool_call.get("name", "")
                tool_args = tool_call.get("args", {})
                # Announce each tool call only once.
                if tool_call_id and tool_call_id not in tool_calls_in_progress:
                    tool_calls_in_progress[tool_call_id] = {
                        "name": tool_name,
                        "args": tool_args,
                    }
                    yield {
                        "type": "tool_call_start",
                        "tool": tool_name,
                        "args": tool_args,
                        "id": tool_call_id,
                    }
        elif token_content:
            yield {
                "type": "llm_token",
                "node": node_name,
                "token": token_content,
                "reasoning_token": reasoning_token,
            }

        # Internal event: hands the updated current_node back to the caller.
        yield {"type": "_update_state", "current_node": new_current_node}

    async def _handle_updates_chunk(
        self, chunk: Dict[str, Any], tool_calls_in_progress: Dict[str, Any], actual_model_used: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Translate an "updates" chunk into front-end events.

        Yields a ``tool_call_end`` for every tool-result message that matches a
        call recorded in ``tool_calls_in_progress`` (and removes it from there),
        then a ``state_update`` with the serialized payload, and finally an
        internal ``_update_state`` carrying ``actual_model_used``.
        """
        serialized_data = self._serialize_value(chunk["data"])

        for msg in self._iter_update_messages(serialized_data):
            if msg.get("role") != "tool":
                continue
            tool_call_id = msg.get("tool_call_id", "")
            if not tool_call_id or tool_call_id not in tool_calls_in_progress:
                continue
            started = tool_calls_in_progress.pop(tool_call_id)
            yield {
                "type": "tool_call_end",
                # Fall back to the name recorded at tool_call_start.
                "tool": msg.get("name") or started.get("name", ""),
                "id": tool_call_id,
                "result": msg.get("content", ""),
            }

        yield {
            "type": "state_update",
            "data": serialized_data,
        }
        # Internal event: the model name is passed back unchanged.
        yield {"type": "_update_state", "actual_model_used": actual_model_used}

    async def process_message_stream(
        self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream agent events: a background task feeds a queue, this generator drains it.

        A background task iterates ``graph.astream()`` and pushes events onto a
        queue, ending with a sentinel. On normal completion (including a graph
        error, which is delivered as an ``error`` event) the stream ends with
        ``node_end`` (if a node was active) and ``done``. When the consumer
        closes or cancels the generator, the background task is cancelled and
        no further events are yielded.
        """
        resolved_model = self._resolve_model(model)
        config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
        info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")

        current_node = None
        tool_calls_in_progress: Dict[str, Any] = {}
        actual_model_used = resolved_model

        token_queue = asyncio.Queue()
        # Unique end-of-stream marker; compared by identity.
        SENTINEL = object()
        # Set before create_task() so the background task's context copy sees the queue.
        token_queue_var.set(token_queue)

        async def run_graph_task():
            """Background task: run graph.astream() and enqueue the resulting events."""
            nonlocal current_node, actual_model_used
            agent_tokens = []
            event_count = 0
            # Initialised once, outside the loop, so the log throttling below works.
            chunk_count = 0
            try:
                info("📡 开始调用 graph.astream()...")
                async for chunk in self.graph.astream(
                    input_state,
                    config=config,
                    stream_mode=["messages", "updates"],
                    version="v2",
                    subgraphs=True,
                ):
                    chunk_count += 1
                    chunk_type = chunk["type"]
                    # Log the first 10 chunks and every 50th afterwards.
                    if chunk_count <= 10 or chunk_count % 50 == 0:
                        info(f" [{chunk_count}] chunk_type={chunk_type}, data={type(chunk.get('data'))}")

                    if chunk_type == "messages":
                        async for event in self._handle_message_chunk(
                            chunk, current_node, tool_calls_in_progress
                        ):
                            if event.get("type") == "_update_state":
                                current_node = event.get("current_node", current_node)
                                continue
                            event_count += 1
                            if event_count <= 10:
                                info(f" → yield event #{event_count}: {event.get('type')}")
                            # Collect agent-node text tokens for the final log line;
                            # non-string content is skipped.
                            if (
                                event.get("type") == "llm_token"
                                and event.get("node") == "agent"
                                and isinstance(event.get("token"), str)
                            ):
                                agent_tokens.append(event["token"])
                            await token_queue.put(event)
                    elif chunk_type == "updates":
                        async for event in self._handle_updates_chunk(
                            chunk, tool_calls_in_progress, actual_model_used
                        ):
                            if event.get("type") == "_update_state":
                                actual_model_used = event.get("actual_model_used", actual_model_used)
                                continue
                            event_count += 1
                            if event_count <= 10:
                                info(f" → yield event #{event_count}: {event.get('type')}")
                            await token_queue.put(event)

                info(f"✅ graph.astream() 完成,共 {event_count} 个 events")
                full_message_content = "".join(agent_tokens)
                if full_message_content:
                    info(f"📄 完整消息内容: {repr(full_message_content)}")
            except Exception as e:
                error(f"❌ 执行图时出错: {e}")
                import traceback
                error(f"📋 堆栈: {traceback.format_exc()}")
                # Report the failure to the consumer as a regular event.
                await token_queue.put({
                    "type": "error",
                    "message": str(e),
                })
            finally:
                # The queue is unbounded, so put_nowait cannot fail and needs no
                # suspension point even while this task is being cancelled.
                token_queue.put_nowait(SENTINEL)

        graph_task = asyncio.create_task(run_graph_task())
        try:
            while True:
                try:
                    event = await asyncio.wait_for(token_queue.get(), timeout=0.5)
                except asyncio.TimeoutError:
                    # Watchdog: stop only when the task is finished AND nothing
                    # is left to deliver; otherwise keep draining.
                    if graph_task.done() and token_queue.empty():
                        if not graph_task.cancelled() and graph_task.exception():
                            error(f"❌ 后台任务异常: {graph_task.exception()}")
                        break
                    continue
                if event is SENTINEL:
                    break
                yield event
        except asyncio.CancelledError:
            info("⚠️ 流式生成被取消")
            raise
        finally:
            # Cleanup only. Yielding here would raise RuntimeError on aclose()
            # and would interfere with cancellation.
            if not graph_task.done():
                info("⏹️ 取消后台任务")
                graph_task.cancel()
                try:
                    await graph_task
                except asyncio.CancelledError:
                    info("✅ 后台任务已取消")

        # Normal completion: emit the closing events so the front end can finish cleanly.
        if current_node:
            yield {
                "type": "node_end",
                "node": current_node,
            }
        yield {
            "type": "done",
            "model_used": actual_model_used,
        }