refactor: 重构 agent 节点为单步推理(移除 while 循环)

This commit is contained in:
2026-05-08 01:48:46 +08:00
parent e851e40763
commit ef07b05c22

View File

@@ -1,15 +1,15 @@
"""
Agent 节点 - 简化版本
直接定义 agent_node 函数,支持动态模型切换和循环检测
Agent 节点 - 简化版本(单步推理)
只负责一次 LLM 调用,不执行工具
"""
import hashlib
import json
from typing import Dict, Any, Optional, List
from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage
from backend.app.main_graph.state import AgentState
from backend.app.logger import info, error
from backend.app.logger import info, error, debug
from backend.app.tools import ALL_TOOLS
from backend.app.agent.stream_context import get_stream_queue
from backend.app.agent.prompts import SYSTEM_PROMPT
@@ -28,7 +28,6 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
latest = results[-1]
prev = results[-2]
# 长度差异太大,不算相似
if len(latest) == 0 or len(prev) == 0:
return len(latest) == len(prev)
@@ -36,7 +35,6 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
if len_ratio < 0.5:
return False
# 检查内容重复度简单前100字符
common_len = 0
for a, b in zip(latest[:100], prev[:100]):
if a == b:
@@ -50,27 +48,23 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
"""
检测是否应该停止(循环检测)
条件连续2次调用相同工具 + 参数相似 + 结果相似
"""
if len(tool_calls) < 2:
return False
# 检查最近的工具调用是否相同
last_tc = tool_calls[-1]
prev_tc = tool_calls[-2]
if last_tc["name"] != prev_tc["name"]:
return False
# 参数是否相似
last_args = _normalize_args(last_tc["args"])
prev_args = _normalize_args(prev_tc["args"])
if last_args != prev_args:
return False
# 结果是否相似
if len(tool_results) >= 2:
return _is_similar_result(tool_results[-2:])
@@ -79,22 +73,43 @@ def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bo
def create_agent_node(chat_services: dict):
"""
创建 Agent 节点 - 支持动态模型切换
创建 Agent 节点 - 单步推理版本
简化设计:
- 直接返回 async 函数,无需工厂包装
- 从 config 中获取模型名称,运行时动态切换
设计:
- 只做一次 LLM 调用
- 不执行工具(工具执行由 tools 节点负责)
- 返回 AIMessage可能包含 tool_calls
"""
async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""Agent 节点:完整的 ReAct 循环"""
"""Agent 节点:单步 LLM 调用"""
queue = get_stream_queue()
is_streaming = queue is not None
# 获取步数
current_step = getattr(state, "current_step", 0)
max_steps = getattr(state, "max_steps", 10)
info(f"[Agent] {current_step} 步开始,最大步数: {max_steps},流式: {is_streaming}")
info(f"[Agent] 第 {current_step + 1} 步开始")
# 步数已达上限
if current_step >= max_steps:
info("[Agent] 达到步数上限,强制结束")
return {
"messages": [AIMessage(content="[系统] 已达到最大步数限制。")],
"stop": True,
"stop_reason": "max_steps",
}
# 循环检测
tool_history = getattr(state, "tool_call_history", [])
result_history = getattr(state, "tool_result_history", [])
if _should_stop_for_loop(tool_history, result_history):
info("[Agent] 检测到循环,终止推理")
return {
"messages": [AIMessage(content="[系统] 检测到工具调用循环,已终止。")],
"stop": True,
"stop_reason": "loop_detected",
}
# 动态获取模型
model_name = "primary"
@@ -111,160 +126,95 @@ def create_agent_node(chat_services: dict):
# 获取记忆上下文
memory_context = getattr(state, "memory_context", "暂无用户背景信息")
# 组装消息(注入记忆上下文到提示词)
prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)
turn = current_step
# 发送节点开始事件
if is_streaming:
await queue.put({"type": "node_start", "node": "agent"})
# 选择 LLM最后一轮不带工具
if current_step + 1 >= max_steps:
current_llm = llm.bind_tools([])
info(f"[Agent] 达到步数上限,使用无工具模型")
else:
current_llm = llm_with_tools
# 初始化
full_content = ""
full_reasoning_content = ""
pending_tool_calls = {}
final_tool_calls = []
try:
while turn < max_steps:
turn += 1
info(f"[Agent] 第 {turn} 轮思考")
# 调用 LLM
if is_streaming:
async for chunk in current_llm.astream(messages):
if isinstance(chunk, AIMessageChunk):
if chunk.content:
full_content += chunk.content
await queue.put({
"type": "llm_token",
"node": "agent",
"token": chunk.content,
"reasoning_token": ""
})
if is_streaming:
await queue.put({"type": "node_start", "node": "agent"})
# 选择 LLM最后一轮不带工具
if turn >= max_steps:
current_llm = llm.bind_tools([])
info(f"[Agent] 达到步数上限,使用无工具模型")
else:
current_llm = llm_with_tools
# 初始化
full_content = ""
full_reasoning_content = ""
pending_tool_calls = {}
final_tool_calls = []
# 循环检测:记录历史调用
tool_call_history: List[dict] = []
tool_result_history: List[str] = []
# 调用 LLM
if is_streaming:
async for chunk in current_llm.astream(messages):
if isinstance(chunk, AIMessageChunk):
if chunk.content:
full_content += chunk.content
if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs:
reasoning = chunk.additional_kwargs.get("reasoning_content", "")
if reasoning:
full_reasoning_content += reasoning
await queue.put({
"type": "llm_token",
"node": "agent",
"token": chunk.content,
"reasoning_token": ""
"token": "",
"reasoning_token": reasoning
})
if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs:
reasoning = chunk.additional_kwargs.get("reasoning_content", "")
if reasoning:
full_reasoning_content += reasoning
await queue.put({
"type": "llm_token",
"node": "agent",
"token": "",
"reasoning_token": reasoning
})
if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
for tc_chunk in chunk.tool_call_chunks:
idx = tc_chunk.get("index", 0)
if idx not in pending_tool_calls:
pending_tool_calls[idx] = {"id": "", "name": "", "args": ""}
if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
for tc_chunk in chunk.tool_call_chunks:
idx = tc_chunk.get("index", 0)
if idx not in pending_tool_calls:
pending_tool_calls[idx] = {"id": "", "name": "", "args": ""}
if tc_chunk.get("id"):
pending_tool_calls[idx]["id"] += tc_chunk["id"]
if tc_chunk.get("name"):
pending_tool_calls[idx]["name"] += tc_chunk["name"]
if tc_chunk.get("args"):
args_val = tc_chunk["args"]
if isinstance(args_val, str):
pending_tool_calls[idx]["args"] += args_val
else:
pending_tool_calls[idx]["args"] += json.dumps(args_val)
else:
result = await current_llm.ainvoke(messages)
full_content = result.content if result.content else ""
if hasattr(result, 'tool_calls') and result.tool_calls:
final_tool_calls = result.tool_calls
if hasattr(result, 'additional_kwargs'):
full_reasoning_content = result.additional_kwargs.get("reasoning_content", "")
if tc_chunk.get("id"):
pending_tool_calls[idx]["id"] += tc_chunk["id"]
if tc_chunk.get("name"):
pending_tool_calls[idx]["name"] += tc_chunk["name"]
if tc_chunk.get("args"):
args_val = tc_chunk["args"]
if isinstance(args_val, str):
pending_tool_calls[idx]["args"] += args_val
else:
import json
pending_tool_calls[idx]["args"] += json.dumps(args_val)
else:
result = await current_llm.ainvoke(messages)
full_content = result.content if result.content else ""
if hasattr(result, 'tool_calls') and result.tool_calls:
final_tool_calls = result.tool_calls
if hasattr(result, 'additional_kwargs'):
full_reasoning_content = result.additional_kwargs.get("reasoning_content", "")
# 整理工具调用
if is_streaming:
for idx in sorted(pending_tool_calls.keys()):
tc_data = pending_tool_calls[idx]
if tc_data["name"]:
args = {}
if tc_data["args"]:
try:
args = json.loads(tc_data["args"])
except Exception as e:
info(f"[Agent] 解析参数失败: {e}")
final_tool_calls.append({
"id": tc_data["id"],
"name": tc_data["name"],
"args": args
})
# 整理工具调用
if is_streaming:
for idx in sorted(pending_tool_calls.keys()):
tc_data = pending_tool_calls[idx]
if tc_data["name"]:
args = {}
if tc_data["args"]:
try:
import json
args = json.loads(tc_data["args"])
except Exception as e:
info(f"[Agent] 解析参数失败: {e}")
final_tool_calls.append({
"id": tc_data["id"],
"name": tc_data["name"],
"args": args
})
# 执行工具
if final_tool_calls:
info(f"[Agent] 第 {turn} 轮:调用 {len(final_tool_calls)} 个工具")
new_messages = []
for tc in final_tool_calls:
tool_name = tc["name"]
tool_args = tc["args"]
tool_id = tc["id"]
if is_streaming:
await queue.put({
"type": "custom",
"data": {"type": "tool_start", "tool": tool_name, "args": tool_args, "id": tool_id}
})
# 查找并执行工具
tool_result = ""
tool_found = False
for tool in ALL_TOOLS:
if tool.name == tool_name:
tool_found = True
try:
tool_result = await tool.ainvoke(tool_args)
except Exception as e:
tool_result = f"工具调用出错: {str(e)}"
error(f"[Agent] 工具 {tool_name} 调用出错: {e}")
break
if not tool_found:
tool_result = f"未找到工具: {tool_name}"
if is_streaming:
await queue.put({
"type": "custom",
"data": {"type": "tool_end", "tool": tool_name, "id": tool_id, "result": str(tool_result)}
})
# 记录历史(用于循环检测)
tool_call_history.append({"name": tool_name, "args": tool_args})
tool_result_history.append(str(tool_result))
new_messages.append(ToolMessage(content=str(tool_result), tool_call_id=tool_id, name=tool_name))
# 循环检测:相同工具 + 相似参数 + 相似结果 → 终止
if _should_stop_for_loop(tool_call_history, tool_result_history):
info(f"[Agent] ⚠️ 检测到循环,强制终止")
# 添加一条终止消息
messages.append(AIMessage(content="[系统] 检测到工具调用循环,已终止。"))
break
messages.extend(new_messages)
continue
else:
info(f"[Agent] 第 {turn} 轮:完成,无工具调用")
break
# 发送节点结束事件
if is_streaming:
await queue.put({"type": "node_end", "node": "agent"})
# 构建响应
response_kwargs = {"content": full_content}
@@ -274,14 +224,15 @@ def create_agent_node(chat_services: dict):
if full_reasoning_content:
response.additional_kwargs["reasoning_content"] = full_reasoning_content
info(f"[Agent] 完成 - content长度: {len(full_content)}, tool_calls: {len(final_tool_calls)}")
return {
"messages": [response],
"current_step": turn,
"llm_calls": getattr(state, "llm_calls", 0) + 1
}
except Exception as e:
error(f"[Agent] ❌ 第 {turn}出错: {e}")
error(f"[Agent] 执行出错: {e}")
import traceback
error(f"[Agent] 堆栈: {traceback.format_exc()}")
if is_streaming: