233 lines
7.2 KiB
Python
233 lines
7.2 KiB
Python
|
|
"""
|
|||
|
|
主图构建器 - 构建整合后的完整主图
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from langgraph.graph import StateGraph, START, END
|
|||
|
|
from typing import Dict, Any
|
|||
|
|
|
|||
|
|
from .state import MainGraphState
|
|||
|
|
from .nodes.reasoning import react_reason_node
|
|||
|
|
from .nodes.web_search import web_search_node
|
|||
|
|
from .nodes.error_handling import error_handling_node
|
|||
|
|
from .nodes.routing import init_state_node, route_by_reasoning, should_summarize
|
|||
|
|
from .nodes.hybrid_router import (
|
|||
|
|
hybrid_router_node,
|
|||
|
|
route_from_hybrid_decision,
|
|||
|
|
check_fast_path_success,
|
|||
|
|
)
|
|||
|
|
from .nodes.fast_paths import (
|
|||
|
|
fast_chitchat_node,
|
|||
|
|
fast_rag_node,
|
|||
|
|
fast_tool_node,
|
|||
|
|
)
|
|||
|
|
from .nodes.llm_call import create_dynamic_llm_call_node
|
|||
|
|
from .nodes.rag_nodes import rag_retrieve_node
|
|||
|
|
from .nodes.retrieve_memory import create_retrieve_memory_node
|
|||
|
|
from .nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
|||
|
|
from .nodes.summarize import create_summarize_node
|
|||
|
|
from .nodes.finalize import finalize_node
|
|||
|
|
from backend.app.subgraphs.contact import build_contact_subgraph
|
|||
|
|
from backend.app.subgraphs.dictionary import build_dictionary_subgraph
|
|||
|
|
from backend.app.subgraphs.news_analysis import build_news_analysis_subgraph
|
|||
|
|
from backend.app.logger import info
|
|||
|
|
|
|||
|
|
from .subgraph_wrapper import create_subgraph_nodes
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== 主图构建 ==========
|
|||
|
|
|
|||
|
|
def build_react_main_graph(
    chat_services: dict,
    tools=None,
    mem0_client=None,
    use_hybrid_router: bool = True
) -> StateGraph:
    """Build the integrated main graph (hybrid routing + dynamic model selection).

    Nodes are registered in this order:
      1. memory: ``retrieve_memory`` (only with a Mem0 client), ``memory_trigger``
      2. init: ``init_state``
      3. hybrid routing (optional): ``hybrid_router`` plus the ``fast_chitchat``,
         ``fast_rag`` and ``fast_tool`` fast paths
      4. ReAct loop (always present): ``react_reason``, ``rag_retrieve``,
         ``web_search``, ``handle_error``, ``llm_call`` and the subgraph nodes
      5. completion: ``summarize`` (only with a Mem0 client), ``finalize``

    The graph is returned uncompiled; compiling it is left to the caller.

    Args:
        chat_services: Mapping of model name -> ChatModel instance, passed to
            the dynamic LLM-call node factory.
        tools: Tool list for the LLM-call node; a falsy value becomes ``[]``.
        mem0_client: Mem0 client. When truthy it is registered globally via
            ``set_mem0_client`` and used to build the memory-retrieval and
            summarize nodes; otherwise those two nodes are omitted.
        use_hybrid_router: Whether to put the hybrid router and its fast paths
            in front of the ReAct loop (False = pure ReAct mode).

    Returns:
        StateGraph: The assembled graph over ``MainGraphState``.
    """
    graph = StateGraph(MainGraphState)

    # ---------- Build node callables ----------
    if mem0_client:
        # Install the client globally (setter lives in the memory_trigger module).
        set_mem0_client(mem0_client)

    llm_node = create_dynamic_llm_call_node(chat_services, tools or [])

    if mem0_client:
        retrieve_memory_node = create_retrieve_memory_node(mem0_client)
        summarize_node = create_summarize_node(mem0_client)
    else:
        retrieve_memory_node = None
        summarize_node = None

    # Subgraphs are built contact -> dictionary -> news_analysis, then wrapped.
    subgraph_nodes = create_subgraph_nodes(
        build_contact_subgraph(),
        build_dictionary_subgraph(),
        build_news_analysis_subgraph(),
    )

    # ---------- Register nodes (registration order is significant to keep) ----------
    registrations = []
    if retrieve_memory_node:
        registrations.append(("retrieve_memory", retrieve_memory_node))
    registrations.append(("memory_trigger", memory_trigger_node))
    registrations.append(("init_state", init_state_node))
    if use_hybrid_router:
        registrations.extend([
            ("hybrid_router", hybrid_router_node),
            ("fast_chitchat", fast_chitchat_node),
            ("fast_rag", fast_rag_node),
            ("fast_tool", fast_tool_node),
        ])
    registrations.extend([
        ("react_reason", react_reason_node),
        ("rag_retrieve", rag_retrieve_node),
        ("web_search", web_search_node),
        ("handle_error", error_handling_node),
    ])
    # NOTE(review): llm_call is skipped when the factory returns None, yet the
    # routing / ReAct-loop edges target "llm_call" unconditionally — confirm the
    # factory can never return None.
    if llm_node is not None:
        registrations.append(("llm_call", llm_node))
    registrations.extend(subgraph_nodes.items())
    if summarize_node:
        registrations.append(("summarize", summarize_node))
    registrations.append(("finalize", finalize_node))

    for node_name, node_callable in registrations:
        graph.add_node(node_name, node_callable)

    # ---------- Wire edges, stage by stage ----------
    _add_memory_edges(graph, retrieve_memory_node)
    graph.add_edge("memory_trigger", "init_state")
    _add_routing_edges(graph, use_hybrid_router, llm_node)
    _add_react_loop_edges(graph, subgraph_nodes)
    _add_finalize_edges(graph, llm_node, summarize_node)

    info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router})")
    return graph
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _add_memory_edges(graph: StateGraph, retrieve_memory_node) -> None:
    """Connect START to the memory stage.

    With a (truthy) retrieval node the path is
    START -> retrieve_memory -> memory_trigger; otherwise START goes straight
    to memory_trigger.
    """
    entry_node = "retrieve_memory" if retrieve_memory_node else "memory_trigger"
    graph.add_edge(START, entry_node)
    if entry_node == "retrieve_memory":
        graph.add_edge("retrieve_memory", "memory_trigger")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _add_routing_edges(graph: StateGraph, use_hybrid_router: bool, llm_node) -> None:
    """Wire the routing stage that follows ``init_state``.

    Hybrid mode: init_state -> hybrid_router, which dispatches through
    ``route_from_hybrid_decision`` to one of the fast paths or (``"react_loop"``)
    into ``react_reason``. Every fast path is then checked with
    ``check_fast_path_success``: ``"llm_call"`` continues to the llm_call node,
    ``"escalate"`` falls back to ``react_reason``.

    Pure ReAct mode: init_state -> react_reason.

    Args:
        graph: Graph being assembled; edges are added in place.
        use_hybrid_router: Whether the hybrid router and fast-path nodes were
            registered on ``graph``.
        llm_node: Not used by this function; kept for signature compatibility.
            NOTE(review): the fast-path edges target ``"llm_call"`` by name even
            if ``llm_node`` is None, in which case that node was never
            registered — confirm this cannot happen.
    """
    if not use_hybrid_router:
        graph.add_edge("init_state", "react_reason")
        info("✅ [图构建] 纯 React 模式")
        return

    graph.add_edge("init_state", "hybrid_router")

    # Hybrid router: fast paths, or full ReAct loop.
    graph.add_conditional_edges(
        "hybrid_router",
        route_from_hybrid_decision,
        {
            "fast_chitchat": "fast_chitchat",
            "fast_rag": "fast_rag",
            "fast_tool": "fast_tool",
            "react_loop": "react_reason",
        },
    )

    # Completion check for every fast path: success goes to llm_call, failure
    # escalates to the ReAct loop (applies to all three fast paths, not only fast_rag).
    for fast_node in ("fast_chitchat", "fast_rag", "fast_tool"):
        graph.add_conditional_edges(
            fast_node,
            check_fast_path_success,
            {
                "llm_call": "llm_call",
                "escalate": "react_reason",
            },
        )

    info("✅ [图构建] 混合路由模式已启用")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) -> None:
    """Wire the ReAct reasoning loop.

    ``react_reason`` dispatches via ``route_by_reasoning`` to ``rag_retrieve``,
    ``web_search``, any subgraph node (keyed by its own name), ``handle_error``
    or ``llm_call``. Every one of those targets except ``llm_call`` gets an
    edge back to ``react_reason``, so ``llm_call`` is the only exit from the
    loop wired here. After RAG retrieval, ``react_reason`` decides the next step.
    """
    worker_names = list(subgraph_nodes.keys())

    # Same key order (and duplicate-key override behaviour) as a literal
    # {"rag_retrieve", "web_search", **subgraphs, "handle_error", "llm_call"}.
    path_map = {"rag_retrieve": "rag_retrieve", "web_search": "web_search"}
    path_map.update((name, name) for name in worker_names)
    path_map["handle_error"] = "handle_error"
    path_map["llm_call"] = "llm_call"

    graph.add_conditional_edges("react_reason", route_by_reasoning, path_map)

    # Loop-back edges into react_reason.
    for source_node in ["rag_retrieve", "web_search", "handle_error", *worker_names]:
        graph.add_edge(source_node, "react_reason")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _add_finalize_edges(graph: StateGraph, llm_node, summarize_node) -> None:
    """Wire the completion stage: llm_call -> [summarize ->] finalize -> END.

    When ``summarize_node`` is truthy, ``should_summarize`` chooses between
    ``summarize`` (which then flows to ``finalize``) and ``finalize`` directly.
    When ``llm_node`` is None no edge leaves ``llm_call``. The
    ``finalize -> END`` edge is always added.
    """
    if llm_node is not None and summarize_node:
        graph.add_conditional_edges(
            "llm_call",
            should_summarize,
            {"summarize": "summarize", "finalize": "finalize"},
        )
        graph.add_edge("summarize", "finalize")
    elif llm_node is not None:
        graph.add_edge("llm_call", "finalize")

    graph.add_edge("finalize", END)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== Exports ==========
# Only the graph builder is public; the _add_*_edges helpers are internal.
__all__ = ["build_react_main_graph"]
|