Files
ailine/backend/app/agent/agent_service.py
root 58a2c8c081
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m43s
refactor: 改用 LangGraph 原生 create_react_agent + astream_events
2026-05-07 02:11:20 +08:00

215 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent 服务类 - 用 LangGraph 原生 astream_events
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import json
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
# 本地模块
from ..model_services import get_cached_chat_services
from ..main_graph.main_graph_builder import build_agent_graph
from backend.app.logger import debug, info, warning, error
from ..main_graph.state import AgentState
class AIAgentService:
    """Service wrapper around the compiled agent graph.

    The instance stores a caller-supplied checkpointer, builds and compiles
    the graph in :meth:`initialize`, and exposes one-shot
    (:meth:`process_message`) and streaming (:meth:`process_message_stream`)
    entry points.
    """

    def __init__(self, checkpointer):
        """Store the checkpointer; everything else is created by ``initialize``.

        Args:
            checkpointer: Checkpointer handed to ``graph_builder.compile``.
                It is only stored here, never opened or closed.
        """
        self.checkpointer = checkpointer
        self.graph = None
        self.chat_services = None
        # Mem0 client, created in initialize().
        self.mem0_client = None

    async def initialize(self):
        """Create the Mem0 client, load the model registry and compile the graph.

        Returns:
            This instance, so the call can be chained after construction.
        """
        # 0. Mem0 client (imported locally, as in the original code).
        from ..memory.mem0_client import Mem0Client

        self.mem0_client = Mem0Client()

        # 1. Cached model registry.
        self.chat_services = get_cached_chat_services()
        info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")

        # 2. Build the graph, then compile it with the external checkpointer.
        info("🔄 构建 Agent 图...")
        graph_builder = build_agent_graph(
            chat_services=self.chat_services,
            mem0_client=self.mem0_client,
        )
        self.graph = graph_builder.compile(checkpointer=self.checkpointer)
        info("✅ Agent 图初始化完成")
        return self

    def _resolve_model(self, model: str) -> str:
        """Validate a model name, falling back to the first registered model.

        Args:
            model: Requested model name; may be empty.

        Returns:
            ``model`` when it is a key of ``self.chat_services``, otherwise
            the first key of that registry (a warning is logged).

        Raises:
            RuntimeError: If the registry is empty or was never loaded, so
                there is nothing to fall back to.
        """
        # Without this guard an empty registry raises a bare StopIteration
        # from next(), and an unloaded (None) registry raises TypeError from
        # the membership test; neither tells the caller what went wrong.
        if not self.chat_services:
            raise RuntimeError(
                "No chat models are available; call initialize() and make sure "
                "at least one model is registered"
            )

        if model and model in self.chat_services:
            return model

        fallback = next(iter(self.chat_services))
        warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
        return fallback

    def _build_invocation(
        self, message: str, thread_id: str, model: str, user_id: str
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Build the ``config`` and ``input_state`` used to invoke the graph.

        Args:
            message: User message text.
            thread_id: Conversation id, stored under ``configurable.thread_id``.
            model: Model name.
            user_id: User id, stored in the config metadata and input state.

        Returns:
            A ``(config, input_state)`` tuple.
        """
        # NOTE(review): `model` is accepted but is not written into either
        # `config` or `input_state`, so the resolved model does not reach the
        # graph through this method - confirm how the graph picks its model.
        from langchain_core.messages import HumanMessage

        config = {
            "configurable": {
                "thread_id": thread_id,
            },
            "metadata": {"user_id": user_id},
        }
        input_state = {
            "messages": [HumanMessage(content=message)],
            "user_id": user_id,
        }
        return config, input_state

    async def process_message(
        self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
    ) -> dict:
        """Run the graph once and return the reply with usage statistics.

        Args:
            message: User message text.
            thread_id: Conversation id.
            model: Requested model name; resolved via ``_resolve_model``.
            user_id: User id.

        Returns:
            A dict with keys ``reply`` (content of the last message in the
            result, or ``""`` when there are no messages), ``token_usage``,
            ``elapsed_time`` and ``model_used``.
        """
        resolved_model = self._resolve_model(model)
        config, input_state = self._build_invocation(
            message, thread_id, resolved_model, user_id
        )

        result = await self.graph.ainvoke(input_state, config=config)

        messages = result.get("messages")
        reply = messages[-1].content if messages else ""

        return {
            "reply": reply,
            "token_usage": result.get("last_token_usage", {}),
            "elapsed_time": result.get("last_elapsed_time", 0.0),
            "model_used": resolved_model,
        }

    async def process_message_stream(
        self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream graph execution as event dicts using ``astream_events``.

        Args:
            message: User message text.
            thread_id: Conversation id.
            model: Requested model name; resolved via ``_resolve_model``.
            user_id: User id.

        Yields:
            Dicts tagged by ``type``:
            ``llm_token`` (``node``, ``token``, ``reasoning_token``),
            ``tool_call_start`` (``tool``, ``args``, ``id``),
            ``tool_call_end`` (``tool``, ``id``, ``result``),
            ``node_start`` / ``node_end`` (``node``),
            ``error`` (``message``) when execution raises, and finally
            ``done`` (``model_used``) after normal completion or an error.
        """
        resolved_model = self._resolve_model(model)
        config, input_state = self._build_invocation(
            message, thread_id, resolved_model, user_id
        )

        info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")

        # Text pieces kept only for the completion log; joined once at the end
        # instead of growing a string with `+=` on every token.
        content_parts = []

        try:
            info("📡 开始调用 graph.astream_events()...")
            async for event in self.graph.astream_events(
                input_state, config=config, version="v2"
            ):
                kind = event["event"]

                if kind == "on_chat_model_stream":
                    # Streaming token from the chat model.
                    chunk = event["data"]["chunk"]
                    content = chunk.content or ""

                    reasoning_content = ""
                    extra = getattr(chunk, "additional_kwargs", None)
                    if extra:
                        reasoning_content = extra.get("reasoning_content", "")

                    # Only plain-string content goes into the log copy, so a
                    # non-string content value cannot abort the stream for
                    # the sake of a log line.
                    if content and isinstance(content, str):
                        content_parts.append(content)

                    # NOTE(review): the source's indentation was lost, so it is
                    # unclear whether this yield was nested under `if content`.
                    # It is emitted for every chunk so that chunks carrying
                    # only `reasoning_content` are not dropped - confirm.
                    yield {
                        "type": "llm_token",
                        "node": "agent",
                        "token": content,
                        "reasoning_token": reasoning_content,
                    }

                elif kind == "on_tool_start":
                    # Tool invocation started.
                    yield {
                        "type": "tool_call_start",
                        "tool": event["name"],
                        "args": event["data"].get("input", {}),
                        "id": event.get("run_id", ""),
                    }

                elif kind == "on_tool_end":
                    # Tool invocation finished.
                    # NOTE(review): str() of a message object yields its repr
                    # rather than just its content - confirm what consumers
                    # expect in `result`.
                    tool_output = event["data"].get("output", "")
                    yield {
                        "type": "tool_call_end",
                        "tool": event["name"],
                        "id": event.get("run_id", ""),
                        "result": str(tool_output),
                    }

                elif kind == "on_chain_start":
                    # A chain/node started.
                    yield {
                        "type": "node_start",
                        "node": event.get("name", "unknown"),
                    }

                elif kind == "on_chain_end":
                    # A chain/node finished.
                    yield {
                        "type": "node_end",
                        "node": event.get("name", "unknown"),
                    }

            info("✅ graph.astream_events() 完成")
            full_message_content = "".join(content_parts)
            if full_message_content:
                info(f"📄 完整消息内容: {repr(full_message_content)}")

        except Exception as e:
            error(f"❌ 执行图时出错: {e}")
            import traceback

            error(f"📋 堆栈: {traceback.format_exc()}")
            yield {
                "type": "error",
                "message": str(e),
            }

        # Deliberately NOT inside a `finally:` clause. Yielding from `finally`
        # while the consumer is closing the generator (GeneratorExit) raises
        # "RuntimeError: async generator ignored GeneratorExit". Placed here,
        # `done` is still sent after normal completion and after a handled
        # error, while early close and cancellation unwind cleanly.
        yield {
            "type": "done",
            "model_used": resolved_model,
        }