"""
AI Agent service.

Runs the compiled LangGraph agent graph, either as a single ``ainvoke`` call
or as an event stream through LangGraph's native ``astream_events`` API.

The checkpointer is injected by the caller; this service does not manage
its connection lifecycle.
"""
import json
import traceback
from typing import AsyncGenerator, Dict, Any, Optional, Tuple

# LangGraph serializer (originally imported to address checkpoint
# deserialization warnings). NOTE(review): not referenced in this module.
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

# Local modules
from ..model_services import get_cached_chat_services
from ..main_graph.main_graph_builder import build_agent_graph
from backend.app.logger import debug, info, warning, error
from ..main_graph.state import AgentState


class AIAgentService:
    """Wraps the compiled agent graph with blocking and streaming entry points.

    ``await initialize()`` must be called before ``process_message`` or
    ``process_message_stream``: it loads the cached chat models, creates the
    Mem0 client and compiles the graph with the injected checkpointer.
    """

    def __init__(self, checkpointer: Any):
        """Store the externally managed checkpointer; heavy setup is deferred.

        Args:
            checkpointer: LangGraph checkpointer handed to ``compile()``. The
                caller owns its connection lifecycle.
        """
        self.checkpointer = checkpointer
        self.graph = None
        # Model name -> chat service, filled by initialize().
        self.chat_services: Optional[Dict[str, Any]] = None
        # Mem0 client, created by initialize().
        self.mem0_client = None

    async def initialize(self) -> "AIAgentService":
        """Load models, create the Mem0 client and compile the agent graph.

        Returns:
            ``self``, so callers can write ``svc = await AIAgentService(cp).initialize()``.
        """
        # 0. Create the Mem0 client (imported lazily, as in the original code).
        from ..memory.mem0_client import Mem0Client
        self.mem0_client = Mem0Client()

        # 1. Fetch the cached model dictionary.
        self.chat_services = get_cached_chat_services()
        info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")

        # 2. Build and compile the graph.
        info("🔄 构建 Agent 图...")
        graph_builder = build_agent_graph(
            chat_services=self.chat_services,
            mem0_client=self.mem0_client,
        )
        self.graph = graph_builder.compile(checkpointer=self.checkpointer)
        info("✅ Agent 图初始化完成")
        return self

    def _require_graph(self):
        """Return the compiled graph, or raise if ``initialize()`` has not run."""
        if self.graph is None:
            raise RuntimeError(
                "AIAgentService.initialize() must be awaited before processing messages"
            )
        return self.graph

    def _resolve_model(self, model: str) -> str:
        """Validate a model name, falling back to the first available model.

        Args:
            model: Requested model name; empty or unknown names fall back.

        Returns:
            The model name that will actually be used.

        Raises:
            RuntimeError: if no models are loaded (not initialized, or the
                model cache is empty). Previously this surfaced as an opaque
                ``TypeError`` / ``StopIteration``.
        """
        if not self.chat_services:
            raise RuntimeError(
                "No chat models available; call initialize() first or check the model configuration"
            )
        if model and model in self.chat_services:
            return model
        fallback = next(iter(self.chat_services))
        warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
        return fallback

    def _build_invocation(
        self, message: str, thread_id: str, model: str, user_id: str
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Build the ``config`` and ``input_state`` for one graph invocation.

        Args:
            message: User message text.
            thread_id: Conversation ID; used as the checkpointer thread key.
            model: Resolved model name. Previously this argument was dropped,
                so the selected model never reached the graph.
            user_id: User ID, placed in both the config metadata and the state.

        Returns:
            ``(config, input_state)`` tuple.
        """
        from langchain_core.messages import HumanMessage

        config = {
            "configurable": {
                "thread_id": thread_id,
                # NOTE(review): confirm graph nodes read the model from this
                # key; unknown configurable keys are otherwise ignored.
                "model": model,
            },
            "metadata": {"user_id": user_id},
        }
        input_state = {
            "messages": [HumanMessage(content=message)],
            "user_id": user_id,
        }
        return config, input_state

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten LangChain message content to plain text.

        Content is usually a ``str`` but may be a list of content blocks
        (plain strings or dicts with ``"type": "text"``); other blocks are
        skipped so string concatenation never sees a list.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(block.get("text") or "")
            return "".join(parts)
        return str(content)

    async def process_message(
        self,
        message: str,
        thread_id: str,
        model: str = "",
        user_id: str = "default_user",
    ) -> dict:
        """Run the graph to completion and return the reply with usage stats.

        Returns:
            Dict with ``reply`` (text of the last message), ``token_usage`` and
            ``elapsed_time`` (read from the final state's ``last_token_usage`` /
            ``last_elapsed_time``), and ``model_used``.
        """
        resolved_model = self._resolve_model(model)
        config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)

        result = await self._require_graph().ainvoke(input_state, config=config)

        messages = result.get("messages") or []
        reply = self._content_to_text(messages[-1].content) if messages else ""
        return {
            "reply": reply,
            # `or` also normalises keys that are present but set to None.
            "token_usage": result.get("last_token_usage") or {},
            "elapsed_time": result.get("last_elapsed_time") or 0.0,
            "model_used": resolved_model,
        }

    def _translate_event(self, event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Map one ``astream_events`` (v2) event to a client payload, or ``None`` to skip it."""
        kind = event["event"]
        data = event.get("data") or {}

        if kind == "on_chat_model_stream":
            # Streamed token chunk.
            chunk = data["chunk"]
            content = self._content_to_text(chunk.content)
            extra = getattr(chunk, "additional_kwargs", None) or {}
            reasoning_content = extra.get("reasoning_content") or ""
            # Reasoning models may stream reasoning_content while content is
            # empty; forward those chunks instead of dropping them.
            if not (content or reasoning_content):
                return None
            return {
                "type": "llm_token",
                "node": "agent",
                "token": content,
                "reasoning_token": reasoning_content,
            }
        if kind == "on_tool_start":
            return {
                "type": "tool_call_start",
                "tool": event["name"],
                "args": data.get("input", {}),
                "id": event.get("run_id", ""),
            }
        if kind == "on_tool_end":
            return {
                "type": "tool_call_end",
                "tool": event["name"],
                "id": event.get("run_id", ""),
                "result": str(data.get("output", "")),
            }
        if kind == "on_chain_start":
            return {"type": "node_start", "node": event.get("name", "unknown")}
        if kind == "on_chain_end":
            return {"type": "node_end", "node": event.get("name", "unknown")}
        return None

    async def process_message_stream(
        self,
        message: str,
        thread_id: str,
        model: str = "",
        user_id: str = "default_user",
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream graph execution as client payloads using ``astream_events``.

        Yields ``llm_token``, ``tool_call_start``/``tool_call_end`` and
        ``node_start``/``node_end`` payloads. A failure yields one ``error``
        payload. A final ``done`` payload follows on normal completion or after
        an error.

        Raises:
            RuntimeError: from model resolution, before any payload is yielded.
        """
        resolved_model = self._resolve_model(model)
        config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)

        info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")
        content_parts = []

        try:
            graph = self._require_graph()
            info("📡 开始调用 graph.astream_events()...")
            async for event in graph.astream_events(input_state, config=config, version="v2"):
                payload = self._translate_event(event)
                if payload is None:
                    continue
                if payload["type"] == "llm_token" and payload["token"]:
                    content_parts.append(payload["token"])
                yield payload

            info("✅ graph.astream_events() 完成")
            if content_parts:
                # Full reply text at debug level: it can be long and contains user data.
                debug(f"📄 完整消息内容: {repr(''.join(content_parts))}")
        except Exception as e:
            error(f"❌ 执行图时出错: {e}")
            error(f"📋 堆栈: {traceback.format_exc()}")
            yield {
                "type": "error",
                "message": str(e),
            }

        # Deliberately NOT inside `finally`. If this generator is closed
        # (aclose()/GeneratorExit, e.g. client disconnect), yielding there
        # raises "async generator ignored GeneratorExit". On cancellation it
        # would swallow the CancelledError.
        yield {
            "type": "done",
            "model_used": resolved_model,
        }