async def process_message_stream(
    self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
) -> AsyncGenerator[Dict[str, Any], None]:
    """Run the agent graph for one user message and stream frontend events.

    Method of ``AIAgentService``. Raw events from ``graph.astream_events``
    (schema ``v2``) are translated into the dictionaries the frontend consumes.

    Args:
        message: User message passed to ``_build_invocation``.
        thread_id: Conversation id passed to ``_build_invocation``.
        model: Requested model name; resolved through ``_resolve_model``.
        user_id: User id passed to ``_build_invocation``.

    Yields:
        ``llm_token``       - ``node``, ``token``, ``reasoning_token``
        ``tool_call_start`` - ``tool``, ``args``, ``id`` (the event's run id)
        ``tool_call_end``   - ``tool``, ``id``, ``result`` (string)
        ``node_start`` / ``node_end`` - ``node``
        ``error``           - ``message``; emitted when the graph raises
        ``done``            - ``model_used``; always last on the normal and
                              the error path

    ``done`` is deliberately NOT emitted when the consumer closes the
    generator early or the task is cancelled: nobody is left to receive it,
    and yielding while ``GeneratorExit`` is being handled makes Python raise
    ``RuntimeError: async generator ignored GeneratorExit``.
    """
    import traceback

    resolved_model = self._resolve_model(model)
    config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
    info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")

    # Collected only for the final log line; joined once instead of
    # growing a string with += on every token.
    message_parts = []

    try:
        info("📡 开始调用 graph.astream_events()...")

        async for event in self.graph.astream_events(input_state, config=config, version="v2"):
            kind = event["event"]
            data = event.get("data") or {}
            # NOTE(review): LangGraph is expected to put the executing node's
            # name under metadata["langgraph_node"] (the previous
            # implementation read the same key) - confirm for nested graphs.
            node_name = (event.get("metadata") or {}).get("langgraph_node")

            if kind == "on_chat_model_stream":
                chunk = data["chunk"]
                content = chunk.content or ""
                if not isinstance(content, str):
                    # Some providers stream a list of content blocks;
                    # keep only the textual parts.
                    content = "".join(
                        part if isinstance(part, str)
                        else str(part.get("text", "")) if isinstance(part, dict)
                        else ""
                        for part in content
                    )
                extra = getattr(chunk, "additional_kwargs", None) or {}
                reasoning_content = extra.get("reasoning_content", "") or ""

                # Tool-call and bookkeeping chunks carry no text at all;
                # forwarding them would send empty tokens to the client.
                if not content and not reasoning_content:
                    continue
                if content:
                    message_parts.append(content)

                yield {
                    "type": "llm_token",
                    "node": node_name or "agent",
                    "token": content,
                    "reasoning_token": reasoning_content,
                }

            elif kind == "on_tool_start":
                yield {
                    "type": "tool_call_start",
                    "tool": event["name"],
                    "args": data.get("input", {}),
                    "id": event.get("run_id", ""),
                }

            elif kind == "on_tool_end":
                tool_output = data.get("output", "")
                # The output may be a message object; report its content
                # rather than the repr of the whole object.
                tool_result = getattr(tool_output, "content", tool_output)
                yield {
                    "type": "tool_call_end",
                    "tool": event["name"],
                    "id": event.get("run_id", ""),
                    "result": str(tool_result),
                }

            elif kind in ("on_chain_start", "on_chain_end"):
                # Chain events fire for the graph itself and for runnables
                # nested inside a node. Only the event whose name equals the
                # node name marks a node boundary.
                if node_name is None or event.get("name") != node_name:
                    continue
                yield {
                    "type": "node_start" if kind == "on_chain_start" else "node_end",
                    "node": node_name,
                }

        info("✅ graph.astream_events() 完成")
        full_message_content = "".join(message_parts)
        if full_message_content:
            info(f"📄 完整消息内容: {repr(full_message_content)}")

    except Exception as e:
        # GeneratorExit and CancelledError are BaseException subclasses, so
        # they are not swallowed here and propagate without further yields.
        error(f"❌ 执行图时出错: {e}")
        error(f"📋 堆栈: {traceback.format_exc()}")
        yield {
            "type": "error",
            "message": str(e)
        }

    # Reached after normal completion and after a reported error, so the
    # frontend can always close the stream cleanly.
    yield {
        "type": "done",
        "model_used": resolved_model
    }
from langgraph.prebuilt import create_react_agent
from langgraph.graph import StateGraph, START, END

# This module lives in backend/app/main_graph/, next to state.py and nodes/,
# so sibling modules are reached with a single leading dot.
from .state import AgentState
from .nodes.memory_trigger import memory_trigger_node, set_mem0_client
from backend.app.logger import info, warning
from backend.app.tools import ALL_TOOLS


def build_agent_graph(
    chat_services: dict,
    mem0_client=None,
    max_steps: int = 10
) -> StateGraph:
    """Build the main graph: init -> (retrieve_memory) -> memory_trigger -> agent.

    The ``agent`` node is the runnable returned by ``create_react_agent`` for
    the primary model and ``ALL_TOOLS``.

    Args:
        chat_services: Mapping of service name to chat model. The ``"primary"``
            entry is used; if it is absent, the first value in the mapping.
        mem0_client: Optional memory client. When given it is registered via
            ``set_mem0_client`` and a ``retrieve_memory`` node is added (best
            effort: a failure to create it is logged and the node is skipped).
        max_steps: Accepted for backward compatibility.
            NOTE(review): nothing in this function applies it any more; the
            tool loop now runs inside ``create_react_agent``. Presumably the
            caller has to bound it through the run config - confirm.

    Returns:
        The assembled, *uncompiled* ``StateGraph``. This function never calls
        ``compile()``; that is left to the caller.

    Raises:
        ValueError: If ``chat_services`` provides no model.
    """
    # Resolve the model lazily. Passing next(iter(...)) as the default of
    # dict.get() evaluates it eagerly and raises StopIteration on an empty
    # mapping instead of a meaningful error.
    primary_model = None
    if chat_services:
        primary_model = chat_services.get("primary")
        if primary_model is None:
            primary_model = next(iter(chat_services.values()))
    if primary_model is None:
        raise ValueError("No LLM service provided")

    # Register the memory client for the nodes that read it.
    if mem0_client:
        set_mem0_client(mem0_client)

    # 1. Init node: reset the step counter at the start of every run.
    async def init_state_node(state: AgentState):
        info("[Init State] 初始化状态,重置步数")
        return {
            "current_step": 0
        }

    # 2. Optional memory retrieval node.
    retrieve_memory_node = None
    if mem0_client:
        try:
            from .nodes.retrieve_memory import create_retrieve_memory_node
            retrieve_memory_node = create_retrieve_memory_node(mem0_client)
        except Exception as e:
            # Best effort by design, but log at warning level so a disabled
            # memory node does not go unnoticed.
            warning(f"[Graph Builder] 记忆节点初始化失败: {e}")

    # 3. ReAct agent used as a single node of the main graph.
    agent_runnable = create_react_agent(primary_model, ALL_TOOLS)

    # 4. Assemble the main graph.
    graph = StateGraph(AgentState)

    graph.add_node("init_state", init_state_node)
    if retrieve_memory_node:
        graph.add_node("retrieve_memory", retrieve_memory_node)
    graph.add_node("memory_trigger", memory_trigger_node)
    graph.add_node("agent", agent_runnable)

    # Edges: START -> init_state -> [retrieve_memory] -> memory_trigger -> agent -> END
    graph.add_edge(START, "init_state")

    if retrieve_memory_node:
        graph.add_edge("init_state", "retrieve_memory")
        graph.add_edge("retrieve_memory", "memory_trigger")
    else:
        graph.add_edge("init_state", "memory_trigger")

    graph.add_edge("memory_trigger", "agent")
    graph.add_edge("agent", END)

    info("✅ [Graph Builder] 极简 Agent 图构建完成(用 create_react_agent)")
    return graph