127 lines
5.0 KiB
Python
127 lines
5.0 KiB
Python
|
|
"""
|
|||
|
|
LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import operator
|
|||
|
|
import asyncio
|
|||
|
|
from typing import Literal, Annotated, Any
|
|||
|
|
from langchain_core.language_models import BaseLLM
|
|||
|
|
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
|
|||
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|||
|
|
from langgraph.graph import StateGraph, START, END
|
|||
|
|
from typing_extensions import TypedDict
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MessageState(TypedDict):
|
|||
|
|
"""对话状态类型定义"""
|
|||
|
|
messages: Annotated[list[AnyMessage], operator.add]
|
|||
|
|
llm_calls: int
|
|||
|
|
|
|||
|
|
|
|||
|
|
class GraphBuilder:
|
|||
|
|
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
|
|||
|
|
|
|||
|
|
def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]):
|
|||
|
|
"""
|
|||
|
|
初始化构建器
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
llm: 大语言模型实例
|
|||
|
|
tools: 工具列表
|
|||
|
|
tools_by_name: 名称到工具函数的映射
|
|||
|
|
"""
|
|||
|
|
self.llm = llm
|
|||
|
|
self.tools = tools
|
|||
|
|
self.tools_by_name = tools_by_name
|
|||
|
|
self._llm_with_tools = llm.bind_tools(tools)
|
|||
|
|
self._prompt = self._create_prompt()
|
|||
|
|
self._chain = self._prompt | self._llm_with_tools
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _create_prompt() -> ChatPromptTemplate:
|
|||
|
|
"""创建系统提示模板(静态方法,无需访问实例)"""
|
|||
|
|
return ChatPromptTemplate.from_messages([
|
|||
|
|
SystemMessage(content=(
|
|||
|
|
"你是一个个人生活助手和数据分析助手。请说中文。"
|
|||
|
|
"当用户询问天气或温度时,使用get_current_temperature工具获取信息。"
|
|||
|
|
"当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
|
|||
|
|
"当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
|
|||
|
|
"当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
|
|||
|
|
"当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
|
|||
|
|
"重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。"
|
|||
|
|
)),
|
|||
|
|
MessagesPlaceholder(variable_name="message")
|
|||
|
|
])
|
|||
|
|
|
|||
|
|
async def call_llm(self, state: MessageState) -> dict:
|
|||
|
|
"""
|
|||
|
|
LLM 调用节点(异步方法)
|
|||
|
|
注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
|
|||
|
|
"""
|
|||
|
|
loop = asyncio.get_event_loop()
|
|||
|
|
response = await loop.run_in_executor(
|
|||
|
|
None,
|
|||
|
|
lambda: self._chain.invoke({"message": state["messages"]})
|
|||
|
|
)
|
|||
|
|
return {
|
|||
|
|
"messages": [response],
|
|||
|
|
"llm_calls": state.get('llm_calls', 0) + 1
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async def call_tools(self, state: MessageState) -> dict:
|
|||
|
|
"""
|
|||
|
|
工具执行节点(异步方法)
|
|||
|
|
对于每个工具调用,在线程池中执行同步工具函数
|
|||
|
|
"""
|
|||
|
|
last_message = state['messages'][-1]
|
|||
|
|
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
|
|||
|
|
return {"messages": []}
|
|||
|
|
|
|||
|
|
results = []
|
|||
|
|
loop = asyncio.get_event_loop()
|
|||
|
|
|
|||
|
|
for tool_call in last_message.tool_calls:
|
|||
|
|
tool_name = tool_call["name"]
|
|||
|
|
tool_args = tool_call["args"]
|
|||
|
|
tool_id = tool_call["id"]
|
|||
|
|
tool_func = self.tools_by_name.get(tool_name)
|
|||
|
|
|
|||
|
|
if tool_func is None:
|
|||
|
|
results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id))
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 同步工具函数在线程池中执行
|
|||
|
|
observation = await loop.run_in_executor(
|
|||
|
|
None,
|
|||
|
|
lambda: tool_func.invoke(tool_args)
|
|||
|
|
)
|
|||
|
|
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
|||
|
|
except Exception as e:
|
|||
|
|
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
|||
|
|
|
|||
|
|
return {"messages": results}
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def should_continue(state: MessageState) -> Literal['tool_node', END]:
|
|||
|
|
"""
|
|||
|
|
条件边判断(静态方法)
|
|||
|
|
决定下一步是进入工具节点还是结束
|
|||
|
|
"""
|
|||
|
|
last_message = state["messages"][-1]
|
|||
|
|
if isinstance(last_message, AIMessage) and bool(last_message.tool_calls):
|
|||
|
|
return 'tool_node'
|
|||
|
|
return END
|
|||
|
|
|
|||
|
|
def build(self) -> StateGraph:
|
|||
|
|
"""
|
|||
|
|
构建未编译的状态图(返回 StateGraph 实例)
|
|||
|
|
图中节点直接使用实例方法 call_llm, call_tools
|
|||
|
|
"""
|
|||
|
|
builder = StateGraph(MessageState)
|
|||
|
|
builder.add_node("llm_call", self.call_llm)
|
|||
|
|
builder.add_node("tool_node", self.call_tools)
|
|||
|
|
builder.add_edge(START, "llm_call")
|
|||
|
|
builder.add_conditional_edges("llm_call", self.should_continue, ["tool_node", END])
|
|||
|
|
builder.add_edge("tool_node", "llm_call")
|
|||
|
|
return builder
|