From eb33203b5cae8f6b88438d7981d030bc3273809a Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Thu, 7 May 2026 02:21:09 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E5=90=8E=E7=9A=84?= =?UTF-8?q?=E6=B5=81=E5=BC=8F=E6=96=B9=E6=A1=88=EF=BC=9A=E5=8F=8C=E5=8D=8F?= =?UTF-8?q?=E7=A8=8B=20+=20=E7=BB=93=E6=9D=9F=E5=93=A8=E5=85=B5=20+=20turn?= =?UTF-8?q?/phase=20=E5=85=83=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/agent/agent_service.py | 320 ++++++++++++++----- backend/app/agent/stream_context.py | 9 + backend/app/main_graph/main_graph_builder.py | 59 +++- backend/app/main_graph/nodes/agent.py | 61 ++-- 4 files changed, 343 insertions(+), 106 deletions(-) create mode 100644 backend/app/agent/stream_context.py diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index a443fff..673dda2 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -1,9 +1,10 @@ """ -AI Agent 服务类 - 用 LangGraph 原生 astream_events +AI Agent 服务类 - 优化版本:双协程 + 结束哨兵 + 完整的取消和异常处理 接收外部传入的 checkpointer,不负责管理连接生命周期 """ import json +import asyncio from typing import AsyncGenerator, Dict, Any, Optional, Tuple # LangGraph 序列化器(修复 checkpoint 反序列化警告) @@ -14,6 +15,7 @@ from ..model_services import get_cached_chat_services from ..main_graph.main_graph_builder import build_agent_graph from backend.app.logger import debug, info, warning, error from ..main_graph.state import AgentState +from .stream_context import token_queue_var class AIAgentService: @@ -118,10 +120,125 @@ class AIAgentService: "model_used": resolved_model } + def _serialize_value(self, value): + """递归将 LangChain 对象转换为可 JSON 序列化的格式""" + if hasattr(value, 'content'): + msg_type = getattr(value, 'type', 'message') + return { + "role": msg_type, + "content": getattr(value, 'content', ''), + "additional_kwargs": getattr(value, 'additional_kwargs', {}), + "tool_calls": getattr(value, 'tool_calls', []) + } + elif isinstance(value, dict): + return {k: self._serialize_value(v) for k, v in value.items()} + elif isinstance(value, (list, tuple)): + return [self._serialize_value(item) for item in value] + else: + try: + json.dumps(value) + return value + except (TypeError, ValueError): + return str(value) + + async def _handle_message_chunk( + self, chunk: Dict[str, Any], current_node: Optional[str], tool_calls_in_progress: Dict[str, Any] + ) -> AsyncGenerator[Dict[str, Any], None]: + """处理 messages 类型的 chunk""" + message_chunk, metadata = chunk["data"] + node_name = metadata.get("langgraph_node", "unknown") + new_current_node = current_node + + # 检测节点变化,发送节点开始事件 + if node_name != current_node: + if current_node: + yield {"type": "node_end", "node": current_node} + yield {"type": "node_start", "node": node_name} + new_current_node = node_name + + # 处理消息内容 + token_content = getattr(message_chunk, 'content', str(message_chunk)) + reasoning_token = "" + if hasattr(message_chunk, 'additional_kwargs'): + reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "") + + # 处理思考过程 + if reasoning_token: + yield { + "type": "llm_token", + "node": node_name, + "reasoning_token": reasoning_token + } + # 处理工具调用 + elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls: + for tool_call in message_chunk.tool_calls: + tool_call_id = tool_call.get("id", "") + tool_name = tool_call.get("name", "") + tool_args = tool_call.get("args", {}) + + # 记录工具调用开始,避免重复 + if tool_call_id and tool_call_id not in tool_calls_in_progress: + tool_calls_in_progress[tool_call_id] = { + "name": tool_name, + "args": tool_args + } + yield { + "type": "tool_call_start", + "tool": tool_name, + "args": tool_args, + "id": tool_call_id + } + # 处理普通 token + elif token_content: + yield { + "type": "llm_token", + "node": node_name, + "token": token_content, + "reasoning_token": reasoning_token + } + + # 返回更新后的 current_node + yield {"type": "_update_state", "current_node": new_current_node} + + async def _handle_updates_chunk( + self, chunk: Dict[str, Any], tool_calls_in_progress: Dict[str, Any], actual_model_used: str + ) -> AsyncGenerator[Dict[str, Any], None]: + """处理 updates 类型的 chunk""" + updates_data = chunk["data"] + new_actual_model = actual_model_used + + serialized_data = self._serialize_value(updates_data) + + # 检查是否有工具结果 + if "messages" in serialized_data: + for msg in serialized_data["messages"]: + # 检测工具结果消息 + if msg.get("role") == "tool": + tool_call_id = msg.get("tool_call_id", "") + tool_name = msg.get("name", "") + tool_result = msg.get("content", "") + + if tool_call_id and tool_call_id in tool_calls_in_progress: + yield { + "type": "tool_call_end", + "tool": tool_name, + "id": tool_call_id, + "result": tool_result + } + del tool_calls_in_progress[tool_call_id] + + yield { + "type": "state_update", + "data": serialized_data + } + + # 返回更新后的模型 + yield {"type": "_update_state", "actual_model_used": new_actual_model} + async def process_message_stream( self, message: str, thread_id: str, model: str = "", user_id: str = "default_user" ) -> AsyncGenerator[Dict[str, Any], None]: - """流式处理消息,用 astream_events 原生支持""" + """流式处理消息 - 双协程 + 结束哨兵 + 完整取消和异常处理""" # 解析模型名称 resolved_model = self._resolve_model(model) @@ -129,85 +246,144 @@ class AIAgentService: config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id) info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}") + current_node = None + tool_calls_in_progress: Dict[str, Any] = {} actual_model_used = resolved_model full_message_content = "" + # 创建 token 队列 + token_queue = asyncio.Queue() + # 结束哨兵 + SENTINEL = object() + + # 设置上下文变量 + token_queue_var.set(token_queue) + + # 事件和错误跟踪 + graph_error = None + graph_done = asyncio.Event() + + async def run_graph_task(): + """后台任务:运行 graph.astream()""" + nonlocal current_node, actual_model_used, full_message_content, graph_error + try: + info(f"📡 开始调用 graph.astream()...") + + event_count = 0 + + async for chunk in self.graph.astream( + input_state, + config=config, + stream_mode=["messages", "updates"], + version="v2", + subgraphs=True + ): + chunk_count = 0 + chunk_count += 1 + chunk_type = chunk["type"] + + # 记录原始 chunk 信息(前 10 个和后 10 个) + if chunk_count <= 10 or chunk_count % 50 == 0: + info(f" [{chunk_count}] chunk_type={chunk_type}, data={type(chunk.get('data'))}") + + if chunk_type == "messages": + async for event in self._handle_message_chunk( + chunk, current_node, tool_calls_in_progress + ): + if event.get("type") == "_update_state": + current_node = event.get("current_node", current_node) + else: + event_count += 1 + # 记录前 10 个事件 + if event_count <= 10: + info(f" → yield event #{event_count}: {event.get('type')}") + + # 如果是 agent 节点的 token,收集完整消息 + if ( + event.get("type") == "llm_token" + and event.get("node") == "agent" + and "token" in event + ): + full_message_content += event["token"] + await token_queue.put(event) + + elif chunk_type == "updates": + async for event in self._handle_updates_chunk( + chunk, tool_calls_in_progress, actual_model_used + ): + if event.get("type") == "_update_state": + actual_model_used = event.get("actual_model_used", actual_model_used) + else: + event_count += 1 + if event_count <= 10: + info(f" → yield event #{event_count}: {event.get('type')}") + await token_queue.put(event) + + # 完整消息集合完成后,一次性打印 + info(f"✅ graph.astream() 完成,共 {event_count} 个 events") + if full_message_content: + info(f"📄 完整消息内容: {repr(full_message_content)}") + + except Exception as e: + error(f"❌ 执行图时出错: {e}") + import traceback + error(f"📋 堆栈: {traceback.format_exc()}") + graph_error = e + await token_queue.put({ + "type": "error", + "message": str(e) + }) + finally: + # 发送结束哨兵 + await token_queue.put(SENTINEL) + graph_done.set() + + # 启动后台任务 + graph_task = asyncio.create_task(run_graph_task()) + try: - info(f"📡 开始调用 graph.astream_events()...") - - async for event in self.graph.astream_events(input_state, config=config, version="v2"): - kind = event["event"] - # info(f"[Stream Event] {kind}") # 调试用 - - if kind == "on_chat_model_stream": - # 流式 token - chunk = event["data"]["chunk"] - content = chunk.content if chunk.content else "" - reasoning_content = "" - if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs: - reasoning_content = chunk.additional_kwargs.get("reasoning_content", "") + # 主协程:从队列里取事件并 yield + while True: + try: + # 等待队列中的事件,带超时检查任务是否完成 + event = await asyncio.wait_for(token_queue.get(), timeout=0.5) - if content: - full_message_content += content + # 检查是否是结束哨兵 + if event is SENTINEL: + break - yield { - "type": "llm_token", - "node": "agent", - "token": content, - "reasoning_token": reasoning_content - } + yield event - elif kind == "on_tool_start": - # 工具调用开始 - tool_name = event["name"] - tool_args = event["data"].get("input", {}) - yield { - "type": "tool_call_start", - "tool": tool_name, - "args": tool_args, - "id": event.get("run_id", "") - } - - elif kind == "on_tool_end": - # 工具调用结束 - tool_name = event["name"] - tool_output = event["data"].get("output", "") - yield { - "type": "tool_call_end", - "tool": tool_name, - "id": event.get("run_id", ""), - "result": str(tool_output) - } - - elif kind == "on_chain_start": - # 节点开始 - node_name = event.get("name", "unknown") - yield { - "type": "node_start", - "node": node_name - } - - elif kind == "on_chain_end": - # 节点结束 - node_name = event.get("name", "unknown") - yield { - "type": "node_end", - "node": node_name - } - - info(f"✅ graph.astream_events() 完成") - if full_message_content: - info(f"📄 完整消息内容: {repr(full_message_content)}") + except asyncio.TimeoutError: + # 超时检查任务是否完成 + if graph_task.done(): + # 检查任务是否抛出异常 + if graph_task.exception(): + exc = graph_task.exception() + error(f"❌ 后台任务异常: {exc}") + break + + except asyncio.CancelledError: + info("⚠️ 流式生成被取消") + raise - except Exception as e: - error(f"❌ 执行图时出错: {e}") - import traceback - error(f"📋 堆栈: {traceback.format_exc()}") - yield { - "type": "error", - "message": str(e) - } finally: + # 无论成功或失败,都清理资源 + # 取消后台任务 + if not graph_task.done(): + info("⏹️ 取消后台任务") + graph_task.cancel() + try: + await graph_task + except asyncio.CancelledError: + info("✅ 后台任务已取消") + + # 发送结束事件,保证前端平稳关闭 + if current_node: + yield { + "type": "node_end", + "node": current_node + } yield { "type": "done", "model_used": actual_model_used diff --git a/backend/app/agent/stream_context.py b/backend/app/agent/stream_context.py new file mode 100644 index 0000000..2d5470d --- /dev/null +++ b/backend/app/agent/stream_context.py @@ -0,0 +1,9 @@ +"""流式上下文,用于在 LangGraph 节点和 agent_service 之间传递 token 队列""" +import contextvars +import asyncio +from typing import Optional, Any + +# 上下文变量:存储当前的 token 队列 +token_queue_var: contextvars.ContextVar[Optional[asyncio.Queue]] = contextvars.ContextVar( + "token_queue", default=None +) diff --git a/backend/app/main_graph/main_graph_builder.py b/backend/app/main_graph/main_graph_builder.py index 7f4cae9..5f048d8 100644 --- a/backend/app/main_graph/main_graph_builder.py +++ b/backend/app/main_graph/main_graph_builder.py @@ -1,11 +1,12 @@ """ -极简 Agent 主图 - 用 LangGraph 原生 create_react_agent + 记忆节点 +极简 Agent 主图 - 自己的节点结构,更好控制流式 """ -from langgraph.prebuilt import create_react_agent from langgraph.graph import StateGraph, START, END +from langgraph.prebuilt import ToolNode from ..state import AgentState from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client +from ..nodes.agent import create_agent_node from backend.app.logger import info, warning from backend.app.tools import ALL_TOOLS @@ -16,7 +17,7 @@ def build_agent_graph( max_steps: int = 10 ): """ - 构建包含记忆节点的 react agent 图 + 构建包含记忆节点的 Agent 图 Args: chat_services: 模型服务字典 @@ -24,7 +25,7 @@ def build_agent_graph( max_steps: 最大步数限制 Returns: - 编译好的 graph + 构建好的 StateGraph(未编译) """ # 获取主模型 primary_model = chat_services.get("primary", next(iter(chat_services.values()))) @@ -37,7 +38,8 @@ def build_agent_graph( async def init_state_node(state: AgentState): info("[Init State] 初始化状态,重置步数") return { - "current_step": 0 + "current_step": 0, + "max_steps": max_steps } # ========== 2. 记忆节点(可选) ========== @@ -49,21 +51,39 @@ def build_agent_graph( except Exception as e: info(f"[Graph Builder] 记忆节点初始化失败: {e}") - # ========== 3. 创建 react agent 子图 ========== - agent_runnable = create_react_agent(primary_model, ALL_TOOLS) + # ========== 3. 核心节点 ========== + llm_with_tools = primary_model.bind_tools(ALL_TOOLS) + agent_node_fn = create_agent_node(llm_with_tools, primary_model) + tool_node_fn = ToolNode(ALL_TOOLS) - # ========== 4. 构建主图 ========== + # ========== 4. 条件边判断函数 ========== + def should_continue(state: AgentState): + """判断是继续调用工具还是结束""" + messages = state.messages + last_message = messages[-1] if messages else None + + if last_message and hasattr(last_message, 'tool_calls') and last_message.tool_calls: + return "tools" + + return "finalize" + + # ========== 5. 完成节点 ========== + async def finalize_node_simple(state: AgentState): + info("[Finalize] 进入完成节点") + return {} + + # ========== 6. 构建图 ========== graph = StateGraph(AgentState) graph.add_node("init_state", init_state_node) if retrieve_memory_node: graph.add_node("retrieve_memory", retrieve_memory_node) graph.add_node("memory_trigger", memory_trigger_node) + graph.add_node("agent", agent_node_fn) + graph.add_node("tools", tool_node_fn) + graph.add_node("finalize", finalize_node_simple) - # 直接把 create_react_agent 的可运行对象作为节点 - graph.add_node("agent", agent_runnable) - - # ========== 边的连接 ========== + # ========== 7. 边的连接 ========== graph.add_edge(START, "init_state") if retrieve_memory_node: @@ -73,7 +93,18 @@ def build_agent_graph( graph.add_edge("init_state", "memory_trigger") graph.add_edge("memory_trigger", "agent") - graph.add_edge("agent", END) - info("✅ [Graph Builder] 极简 Agent 图构建完成(用 create_react_agent)") + graph.add_conditional_edges( + "agent", + should_continue, + { + "tools": "tools", + "finalize": "finalize" + } + ) + + graph.add_edge("tools", "agent") + graph.add_edge("finalize", END) + + info("✅ [Graph Builder] 极简 Agent 图构建完成") return graph diff --git a/backend/app/main_graph/nodes/agent.py b/backend/app/main_graph/nodes/agent.py index 7acce36..daa99de 100644 --- a/backend/app/main_graph/nodes/agent.py +++ b/backend/app/main_graph/nodes/agent.py @@ -67,7 +67,8 @@ def create_agent_node(llm_with_tools, llm): Returns: 状态更新字典 """ - info(f"[Agent] 第 {state.current_step} 步推理") + current_step = state.get("current_step", 0) + info(f"[Agent] 第 {current_step} 步推理") try: # 组装完整消息:系统提示 + 历史消息 @@ -76,8 +77,8 @@ def create_agent_node(llm_with_tools, llm): info(f"[Agent] 消息数量: {len(full_messages)}, 最后一条: {type(full_messages[-1]).__name__}") # 判断是否达到步数上限 - if state.current_step >= state.max_steps: - info(f"[Agent] 达到步数上限 {state.max_steps},强制结束,不绑定工具") + if current_step >= state.get("max_steps", 10): + info(f"[Agent] 达到步数上限 {state.get('max_steps', 10)},强制结束,不绑定工具") current_llm = llm.bind_tools([]) else: current_llm = llm_with_tools @@ -86,6 +87,9 @@ def create_agent_node(llm_with_tools, llm): # 获取 token 队列 token_queue = token_queue_var.get() + if token_queue is None: + error("[Agent] ❌ token_queue 为 None!") + raise RuntimeError("token_queue 上下文变量未设置") # 完整消息 full_content = "" @@ -98,26 +102,28 @@ def create_agent_node(llm_with_tools, llm): # 处理 content if chunk.content: full_content += chunk.content - if token_queue: - await token_queue.put({ - "type": "llm_token", - "node": "agent", - "token": chunk.content, - "reasoning_token": "" - }) + await token_queue.put({ + "type": "llm_token", + "node": "agent", + "token": chunk.content, + "reasoning_token": "", + "turn": current_step, + "phase": "answering" if not full_tool_calls else "thinking" + }) # 处理 reasoning_content if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs: reasoning_content = chunk.additional_kwargs.get("reasoning_content", "") if reasoning_content: full_reasoning_content += reasoning_content - if token_queue: - await token_queue.put({ - "type": "llm_token", - "node": "agent", - "token": "", - "reasoning_token": reasoning_content - }) + await token_queue.put({ + "type": "llm_token", + "node": "agent", + "token": "", + "reasoning_token": reasoning_content, + "turn": current_step, + "phase": "thinking" + }) # 处理 tool_calls if hasattr(chunk, 'tool_calls') and chunk.tool_calls: @@ -133,6 +139,14 @@ def create_agent_node(llm_with_tools, llm): break if not found: full_tool_calls.append(tc) + # 发送工具调用开始事件 + await token_queue.put({ + "type": "tool_call_start", + "tool": tc.get("name"), + "args": tc.get("args"), + "id": tc.get("id", ""), + "turn": current_step + }) # 构建完整的 AIMessage response = AIMessage( @@ -149,14 +163,21 @@ def create_agent_node(llm_with_tools, llm): # 返回状态更新 return { "messages": [response], - "current_step": state.current_step + 1, - "llm_calls": state.llm_calls + 1 + "current_step": current_step + 1, + "llm_calls": state.get("llm_calls", 0) + 1 } except Exception as e: - error(f"[Agent] ❌ 第 {state.current_step} 步推理出错: {e}") + error(f"[Agent] ❌ 第 {current_step} 步推理出错: {e}") import traceback error(f"[Agent] 堆栈: {traceback.format_exc()}") + # 发送错误事件 + token_queue = token_queue_var.get() + if token_queue: + await token_queue.put({ + "type": "error", + "message": str(e) + }) raise return agent_node