91 lines
3.2 KiB
Python
91 lines
3.2 KiB
Python
|
|
"""
|
|||
|
|
工具执行节点模块
|
|||
|
|
负责执行 AI 调用的工具函数
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
from typing import Any, Dict
|
|||
|
|
from langchain_core.messages import AIMessage, ToolMessage
|
|||
|
|
from langgraph.runtime import Runtime
|
|||
|
|
|
|||
|
|
# 本地模块
|
|||
|
|
from app.state import MessagesState, GraphContext
|
|||
|
|
from app.utils.logging import log_state_change
|
|||
|
|
from app.logger import debug, info
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
|||
|
|
"""
|
|||
|
|
工厂函数:创建工具执行节点
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
tools_by_name: 名称到工具函数的映射字典
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
异步节点函数
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
async def call_tools(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
工具执行节点(异步方法)
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
state: 当前对话状态
|
|||
|
|
runtime: LangGraph 运行时上下文
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
包含 ToolMessage 的状态更新
|
|||
|
|
"""
|
|||
|
|
log_state_change("tool_node", state, "进入")
|
|||
|
|
|
|||
|
|
last_message = state['messages'][-1]
|
|||
|
|
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
|
|||
|
|
log_state_change("tool_node", state, "离开(无工具调用)")
|
|||
|
|
return {"messages": []}
|
|||
|
|
|
|||
|
|
results = []
|
|||
|
|
loop = asyncio.get_event_loop()
|
|||
|
|
|
|||
|
|
info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")
|
|||
|
|
|
|||
|
|
for tool_call in last_message.tool_calls:
|
|||
|
|
tool_name = tool_call["name"]
|
|||
|
|
tool_args = tool_call["args"]
|
|||
|
|
tool_id = tool_call["id"]
|
|||
|
|
tool_func = tools_by_name.get(tool_name)
|
|||
|
|
|
|||
|
|
debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")
|
|||
|
|
|
|||
|
|
if tool_func is None:
|
|||
|
|
err_msg = f"Tool {tool_name} not found"
|
|||
|
|
debug(f" └─ ❌ {err_msg}")
|
|||
|
|
results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 修复闭包问题:将变量作为默认参数传入 lambda
|
|||
|
|
# 如果工具支持异步 (ainvoke),优先使用异步调用
|
|||
|
|
if hasattr(tool_func, 'ainvoke'):
|
|||
|
|
observation = await tool_func.ainvoke(tool_args)
|
|||
|
|
else:
|
|||
|
|
observation = await loop.run_in_executor(
|
|||
|
|
None,
|
|||
|
|
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 字符打印
|
|||
|
|
result_preview = str(observation).replace("\n", " ")
|
|||
|
|
debug(f" └─ ✅ 结果: {result_preview}")
|
|||
|
|
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
|||
|
|
except Exception as e:
|
|||
|
|
debug(f" └─ ❌ 异常: {e}")
|
|||
|
|
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
|||
|
|
|
|||
|
|
info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
|
|||
|
|
|
|||
|
|
result = {"messages": results}
|
|||
|
|
log_state_change("tool_node", {**state, **result}, "离开")
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
return call_tools
|