"""
AI agent service class - fully simplified version.

Implemented per the guide: ``stream_mode="messages"`` is deliberately not
used, to avoid duplicate tokens.
"""
import asyncio
import json
import traceback
from typing import AsyncGenerator, Dict, Any, Optional, Tuple

# LangGraph serializer (fixes the checkpoint deserialization warning)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

# Local modules
from backend.app.model_services import get_cached_chat_services
from backend.app.main_graph.main_graph_builder import build_agent_graph
from backend.app.logger import debug, info, warning, error
from backend.app.main_graph.state import AgentState
from .stream_context import set_stream_queue


class AIAgentService:
    """Own the compiled agent graph and expose one-shot and streaming entry points.

    ``initialize()`` must be awaited before any message is processed; until
    then ``graph``, ``chat_services`` and ``mem0_client`` are all ``None``.
    """

    def __init__(self, checkpointer):
        """Remember the checkpointer; all real setup is deferred to ``initialize``.

        Args:
            checkpointer: Handed unchanged to ``graph_builder.compile`` when
                the graph is compiled in ``initialize``.
        """
        self.checkpointer = checkpointer
        self.graph = None
        self.chat_services = None
        # Mem0 client
        self.mem0_client = None

    async def initialize(self) -> "AIAgentService":
        """Create the Mem0 client, load the cached models and compile the graph.

        Returns:
            This instance, so the call can be chained after construction.
        """
        # 0. Create the Mem0 client. The import is deliberately function-local.
        # NOTE(review): presumably this avoids an import cycle or a costly
        # import at module load time - confirm before hoisting it.
        from ..memory.mem0_client import Mem0Client
        self.mem0_client = Mem0Client()

        # 1. Fetch the cached model dictionary
        self.chat_services = get_cached_chat_services()
        info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")

        # 2. Build the graph
        info("🔄 构建 Agent 图...")
        graph_builder = build_agent_graph(
            chat_services=self.chat_services,
            mem0_client=self.mem0_client,
        )

        # Compile the graph with the checkpointer supplied at construction
        self.graph = graph_builder.compile(checkpointer=self.checkpointer)

        info("✅ Agent 图初始化完成")
        return self

    def _resolve_model(self, model: str) -> str:
        """Validate a model name, falling back to the first available model.

        Args:
            model: Requested model name; may be empty.

        Returns:
            ``model`` if it is a key of ``self.chat_services``, otherwise the
            first key of ``self.chat_services`` (a warning is logged).

        Raises:
            RuntimeError: If no models are available, i.e. ``initialize()``
                has not been awaited yet or the model dictionary is empty.
        """
        # Without this guard an empty dict makes next() raise StopIteration,
        # which turns into an opaque RuntimeError inside the async callers,
        # and an uninitialised service fails with TypeError/AttributeError.
        if not self.chat_services:
            raise RuntimeError("没有可用的模型: 服务尚未初始化或模型列表为空")

        if model and model in self.chat_services:
            return model

        fallback = next(iter(self.chat_services.keys()))
        warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
        return fallback

    def _build_invocation(
        self,
        message: str,
        thread_id: str,
        model: str,
        user_id: str
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Build the ``config`` and ``input_state`` needed to invoke the graph.

        Args:
            message: User message text; wrapped in a ``HumanMessage``.
            thread_id: Conversation id, stored under ``configurable.thread_id``.
            model: Resolved model name. Currently not written into either
                returned dict.
            user_id: User id, stored in the config metadata and the input state.

        Returns:
            A ``(config, input_state)`` tuple.
        """
        from langchain_core.messages import HumanMessage

        # NOTE(review): `model` is accepted but never forwarded to the graph
        # from here, so the model chosen by the caller has no visible effect in
        # this method. How the graph picks its model is not visible in this
        # file - confirm whether it should be passed via config or state.
        config = {
            "configurable": {
                "thread_id": thread_id,
            },
            "metadata": {"user_id": user_id}
        }

        input_state = {
            "messages": [HumanMessage(content=message)],
            "user_id": user_id,
        }

        return config, input_state

    async def process_message(
        self,
        message: str,
        thread_id: str,
        model: str = "",
        user_id: str = "default_user"
    ) -> dict:
        """Run the graph once and return the reply with usage statistics.

        Args:
            message: User message text.
            thread_id: Conversation id.
            model: Requested model name; resolved through ``_resolve_model``.
            user_id: User id.

        Returns:
            A dict with the keys ``reply`` (content of the last message, or
            ``""`` when the result has no messages), ``token_usage``,
            ``elapsed_time`` and ``model_used``.

        Raises:
            RuntimeError: If no models are available (see ``_resolve_model``).
        """
        # Resolve the model name
        resolved_model = self._resolve_model(model)

        # Build the invocation arguments
        config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)

        result = await self.graph.ainvoke(input_state, config=config)

        reply = ""
        if result.get("messages"):
            reply = result["messages"][-1].content

        return {
            "reply": reply,
            "token_usage": result.get("last_token_usage", {}),
            "elapsed_time": result.get("last_elapsed_time", 0.0),
            "model_used": resolved_model
        }

    async def process_message_stream(
        self,
        message: str,
        thread_id: str,
        model: str = "",
        user_id: str = "default_user"
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Run the graph in a background task and stream its events.

        Every item placed on the internal queue is yielded as-is. This method
        itself enqueues ``{"type": "graph_update", "data": chunk}`` for each
        chunk from ``graph.astream`` and ``{"type": "error", "message": ...}``
        if the graph raises. When the stream finishes normally, a closing
        ``{"type": "done", "model_used": ...}`` event is yielded.

        Args:
            message: User message text.
            thread_id: Conversation id.
            model: Requested model name; resolved through ``_resolve_model``.
            user_id: User id.

        Raises:
            RuntimeError: If no models are available (see ``_resolve_model``).
        """
        # Resolve the model name
        resolved_model = self._resolve_model(model)

        # Build the invocation arguments
        config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)

        info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")

        # Create the token queue and publish it before the task is created.
        # NOTE(review): the original comment says set_stream_queue sets a
        # context variable; if so, the task created below inherits it because
        # asyncio.create_task copies the current context - confirm.
        queue = asyncio.Queue()
        set_stream_queue(queue)

        async def run_graph():
            """Background task: run the graph requesting only "updates" events."""
            try:
                info("📡 开始调用 graph.astream()...")
                # Only stream_mode=["updates"]; "messages" is left out on
                # purpose to avoid duplicate tokens.
                async for chunk in self.graph.astream(
                    input_state,
                    config=config,
                    stream_mode=["updates"],
                    version="v2",
                    subgraphs=True
                ):
                    # Forward each state update unchanged; consumers can pick
                    # out events such as the final result.
                    await queue.put({
                        "type": "graph_update",
                        "data": chunk,
                    })
            except Exception as e:
                error(f"❌ 执行图时出错: {e}")
                error(f"📋 堆栈: {traceback.format_exc()}")
                await queue.put({"type": "error", "message": str(e)})
            finally:
                # End-of-stream sentinel; enqueued even when the task is
                # cancelled so the consumer loop can never wait forever.
                await queue.put(None)

        # Start the background task
        bg_task = asyncio.create_task(run_graph())

        try:
            while True:
                event = await queue.get()
                if event is None:
                    break
                yield event
        except GeneratorExit:
            # The consumer closed the generator (e.g. the client disconnected):
            # cancel the background task and let GeneratorExit propagate.
            info("⚠️ GeneratorExit,取消后台任务")
            bg_task.cancel()
            raise
        finally:
            # Make sure the background task never outlives this generator
            if not bg_task.done():
                info("⏹️ 清理后台任务")
                bg_task.cancel()
                try:
                    await bg_task
                except asyncio.CancelledError:
                    info("✅ 后台任务已取消")

        # Closing event so the front end can shut the stream down cleanly.
        # It is kept outside the try/finally on purpose: yielding while
        # GeneratorExit is being handled raises RuntimeError, so this is only
        # reached when the stream ended normally via the sentinel.
        yield {
            "type": "done",
            "model_used": resolved_model
        }