diff --git a/backend/app/main_graph/nodes/tools.py b/backend/app/main_graph/nodes/tools.py new file mode 100644 index 0000000..1c5d22f --- /dev/null +++ b/backend/app/main_graph/nodes/tools.py @@ -0,0 +1,113 @@ +""" +Tools 节点 - 负责执行 tool_calls +""" + +from typing import Dict, Any, Optional, List +from langchain_core.runnables.config import RunnableConfig +from langchain_core.messages import AIMessage, ToolMessage + +from backend.app.main_graph.state import AgentState +from backend.app.logger import info, error, debug +from backend.app.tools import ALL_TOOLS +from backend.app.agent.stream_context import get_stream_queue + + +async def tools_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]: + """ + Tools 节点:执行 AIMessage.tool_calls,返回 ToolMessage 列表 + + 职责: + 1. 获取最后一条 AIMessage 的 tool_calls + 2. 遍历执行每个工具 + 3. 记录历史(tool_call_history, tool_result_history) + 4. 更新步数 current_step += 1 + 5. 发送工具开始/结束事件 + """ + queue = get_stream_queue() + is_streaming = queue is not None + + # 获取最后一条 AIMessage + last_message = state.messages[-1] + if not isinstance(last_message, AIMessage) or not last_message.tool_calls: + info("[Tools] 没有工具调用,跳过") + return {} + + tool_calls = last_message.tool_calls + info(f"[Tools] 执行 {len(tool_calls)} 个工具调用") + + # 获取历史记录 + tool_call_history: List[dict] = list(getattr(state, "tool_call_history", [])) + tool_result_history: List[str] = list(getattr(state, "tool_result_history", [])) + tools_used: List[str] = list(getattr(state, "tools_used", [])) + + tool_messages = [] + + for tc in tool_calls: + tool_name = tc["name"] + tool_args = tc["args"] + tool_id = tc["id"] + + tools_used.append(tool_name) + + # 发送工具开始事件 + if is_streaming: + await queue.put({ + "type": "custom", + "data": { + "type": "tool_start", + "tool": tool_name, + "args": tool_args, + "id": tool_id + } + }) + + # 查找并执行工具 + tool_result = "" + tool_found = False + + for tool in ALL_TOOLS: + if tool.name == tool_name: + tool_found = True + try: + tool_result = await tool.ainvoke(tool_args) + debug(f"[Tools] 工具 {tool_name} 执行成功") + except Exception as e: + tool_result = f"工具调用出错: {str(e)}" + error(f"[Tools] 工具 {tool_name} 调用出错: {e}") + break + + if not tool_found: + tool_result = f"未找到工具: {tool_name}" + error(f"[Tools] 未找到工具: {tool_name}") + + # 发送工具结束事件 + if is_streaming: + await queue.put({ + "type": "custom", + "data": { + "type": "tool_end", + "tool": tool_name, + "id": tool_id, + "result": str(tool_result) + } + }) + + # 记录历史 + tool_call_history.append({"name": tool_name, "args": tool_args}) + tool_result_history.append(str(tool_result)) + + # 创建 ToolMessage + tool_messages.append( + ToolMessage(content=str(tool_result), tool_call_id=tool_id, name=tool_name) + ) + + # 更新步数 + current_step = getattr(state, "current_step", 0) + 1 + + return { + "messages": tool_messages, + "current_step": current_step, + "tool_call_history": tool_call_history, + "tool_result_history": tool_result_history, + "tools_used": tools_used, + }