refactor: 重构 agent 节点为单步推理(移除 while 循环)
This commit is contained in:
@@ -1,15 +1,15 @@
|
|||||||
"""
|
"""
|
||||||
Agent 节点 - 简化版本
|
Agent 节点 - 简化版本(单步推理)
|
||||||
直接定义 agent_node 函数,支持动态模型切换和循环检测
|
只负责一次 LLM 调用,不执行工具
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
import json
|
||||||
from typing import Dict, Any, Optional, List
|
from typing import Dict, Any, Optional, List
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage
|
||||||
|
|
||||||
from backend.app.main_graph.state import AgentState
|
from backend.app.main_graph.state import AgentState
|
||||||
from backend.app.logger import info, error
|
from backend.app.logger import info, error, debug
|
||||||
from backend.app.tools import ALL_TOOLS
|
from backend.app.tools import ALL_TOOLS
|
||||||
from backend.app.agent.stream_context import get_stream_queue
|
from backend.app.agent.stream_context import get_stream_queue
|
||||||
from backend.app.agent.prompts import SYSTEM_PROMPT
|
from backend.app.agent.prompts import SYSTEM_PROMPT
|
||||||
@@ -28,7 +28,6 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
|
|||||||
latest = results[-1]
|
latest = results[-1]
|
||||||
prev = results[-2]
|
prev = results[-2]
|
||||||
|
|
||||||
# 长度差异太大,不算相似
|
|
||||||
if len(latest) == 0 or len(prev) == 0:
|
if len(latest) == 0 or len(prev) == 0:
|
||||||
return len(latest) == len(prev)
|
return len(latest) == len(prev)
|
||||||
|
|
||||||
@@ -36,7 +35,6 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
|
|||||||
if len_ratio < 0.5:
|
if len_ratio < 0.5:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查内容重复度(简单:前100字符)
|
|
||||||
common_len = 0
|
common_len = 0
|
||||||
for a, b in zip(latest[:100], prev[:100]):
|
for a, b in zip(latest[:100], prev[:100]):
|
||||||
if a == b:
|
if a == b:
|
||||||
@@ -50,27 +48,23 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
|
|||||||
def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
|
def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
|
||||||
"""
|
"""
|
||||||
检测是否应该停止(循环检测)
|
检测是否应该停止(循环检测)
|
||||||
|
|
||||||
条件:连续2次调用相同工具 + 参数相似 + 结果相似
|
条件:连续2次调用相同工具 + 参数相似 + 结果相似
|
||||||
"""
|
"""
|
||||||
if len(tool_calls) < 2:
|
if len(tool_calls) < 2:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查最近的工具调用是否相同
|
|
||||||
last_tc = tool_calls[-1]
|
last_tc = tool_calls[-1]
|
||||||
prev_tc = tool_calls[-2]
|
prev_tc = tool_calls[-2]
|
||||||
|
|
||||||
if last_tc["name"] != prev_tc["name"]:
|
if last_tc["name"] != prev_tc["name"]:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 参数是否相似
|
|
||||||
last_args = _normalize_args(last_tc["args"])
|
last_args = _normalize_args(last_tc["args"])
|
||||||
prev_args = _normalize_args(prev_tc["args"])
|
prev_args = _normalize_args(prev_tc["args"])
|
||||||
|
|
||||||
if last_args != prev_args:
|
if last_args != prev_args:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 结果是否相似
|
|
||||||
if len(tool_results) >= 2:
|
if len(tool_results) >= 2:
|
||||||
return _is_similar_result(tool_results[-2:])
|
return _is_similar_result(tool_results[-2:])
|
||||||
|
|
||||||
@@ -79,22 +73,43 @@ def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bo
|
|||||||
|
|
||||||
def create_agent_node(chat_services: dict):
|
def create_agent_node(chat_services: dict):
|
||||||
"""
|
"""
|
||||||
创建 Agent 节点 - 支持动态模型切换
|
创建 Agent 节点 - 单步推理版本
|
||||||
|
|
||||||
简化设计:
|
设计:
|
||||||
- 直接返回 async 函数,无需工厂包装
|
- 只做一次 LLM 调用
|
||||||
- 从 config 中获取模型名称,运行时动态切换
|
- 不执行工具(工具执行由 tools 节点负责)
|
||||||
|
- 返回 AIMessage(可能包含 tool_calls)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||||
"""Agent 节点:完整的 ReAct 循环"""
|
"""Agent 节点:单步 LLM 调用"""
|
||||||
queue = get_stream_queue()
|
queue = get_stream_queue()
|
||||||
is_streaming = queue is not None
|
is_streaming = queue is not None
|
||||||
|
|
||||||
# 获取步数
|
# 获取步数
|
||||||
current_step = getattr(state, "current_step", 0)
|
current_step = getattr(state, "current_step", 0)
|
||||||
max_steps = getattr(state, "max_steps", 10)
|
max_steps = getattr(state, "max_steps", 10)
|
||||||
info(f"[Agent] 从第 {current_step} 步开始,最大步数: {max_steps},流式: {is_streaming}")
|
info(f"[Agent] 第 {current_step + 1} 步开始")
|
||||||
|
|
||||||
|
# 步数已达上限
|
||||||
|
if current_step >= max_steps:
|
||||||
|
info("[Agent] 达到步数上限,强制结束")
|
||||||
|
return {
|
||||||
|
"messages": [AIMessage(content="[系统] 已达到最大步数限制。")],
|
||||||
|
"stop": True,
|
||||||
|
"stop_reason": "max_steps",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 循环检测
|
||||||
|
tool_history = getattr(state, "tool_call_history", [])
|
||||||
|
result_history = getattr(state, "tool_result_history", [])
|
||||||
|
if _should_stop_for_loop(tool_history, result_history):
|
||||||
|
info("[Agent] 检测到循环,终止推理")
|
||||||
|
return {
|
||||||
|
"messages": [AIMessage(content="[系统] 检测到工具调用循环,已终止。")],
|
||||||
|
"stop": True,
|
||||||
|
"stop_reason": "loop_detected",
|
||||||
|
}
|
||||||
|
|
||||||
# 动态获取模型
|
# 动态获取模型
|
||||||
model_name = "primary"
|
model_name = "primary"
|
||||||
@@ -111,160 +126,95 @@ def create_agent_node(chat_services: dict):
|
|||||||
|
|
||||||
# 获取记忆上下文
|
# 获取记忆上下文
|
||||||
memory_context = getattr(state, "memory_context", "暂无用户背景信息")
|
memory_context = getattr(state, "memory_context", "暂无用户背景信息")
|
||||||
|
|
||||||
# 组装消息(注入记忆上下文到提示词)
|
|
||||||
prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
|
prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
|
||||||
messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)
|
messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)
|
||||||
turn = current_step
|
|
||||||
|
# 发送节点开始事件
|
||||||
|
if is_streaming:
|
||||||
|
await queue.put({"type": "node_start", "node": "agent"})
|
||||||
|
|
||||||
|
# 选择 LLM(最后一轮不带工具)
|
||||||
|
if current_step + 1 >= max_steps:
|
||||||
|
current_llm = llm.bind_tools([])
|
||||||
|
info(f"[Agent] 达到步数上限,使用无工具模型")
|
||||||
|
else:
|
||||||
|
current_llm = llm_with_tools
|
||||||
|
|
||||||
|
# 初始化
|
||||||
|
full_content = ""
|
||||||
|
full_reasoning_content = ""
|
||||||
|
pending_tool_calls = {}
|
||||||
|
final_tool_calls = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while turn < max_steps:
|
# 调用 LLM
|
||||||
turn += 1
|
if is_streaming:
|
||||||
info(f"[Agent] 第 {turn} 轮思考")
|
async for chunk in current_llm.astream(messages):
|
||||||
|
if isinstance(chunk, AIMessageChunk):
|
||||||
|
if chunk.content:
|
||||||
|
full_content += chunk.content
|
||||||
|
await queue.put({
|
||||||
|
"type": "llm_token",
|
||||||
|
"node": "agent",
|
||||||
|
"token": chunk.content,
|
||||||
|
"reasoning_token": ""
|
||||||
|
})
|
||||||
|
|
||||||
if is_streaming:
|
if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs:
|
||||||
await queue.put({"type": "node_start", "node": "agent"})
|
reasoning = chunk.additional_kwargs.get("reasoning_content", "")
|
||||||
|
if reasoning:
|
||||||
# 选择 LLM(最后一轮不带工具)
|
full_reasoning_content += reasoning
|
||||||
if turn >= max_steps:
|
|
||||||
current_llm = llm.bind_tools([])
|
|
||||||
info(f"[Agent] 达到步数上限,使用无工具模型")
|
|
||||||
else:
|
|
||||||
current_llm = llm_with_tools
|
|
||||||
|
|
||||||
# 初始化
|
|
||||||
full_content = ""
|
|
||||||
full_reasoning_content = ""
|
|
||||||
pending_tool_calls = {}
|
|
||||||
final_tool_calls = []
|
|
||||||
|
|
||||||
# 循环检测:记录历史调用
|
|
||||||
tool_call_history: List[dict] = []
|
|
||||||
tool_result_history: List[str] = []
|
|
||||||
|
|
||||||
# 调用 LLM
|
|
||||||
if is_streaming:
|
|
||||||
async for chunk in current_llm.astream(messages):
|
|
||||||
if isinstance(chunk, AIMessageChunk):
|
|
||||||
if chunk.content:
|
|
||||||
full_content += chunk.content
|
|
||||||
await queue.put({
|
await queue.put({
|
||||||
"type": "llm_token",
|
"type": "llm_token",
|
||||||
"node": "agent",
|
"node": "agent",
|
||||||
"token": chunk.content,
|
"token": "",
|
||||||
"reasoning_token": ""
|
"reasoning_token": reasoning
|
||||||
})
|
})
|
||||||
|
|
||||||
if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs:
|
if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
|
||||||
reasoning = chunk.additional_kwargs.get("reasoning_content", "")
|
for tc_chunk in chunk.tool_call_chunks:
|
||||||
if reasoning:
|
idx = tc_chunk.get("index", 0)
|
||||||
full_reasoning_content += reasoning
|
if idx not in pending_tool_calls:
|
||||||
await queue.put({
|
pending_tool_calls[idx] = {"id": "", "name": "", "args": ""}
|
||||||
"type": "llm_token",
|
|
||||||
"node": "agent",
|
|
||||||
"token": "",
|
|
||||||
"reasoning_token": reasoning
|
|
||||||
})
|
|
||||||
|
|
||||||
if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
|
if tc_chunk.get("id"):
|
||||||
for tc_chunk in chunk.tool_call_chunks:
|
pending_tool_calls[idx]["id"] += tc_chunk["id"]
|
||||||
idx = tc_chunk.get("index", 0)
|
if tc_chunk.get("name"):
|
||||||
if idx not in pending_tool_calls:
|
pending_tool_calls[idx]["name"] += tc_chunk["name"]
|
||||||
pending_tool_calls[idx] = {"id": "", "name": "", "args": ""}
|
if tc_chunk.get("args"):
|
||||||
|
args_val = tc_chunk["args"]
|
||||||
|
if isinstance(args_val, str):
|
||||||
|
pending_tool_calls[idx]["args"] += args_val
|
||||||
|
else:
|
||||||
|
pending_tool_calls[idx]["args"] += json.dumps(args_val)
|
||||||
|
else:
|
||||||
|
result = await current_llm.ainvoke(messages)
|
||||||
|
full_content = result.content if result.content else ""
|
||||||
|
if hasattr(result, 'tool_calls') and result.tool_calls:
|
||||||
|
final_tool_calls = result.tool_calls
|
||||||
|
if hasattr(result, 'additional_kwargs'):
|
||||||
|
full_reasoning_content = result.additional_kwargs.get("reasoning_content", "")
|
||||||
|
|
||||||
if tc_chunk.get("id"):
|
# 整理工具调用
|
||||||
pending_tool_calls[idx]["id"] += tc_chunk["id"]
|
if is_streaming:
|
||||||
if tc_chunk.get("name"):
|
for idx in sorted(pending_tool_calls.keys()):
|
||||||
pending_tool_calls[idx]["name"] += tc_chunk["name"]
|
tc_data = pending_tool_calls[idx]
|
||||||
if tc_chunk.get("args"):
|
if tc_data["name"]:
|
||||||
args_val = tc_chunk["args"]
|
args = {}
|
||||||
if isinstance(args_val, str):
|
if tc_data["args"]:
|
||||||
pending_tool_calls[idx]["args"] += args_val
|
try:
|
||||||
else:
|
args = json.loads(tc_data["args"])
|
||||||
import json
|
except Exception as e:
|
||||||
pending_tool_calls[idx]["args"] += json.dumps(args_val)
|
info(f"[Agent] 解析参数失败: {e}")
|
||||||
else:
|
final_tool_calls.append({
|
||||||
result = await current_llm.ainvoke(messages)
|
"id": tc_data["id"],
|
||||||
full_content = result.content if result.content else ""
|
"name": tc_data["name"],
|
||||||
if hasattr(result, 'tool_calls') and result.tool_calls:
|
"args": args
|
||||||
final_tool_calls = result.tool_calls
|
})
|
||||||
if hasattr(result, 'additional_kwargs'):
|
|
||||||
full_reasoning_content = result.additional_kwargs.get("reasoning_content", "")
|
|
||||||
|
|
||||||
# 整理工具调用
|
# 发送节点结束事件
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
for idx in sorted(pending_tool_calls.keys()):
|
await queue.put({"type": "node_end", "node": "agent"})
|
||||||
tc_data = pending_tool_calls[idx]
|
|
||||||
if tc_data["name"]:
|
|
||||||
args = {}
|
|
||||||
if tc_data["args"]:
|
|
||||||
try:
|
|
||||||
import json
|
|
||||||
args = json.loads(tc_data["args"])
|
|
||||||
except Exception as e:
|
|
||||||
info(f"[Agent] 解析参数失败: {e}")
|
|
||||||
final_tool_calls.append({
|
|
||||||
"id": tc_data["id"],
|
|
||||||
"name": tc_data["name"],
|
|
||||||
"args": args
|
|
||||||
})
|
|
||||||
|
|
||||||
# 执行工具
|
|
||||||
if final_tool_calls:
|
|
||||||
info(f"[Agent] 第 {turn} 轮:调用 {len(final_tool_calls)} 个工具")
|
|
||||||
new_messages = []
|
|
||||||
|
|
||||||
for tc in final_tool_calls:
|
|
||||||
tool_name = tc["name"]
|
|
||||||
tool_args = tc["args"]
|
|
||||||
tool_id = tc["id"]
|
|
||||||
|
|
||||||
if is_streaming:
|
|
||||||
await queue.put({
|
|
||||||
"type": "custom",
|
|
||||||
"data": {"type": "tool_start", "tool": tool_name, "args": tool_args, "id": tool_id}
|
|
||||||
})
|
|
||||||
|
|
||||||
# 查找并执行工具
|
|
||||||
tool_result = ""
|
|
||||||
tool_found = False
|
|
||||||
for tool in ALL_TOOLS:
|
|
||||||
if tool.name == tool_name:
|
|
||||||
tool_found = True
|
|
||||||
try:
|
|
||||||
tool_result = await tool.ainvoke(tool_args)
|
|
||||||
except Exception as e:
|
|
||||||
tool_result = f"工具调用出错: {str(e)}"
|
|
||||||
error(f"[Agent] 工具 {tool_name} 调用出错: {e}")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not tool_found:
|
|
||||||
tool_result = f"未找到工具: {tool_name}"
|
|
||||||
|
|
||||||
if is_streaming:
|
|
||||||
await queue.put({
|
|
||||||
"type": "custom",
|
|
||||||
"data": {"type": "tool_end", "tool": tool_name, "id": tool_id, "result": str(tool_result)}
|
|
||||||
})
|
|
||||||
|
|
||||||
# 记录历史(用于循环检测)
|
|
||||||
tool_call_history.append({"name": tool_name, "args": tool_args})
|
|
||||||
tool_result_history.append(str(tool_result))
|
|
||||||
|
|
||||||
new_messages.append(ToolMessage(content=str(tool_result), tool_call_id=tool_id, name=tool_name))
|
|
||||||
|
|
||||||
# 循环检测:相同工具 + 相似参数 + 相似结果 → 终止
|
|
||||||
if _should_stop_for_loop(tool_call_history, tool_result_history):
|
|
||||||
info(f"[Agent] ⚠️ 检测到循环,强制终止")
|
|
||||||
# 添加一条终止消息
|
|
||||||
messages.append(AIMessage(content="[系统] 检测到工具调用循环,已终止。"))
|
|
||||||
break
|
|
||||||
|
|
||||||
messages.extend(new_messages)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
info(f"[Agent] 第 {turn} 轮:完成,无工具调用")
|
|
||||||
break
|
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
response_kwargs = {"content": full_content}
|
response_kwargs = {"content": full_content}
|
||||||
@@ -274,14 +224,15 @@ def create_agent_node(chat_services: dict):
|
|||||||
if full_reasoning_content:
|
if full_reasoning_content:
|
||||||
response.additional_kwargs["reasoning_content"] = full_reasoning_content
|
response.additional_kwargs["reasoning_content"] = full_reasoning_content
|
||||||
|
|
||||||
|
info(f"[Agent] 完成 - content长度: {len(full_content)}, tool_calls: {len(final_tool_calls)}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [response],
|
"messages": [response],
|
||||||
"current_step": turn,
|
|
||||||
"llm_calls": getattr(state, "llm_calls", 0) + 1
|
"llm_calls": getattr(state, "llm_calls", 0) + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"[Agent] ❌ 第 {turn} 轮出错: {e}")
|
error(f"[Agent] 执行出错: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
error(f"[Agent] 堆栈: {traceback.format_exc()}")
|
error(f"[Agent] 堆栈: {traceback.format_exc()}")
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
|
|||||||
Reference in New Issue
Block a user