243 lines
9.1 KiB
Python
243 lines
9.1 KiB
Python
"""
|
||
Agent 节点 - 简化版本(单步推理)
|
||
只负责一次 LLM 调用,不执行工具
|
||
"""
|
||
|
||
import json
|
||
from typing import Dict, Any, Optional, List
|
||
from langchain_core.runnables.config import RunnableConfig
|
||
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage
|
||
|
||
from backend.app.main_graph.state import AgentState
|
||
from backend.app.logger import info, error, debug
|
||
from backend.app.tools import ALL_TOOLS
|
||
from backend.app.agent.stream_context import get_stream_queue
|
||
from backend.app.agent.prompts import SYSTEM_PROMPT
|
||
|
||
|
||
def _normalize_args(args: dict) -> str:
|
||
"""标准化工具参数用于比较"""
|
||
return str(sorted(args.items()))
|
||
|
||
|
||
def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
|
||
"""检测结果是否相似(简单实现:长度相似+部分内容重复)"""
|
||
if len(results) < 2:
|
||
return False
|
||
|
||
latest = results[-1]
|
||
prev = results[-2]
|
||
|
||
if len(latest) == 0 or len(prev) == 0:
|
||
return len(latest) == len(prev)
|
||
|
||
len_ratio = min(len(latest), len(prev)) / max(len(latest), len(prev))
|
||
if len_ratio < 0.5:
|
||
return False
|
||
|
||
common_len = 0
|
||
for a, b in zip(latest[:100], prev[:100]):
|
||
if a == b:
|
||
common_len += 1
|
||
else:
|
||
break
|
||
|
||
return (common_len / 100) > threshold
|
||
|
||
|
||
def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
|
||
"""
|
||
检测是否应该停止(循环检测)
|
||
条件:连续2次调用相同工具 + 参数相似 + 结果相似
|
||
"""
|
||
if len(tool_calls) < 2:
|
||
return False
|
||
|
||
last_tc = tool_calls[-1]
|
||
prev_tc = tool_calls[-2]
|
||
|
||
if last_tc["name"] != prev_tc["name"]:
|
||
return False
|
||
|
||
last_args = _normalize_args(last_tc["args"])
|
||
prev_args = _normalize_args(prev_tc["args"])
|
||
|
||
if last_args != prev_args:
|
||
return False
|
||
|
||
if len(tool_results) >= 2:
|
||
return _is_similar_result(tool_results[-2:])
|
||
|
||
return False
|
||
|
||
|
||
def _merge_stream_fragment(current: str, fragment: Optional[str]) -> str:
    """Append one streamed fragment of a tool-call ``id``/``name`` to ``current``.

    Defensive: if a provider resends the complete value on every chunk, plain
    concatenation would produce e.g. ``"call_1call_1"``. A fragment identical
    to the value accumulated so far is therefore treated as a repeat and dropped.
    """
    if not fragment or fragment == current:
        return current
    return current + fragment


def _accumulate_tool_call_chunk(pending: Dict[Any, Dict[str, str]], tc_chunk: dict) -> None:
    """Merge one streamed ``tool_call_chunk`` into ``pending``, keyed by its ``index``."""
    # ``.get("index", 0)`` would return None for an explicit ``index: None``.
    # Map it to 0 so that all keys can be sorted together later.
    idx = tc_chunk.get("index")
    if idx is None:
        idx = 0
    entry = pending.setdefault(idx, {"id": "", "name": "", "args": ""})
    entry["id"] = _merge_stream_fragment(entry["id"], tc_chunk.get("id"))
    entry["name"] = _merge_stream_fragment(entry["name"], tc_chunk.get("name"))
    args_val = tc_chunk.get("args")
    if args_val:
        # String fragments are concatenated and parsed as JSON once the stream
        # ends; any non-string value is serialized first.
        entry["args"] += args_val if isinstance(args_val, str) else json.dumps(args_val)


def _finalize_tool_calls(pending: Dict[Any, Dict[str, str]]) -> List[dict]:
    """Convert accumulated stream fragments into ``{"id", "name", "args"}`` tool calls.

    - Entries without a name are dropped.
    - Invalid JSON arguments are logged and replaced with ``{}``.
    - Arguments that decode to something other than a JSON object are also
      logged and replaced with ``{}``, so ``AIMessage`` construction does not fail.
    """
    tool_calls: List[dict] = []
    for idx in sorted(pending):
        tc_data = pending[idx]
        if not tc_data["name"]:
            continue
        args: Dict[str, Any] = {}
        if tc_data["args"]:
            try:
                parsed = json.loads(tc_data["args"])
            except (ValueError, RecursionError) as e:
                info(f"[Agent] 解析参数失败: {e}")
            else:
                if isinstance(parsed, dict):
                    args = parsed
                else:
                    info(f"[Agent] 解析参数失败: 参数不是 JSON 对象 ({type(parsed).__name__})")
        tool_calls.append({
            "id": tc_data["id"],
            "name": tc_data["name"],
            "args": args,
        })
    return tool_calls


def _select_chat_model(chat_services: dict, config: "Optional[RunnableConfig]"):
    """Pick the chat model named by ``config["configurable"]["model"]`` (default ``"primary"``).

    Falls back to the first registered service when the requested name is missing.

    Raises:
        RuntimeError: if ``chat_services`` is empty.
    """
    model_name = "primary"
    if config:
        configurable = config.get("configurable") or {}
        model_name = configurable.get("model", "primary")

    llm = chat_services.get(model_name)
    if llm is None:
        if not chat_services:
            # Without this check next(iter(...)) would raise StopIteration, which
            # the coroutine would surface as an uninformative RuntimeError.
            raise RuntimeError("[Agent] 没有可用的聊天模型 (chat_services 为空)")
        llm = next(iter(chat_services.values()))
        info(f"[Agent] 模型 '{model_name}' 不可用,使用 '{type(llm).__name__}'")
    return llm


def create_agent_node(chat_services: dict):
    """
    Create the agent graph node (single-step reasoning version).

    Design:
    - performs exactly one LLM call per invocation;
    - does not execute tools (tool execution is the tools node's job);
    - returns an AIMessage, which may carry ``tool_calls``.

    Args:
        chat_services: mapping of model name -> chat model. The model is chosen
            per call from ``config["configurable"]["model"]`` (default
            ``"primary"``), falling back to the first entry.

    Returns:
        The async ``agent_node(state, config)`` function.
    """

    async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
        """Agent node: one LLM call, streaming events to the stream queue when one is active.

        Returns a partial state update, which is one of:
        - ``messages`` + ``stop`` + ``stop_reason`` when the step limit is reached
          or a tool-call loop is detected;
        - the new ``AIMessage`` plus an incremented ``llm_calls`` counter otherwise.

        Re-raises any exception from the LLM call after logging it. In streaming
        mode it also emits an ``error`` event before re-raising.
        """
        queue = get_stream_queue()
        is_streaming = queue is not None

        # Step bookkeeping
        current_step = getattr(state, "current_step", 0)
        max_steps = getattr(state, "max_steps", 10)
        info(f"[Agent] 第 {current_step + 1} 步开始")

        # Step limit already reached: stop without calling the LLM.
        if current_step >= max_steps:
            info("[Agent] 达到步数上限,强制结束")
            return {
                "messages": [AIMessage(content="[系统] 已达到最大步数限制。")],
                "stop": True,
                "stop_reason": "max_steps",
            }

        # Loop detection over the recorded tool calls/results.
        tool_history = getattr(state, "tool_call_history", [])
        result_history = getattr(state, "tool_result_history", [])
        if _should_stop_for_loop(tool_history, result_history):
            info("[Agent] 检测到循环,终止推理")
            return {
                "messages": [AIMessage(content="[系统] 检测到工具调用循环,已终止。")],
                "stop": True,
                "stop_reason": "loop_detected",
            }

        llm = _select_chat_model(chat_services, config)

        # On the last allowed step call the bare model so it must answer in text.
        # bind_tools([]) would send an empty `tools` list, which some APIs
        # (e.g. OpenAI) reject.
        if current_step + 1 >= max_steps:
            current_llm = llm
            info("[Agent] 达到步数上限,使用无工具模型")
        else:
            current_llm = llm.bind_tools(ALL_TOOLS)

        # A missing, None or empty memory context falls back to the default text
        # instead of rendering "None" into the system prompt.
        memory_context = getattr(state, "memory_context", None) or "暂无用户背景信息"
        prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
        messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)

        if is_streaming:
            await queue.put({"type": "node_start", "node": "agent"})

        full_content = ""
        full_reasoning_content = ""
        pending_tool_calls: Dict[Any, Dict[str, str]] = {}
        final_tool_calls: List[dict] = []

        try:
            if is_streaming:
                async for chunk in current_llm.astream(messages):
                    if not isinstance(chunk, AIMessageChunk):
                        continue

                    if chunk.content:
                        # NOTE(review): assumes string content; list-style
                        # content blocks would raise TypeError on `+=`.
                        full_content += chunk.content
                        await queue.put({
                            "type": "llm_token",
                            "node": "agent",
                            "token": chunk.content,
                            "reasoning_token": ""
                        })

                    reasoning = chunk.additional_kwargs.get("reasoning_content", "")
                    if reasoning:
                        full_reasoning_content += reasoning
                        await queue.put({
                            "type": "llm_token",
                            "node": "agent",
                            "token": "",
                            "reasoning_token": reasoning
                        })

                    for tc_chunk in chunk.tool_call_chunks or []:
                        _accumulate_tool_call_chunk(pending_tool_calls, tc_chunk)

                final_tool_calls = _finalize_tool_calls(pending_tool_calls)
            else:
                result = await current_llm.ainvoke(messages)
                full_content = result.content if result.content else ""
                final_tool_calls = list(getattr(result, "tool_calls", None) or [])
                extra = getattr(result, "additional_kwargs", None) or {}
                full_reasoning_content = extra.get("reasoning_content", "") or ""

            if is_streaming:
                await queue.put({"type": "node_end", "node": "agent"})

            # Build the response message.
            response_kwargs = {"content": full_content}
            if final_tool_calls:
                response_kwargs["tool_calls"] = final_tool_calls
            response = AIMessage(**response_kwargs)
            if full_reasoning_content:
                response.additional_kwargs["reasoning_content"] = full_reasoning_content

            info(f"[Agent] 完成 - content长度: {len(full_content)}, tool_calls: {len(final_tool_calls)}")

            return {
                "messages": [response],
                "llm_calls": getattr(state, "llm_calls", 0) + 1
            }

        except Exception as e:
            # Top-level boundary of the node: log, notify the stream, re-raise.
            import traceback

            error(f"[Agent] 执行出错: {e}")
            error(f"[Agent] 堆栈: {traceback.format_exc()}")
            if is_streaming:
                await queue.put({"type": "error", "message": str(e)})
            raise

    return agent_node
|