Files
ailine/backend/app/main_graph/nodes/agent.py

243 lines
9.1 KiB
Python
Raw Normal View History

"""
Agent 节点 - 简化版本单步推理
只负责一次 LLM 调用不执行工具
"""
import json
from typing import Dict, Any, Optional, List
from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage
from backend.app.main_graph.state import AgentState
from backend.app.logger import info, error, debug
from backend.app.tools import ALL_TOOLS
from backend.app.agent.stream_context import get_stream_queue
from backend.app.agent.prompts import SYSTEM_PROMPT
def _normalize_args(args: dict) -> str:
"""标准化工具参数用于比较"""
return str(sorted(args.items()))
def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
"""检测结果是否相似(简单实现:长度相似+部分内容重复)"""
if len(results) < 2:
return False
latest = results[-1]
prev = results[-2]
if len(latest) == 0 or len(prev) == 0:
return len(latest) == len(prev)
len_ratio = min(len(latest), len(prev)) / max(len(latest), len(prev))
if len_ratio < 0.5:
return False
common_len = 0
for a, b in zip(latest[:100], prev[:100]):
if a == b:
common_len += 1
else:
break
return (common_len / 100) > threshold
def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
"""
检测是否应该停止循环检测
条件连续2次调用相同工具 + 参数相似 + 结果相似
"""
if len(tool_calls) < 2:
return False
last_tc = tool_calls[-1]
prev_tc = tool_calls[-2]
if last_tc["name"] != prev_tc["name"]:
return False
last_args = _normalize_args(last_tc["args"])
prev_args = _normalize_args(prev_tc["args"])
if last_args != prev_args:
return False
if len(tool_results) >= 2:
return _is_similar_result(tool_results[-2:])
return False
def _select_llm(chat_services: dict, config: Optional[RunnableConfig]):
    """Pick the chat model for this run.

    The model name comes from ``config["configurable"]["model"]`` and defaults
    to ``"primary"``. If that name is not registered in ``chat_services``, the
    first registered model is used and the substitution is logged.

    Raises:
        RuntimeError: if ``chat_services`` is empty.
    """
    model_name = "primary"
    if config:
        # `or {}` also covers an explicit `configurable=None`.
        configurable = config.get("configurable") or {}
        model_name = configurable.get("model", "primary")
    llm = chat_services.get(model_name)
    if llm is None:
        if not chat_services:
            raise RuntimeError("[Agent] chat_services 为空,没有可用的模型")
        llm = next(iter(chat_services.values()))
        info(f"[Agent] 模型 '{model_name}' 不可用,使用 '{type(llm).__name__}'")
    return llm


def _accumulate_tool_call_chunks(pending: Dict[Any, Dict[str, str]], tool_call_chunks: List[dict]) -> None:
    """Merge streamed tool-call fragments into ``pending``, keyed by chunk index.

    ``id``, ``name`` and ``args`` fragments are concatenated in arrival order.
    Non-string ``args`` fragments are JSON-encoded before concatenation.
    """
    for tc_chunk in tool_call_chunks:
        idx = tc_chunk.get("index", 0)
        slot = pending.setdefault(idx, {"id": "", "name": "", "args": ""})
        if tc_chunk.get("id"):
            slot["id"] += tc_chunk["id"]
        if tc_chunk.get("name"):
            slot["name"] += tc_chunk["name"]
        if tc_chunk.get("args"):
            args_val = tc_chunk["args"]
            slot["args"] += args_val if isinstance(args_val, str) else json.dumps(args_val)


def _assemble_tool_calls(pending: Dict[Any, Dict[str, str]]) -> List[dict]:
    """Turn accumulated streaming fragments into tool-call dicts, ordered by index.

    Entries with no tool name are dropped. Argument strings are parsed as JSON.
    An empty string, unparsable JSON, or JSON that is not an object all yield
    ``{}``, so the resulting ``AIMessage`` (whose tool-call ``args`` must be a
    dict) can still be constructed.
    """
    tool_calls: List[dict] = []
    for idx in sorted(pending):
        tc_data = pending[idx]
        if not tc_data["name"]:
            continue
        args: Any = {}
        if tc_data["args"]:
            try:
                args = json.loads(tc_data["args"])
            except ValueError as e:
                # The tool is still invoked, with empty args -- make this visible.
                error(f"[Agent] 解析参数失败: {e}")
                args = {}
        if not isinstance(args, dict):
            error(f"[Agent] 工具参数不是 JSON 对象,已改为空参数: {tc_data['args']!r}")
            args = {}
        tool_calls.append({
            "id": tc_data["id"],
            "name": tc_data["name"],
            "args": args,
        })
    return tool_calls


async def _stream_llm(llm, messages: list, queue) -> tuple:
    """Stream one LLM response and forward every token to ``queue``.

    Content tokens and ``reasoning_content`` tokens (read from
    ``additional_kwargs``) are pushed as ``llm_token`` events while being
    accumulated. Tool-call fragments are merged by chunk index.

    Returns:
        ``(content, reasoning_content, pending_tool_calls)``, where
        ``pending_tool_calls`` maps chunk index -> ``{"id", "name", "args"}``
        and ``args`` is still the raw, unparsed string.
    """
    content = ""
    reasoning_content = ""
    pending: Dict[Any, Dict[str, str]] = {}
    async for chunk in llm.astream(messages):
        if not isinstance(chunk, AIMessageChunk):
            continue
        if chunk.content:
            content += chunk.content
            await queue.put({
                "type": "llm_token",
                "node": "agent",
                "token": chunk.content,
                "reasoning_token": "",
            })
        if hasattr(chunk, "additional_kwargs") and chunk.additional_kwargs:
            reasoning = chunk.additional_kwargs.get("reasoning_content", "")
            if reasoning:
                reasoning_content += reasoning
                await queue.put({
                    "type": "llm_token",
                    "node": "agent",
                    "token": "",
                    "reasoning_token": reasoning,
                })
        if hasattr(chunk, "tool_call_chunks") and chunk.tool_call_chunks:
            _accumulate_tool_call_chunks(pending, chunk.tool_call_chunks)
    return content, reasoning_content, pending


def create_agent_node(chat_services: dict):
    """
    Build the single-step agent node.

    Design:
    - exactly one LLM call per invocation;
    - tools are never executed here (the ``tools`` node does that);
    - the returned ``AIMessage`` may carry ``tool_calls``.

    Args:
        chat_services: mapping of model name -> chat model. ``"primary"`` is
            used unless ``config["configurable"]["model"]`` names another one.

    Returns:
        The async ``agent_node(state, config=None)`` function.
    """

    async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
        """Run one LLM call for the current step and return the state update.

        Returns a stop update (``stop=True`` plus ``stop_reason``) without
        calling the LLM when the step budget is exhausted or a tool-call loop
        is detected. Otherwise returns the new ``AIMessage`` and an incremented
        ``llm_calls``. On LLM failure an ``error`` event is streamed (when
        streaming) and the exception is re-raised.
        """
        queue = get_stream_queue()
        # Events are only emitted when a stream queue is present.
        is_streaming = queue is not None

        current_step = getattr(state, "current_step", 0)
        max_steps = getattr(state, "max_steps", 10)
        info(f"[Agent] 第 {current_step + 1} 步开始")

        # Step budget exhausted: stop without calling the LLM.
        if current_step >= max_steps:
            info("[Agent] 达到步数上限,强制结束")
            return {
                "messages": [AIMessage(content="[系统] 已达到最大步数限制。")],
                "stop": True,
                "stop_reason": "max_steps",
            }

        # Loop detection. `or []` also tolerates a history field set to None.
        tool_history = getattr(state, "tool_call_history", None) or []
        result_history = getattr(state, "tool_result_history", None) or []
        if _should_stop_for_loop(tool_history, result_history):
            info("[Agent] 检测到循环,终止推理")
            return {
                "messages": [AIMessage(content="[系统] 检测到工具调用循环,已终止。")],
                "stop": True,
                "stop_reason": "loop_detected",
            }

        llm = _select_llm(chat_services, config)
        llm_with_tools = llm.bind_tools(ALL_TOOLS)

        # Inject the memory context into the system prompt.
        memory_context = getattr(state, "memory_context", "暂无用户背景信息")
        prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
        messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)

        if is_streaming:
            await queue.put({"type": "node_start", "node": "agent"})

        # On the last permitted step, bind no tools so the model answers in
        # text instead of requesting another tool round.
        if current_step + 1 >= max_steps:
            # NOTE(review): some providers reject an empty `tools` array --
            # confirm bind_tools([]) works for every model in chat_services.
            current_llm = llm.bind_tools([])
            info("[Agent] 达到步数上限,使用无工具模型")
        else:
            current_llm = llm_with_tools

        try:
            if is_streaming:
                full_content, full_reasoning_content, pending_tool_calls = await _stream_llm(
                    current_llm, messages, queue
                )
                final_tool_calls = _assemble_tool_calls(pending_tool_calls)
            else:
                result = await current_llm.ainvoke(messages)
                full_content = result.content if result.content else ""
                final_tool_calls = result.tool_calls if getattr(result, "tool_calls", None) else []
                full_reasoning_content = ""
                if hasattr(result, "additional_kwargs"):
                    full_reasoning_content = result.additional_kwargs.get("reasoning_content", "")

            if is_streaming:
                await queue.put({"type": "node_end", "node": "agent"})

            # Only attach tool_calls when there are any.
            response_kwargs = {"content": full_content}
            if final_tool_calls:
                response_kwargs["tool_calls"] = final_tool_calls
            response = AIMessage(**response_kwargs)
            if full_reasoning_content:
                response.additional_kwargs["reasoning_content"] = full_reasoning_content

            info(f"[Agent] 完成 - content长度: {len(full_content)}, tool_calls: {len(final_tool_calls)}")
            return {
                "messages": [response],
                "llm_calls": getattr(state, "llm_calls", 0) + 1,
            }
        except Exception as e:
            import traceback

            error(f"[Agent] 执行出错: {e}")
            error(f"[Agent] 堆栈: {traceback.format_exc()}")
            if is_streaming:
                await queue.put({"type": "error", "message": str(e)})
            raise

    return agent_node