Files
ailine/backend/app/main_graph/nodes/agent.py
root 6dfa9f572e
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m24s
重构:清理废弃代码 + 优化 Agent 架构
主要变更:
- 删除 deprecated 文件夹(intent/hybrid_router/rag_nodes 等)
- 删除 intent_classifier.py(未使用)
- 删除 subgraph_wrapper.py(死代码)
- 重构 agent.py:简化工厂函数,支持动态模型切换
- 重构 prompts.py:添加信息获取优先级、思维链要求、工具调用约束
- 优化 tools:统一位置,rag_search 返回置信度评估
- 新增 RAG 置信度评估:embedding(25%) + rerank(25%) + LLM(50%)
- 添加循环检测:防止工具无限重复调用

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-08 00:29:12 +08:00

292 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Agent 节点 - 简化版本
直接定义 agent_node 函数,支持动态模型切换和循环检测
"""
import hashlib
import json
import traceback
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage
from langchain_core.runnables.config import RunnableConfig

from backend.app.agent.prompts import SYSTEM_PROMPT
from backend.app.agent.stream_context import get_stream_queue
from backend.app.logger import error, info
from backend.app.main_graph.state import AgentState
from backend.app.tools import ALL_TOOLS
def _normalize_args(args: dict) -> str:
"""标准化工具参数用于比较"""
return str(sorted(args.items()))
def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
"""检测结果是否相似(简单实现:长度相似+部分内容重复)"""
if len(results) < 2:
return False
latest = results[-1]
prev = results[-2]
# 长度差异太大,不算相似
if len(latest) == 0 or len(prev) == 0:
return len(latest) == len(prev)
len_ratio = min(len(latest), len(prev)) / max(len(latest), len(prev))
if len_ratio < 0.5:
return False
# 检查内容重复度简单前100字符
common_len = 0
for a, b in zip(latest[:100], prev[:100]):
if a == b:
common_len += 1
else:
break
return (common_len / 100) > threshold
def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
    """Report whether the two most recent tool calls look like a loop.

    A loop means all of the following hold:
    the last two calls name the same tool, their arguments normalize to the
    same string via ``_normalize_args``, and — when at least two results have
    been recorded — ``_is_similar_result`` judges the last two results alike.

    Args:
        tool_calls: Call history; each item is a dict with ``"name"`` and ``"args"``.
        tool_results: Stringified tool results, in call order.

    Returns:
        ``True`` when the agent should stop calling tools, else ``False``.
    """
    if len(tool_calls) < 2:
        return False

    latest, earlier = tool_calls[-1], tool_calls[-2]

    # Different tools on consecutive calls: not a loop.
    if latest["name"] != earlier["name"]:
        return False

    # Same tool but different arguments: still making progress.
    if _normalize_args(latest["args"]) != _normalize_args(earlier["args"]):
        return False

    # Only a repeated call that also returned a near-identical result is a loop.
    return len(tool_results) >= 2 and _is_similar_result(tool_results[-2:])
def _merge_tool_call_chunks(pending: Dict[Any, Dict[str, str]], tool_call_chunks: List[dict]) -> None:
    """Accumulate streamed tool-call fragments into ``pending``, keyed by chunk index.

    ``id``, ``name`` and ``args`` fragments are concatenated in arrival order;
    non-string ``args`` fragments are JSON-encoded before concatenation.
    """
    for tc_chunk in tool_call_chunks:
        idx = tc_chunk.get("index")
        if idx is None:
            # The index may be absent or explicitly None; treat both as call 0
            # so the keys stay sortable when the calls are finalized.
            idx = 0
        slot = pending.setdefault(idx, {"id": "", "name": "", "args": ""})
        if tc_chunk.get("id"):
            slot["id"] += tc_chunk["id"]
        if tc_chunk.get("name"):
            slot["name"] += tc_chunk["name"]
        args_val = tc_chunk.get("args")
        if args_val:
            slot["args"] += args_val if isinstance(args_val, str) else json.dumps(args_val)


def _finalize_tool_calls(pending: Dict[Any, Dict[str, str]]) -> List[dict]:
    """Convert accumulated fragments into ``{"id", "name", "args"}`` dicts, in index order.

    Entries that never received a tool name are dropped. An ``args`` string that
    cannot be parsed as JSON is logged and replaced by ``{}``.
    """
    tool_calls: List[dict] = []
    for idx in sorted(pending):
        tc_data = pending[idx]
        if not tc_data["name"]:
            continue
        args: Any = {}
        if tc_data["args"]:
            try:
                args = json.loads(tc_data["args"])
            except (ValueError, RecursionError) as e:
                info(f"[Agent] 解析参数失败: {e}")
        tool_calls.append({"id": tc_data["id"], "name": tc_data["name"], "args": args})
    return tool_calls


async def _stream_llm(llm: Any, messages: List[Any], queue: Any) -> Tuple[str, str, List[dict]]:
    """Run one streamed LLM turn, forwarding tokens to ``queue`` as they arrive.

    Content deltas and ``reasoning_content`` deltas are each pushed as
    ``llm_token`` events; tool-call fragments are accumulated and parsed at the end.

    Returns:
        ``(content, reasoning_content, tool_calls)`` for the turn.
    """
    content = ""
    reasoning_content = ""
    pending: Dict[Any, Dict[str, str]] = {}
    async for chunk in llm.astream(messages):
        if not isinstance(chunk, AIMessageChunk):
            continue
        if chunk.content:
            content += chunk.content
            await queue.put({
                "type": "llm_token",
                "node": "agent",
                "token": chunk.content,
                "reasoning_token": "",
            })
        reasoning = (getattr(chunk, "additional_kwargs", None) or {}).get("reasoning_content", "")
        if reasoning:
            reasoning_content += reasoning
            await queue.put({
                "type": "llm_token",
                "node": "agent",
                "token": "",
                "reasoning_token": reasoning,
            })
        tc_chunks = getattr(chunk, "tool_call_chunks", None)
        if tc_chunks:
            _merge_tool_call_chunks(pending, tc_chunks)
    return content, reasoning_content, _finalize_tool_calls(pending)


async def _invoke_llm(llm: Any, messages: List[Any]) -> Tuple[str, str, List[dict]]:
    """Run one non-streamed LLM turn.

    Returns:
        ``(content, reasoning_content, tool_calls)`` for the turn.
    """
    result = await llm.ainvoke(messages)
    content = result.content if result.content else ""
    tool_calls = list(getattr(result, "tool_calls", None) or [])
    reasoning_content = (getattr(result, "additional_kwargs", None) or {}).get("reasoning_content", "") or ""
    return content, reasoning_content, tool_calls


async def _execute_tool(tools_by_name: Dict[str, Any], tool_name: str, tool_args: Any) -> Any:
    """Invoke the named tool and return its result.

    Tool problems do not propagate: an unknown tool name or an ``Exception``
    raised by the tool is turned into an error string for the model to read.
    """
    tool = tools_by_name.get(tool_name)
    if tool is None:
        return f"未找到工具: {tool_name}"
    try:
        return await tool.ainvoke(tool_args)
    except Exception as e:
        # Report the failure to the model as the tool result instead of aborting the turn.
        error(f"[Agent] 工具 {tool_name} 调用出错: {e}")
        return f"工具调用出错: {str(e)}"


def create_agent_node(chat_services: dict):
    """Create the agent graph node, with the chat model selectable at runtime.

    Args:
        chat_services: Mapping of model name -> chat model. The entry named by
            ``config["configurable"]["model"]`` (default ``"primary"``) is used;
            if that name is missing, the first entry of the mapping is used.

    Returns:
        The async ``agent_node(state, config=None)`` function.
    """

    async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
        """Agent node: run the full tool-calling (ReAct) loop for one graph step.

        Each turn calls the LLM. When it requests tools, the assistant message
        with its ``tool_calls`` and one ToolMessage per call are appended to the
        working message list, and the loop continues. The loop ends when the
        model answers without tool calls. The turn that reaches ``max_steps``,
        or the turn after a repeated-tool-call loop is detected, runs the model
        with no tools bound so it has to answer.

        Returns:
            State update with the final AIMessage (no ``tool_calls``), the new
            ``current_step``, and ``llm_calls`` increased by the number of LLM
            calls made.

        Raises:
            Re-raises any exception from the loop after logging it (and, when
            streaming, pushing an ``error`` event to the stream queue).
        """
        queue = get_stream_queue()
        is_streaming = queue is not None

        current_step = getattr(state, "current_step", 0)
        max_steps = getattr(state, "max_steps", 10)
        info(f"[Agent] 从第 {current_step} 步开始,最大步数: {max_steps},流式: {is_streaming}")

        # Resolve the model for this call; fall back to the first configured one.
        model_name = "primary"
        if config:
            configurable = config.get("configurable") or {}
            model_name = configurable.get("model", "primary")
        llm = chat_services.get(model_name)
        if llm is None:
            llm = next(iter(chat_services.values()), None)
            if llm is None:
                raise ValueError("chat_services is empty: no chat model configured")
            info(f"[Agent] 模型 '{model_name}' 不可用,使用 '{type(llm).__name__}'")
        llm_with_tools = llm.bind_tools(ALL_TOOLS)

        # setdefault keeps the first tool for a duplicated name (first-match lookup).
        tools_by_name: Dict[str, Any] = {}
        for tool in ALL_TOOLS:
            tools_by_name.setdefault(tool.name, tool)

        # `or` also covers a memory_context attribute that exists but is None/empty.
        memory_context = getattr(state, "memory_context", None) or "暂无用户背景信息"
        prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
        messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)

        # Loop-detection history must span turns: a loop is the model re-issuing
        # the same call on consecutive turns, so it cannot be reset per turn.
        tool_call_history: List[dict] = []
        tool_result_history: List[str] = []

        turn = current_step
        llm_call_count = 0
        # If the step budget is already used up, still produce one tool-free
        # answer rather than returning nothing.
        force_final = turn >= max_steps
        full_content = ""
        full_reasoning_content = ""

        try:
            while True:
                turn += 1
                info(f"[Agent] 第 {turn} 轮思考")
                if is_streaming:
                    await queue.put({"type": "node_start", "node": "agent"})

                final_round = force_final or turn >= max_steps
                if final_round:
                    # Use the plain model: binding an empty tool list sends
                    # `tools: []`, which OpenAI-compatible APIs reject.
                    current_llm = llm
                    if turn >= max_steps:
                        info("[Agent] 达到步数上限,使用无工具模型")
                    else:
                        info("[Agent] 检测到循环,使用无工具模型生成最终回答")
                else:
                    current_llm = llm_with_tools

                if is_streaming:
                    full_content, full_reasoning_content, tool_calls = await _stream_llm(
                        current_llm, messages, queue
                    )
                else:
                    full_content, full_reasoning_content, tool_calls = await _invoke_llm(
                        current_llm, messages
                    )
                llm_call_count += 1

                if not tool_calls:
                    info(f"[Agent] 第 {turn} 轮:完成,无工具调用")
                    break
                if final_round:
                    # No tools were bound, so these calls cannot be honoured;
                    # dropping them keeps dangling tool_calls out of the history.
                    info(f"[Agent] 第 {turn} 轮:无工具模型仍返回工具调用,已忽略")
                    break

                info(f"[Agent] 第 {turn} 轮:调用 {len(tool_calls)} 个工具")
                # The assistant message carrying tool_calls must precede its
                # ToolMessages, otherwise the next chat request is rejected.
                messages.append(AIMessage(content=full_content, tool_calls=tool_calls))

                loop_detected = False
                for tc in tool_calls:
                    tool_name = tc["name"]
                    tool_args = tc["args"]
                    tool_id = tc["id"]
                    if is_streaming:
                        await queue.put({
                            "type": "custom",
                            "data": {"type": "tool_start", "tool": tool_name, "args": tool_args, "id": tool_id},
                        })

                    tool_result = await _execute_tool(tools_by_name, tool_name, tool_args)

                    if is_streaming:
                        await queue.put({
                            "type": "custom",
                            "data": {"type": "tool_end", "tool": tool_name, "id": tool_id, "result": str(tool_result)},
                        })

                    tool_call_history.append({"name": tool_name, "args": tool_args})
                    tool_result_history.append(str(tool_result))
                    # Every tool_call needs a matching ToolMessage, so the remaining
                    # calls still run even after a loop has been detected.
                    messages.append(ToolMessage(content=str(tool_result), tool_call_id=tool_id, name=tool_name))
                    if _should_stop_for_loop(tool_call_history, tool_result_history):
                        loop_detected = True

                if loop_detected:
                    # Same tool + same args + similar result: stop tool use and
                    # force the next turn to produce a final answer.
                    info("[Agent] ⚠️ 检测到循环,强制终止")
                    force_final = True

            response = AIMessage(content=full_content)
            if full_reasoning_content:
                response.additional_kwargs["reasoning_content"] = full_reasoning_content
            return {
                "messages": [response],
                "current_step": turn,
                "llm_calls": getattr(state, "llm_calls", 0) + llm_call_count,
            }
        except Exception as e:
            error(f"[Agent] ❌ 第 {turn} 轮出错: {e}")
            error(f"[Agent] 堆栈: {traceback.format_exc()}")
            if is_streaming:
                await queue.put({"type": "error", "message": str(e)})
            raise

    return agent_node