refactor: 重构 agent 节点为单步推理(移除 while 循环)
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
"""
|
||||
Agent 节点 - 简化版本
|
||||
直接定义 agent_node 函数,支持动态模型切换和循环检测
|
||||
Agent 节点 - 简化版本(单步推理)
|
||||
只负责一次 LLM 调用,不执行工具
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage
|
||||
|
||||
from backend.app.main_graph.state import AgentState
|
||||
from backend.app.logger import info, error
|
||||
from backend.app.logger import info, error, debug
|
||||
from backend.app.tools import ALL_TOOLS
|
||||
from backend.app.agent.stream_context import get_stream_queue
|
||||
from backend.app.agent.prompts import SYSTEM_PROMPT
|
||||
@@ -28,7 +28,6 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
|
||||
latest = results[-1]
|
||||
prev = results[-2]
|
||||
|
||||
# 长度差异太大,不算相似
|
||||
if len(latest) == 0 or len(prev) == 0:
|
||||
return len(latest) == len(prev)
|
||||
|
||||
@@ -36,7 +35,6 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
|
||||
if len_ratio < 0.5:
|
||||
return False
|
||||
|
||||
# 检查内容重复度(简单:前100字符)
|
||||
common_len = 0
|
||||
for a, b in zip(latest[:100], prev[:100]):
|
||||
if a == b:
|
||||
@@ -50,27 +48,23 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
|
||||
def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
|
||||
"""
|
||||
检测是否应该停止(循环检测)
|
||||
|
||||
条件:连续2次调用相同工具 + 参数相似 + 结果相似
|
||||
"""
|
||||
if len(tool_calls) < 2:
|
||||
return False
|
||||
|
||||
# 检查最近的工具调用是否相同
|
||||
last_tc = tool_calls[-1]
|
||||
prev_tc = tool_calls[-2]
|
||||
|
||||
if last_tc["name"] != prev_tc["name"]:
|
||||
return False
|
||||
|
||||
# 参数是否相似
|
||||
last_args = _normalize_args(last_tc["args"])
|
||||
prev_args = _normalize_args(prev_tc["args"])
|
||||
|
||||
if last_args != prev_args:
|
||||
return False
|
||||
|
||||
# 结果是否相似
|
||||
if len(tool_results) >= 2:
|
||||
return _is_similar_result(tool_results[-2:])
|
||||
|
||||
@@ -79,22 +73,43 @@ def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bo
|
||||
|
||||
def create_agent_node(chat_services: dict):
|
||||
"""
|
||||
创建 Agent 节点 - 支持动态模型切换
|
||||
创建 Agent 节点 - 单步推理版本
|
||||
|
||||
简化设计:
|
||||
- 直接返回 async 函数,无需工厂包装
|
||||
- 从 config 中获取模型名称,运行时动态切换
|
||||
设计:
|
||||
- 只做一次 LLM 调用
|
||||
- 不执行工具(工具执行由 tools 节点负责)
|
||||
- 返回 AIMessage(可能包含 tool_calls)
|
||||
"""
|
||||
|
||||
async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||
"""Agent 节点:完整的 ReAct 循环"""
|
||||
"""Agent 节点:单步 LLM 调用"""
|
||||
queue = get_stream_queue()
|
||||
is_streaming = queue is not None
|
||||
|
||||
# 获取步数
|
||||
current_step = getattr(state, "current_step", 0)
|
||||
max_steps = getattr(state, "max_steps", 10)
|
||||
info(f"[Agent] 从第 {current_step} 步开始,最大步数: {max_steps},流式: {is_streaming}")
|
||||
info(f"[Agent] 第 {current_step + 1} 步开始")
|
||||
|
||||
# 步数已达上限
|
||||
if current_step >= max_steps:
|
||||
info("[Agent] 达到步数上限,强制结束")
|
||||
return {
|
||||
"messages": [AIMessage(content="[系统] 已达到最大步数限制。")],
|
||||
"stop": True,
|
||||
"stop_reason": "max_steps",
|
||||
}
|
||||
|
||||
# 循环检测
|
||||
tool_history = getattr(state, "tool_call_history", [])
|
||||
result_history = getattr(state, "tool_result_history", [])
|
||||
if _should_stop_for_loop(tool_history, result_history):
|
||||
info("[Agent] 检测到循环,终止推理")
|
||||
return {
|
||||
"messages": [AIMessage(content="[系统] 检测到工具调用循环,已终止。")],
|
||||
"stop": True,
|
||||
"stop_reason": "loop_detected",
|
||||
}
|
||||
|
||||
# 动态获取模型
|
||||
model_name = "primary"
|
||||
@@ -111,22 +126,15 @@ def create_agent_node(chat_services: dict):
|
||||
|
||||
# 获取记忆上下文
|
||||
memory_context = getattr(state, "memory_context", "暂无用户背景信息")
|
||||
|
||||
# 组装消息(注入记忆上下文到提示词)
|
||||
prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
|
||||
messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)
|
||||
turn = current_step
|
||||
|
||||
try:
|
||||
while turn < max_steps:
|
||||
turn += 1
|
||||
info(f"[Agent] 第 {turn} 轮思考")
|
||||
|
||||
# 发送节点开始事件
|
||||
if is_streaming:
|
||||
await queue.put({"type": "node_start", "node": "agent"})
|
||||
|
||||
# 选择 LLM(最后一轮不带工具)
|
||||
if turn >= max_steps:
|
||||
if current_step + 1 >= max_steps:
|
||||
current_llm = llm.bind_tools([])
|
||||
info(f"[Agent] 达到步数上限,使用无工具模型")
|
||||
else:
|
||||
@@ -138,10 +146,7 @@ def create_agent_node(chat_services: dict):
|
||||
pending_tool_calls = {}
|
||||
final_tool_calls = []
|
||||
|
||||
# 循环检测:记录历史调用
|
||||
tool_call_history: List[dict] = []
|
||||
tool_result_history: List[str] = []
|
||||
|
||||
try:
|
||||
# 调用 LLM
|
||||
if is_streaming:
|
||||
async for chunk in current_llm.astream(messages):
|
||||
@@ -181,7 +186,6 @@ def create_agent_node(chat_services: dict):
|
||||
if isinstance(args_val, str):
|
||||
pending_tool_calls[idx]["args"] += args_val
|
||||
else:
|
||||
import json
|
||||
pending_tool_calls[idx]["args"] += json.dumps(args_val)
|
||||
else:
|
||||
result = await current_llm.ainvoke(messages)
|
||||
@@ -199,7 +203,6 @@ def create_agent_node(chat_services: dict):
|
||||
args = {}
|
||||
if tc_data["args"]:
|
||||
try:
|
||||
import json
|
||||
args = json.loads(tc_data["args"])
|
||||
except Exception as e:
|
||||
info(f"[Agent] 解析参数失败: {e}")
|
||||
@@ -209,62 +212,9 @@ def create_agent_node(chat_services: dict):
|
||||
"args": args
|
||||
})
|
||||
|
||||
# 执行工具
|
||||
if final_tool_calls:
|
||||
info(f"[Agent] 第 {turn} 轮:调用 {len(final_tool_calls)} 个工具")
|
||||
new_messages = []
|
||||
|
||||
for tc in final_tool_calls:
|
||||
tool_name = tc["name"]
|
||||
tool_args = tc["args"]
|
||||
tool_id = tc["id"]
|
||||
|
||||
# 发送节点结束事件
|
||||
if is_streaming:
|
||||
await queue.put({
|
||||
"type": "custom",
|
||||
"data": {"type": "tool_start", "tool": tool_name, "args": tool_args, "id": tool_id}
|
||||
})
|
||||
|
||||
# 查找并执行工具
|
||||
tool_result = ""
|
||||
tool_found = False
|
||||
for tool in ALL_TOOLS:
|
||||
if tool.name == tool_name:
|
||||
tool_found = True
|
||||
try:
|
||||
tool_result = await tool.ainvoke(tool_args)
|
||||
except Exception as e:
|
||||
tool_result = f"工具调用出错: {str(e)}"
|
||||
error(f"[Agent] 工具 {tool_name} 调用出错: {e}")
|
||||
break
|
||||
|
||||
if not tool_found:
|
||||
tool_result = f"未找到工具: {tool_name}"
|
||||
|
||||
if is_streaming:
|
||||
await queue.put({
|
||||
"type": "custom",
|
||||
"data": {"type": "tool_end", "tool": tool_name, "id": tool_id, "result": str(tool_result)}
|
||||
})
|
||||
|
||||
# 记录历史(用于循环检测)
|
||||
tool_call_history.append({"name": tool_name, "args": tool_args})
|
||||
tool_result_history.append(str(tool_result))
|
||||
|
||||
new_messages.append(ToolMessage(content=str(tool_result), tool_call_id=tool_id, name=tool_name))
|
||||
|
||||
# 循环检测:相同工具 + 相似参数 + 相似结果 → 终止
|
||||
if _should_stop_for_loop(tool_call_history, tool_result_history):
|
||||
info(f"[Agent] ⚠️ 检测到循环,强制终止")
|
||||
# 添加一条终止消息
|
||||
messages.append(AIMessage(content="[系统] 检测到工具调用循环,已终止。"))
|
||||
break
|
||||
|
||||
messages.extend(new_messages)
|
||||
continue
|
||||
else:
|
||||
info(f"[Agent] 第 {turn} 轮:完成,无工具调用")
|
||||
break
|
||||
await queue.put({"type": "node_end", "node": "agent"})
|
||||
|
||||
# 构建响应
|
||||
response_kwargs = {"content": full_content}
|
||||
@@ -274,14 +224,15 @@ def create_agent_node(chat_services: dict):
|
||||
if full_reasoning_content:
|
||||
response.additional_kwargs["reasoning_content"] = full_reasoning_content
|
||||
|
||||
info(f"[Agent] 完成 - content长度: {len(full_content)}, tool_calls: {len(final_tool_calls)}")
|
||||
|
||||
return {
|
||||
"messages": [response],
|
||||
"current_step": turn,
|
||||
"llm_calls": getattr(state, "llm_calls", 0) + 1
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error(f"[Agent] ❌ 第 {turn} 轮出错: {e}")
|
||||
error(f"[Agent] 执行出错: {e}")
|
||||
import traceback
|
||||
error(f"[Agent] 堆栈: {traceback.format_exc()}")
|
||||
if is_streaming:
|
||||
|
||||
Reference in New Issue
Block a user