From ef07b05c22c4daeeb8f2032de2198d0cfa5e1ec0 Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Fri, 8 May 2026 01:48:46 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=20agent=20?= =?UTF-8?q?=E8=8A=82=E7=82=B9=E4=B8=BA=E5=8D=95=E6=AD=A5=E6=8E=A8=E7=90=86?= =?UTF-8?q?=EF=BC=88=E7=A7=BB=E9=99=A4=20while=20=E5=BE=AA=E7=8E=AF?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/main_graph/nodes/agent.py | 269 +++++++++++--------------- 1 file changed, 110 insertions(+), 159 deletions(-) diff --git a/backend/app/main_graph/nodes/agent.py b/backend/app/main_graph/nodes/agent.py index b60971c..fcb64d1 100644 --- a/backend/app/main_graph/nodes/agent.py +++ b/backend/app/main_graph/nodes/agent.py @@ -1,15 +1,15 @@ """ -Agent 节点 - 简化版本 -直接定义 agent_node 函数,支持动态模型切换和循环检测 +Agent 节点 - 简化版本(单步推理) +只负责一次 LLM 调用,不执行工具 """ -import hashlib +import json from typing import Dict, Any, Optional, List from langchain_core.runnables.config import RunnableConfig from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage from backend.app.main_graph.state import AgentState -from backend.app.logger import info, error +from backend.app.logger import info, error, debug from backend.app.tools import ALL_TOOLS from backend.app.agent.stream_context import get_stream_queue from backend.app.agent.prompts import SYSTEM_PROMPT @@ -28,7 +28,6 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool: latest = results[-1] prev = results[-2] - # 长度差异太大,不算相似 if len(latest) == 0 or len(prev) == 0: return len(latest) == len(prev) @@ -36,7 +35,6 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool: if len_ratio < 0.5: return False - # 检查内容重复度(简单:前100字符) common_len = 0 for a, b in zip(latest[:100], prev[:100]): if a == b: @@ -50,27 +48,23 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool: def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool: """ 检测是否应该停止(循环检测) - 条件:连续2次调用相同工具 + 参数相似 + 结果相似 """ if len(tool_calls) < 2: return False - # 检查最近的工具调用是否相同 last_tc = tool_calls[-1] prev_tc = tool_calls[-2] if last_tc["name"] != prev_tc["name"]: return False - # 参数是否相似 last_args = _normalize_args(last_tc["args"]) prev_args = _normalize_args(prev_tc["args"]) if last_args != prev_args: return False - # 结果是否相似 if len(tool_results) >= 2: return _is_similar_result(tool_results[-2:]) @@ -79,22 +73,43 @@ def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bo def create_agent_node(chat_services: dict): """ - 创建 Agent 节点 - 支持动态模型切换 + 创建 Agent 节点 - 单步推理版本 - 简化设计: - - 直接返回 async 函数,无需工厂包装 - - 从 config 中获取模型名称,运行时动态切换 + 设计: + - 只做一次 LLM 调用 + - 不执行工具(工具执行由 tools 节点负责) + - 返回 AIMessage(可能包含 tool_calls) """ async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]: - """Agent 节点:完整的 ReAct 循环""" + """Agent 节点:单步 LLM 调用""" queue = get_stream_queue() is_streaming = queue is not None # 获取步数 current_step = getattr(state, "current_step", 0) max_steps = getattr(state, "max_steps", 10) - info(f"[Agent] 从第 {current_step} 步开始,最大步数: {max_steps},流式: {is_streaming}") + info(f"[Agent] 第 {current_step + 1} 步开始") + + # 步数已达上限 + if current_step >= max_steps: + info("[Agent] 达到步数上限,强制结束") + return { + "messages": [AIMessage(content="[系统] 已达到最大步数限制。")], + "stop": True, + "stop_reason": "max_steps", + } + + # 循环检测 + tool_history = getattr(state, "tool_call_history", []) + result_history = getattr(state, "tool_result_history", []) + if _should_stop_for_loop(tool_history, result_history): + info("[Agent] 检测到循环,终止推理") + return { + "messages": [AIMessage(content="[系统] 检测到工具调用循环,已终止。")], + "stop": True, + "stop_reason": "loop_detected", + } # 动态获取模型 model_name = "primary" @@ -111,160 +126,95 @@ def create_agent_node(chat_services: dict): # 获取记忆上下文 memory_context = getattr(state, "memory_context", "暂无用户背景信息") - - # 组装消息(注入记忆上下文到提示词) prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context) messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages) - turn = current_step + + # 发送节点开始事件 + if is_streaming: + await queue.put({"type": "node_start", "node": "agent"}) + + # 选择 LLM(最后一轮不带工具) + if current_step + 1 >= max_steps: + current_llm = llm.bind_tools([]) + info(f"[Agent] 达到步数上限,使用无工具模型") + else: + current_llm = llm_with_tools + + # 初始化 + full_content = "" + full_reasoning_content = "" + pending_tool_calls = {} + final_tool_calls = [] try: - while turn < max_steps: - turn += 1 - info(f"[Agent] 第 {turn} 轮思考") + # 调用 LLM + if is_streaming: + async for chunk in current_llm.astream(messages): + if isinstance(chunk, AIMessageChunk): + if chunk.content: + full_content += chunk.content + await queue.put({ + "type": "llm_token", + "node": "agent", + "token": chunk.content, + "reasoning_token": "" + }) - if is_streaming: - await queue.put({"type": "node_start", "node": "agent"}) - - # 选择 LLM(最后一轮不带工具) - if turn >= max_steps: - current_llm = llm.bind_tools([]) - info(f"[Agent] 达到步数上限,使用无工具模型") - else: - current_llm = llm_with_tools - - # 初始化 - full_content = "" - full_reasoning_content = "" - pending_tool_calls = {} - final_tool_calls = [] - - # 循环检测:记录历史调用 - tool_call_history: List[dict] = [] - tool_result_history: List[str] = [] - - # 调用 LLM - if is_streaming: - async for chunk in current_llm.astream(messages): - if isinstance(chunk, AIMessageChunk): - if chunk.content: - full_content += chunk.content + if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs: + reasoning = chunk.additional_kwargs.get("reasoning_content", "") + if reasoning: + full_reasoning_content += reasoning await queue.put({ "type": "llm_token", "node": "agent", - "token": chunk.content, - "reasoning_token": "" + "token": "", + "reasoning_token": reasoning }) - if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs: - reasoning = chunk.additional_kwargs.get("reasoning_content", "") - if reasoning: - full_reasoning_content += reasoning - await queue.put({ - "type": "llm_token", - "node": "agent", - "token": "", - "reasoning_token": reasoning - }) + if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks: + for tc_chunk in chunk.tool_call_chunks: + idx = tc_chunk.get("index", 0) + if idx not in pending_tool_calls: + pending_tool_calls[idx] = {"id": "", "name": "", "args": ""} - if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks: - for tc_chunk in chunk.tool_call_chunks: - idx = tc_chunk.get("index", 0) - if idx not in pending_tool_calls: - pending_tool_calls[idx] = {"id": "", "name": "", "args": ""} + if tc_chunk.get("id"): + pending_tool_calls[idx]["id"] += tc_chunk["id"] + if tc_chunk.get("name"): + pending_tool_calls[idx]["name"] += tc_chunk["name"] + if tc_chunk.get("args"): + args_val = tc_chunk["args"] + if isinstance(args_val, str): + pending_tool_calls[idx]["args"] += args_val + else: + pending_tool_calls[idx]["args"] += json.dumps(args_val) + else: + result = await current_llm.ainvoke(messages) + full_content = result.content if result.content else "" + if hasattr(result, 'tool_calls') and result.tool_calls: + final_tool_calls = result.tool_calls + if hasattr(result, 'additional_kwargs'): + full_reasoning_content = result.additional_kwargs.get("reasoning_content", "") - if tc_chunk.get("id"): - pending_tool_calls[idx]["id"] += tc_chunk["id"] - if tc_chunk.get("name"): - pending_tool_calls[idx]["name"] += tc_chunk["name"] - if tc_chunk.get("args"): - args_val = tc_chunk["args"] - if isinstance(args_val, str): - pending_tool_calls[idx]["args"] += args_val - else: - import json - pending_tool_calls[idx]["args"] += json.dumps(args_val) - else: - result = await current_llm.ainvoke(messages) - full_content = result.content if result.content else "" - if hasattr(result, 'tool_calls') and result.tool_calls: - final_tool_calls = result.tool_calls - if hasattr(result, 'additional_kwargs'): - full_reasoning_content = result.additional_kwargs.get("reasoning_content", "") + # 整理工具调用 + if is_streaming: + for idx in sorted(pending_tool_calls.keys()): + tc_data = pending_tool_calls[idx] + if tc_data["name"]: + args = {} + if tc_data["args"]: + try: + args = json.loads(tc_data["args"]) + except Exception as e: + info(f"[Agent] 解析参数失败: {e}") + final_tool_calls.append({ + "id": tc_data["id"], + "name": tc_data["name"], + "args": args + }) - # 整理工具调用 - if is_streaming: - for idx in sorted(pending_tool_calls.keys()): - tc_data = pending_tool_calls[idx] - if tc_data["name"]: - args = {} - if tc_data["args"]: - try: - import json - args = json.loads(tc_data["args"]) - except Exception as e: - info(f"[Agent] 解析参数失败: {e}") - final_tool_calls.append({ - "id": tc_data["id"], - "name": tc_data["name"], - "args": args - }) - - # 执行工具 - if final_tool_calls: - info(f"[Agent] 第 {turn} 轮:调用 {len(final_tool_calls)} 个工具") - new_messages = [] - - for tc in final_tool_calls: - tool_name = tc["name"] - tool_args = tc["args"] - tool_id = tc["id"] - - if is_streaming: - await queue.put({ - "type": "custom", - "data": {"type": "tool_start", "tool": tool_name, "args": tool_args, "id": tool_id} - }) - - # 查找并执行工具 - tool_result = "" - tool_found = False - for tool in ALL_TOOLS: - if tool.name == tool_name: - tool_found = True - try: - tool_result = await tool.ainvoke(tool_args) - except Exception as e: - tool_result = f"工具调用出错: {str(e)}" - error(f"[Agent] 工具 {tool_name} 调用出错: {e}") - break - - if not tool_found: - tool_result = f"未找到工具: {tool_name}" - - if is_streaming: - await queue.put({ - "type": "custom", - "data": {"type": "tool_end", "tool": tool_name, "id": tool_id, "result": str(tool_result)} - }) - - # 记录历史(用于循环检测) - tool_call_history.append({"name": tool_name, "args": tool_args}) - tool_result_history.append(str(tool_result)) - - new_messages.append(ToolMessage(content=str(tool_result), tool_call_id=tool_id, name=tool_name)) - - # 循环检测:相同工具 + 相似参数 + 相似结果 → 终止 - if _should_stop_for_loop(tool_call_history, tool_result_history): - info(f"[Agent] ⚠️ 检测到循环,强制终止") - # 添加一条终止消息 - messages.append(AIMessage(content="[系统] 检测到工具调用循环,已终止。")) - break - - messages.extend(new_messages) - continue - else: - info(f"[Agent] 第 {turn} 轮:完成,无工具调用") - break + # 发送节点结束事件 + if is_streaming: + await queue.put({"type": "node_end", "node": "agent"}) # 构建响应 response_kwargs = {"content": full_content} @@ -274,14 +224,15 @@ def create_agent_node(chat_services: dict): if full_reasoning_content: response.additional_kwargs["reasoning_content"] = full_reasoning_content + info(f"[Agent] 完成 - content长度: {len(full_content)}, tool_calls: {len(final_tool_calls)}") + return { "messages": [response], - "current_step": turn, "llm_calls": getattr(state, "llm_calls", 0) + 1 } except Exception as e: - error(f"[Agent] ❌ 第 {turn} 轮出错: {e}") + error(f"[Agent] 执行出错: {e}") import traceback error(f"[Agent] 堆栈: {traceback.format_exc()}") if is_streaming: