Files
ailine/backend/app/main_graph/nodes/tools.py
2026-05-08 01:44:13 +08:00

114 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Tools 节点 - 负责执行 tool_calls
"""
from typing import Dict, Any, Optional, List
from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import AIMessage, ToolMessage
from backend.app.main_graph.state import AgentState
from backend.app.logger import info, error, debug
from backend.app.tools import ALL_TOOLS
from backend.app.agent.stream_context import get_stream_queue
async def _emit_stream_event(queue: Any, payload: Dict[str, Any]) -> None:
    """Put a ``{"type": "custom", "data": payload}`` event on the stream queue.

    Args:
        queue: The value returned by ``get_stream_queue()``. When it is ``None``
            the run is not streaming and this function does nothing. Otherwise
            it is assumed to expose an awaitable ``put`` (presumably an
            ``asyncio.Queue`` — confirm in ``stream_context``).
        payload: Event body placed under the ``"data"`` key.
    """
    if queue is None:
        return
    await queue.put({"type": "custom", "data": payload})


async def _run_tool(tools_by_name: Dict[str, Any], tool_name: str, tool_args: Any) -> Any:
    """Invoke the tool registered under ``tool_name`` and return its result.

    Failures never propagate. An unknown tool name, or an exception raised by
    the tool, is logged and converted into a human-readable error string. That
    string is then reported back to the model like any other tool output.

    Args:
        tools_by_name: Mapping from tool name to tool object (must support ``ainvoke``).
        tool_name: Name requested by the model.
        tool_args: Arguments requested by the model, passed to ``ainvoke`` unchanged.

    Returns:
        The raw value returned by ``tool.ainvoke`` on success, else an error string.
    """
    tool = tools_by_name.get(tool_name)
    if tool is None:
        error(f"[Tools] 未找到工具: {tool_name}")
        return f"未找到工具: {tool_name}"
    try:
        result = await tool.ainvoke(tool_args)
    except Exception as e:  # boundary: a failing tool must not abort the graph run
        error(f"[Tools] 工具 {tool_name} 调用出错: {e}")
        return f"工具调用出错: {str(e)}"
    debug(f"[Tools] 工具 {tool_name} 执行成功")
    return result


async def tools_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
    """Execute the tool calls requested by the most recent AIMessage.

    Steps:
    1. Read ``tool_calls`` from the last message. Return ``{}`` when there are
       no messages, when the last message is not an ``AIMessage``, or when it
       requested no tools.
    2. Run each requested tool in order. Unknown tools and tool exceptions
       become error strings instead of raising.
    3. Append to ``tool_call_history``, ``tool_result_history`` and ``tools_used``.
    4. Increment ``current_step`` by 1.
    5. When a stream queue is active, emit ``tool_start`` and ``tool_end``
       custom events around each call.

    Args:
        state: Current graph state. Reads ``messages``, ``tool_call_history``,
            ``tool_result_history``, ``tools_used`` and ``current_step``.
        config: Runnable config. It is part of the node signature but is not
            used here.

    Returns:
        A partial state update with these keys:
        - ``messages``: one ``ToolMessage`` per call, in call order.
        - ``current_step``: the incremented step count.
        - ``tool_call_history``, ``tool_result_history``, ``tools_used``:
          the full lists, meaning prior entries followed by this turn's
          entries.

        NOTE(review): returning full lists assumes these AgentState fields are
        overwritten, not merged by an append reducer. Confirm in ``state.py``.
    """
    queue = get_stream_queue()

    # Guard against an empty message list instead of raising IndexError.
    messages = getattr(state, "messages", None) or []
    last_message = messages[-1] if messages else None
    if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
        info("[Tools] 没有工具调用,跳过")
        return {}

    tool_calls = last_message.tool_calls
    info(f"[Tools] 执行 {len(tool_calls)} 个工具调用")

    # Copy prior histories so the incoming state is never mutated.
    # `or []` also tolerates fields that exist but are set to None.
    tool_call_history: List[dict] = list(getattr(state, "tool_call_history", None) or [])
    tool_result_history: List[str] = list(getattr(state, "tool_result_history", None) or [])
    tools_used: List[str] = list(getattr(state, "tools_used", None) or [])

    # Index the registry once per invocation. setdefault keeps the first
    # registered tool when two tools share a name.
    tools_by_name: Dict[str, Any] = {}
    for registered_tool in ALL_TOOLS:
        tools_by_name.setdefault(registered_tool.name, registered_tool)

    tool_messages: List[ToolMessage] = []
    for tc in tool_calls:
        tool_name = tc["name"]
        tool_args = tc["args"]
        tool_id = tc["id"]
        tools_used.append(tool_name)

        await _emit_stream_event(queue, {
            "type": "tool_start",
            "tool": tool_name,
            "args": tool_args,
            "id": tool_id,
        })

        tool_result = await _run_tool(tools_by_name, tool_name, tool_args)
        # Stringify once. The stream event, the history and the ToolMessage
        # all carry this same text.
        result_text = str(tool_result)

        await _emit_stream_event(queue, {
            "type": "tool_end",
            "tool": tool_name,
            "id": tool_id,
            "result": result_text,
        })

        tool_call_history.append({"name": tool_name, "args": tool_args})
        tool_result_history.append(result_text)
        tool_messages.append(
            ToolMessage(content=result_text, tool_call_id=tool_id, name=tool_name)
        )

    # One executed batch of tool calls counts as a single step.
    current_step = (getattr(state, "current_step", 0) or 0) + 1
    return {
        "messages": tool_messages,
        "current_step": current_step,
        "tool_call_history": tool_call_history,
        "tool_result_history": tool_result_history,
        "tools_used": tools_used,
    }