# File: ailine/backend/app/main_graph/utils/main_graph_builder.py
# (web-view header: 314 lines, 11 KiB, Python)

"""
整合后的完整主图构建器 - 结合旧图和新图的优点
Main Graph Builder - Integrated Full Version (Old + New)
"""
from app.main_graph.graph import StateGraph, START, END
from typing import Dict, Any, Optional
from langchain_core.runnables.config import RunnableConfig
from app.main_graph.state import MainGraphState, CurrentAction, MessagesState
# (web-view timestamp: 2026-05-01 00:36:30 +08:00)
from app.main_graph.nodes.react_nodes import (
init_state_node,
react_reason_node,
web_search_node,
error_handling_node,
route_by_reasoning
)
from app.main_graph.nodes.llm_call import create_llm_call_node
# (web-view timestamp: 2026-05-01 00:36:30 +08:00)
from app.main_graph.nodes.rag_nodes import rag_retrieve_node
from app.main_graph.nodes.retrieve_memory import create_retrieve_memory_node
from app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client
from app.main_graph.nodes.summarize import create_summarize_node
from app.main_graph.nodes.finalize import finalize_node
from app.subgraphs.contact import build_contact_subgraph
from app.subgraphs.dictionary import build_dictionary_subgraph
from app.subgraphs.news_analysis import build_news_analysis_subgraph
from app.memory.mem0_client import Mem0Client
from app.logger import info, debug
# ========== Module-level state (used to pass the mem0_client around) ==========
# Keeping the client here means the legacy node signatures do not have to change.
_global_mem0_client: Optional[Mem0Client] = None
def set_global_mem0_client(client: Mem0Client):
    """
    Register the module-wide mem0 client.

    Stores ``client`` in the module-level ``_global_mem0_client`` and then
    forwards the same client to ``set_mem0_client`` so that
    ``memory_trigger_node`` receives it as well.

    Args:
        client: The Mem0Client instance to share.
    """
    global _global_mem0_client
    _global_mem0_client = client
    # memory_trigger_node is configured through its own setter.
    set_mem0_client(client)
# ========== Subgraph wrapper (propagates subgraph errors to the main graph) ==========
def wrap_subgraph_for_error_handling(subgraph, name: str):
    """
    Wrap a compiled subgraph so its failures are recorded on the main graph state.

    Args:
        subgraph: Compiled subgraph; only its ``invoke(state)`` method is used.
        name: Subgraph name. It selects the state attribute that receives the
            result ("contact" -> ``contact_result``, "dictionary" ->
            ``dictionary_result``, "news_analysis" -> ``news_result``; for any
            other name the result is not stored) and labels the error record.

    Returns:
        A node function that takes the main graph state, mutates it in place
        and returns it.
    """
    # State attribute that receives the result of each known subgraph.
    target_attr = {
        "contact": "contact_result",
        "dictionary": "dictionary_result",
        "news_analysis": "news_result",
    }.get(name)

    def wrapped_node(state: MainGraphState) -> MainGraphState:
        try:
            outcome = subgraph.invoke(state)
            if target_attr is not None:
                setattr(state, target_attr, outcome)
            state.success = True
        except Exception as exc:
            # Local imports: only needed on the failure path.
            from app.main_graph.state import ErrorRecord, ErrorSeverity
            from datetime import datetime

            failure = ErrorRecord(
                error_type=f"{name}SubgraphError",
                error_message=str(exc),
                severity=ErrorSeverity.WARNING,
                source=f"{name}_subgraph",
                timestamp=datetime.now().isoformat(),
                retry_count=0,
                max_retries=1,
                context={"user_query": state.user_query},
            )
            # Record the failure and switch the state to the error-handling phase.
            state.errors.append(failure)
            state.current_error = failure
            state.current_phase = "error_handling"
            state.success = False
        return state

    return wrapped_node
# ========== Decide whether the conversation should be summarized ==========
def should_summarize(state: MainGraphState) -> str:
    """
    Route to summarization once the conversation is long enough.

    Args:
        state: Current graph state; its ``messages`` attribute is inspected.
            A missing or ``None`` ``messages`` counts as an empty conversation.

    Returns:
        "summarize" when there are at least 4 messages, otherwise "finalize".
    """
    # `or []` also covers a state whose `messages` is explicitly None, which
    # would otherwise make len() raise TypeError.
    messages = getattr(state, "messages", None) or []
    return "summarize" if len(messages) >= 4 else "finalize"
# ========== Compatibility layer: run legacy nodes on the new state ==========
def adapt_old_node_for_new_state(old_node):
    """
    Adapt a legacy node (which expects MessagesState) to MainGraphState.

    Args:
        old_node: Async legacy node, awaited as ``old_node(old_state, config)``.

    Returns:
        An async node function ``adapted_node(state, config)``. It builds the
        legacy state dict from ``state``, awaits ``old_node``, copies
        ``memory_context`` and ``llm_calls`` from the result back onto
        ``state`` when present, and returns the legacy node's result
        (an empty dict when the legacy node returned ``None``).
    """
    async def adapted_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
        # Project MainGraphState onto the dict shape the legacy node reads.
        legacy_state: MessagesState = {
            "messages": state.messages,
            "llm_calls": getattr(state, "llm_calls", 0),
            "memory_context": getattr(state, "memory_context", ""),
            "system_prompt": getattr(state, "system_prompt", ""),
        }

        result = await old_node(legacy_state, config)

        # A legacy node that returns nothing has no update to apply; without
        # this guard the membership tests below raise TypeError on None.
        if result is None:
            return {}

        # Mirror the fields the legacy node may have updated onto the new state.
        if "memory_context" in result:
            state.memory_context = result["memory_context"]
        if "llm_calls" in result:
            state.llm_calls = result["llm_calls"]
        return result

    return adapted_node
# ========== Main graph construction ==========
def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph:
    """
    Build the integrated main graph (legacy memory pipeline + ReAct loop).

    Flow:
        START
          -> retrieve_memory   (only when ``mem0_client`` is given)
          -> memory_trigger
          -> init_state
          -> react_reason, which routes via ``route_by_reasoning`` to:
               rag_retrieve / web_search / contact_subgraph /
               dictionary_subgraph / news_analysis_subgraph / handle_error
                   (each of these loops back to react_reason)
               llm_call
                   (when no ``llm`` is given this route goes straight to
                   finalize instead)
          -> summarize         (only when both ``llm`` and ``mem0_client``
                                are given and ``should_summarize`` picks it)
          -> finalize
          -> END

    Args:
        llm: Model handed to ``create_llm_call_node``; when ``None`` no
            ``llm_call`` node is created.
        tools: Tools handed to ``create_llm_call_node`` (an empty list is
            used when this is ``None`` or empty).
        mem0_client: Memory client. When truthy it is registered module-wide
            and used to create the ``retrieve_memory`` and ``summarize`` nodes.

    Returns:
        The assembled graph; this function does not call ``compile()`` on it.
    """
    graph = StateGraph(MainGraphState)

    # Publish the client module-wide (this also hands it to memory_trigger_node).
    if mem0_client:
        set_global_mem0_client(mem0_client)

    # ----- Optional nodes: each exists only when its dependency was supplied -----
    llm_node = None
    if llm is not None:
        llm_node = create_llm_call_node(llm, tools or [])

    retrieve_memory_node = None
    summarize_node = None
    if mem0_client:
        retrieve_memory_node = create_retrieve_memory_node(mem0_client)
        summarize_node = create_summarize_node(mem0_client)

    has_llm = llm_node is not None
    # "summarize" is only entered from "llm_call"; without an LLM node it would
    # have no incoming edge, so it is not registered in that case.
    can_summarize = bool(summarize_node) and has_llm

    # ========== Nodes ==========
    # Phase 1: memory retrieval (from the legacy graph)
    if retrieve_memory_node:
        graph.add_node("retrieve_memory", adapt_old_node_for_new_state(retrieve_memory_node))
    graph.add_node("memory_trigger", memory_trigger_node)

    # Phase 2: ReAct reasoning loop (from the new graph)
    graph.add_node("init_state", init_state_node)
    graph.add_node("react_reason", react_reason_node)
    graph.add_node("rag_retrieve", rag_retrieve_node)
    graph.add_node("web_search", web_search_node)
    graph.add_node("handle_error", error_handling_node)
    if has_llm:
        graph.add_node("llm_call", llm_node)

    # Subgraph nodes: build all three first, then compile and wrap each one.
    contact_graph = build_contact_subgraph()
    dictionary_graph = build_dictionary_subgraph()
    news_analysis_graph = build_news_analysis_subgraph()
    for sub_name, sub_graph in (
        ("contact", contact_graph),
        ("dictionary", dictionary_graph),
        ("news_analysis", news_analysis_graph),
    ):
        graph.add_node(
            f"{sub_name}_subgraph",
            wrap_subgraph_for_error_handling(sub_graph.compile(), sub_name),
        )

    # Phase 3: completion (from the legacy graph)
    if can_summarize:
        graph.add_node("summarize", adapt_old_node_for_new_state(summarize_node))
    graph.add_node("finalize", finalize_node)

    # ========== Edges ==========
    # Phase 1: memory retrieval
    if retrieve_memory_node:
        graph.add_edge(START, "retrieve_memory")
        graph.add_edge("retrieve_memory", "memory_trigger")
    else:
        graph.add_edge(START, "memory_trigger")

    # Enter phase 2
    graph.add_edge("memory_trigger", "init_state")
    graph.add_edge("init_state", "react_reason")

    # Phase 2: ReAct loop. Every node below returns to react_reason afterwards.
    loop_nodes = (
        "rag_retrieve",
        "web_search",
        "contact_subgraph",
        "dictionary_subgraph",
        "news_analysis_subgraph",
        "handle_error",
    )
    routes = {node: node for node in loop_nodes}
    # Without an LLM node the "llm_call" route would point at a node that was
    # never added and "finalize" would have no incoming edge; send the route
    # to "finalize" so the graph stays well-formed and can terminate.
    routes["llm_call"] = "llm_call" if has_llm else "finalize"
    graph.add_conditional_edges("react_reason", route_by_reasoning, routes)

    for node in loop_nodes:
        graph.add_edge(node, "react_reason")

    # Phase 3: after llm_call, move on to completion
    if has_llm:
        if can_summarize:
            # Let should_summarize decide whether to summarize first.
            graph.add_conditional_edges(
                "llm_call",
                should_summarize,
                {
                    "summarize": "summarize",
                    "finalize": "finalize",
                },
            )
            graph.add_edge("summarize", "finalize")
        else:
            # No summarize node: go straight to finalize.
            graph.add_edge("llm_call", "finalize")

    # Done
    graph.add_edge("finalize", END)
    info("✅ [图构建] 整合后的完整主图构建完成")
    return graph
# ========== Compatibility: keep the old function name ==========
def build_main_graph() -> StateGraph:
    """
    Backward-compatible entry point for code that still calls ``build_main_graph()``.

    Returns:
        The graph produced by ``build_react_main_graph()`` called with its
        default arguments.
    """
    react_graph = build_react_main_graph()
    return react_graph
# ========== Public API ==========
__all__ = [
    "build_react_main_graph",
    "build_main_graph",
    "wrap_subgraph_for_error_handling",
    "set_global_mem0_client",
]