77 lines
2.3 KiB
Python
77 lines
2.3 KiB
Python
"""
|
|
LangGraph 状态图构建模块 - 精简版,仅负责组装图
|
|
所有节点逻辑已拆分到独立模块
|
|
"""
|
|
|
|
from langchain_core.language_models import BaseLLM
|
|
from langgraph.graph import StateGraph, START, END
|
|
|
|
# 本地模块
|
|
from app.state import MessagesState, GraphContext
|
|
from app.nodes import (
|
|
create_llm_call_node,
|
|
create_tool_call_node,
|
|
create_retrieve_memory_node,
|
|
create_summarize_node,
|
|
should_continue
|
|
)
|
|
from app.memory import Mem0Client
|
|
|
|
|
|
class GraphBuilder:
    """Assemble the LangGraph state graph; this class only wires the graph.

    All node logic lives in ``app.nodes``. The builder creates each node via
    its factory function, registers the nodes, and connects them with edges.
    It returns the graph uncompiled; compiling it is left to the caller.
    """

    def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
        """Store the dependencies the node factories are built from.

        Args:
            llm: Language model instance. It is passed to the LLM-call node
                factory and to ``Mem0Client``.
                NOTE(review): LangChain chat models subclass
                ``BaseChatModel``, not ``BaseLLM``. ``BaseLanguageModel``
                may be the more accurate annotation; confirm against callers.
            tools: List of tools, passed to the LLM-call node factory.
            tools_by_name: Mapping of tool name to tool, passed to the
                tool-call node factory (presumably for dispatching tool
                calls by name; verify in ``app.nodes``).
        """
        self.llm = llm
        self.tools = tools
        self.tools_by_name = tools_by_name

        # One Mem0 client, shared by the retrieve-memory and summarize nodes.
        # The original note says it initializes lazily on first use; that
        # behavior lives inside Mem0Client and is not visible here.
        self.mem0_client = Mem0Client(llm)

    def build(self) -> StateGraph:
        """Build the uncompiled state graph.

        Topology::

            START -> retrieve_memory -> llm_call
            llm_call --should_continue--> tool_node | summarize | END
            tool_node -> llm_call
            summarize -> END

        Returns:
            A ``StateGraph`` that has not been compiled yet.
        """
        builder = StateGraph(MessagesState, context_schema=GraphContext)

        # Nodes are created by factory functions so that their dependencies
        # (LLM, tools, memory client) are injected explicitly.
        retrieve_memory_node = create_retrieve_memory_node(self.mem0_client)
        llm_call_node = create_llm_call_node(self.llm, self.tools)
        tool_call_node = create_tool_call_node(self.tools_by_name)
        summarize_node = create_summarize_node(self.mem0_client)

        # Register nodes.
        builder.add_node("retrieve_memory", retrieve_memory_node)
        builder.add_node("llm_call", llm_call_node)
        builder.add_node("tool_node", tool_call_node)
        builder.add_node("summarize", summarize_node)

        # Wire edges.
        builder.add_edge(START, "retrieve_memory")
        builder.add_edge("retrieve_memory", "llm_call")

        # should_continue's return value is looked up as a key in this map.
        # LangGraph's END constant is the string "__end__", not "END". Both
        # spellings are therefore mapped to END:
        # - the literal "END" key keeps working if the router returns it;
        # - a router returning END itself no longer hits an unknown branch.
        # NOTE(review): confirm which value should_continue actually returns.
        builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {
                "tool_node": "tool_node",
                "summarize": "summarize",
                "END": END,
                END: END,
            },
        )

        # After tools run, control returns to the LLM. Summarization ends
        # the run.
        builder.add_edge("tool_node", "llm_call")
        builder.add_edge("summarize", END)

        return builder
|