""" LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数 """ import operator import asyncio from typing import Literal, Annotated, Any from langchain_core.language_models import BaseLLM from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langgraph.graph import StateGraph, START, END from typing_extensions import TypedDict class MessageState(TypedDict): """对话状态类型定义""" messages: Annotated[list[AnyMessage], operator.add] llm_calls: int class GraphBuilder: """LangGraph 状态图构建器 - 所有节点均为类方法""" def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]): """ 初始化构建器 Args: llm: 大语言模型实例 tools: 工具列表 tools_by_name: 名称到工具函数的映射 """ self.llm = llm self.tools = tools self.tools_by_name = tools_by_name self._llm_with_tools = llm.bind_tools(tools) self._prompt = self._create_prompt() self._chain = self._prompt | self._llm_with_tools @staticmethod def _create_prompt() -> ChatPromptTemplate: """创建系统提示模板(静态方法,无需访问实例)""" return ChatPromptTemplate.from_messages([ SystemMessage(content=( "你是一个个人生活助手和数据分析助手。请说中文。" "当用户询问天气或温度时,使用get_current_temperature工具获取信息。" "当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。" "当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。" "当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。" "当用户要求抓取网页时,请使用 fetch_webpage_content 工具。" "重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。" )), MessagesPlaceholder(variable_name="message") ]) async def call_llm(self, state: MessageState) -> dict: """ LLM 调用节点(异步方法) 注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环 """ loop = asyncio.get_event_loop() response = await loop.run_in_executor( None, lambda: self._chain.invoke({"message": state["messages"]}) ) return { "messages": [response], "llm_calls": state.get('llm_calls', 0) + 1 } async def call_tools(self, state: MessageState) -> dict: """ 工具执行节点(异步方法) 对于每个工具调用,在线程池中执行同步工具函数 """ last_message = state['messages'][-1] if not isinstance(last_message, AIMessage) or not last_message.tool_calls: return {"messages": []} results = [] loop = asyncio.get_event_loop() for tool_call in last_message.tool_calls: tool_name = tool_call["name"] tool_args = tool_call["args"] tool_id = tool_call["id"] tool_func = self.tools_by_name.get(tool_name) if tool_func is None: results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id)) continue try: # 同步工具函数在线程池中执行 observation = await loop.run_in_executor( None, lambda: tool_func.invoke(tool_args) ) results.append(ToolMessage(content=str(observation), tool_call_id=tool_id)) except Exception as e: results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id)) return {"messages": results} @staticmethod def should_continue(state: MessageState) -> Literal['tool_node', END]: """ 条件边判断(静态方法) 决定下一步是进入工具节点还是结束 """ last_message = state["messages"][-1] if isinstance(last_message, AIMessage) and bool(last_message.tool_calls): return 'tool_node' return END def build(self) -> StateGraph: """ 构建未编译的状态图(返回 StateGraph 实例) 图中节点直接使用实例方法 call_llm, call_tools """ builder = StateGraph(MessageState) builder.add_node("llm_call", self.call_llm) builder.add_node("tool_node", self.call_tools) builder.add_edge(START, "llm_call") builder.add_conditional_edges("llm_call", self.should_continue, ["tool_node", END]) builder.add_edge("tool_node", "llm_call") return builder