"""
Integrated main graph builder - combines the strengths of the old and new graphs.

Main Graph Builder - Integrated Full Version (Old + New)
"""
from app.main_graph.graph import StateGraph, START, END
from typing import Dict, Any, Optional
from langchain_core.runnables.config import RunnableConfig
from app.main_graph.state import MainGraphState, CurrentAction, MessagesState
from app.main_graph.nodes.react_nodes import (
    init_state_node,
    react_reason_node,
    web_search_node,
    error_handling_node,
    route_by_reasoning,
)
from app.main_graph.nodes.llm_call import create_llm_call_node
from app.main_graph.nodes.rag_nodes import rag_retrieve_node
from app.main_graph.nodes.retrieve_memory import create_retrieve_memory_node
from app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client
from app.main_graph.nodes.summarize import create_summarize_node
from app.main_graph.nodes.finalize import finalize_node
from app.subgraphs.contact import build_contact_subgraph
from app.subgraphs.dictionary import build_dictionary_subgraph
from app.subgraphs.news_analysis import build_news_analysis_subgraph
from app.memory.mem0_client import Mem0Client
from app.logger import info, debug


# ========== Global state (used to hand mem0_client to nodes) ==========
# Kept as a module-level global so the old nodes' signatures do not have to change.
_global_mem0_client: Optional[Mem0Client] = None


def set_global_mem0_client(client: Mem0Client):
    """Store *client* as the module-global mem0 client.

    The client is also forwarded to ``set_mem0_client`` from the
    memory_trigger module, so ``memory_trigger_node`` uses the same client.
    """
    global _global_mem0_client
    _global_mem0_client = client
    set_mem0_client(client)  # also hand it to memory_trigger_node


# ========== Subgraph wrapper (propagates subgraph errors) ==========
def wrap_subgraph_for_error_handling(subgraph, name: str):
    """
    Wrap a compiled subgraph so that its failures are reported to the main graph.

    On success, the subgraph's output is stored on the main-state field that
    matches *name*:

    - "contact" -> ``contact_result``
    - "dictionary" -> ``dictionary_result``
    - "news_analysis" -> ``news_result``
    - any other name stores nothing

    ``state.success`` is then set to True. On any exception, an
    ``ErrorRecord`` is appended to ``state.errors`` and also set as
    ``state.current_error``. The phase is switched to "error_handling" and
    ``state.success`` is set to False. The wrapped node never re-raises.

    Args:
        subgraph: The compiled subgraph; it must provide ``invoke``.
        name: Subgraph name. It selects the result field and prefixes the
            error type and source labels.

    Returns:
        A node function that takes and returns ``MainGraphState``.
    """
    # Resolve the target field once, at wrap time, instead of on every call.
    result_field = {
        "contact": "contact_result",
        "dictionary": "dictionary_result",
        "news_analysis": "news_result",
    }.get(name)

    def wrapped_node(state: MainGraphState) -> MainGraphState:
        try:
            # Run the subgraph on the main-graph state.
            result = subgraph.invoke(state)

            # Copy the subgraph output onto the matching main-graph field.
            if result_field is not None:
                setattr(state, result_field, result)

            # Mark success.
            state.success = True
            return state

        except Exception as e:
            # Turn the subgraph failure into an error record on the main
            # state instead of letting it abort the whole main graph.
            from app.main_graph.state import ErrorRecord, ErrorSeverity
            from datetime import datetime

            error_record = ErrorRecord(
                error_type=f"{name}SubgraphError",
                error_message=str(e),
                severity=ErrorSeverity.WARNING,
                source=f"{name}_subgraph",
                timestamp=datetime.now().isoformat(),
                retry_count=0,
                max_retries=1,
                # getattr: a missing attribute here must not mask the
                # original subgraph error with an AttributeError.
                context={"user_query": getattr(state, "user_query", None)},
            )

            state.errors.append(error_record)
            state.current_error = error_record
            state.current_phase = "error_handling"
            state.success = False
            return state

    return wrapped_node


# ========== Decide whether to summarize ==========
def should_summarize(state: MainGraphState) -> str:
    """
    Decide whether the conversation is long enough to be summarized.

    Args:
        state: Current graph state. A missing or None ``messages`` attribute
            is treated as an empty conversation.

    Returns:
        "summarize" when there are at least 4 messages, otherwise "finalize".
    """
    messages = getattr(state, "messages", None) or []
    return "summarize" if len(messages) >= 4 else "finalize"


# ========== Compatibility layer: run old nodes on the new state ==========
def adapt_old_node_for_new_state(old_node):
    """
    Adapt an old async node (which expects a ``MessagesState`` dict) to
    ``MainGraphState``.

    The adapted node does the following:

    1. Builds a ``MessagesState`` from the main state.
    2. Awaits the old node with it.
    3. Copies ``memory_context`` and ``llm_calls`` back onto the main state
       when the old node returned them.
    4. Returns the old node's result dict, or ``{}`` if it returned None.

    Args:
        old_node: Async callable ``(MessagesState, RunnableConfig) -> dict``.

    Returns:
        The adapted async node function.
    """
    async def adapted_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
        # Convert MainGraphState into the MessagesState shape the old node expects.
        old_state: MessagesState = {
            "messages": state.messages,
            "llm_calls": getattr(state, "llm_calls", 0),
            "memory_context": getattr(state, "memory_context", ""),
            "system_prompt": getattr(state, "system_prompt", ""),
        }

        # Call the old node.
        result = await old_node(old_state, config)
        if result is None:
            # A node that returns nothing means "no updates";
            # avoid a TypeError on the membership tests below.
            result = {}

        # Write the relevant results back onto MainGraphState.
        if "memory_context" in result:
            state.memory_context = result["memory_context"]
        if "llm_calls" in result:
            state.llm_calls = result["llm_calls"]

        return result

    return adapted_node


# ========== Main graph construction ==========
def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph:
    """
    Build the integrated main graph (simplified version: get the system working first).

    Flow::

        START
          -> init_state
          -> react_reason  <--------------------------------+
               | route_by_reasoning                          |
               +- rag_retrieve ------------------------------+
               +- contact_subgraph --------------------------+
               +- dictionary_subgraph -----------------------+
               +- news_analysis_subgraph --------------------+
               +- web_search --------------------------------+
               +- handle_error ------------------------------+
               +- llm_call -> finalize   (when ``llm`` is given)
               +- finalize               ("llm_call" decision when no ``llm``)
          finalize -> END

    Args:
        llm: Model passed to ``create_llm_call_node``. When None, no
            ``llm_call`` node is added, and the reasoner's "llm_call"
            decision routes directly to ``finalize``.
        tools: Tools passed to ``create_llm_call_node``; None means an empty list.
        mem0_client: If truthy, it is registered via ``set_global_mem0_client``.

    Returns:
        The uncompiled ``StateGraph``.
    """
    graph = StateGraph(MainGraphState)

    # Register the global mem0_client (memory features are not used for now).
    if mem0_client:
        set_global_mem0_client(mem0_client)

    # Build the LLM node only when a model is supplied.
    llm_node = None
    if llm is not None:
        llm_node = create_llm_call_node(llm, tools or [])

    # ========== Nodes ==========
    # Simplified: memory-retrieval nodes are not wired in yet.
    graph.add_node("init_state", init_state_node)
    graph.add_node("react_reason", react_reason_node)
    graph.add_node("rag_retrieve", rag_retrieve_node)
    graph.add_node("web_search", web_search_node)
    graph.add_node("handle_error", error_handling_node)
    if llm_node is not None:
        graph.add_node("llm_call", llm_node)

    # Subgraph nodes, each wrapped so its errors land on the main state.
    contact_graph = build_contact_subgraph()
    dictionary_graph = build_dictionary_subgraph()
    news_analysis_graph = build_news_analysis_subgraph()

    graph.add_node(
        "contact_subgraph",
        wrap_subgraph_for_error_handling(contact_graph.compile(), "contact"),
    )
    graph.add_node(
        "dictionary_subgraph",
        wrap_subgraph_for_error_handling(dictionary_graph.compile(), "dictionary"),
    )
    graph.add_node(
        "news_analysis_subgraph",
        wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis"),
    )

    # Completion node.
    graph.add_node("finalize", finalize_node)

    # ========== Edges ==========
    # Simplified: go straight from START to init_state.
    graph.add_edge(START, "init_state")
    graph.add_edge("init_state", "react_reason")

    # The "llm_call" decision must point at a node that actually exists.
    # Without an LLM node, fall through to finalize. This keeps the path map
    # valid and keeps finalize reachable.
    llm_call_target = "llm_call" if llm_node is not None else "finalize"

    # Conditional routing out of the reasoning step.
    graph.add_conditional_edges(
        "react_reason",
        route_by_reasoning,
        {
            "rag_retrieve": "rag_retrieve",
            "web_search": "web_search",
            "contact_subgraph": "contact_subgraph",
            "dictionary_subgraph": "dictionary_subgraph",
            "news_analysis_subgraph": "news_analysis_subgraph",
            "handle_error": "handle_error",
            "llm_call": llm_call_target,
        },
    )

    # Loop edges: every tool/subgraph/error step returns to reasoning.
    graph.add_edge("rag_retrieve", "react_reason")
    graph.add_edge("web_search", "react_reason")
    graph.add_edge("contact_subgraph", "react_reason")
    graph.add_edge("dictionary_subgraph", "react_reason")
    graph.add_edge("news_analysis_subgraph", "react_reason")
    graph.add_edge("handle_error", "react_reason")

    # After llm_call, go straight to finalize.
    if llm_node is not None:
        graph.add_edge("llm_call", "finalize")

    # Done.
    graph.add_edge("finalize", END)

    info("✅ [图构建] 整合后的完整主图构建完成(简化版)")

    return graph


# ========== Compatibility: keep the old function name ==========
def build_main_graph() -> StateGraph:
    """
    Compatibility shim: old callers of ``build_main_graph()`` get the ReAct version.

    No LLM is supplied, so the "llm_call" route ends at ``finalize``.
    """
    return build_react_main_graph()


# ========== Exports ==========
__all__ = [
    "build_react_main_graph",
    "build_main_graph",
    "wrap_subgraph_for_error_handling",
    "set_global_mem0_client",
]