114 lines
3.6 KiB
Python
114 lines
3.6 KiB
Python
"""
Tools 节点 - 负责执行 tool_calls

Tools node: executes the ``tool_calls`` requested by the most recent AIMessage.
"""
|
||
|
||
from typing import Dict, Any, Optional, List
|
||
from langchain_core.runnables.config import RunnableConfig
|
||
from langchain_core.messages import AIMessage, ToolMessage
|
||
|
||
from backend.app.main_graph.state import AgentState
|
||
from backend.app.logger import info, error, debug
|
||
from backend.app.tools import ALL_TOOLS
|
||
from backend.app.agent.stream_context import get_stream_queue
|
||
|
||
|
||
def _copy_state_list(state: "AgentState", field: str) -> list:
    """Return a fresh list copy of the list-valued state attribute *field*.

    A missing attribute and an explicit ``None`` both yield an empty list, so
    callers can append to the result without touching the state object.
    """
    return list(getattr(state, field, None) or [])


async def _emit_stream_event(queue: Any, payload: Dict[str, Any]) -> None:
    """Put one ``{"type": "custom", "data": payload}`` event on *queue*.

    Does nothing when *queue* is ``None`` (i.e. no stream queue is active).
    """
    if queue is not None:
        await queue.put({"type": "custom", "data": payload})


async def _execute_tool(tool: Any, tool_name: str, tool_args: Any) -> Any:
    """Run *tool* with *tool_args* and return its raw result.

    Failures are reported as error strings instead of being raised, so one
    bad tool call does not abort the remaining calls:
      * ``tool is None``  -> "未找到工具: <name>" (tool not found)
      * ``Exception``     -> "工具调用出错: <message>" (tool raised)
    """
    if tool is None:
        error(f"[Tools] 未找到工具: {tool_name}")
        return f"未找到工具: {tool_name}"
    try:
        result = await tool.ainvoke(tool_args)
    except Exception as e:
        error(f"[Tools] 工具 {tool_name} 调用出错: {e}")
        return f"工具调用出错: {str(e)}"
    debug(f"[Tools] 工具 {tool_name} 执行成功")
    return result


async def tools_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """Tools node: execute the ``tool_calls`` of the latest AIMessage.

    Steps:
      1. Read ``tool_calls`` from the last message, if it is an AIMessage.
      2. Execute each requested tool in order.
      3. Append to ``tool_call_history``, ``tool_result_history`` and
         ``tools_used``.
      4. Increment ``current_step`` by one.
      5. When a stream queue is active, emit ``tool_start`` / ``tool_end``
         custom events around every call.

    Args:
        state: Current graph state. Reads ``messages`` and, when present,
            ``tool_call_history``, ``tool_result_history``, ``tools_used`` and
            ``current_step``.
        config: Runnable config; accepted for the node signature but not used.

    Returns:
        ``{}`` when there is nothing to execute. Otherwise a partial state
        update with the new ToolMessages (one per tool call, in order), the
        incremented ``current_step`` and the extended history lists.
    """
    queue = get_stream_queue()

    # An empty message list is treated like a reply without tool calls
    # instead of raising IndexError.
    last_message = state.messages[-1] if state.messages else None
    if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
        info("[Tools] 没有工具调用,跳过")
        return {}

    tool_calls = last_message.tool_calls
    info(f"[Tools] 执行 {len(tool_calls)} 个工具调用")

    # Copies, so the returned lists are new objects rather than in-place edits.
    tool_call_history: List[dict] = _copy_state_list(state, "tool_call_history")
    tool_result_history: List[str] = _copy_state_list(state, "tool_result_history")
    tools_used: List[str] = _copy_state_list(state, "tools_used")

    # Build the name -> tool index once; setdefault keeps the FIRST tool
    # registered under a name, matching a front-to-back linear search.
    tools_by_name: Dict[str, Any] = {}
    for tool in ALL_TOOLS:
        tools_by_name.setdefault(tool.name, tool)

    tool_messages: List[ToolMessage] = []
    for tc in tool_calls:
        tool_name = tc["name"]
        tool_args = tc["args"]
        tool_id = tc["id"]

        tools_used.append(tool_name)

        await _emit_stream_event(queue, {
            "type": "tool_start",
            "tool": tool_name,
            "args": tool_args,
            "id": tool_id,
        })

        tool_result = await _execute_tool(tools_by_name.get(tool_name), tool_name, tool_args)
        result_text = str(tool_result)

        await _emit_stream_event(queue, {
            "type": "tool_end",
            "tool": tool_name,
            "id": tool_id,
            "result": result_text,
        })

        tool_call_history.append({"name": tool_name, "args": tool_args})
        tool_result_history.append(result_text)
        tool_messages.append(
            ToolMessage(content=result_text, tool_call_id=tool_id, name=tool_name)
        )

    # A missing or None step counter counts as 0.
    current_step = (getattr(state, "current_step", None) or 0) + 1

    return {
        "messages": tool_messages,
        "current_step": current_step,
        "tool_call_history": tool_call_history,
        "tool_result_history": tool_result_history,
        "tools_used": tools_used,
    }