Files
ailine/backend/app/main_graph/nodes/agent.py

243 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Agent 节点 - 简化版本(单步推理)
只负责一次 LLM 调用,不执行工具
"""
import json
from typing import Dict, Any, Optional, List
from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage
from backend.app.main_graph.state import AgentState
from backend.app.logger import info, error, debug
from backend.app.tools import ALL_TOOLS
from backend.app.agent.stream_context import get_stream_queue
from backend.app.agent.prompts import SYSTEM_PROMPT
def _normalize_args(args: dict) -> str:
"""标准化工具参数用于比较"""
return str(sorted(args.items()))
def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
"""检测结果是否相似(简单实现:长度相似+部分内容重复)"""
if len(results) < 2:
return False
latest = results[-1]
prev = results[-2]
if len(latest) == 0 or len(prev) == 0:
return len(latest) == len(prev)
len_ratio = min(len(latest), len(prev)) / max(len(latest), len(prev))
if len_ratio < 0.5:
return False
common_len = 0
for a, b in zip(latest[:100], prev[:100]):
if a == b:
common_len += 1
else:
break
return (common_len / 100) > threshold
def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
"""
检测是否应该停止(循环检测)
条件连续2次调用相同工具 + 参数相似 + 结果相似
"""
if len(tool_calls) < 2:
return False
last_tc = tool_calls[-1]
prev_tc = tool_calls[-2]
if last_tc["name"] != prev_tc["name"]:
return False
last_args = _normalize_args(last_tc["args"])
prev_args = _normalize_args(prev_tc["args"])
if last_args != prev_args:
return False
if len(tool_results) >= 2:
return _is_similar_result(tool_results[-2:])
return False
def _accumulate_tool_call_chunk(pending: Dict[Any, Dict[str, str]], tc_chunk: dict) -> None:
    """Merge one streamed tool-call fragment into ``pending`` (keyed by index)."""
    idx = tc_chunk.get("index")
    if idx is None:
        # The key may be present with a ``None`` value; mapping it to 0 keeps
        # the keys sortable when the fragments are assembled afterwards.
        idx = 0
    slot = pending.setdefault(idx, {"id": "", "name": "", "args": ""})
    if tc_chunk.get("id"):
        slot["id"] += tc_chunk["id"]
    if tc_chunk.get("name"):
        slot["name"] += tc_chunk["name"]
    args_val = tc_chunk.get("args")
    if args_val:
        # Argument fragments are concatenated as text and parsed once the
        # stream has finished; non-string fragments are serialised first.
        slot["args"] += args_val if isinstance(args_val, str) else json.dumps(args_val)


def _finalize_tool_calls(pending: Dict[Any, Dict[str, str]]) -> List[dict]:
    """Turn accumulated fragments into tool-call dicts, ordered by index."""
    tool_calls = []
    for idx in sorted(pending):
        tc_data = pending[idx]
        if not tc_data["name"]:
            # Fragments that never received a tool name are dropped.
            continue
        args = {}
        if tc_data["args"]:
            try:
                args = json.loads(tc_data["args"])
            except (ValueError, RecursionError) as e:
                # Best effort: keep the call with empty args instead of failing.
                info(f"[Agent] 解析参数失败: {e}")
        tool_calls.append({
            "id": tc_data["id"],
            "name": tc_data["name"],
            "args": args,
        })
    return tool_calls


def create_agent_node(chat_services: dict):
    """Create the single-step agent node.

    Design:
    - performs exactly one LLM call per invocation;
    - never executes tools itself (the returned ``AIMessage`` may carry
      ``tool_calls``);
    - when a stream queue is available, forwards tokens and node events to it.

    Args:
        chat_services: Mapping of model name to chat model. The model named by
            ``config["configurable"]["model"]`` (default ``"primary"``) is
            used; if it is missing, the first model in the mapping is used.

    Returns:
        The async ``agent_node(state, config)`` callable.
    """

    async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
        """Run one LLM step and return the state update.

        Returns a dict with ``"messages"`` and ``"llm_calls"`` on a normal
        step, or ``"messages"``, ``"stop"`` and ``"stop_reason"`` when the step
        limit is reached or a tool-call loop is detected.

        Raises:
            RuntimeError: If ``chat_services`` is empty.
            Exception: Any error from the LLM call is logged, reported on the
                stream queue (when streaming) and re-raised.
        """
        queue = get_stream_queue()
        is_streaming = queue is not None

        # Step bookkeeping.
        current_step = getattr(state, "current_step", 0)
        max_steps = getattr(state, "max_steps", 10)
        info(f"[Agent] 第 {current_step + 1} 步开始")

        # Step limit already reached: stop without calling the LLM.
        if current_step >= max_steps:
            info("[Agent] 达到步数上限,强制结束")
            return {
                "messages": [AIMessage(content="[系统] 已达到最大步数限制。")],
                "stop": True,
                "stop_reason": "max_steps",
            }

        # Loop detection on the recorded tool history.
        tool_history = getattr(state, "tool_call_history", [])
        result_history = getattr(state, "tool_result_history", [])
        if _should_stop_for_loop(tool_history, result_history):
            info("[Agent] 检测到循环,终止推理")
            return {
                "messages": [AIMessage(content="[系统] 检测到工具调用循环,已终止。")],
                "stop": True,
                "stop_reason": "loop_detected",
            }

        # Resolve the model dynamically from the run configuration.
        model_name = "primary"
        if config:
            configurable = config.get("configurable", {})
            model_name = configurable.get("model", "primary")
        llm = chat_services.get(model_name)
        if llm is None:
            if not chat_services:
                # next(iter(...)) on an empty mapping would raise StopIteration
                # inside a coroutine, which surfaces as an opaque RuntimeError.
                raise RuntimeError("No chat service is configured for the agent node")
            llm = next(iter(chat_services.values()))
            info(f"[Agent] 模型 '{model_name}' 不可用,使用 '{type(llm).__name__}'")

        # The final round runs without tools, so only bind them when needed.
        is_final_round = current_step + 1 >= max_steps
        llm_with_tools = None if is_final_round else llm.bind_tools(ALL_TOOLS)

        # Inject the memory context into the system prompt.
        memory_context = getattr(state, "memory_context", "暂无用户背景信息")
        prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
        messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)

        # Announce the node start to stream consumers.
        if is_streaming:
            await queue.put({"type": "node_start", "node": "agent"})

        # Choose the LLM: the last round is bound to an empty tool list.
        if is_final_round:
            # NOTE(review): some providers may reject an empty tool list —
            # confirm; kept as in the original implementation.
            current_llm = llm.bind_tools([])
            info("[Agent] 达到步数上限,使用无工具模型")
        else:
            current_llm = llm_with_tools

        # Accumulators for the response.
        full_content = ""
        full_reasoning_content = ""
        pending_tool_calls = {}
        final_tool_calls = []

        try:
            if is_streaming:
                async for chunk in current_llm.astream(messages):
                    if not isinstance(chunk, AIMessageChunk):
                        continue

                    if chunk.content:
                        full_content += chunk.content
                        await queue.put({
                            "type": "llm_token",
                            "node": "agent",
                            "token": chunk.content,
                            "reasoning_token": "",
                        })

                    # Reasoning text, when present, is read from additional_kwargs.
                    extra = getattr(chunk, "additional_kwargs", None) or {}
                    reasoning = extra.get("reasoning_content", "")
                    if reasoning:
                        full_reasoning_content += reasoning
                        await queue.put({
                            "type": "llm_token",
                            "node": "agent",
                            "token": "",
                            "reasoning_token": reasoning,
                        })

                    # Tool calls arrive as fragments; collect them per index.
                    for tc_chunk in getattr(chunk, "tool_call_chunks", None) or []:
                        _accumulate_tool_call_chunk(pending_tool_calls, tc_chunk)
            else:
                result = await current_llm.ainvoke(messages)
                full_content = result.content if result.content else ""
                if hasattr(result, 'tool_calls') and result.tool_calls:
                    final_tool_calls = result.tool_calls
                if hasattr(result, 'additional_kwargs'):
                    full_reasoning_content = result.additional_kwargs.get("reasoning_content", "")

            # Assemble the streamed tool-call fragments.
            if is_streaming:
                final_tool_calls.extend(_finalize_tool_calls(pending_tool_calls))

            # Announce the node end to stream consumers.
            if is_streaming:
                await queue.put({"type": "node_end", "node": "agent"})

            # Build the response message.
            response_kwargs = {"content": full_content}
            if final_tool_calls:
                response_kwargs["tool_calls"] = final_tool_calls
            response = AIMessage(**response_kwargs)
            if full_reasoning_content:
                response.additional_kwargs["reasoning_content"] = full_reasoning_content

            info(f"[Agent] 完成 - content长度: {len(full_content)}, tool_calls: {len(final_tool_calls)}")
            return {
                "messages": [response],
                "llm_calls": getattr(state, "llm_calls", 0) + 1,
            }
        except Exception as e:
            error(f"[Agent] 执行出错: {e}")
            import traceback
            error(f"[Agent] 堆栈: {traceback.format_exc()}")
            if is_streaming:
                await queue.put({"type": "error", "message": str(e)})
            raise

    return agent_node