diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index a1e4aa0..48c9611 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -13,7 +13,7 @@ from app.main_graph.config import set_stream_writer from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider from app.main_graph.utils.rag_initializer import init_rag_tool from app.core.intent_classifier import get_intent_classifier -from app.logger import info, warning +from app.logger import info, warning, error from app.main_graph.state import MainGraphState, CurrentAction @@ -27,8 +27,19 @@ class AIAgentService: self.intent_classifier = get_intent_classifier() # RAG 管道(可选,需要时设置) self.rag_pipeline = None + # Mem0 客户端 + self.mem0_client = None async def initialize(self): + # 0. 初始化 Mem0 客户端 + from app.memory.mem0_client import Mem0Client + # 创建一个临时的 LLM 用于 Mem0(用第一个可用的) + chat_services = get_all_chat_services() + temp_llm = None + if chat_services: + temp_llm = list(chat_services.values())[0] + self.mem0_client = Mem0Client(temp_llm) + # 1. 初始化 RAG 工具(如果需要) def create_local_llm(): provider = LocalVLLMChatProvider() @@ -42,11 +53,14 @@ class AIAgentService: set_global_rag_tool(rag_tool) # 2. 构建各模型的 Graph(使用新版 React 模式) - chat_services = get_all_chat_services() for name, llm in chat_services.items(): try: info(f"🔄 初始化模型 '{name}'...") - graph = build_react_main_graph(llm=llm, tools=self.tools).compile(checkpointer=self.checkpointer) + graph = build_react_main_graph( + llm=llm, + tools=self.tools, + mem0_client=self.mem0_client + ).compile(checkpointer=self.checkpointer) self.graphs[name] = graph info(f"✅ 模型 '{name}' 初始化成功") except Exception as e: diff --git a/backend/app/main_graph/nodes/llm_call.py b/backend/app/main_graph/nodes/llm_call.py index aee7053..5de699e 100644 --- a/backend/app/main_graph/nodes/llm_call.py +++ b/backend/app/main_graph/nodes/llm_call.py @@ -126,6 +126,9 @@ def create_llm_call_node(llm, tools: list): debug(f"📋 [LLM统计] 详细用量: {token_usage}") debug("="*80 + "\n") + # 检查是否有工具调用 + has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0 + result = { "messages": [response], "llm_calls": getattr(state, 'llm_calls', 0) + 1, @@ -134,7 +137,8 @@ def create_llm_call_node(llm, tools: list): "turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1, "final_result": response.content, "success": True, - "current_phase": "done" + "current_phase": "done", + "has_tool_calls": has_tool_calls } log_state_change("llm_call", {**state, **result}, "离开") diff --git a/backend/app/main_graph/utils/main_graph_builder.py b/backend/app/main_graph/utils/main_graph_builder.py index 51765a5..cf30b01 100644 --- a/backend/app/main_graph/utils/main_graph_builder.py +++ b/backend/app/main_graph/utils/main_graph_builder.py @@ -1,12 +1,13 @@ """ -React 模式主图构建器 - 完整循环推理版本 -Main Graph Builder - Full React Mode with Loop Reasoning +整合后的完整主图构建器 - 结合旧图和新图的优点 +Main Graph Builder - Integrated Full Version (Old + New) """ from app.main_graph.graph import StateGraph, START, END -from typing import Dict, Any +from typing import Dict, Any, Optional +from langchain_core.runnables.config import RunnableConfig -from app.main_graph.state import MainGraphState, CurrentAction +from app.main_graph.state import MainGraphState, CurrentAction, MessagesState from app.main_graph.nodes.react_nodes import ( init_state_node, react_reason_node, @@ -16,12 +17,30 @@ from app.main_graph.nodes.react_nodes import ( ) from app.main_graph.nodes.llm_call import create_llm_call_node from app.main_graph.nodes.rag_nodes import rag_retrieve_node +from app.main_graph.nodes.retrieve_memory import create_retrieve_memory_node +from app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client +from app.main_graph.nodes.summarize import create_summarize_node +from app.main_graph.nodes.finalize import finalize_node from app.subgraphs.contact import build_contact_subgraph from app.subgraphs.dictionary import build_dictionary_subgraph from app.subgraphs.news_analysis import build_news_analysis_subgraph +from app.memory.mem0_client import Mem0Client +from app.logger import info, debug -# ========== 子图包装器(处理子图错误传递) ========== +# ========== 全局变量(用于传递 mem0_client)========== +# 这样就不用改旧节点的签名了 +_global_mem0_client: Optional[Mem0Client] = None + + +def set_global_mem0_client(client: Mem0Client): + """设置全局的 mem0_client""" + global _global_mem0_client + _global_mem0_client = client + set_mem0_client(client) # 同时设置给 memory_trigger_node + + +# ========== 子图包装器(处理子图错误传递)========== def wrap_subgraph_for_error_handling(subgraph, name: str): """ 包装子图,使其错误能传递给主图 @@ -74,64 +93,126 @@ def wrap_subgraph_for_error_handling(subgraph, name: str): return wrapped_node -# ========== 主图构建 ========== -def build_react_main_graph(llm=None, tools=None) -> StateGraph: +# ========== 检查是否需要总结 ========== +def should_summarize(state: MainGraphState) -> str: """ - 构建完整的 React 模式主图 + 检查是否需要总结对话(对话足够长时) + + Args: + state: 当前图状态 + + Returns: + "summarize" 或 "finalize" + """ + messages = getattr(state, 'messages', []) + if len(messages) >= 4: + return "summarize" + else: + return "finalize" - 流程: + +# ========== 兼容层:让旧节点工作在新状态上 ========== +def adapt_old_node_for_new_state(old_node): + """ + 适配旧节点(期望 MessagesState)到新状态 MainGraphState + + Args: + old_node: 旧节点函数 + + Returns: 适配后的节点函数 + """ + async def adapted_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: + # 把 MainGraphState 转换为 MessagesState(旧节点期望的格式) + old_state: MessagesState = { + "messages": state.messages, + "llm_calls": getattr(state, 'llm_calls', 0), + "memory_context": getattr(state, 'memory_context', ""), + "system_prompt": getattr(state, 'system_prompt', "") + } + + # 调用旧节点 + result = await old_node(old_state, config) + + # 把结果更新回 MainGraphState + if "memory_context" in result: + state.memory_context = result["memory_context"] + if "llm_calls" in result: + state.llm_calls = result["llm_calls"] + + return result + + return adapted_node + + +# ========== 主图构建 ========== +def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph: + """ + 构建整合后的完整主图 + + 完整流程: START ↓ - init_state (初始化) + retrieve_memory (从Mem0检索长期记忆) ← 来自旧图 ↓ - react_reason (推理) ←──────────────┐ - ↓ │ - 条件路由 │ - ├─ rag_retrieve →───────────────┤ - ├─ contact_subgraph →───────────┤ - ├─ dictionary_subgraph →────────┤ - ├─ news_analysis_subgraph →─────┤ - ├─ handle_error → (重试或结束) ─┤ - └─ llm_call (大模型调用) ←──────┘ + memory_trigger (记忆触发器) ← 来自旧图 + ↓ + init_state (初始化) ← 来自新图 + ↓ + react_reason (推理) ←──────────────────────┐ + ↓ │ + 条件路由 │ + ├─ rag_retrieve →─────────────────────────┤ + ├─ contact_subgraph →─────────────────────┤ + ├─ dictionary_subgraph →──────────────────┤ + ├─ news_analysis_subgraph →───────────────┤ + ├─ web_search →───────────────────────────┤ + ├─ handle_error → (重试或结束) ───────────┤ + └─ llm_call (大模型调用) ←────────────────┘ ↓ - 🔍 观察 (检查 tool_calls) + 检查:需要总结? + ├─ 是 → summarize (提交给Mem0存储) ← 来自旧图 + └─ 否 → (跳过) ↓ - [有工具调用?] - ├─ 是 → 执行工具 → 回到 llm_call - └─ 否 → END + finalize (发送完成事件) ← 来自旧图 + ↓ + END """ # 创建图 graph = StateGraph(MainGraphState) - # 创建 llm_call 节点 + # 设置全局 mem0_client + if mem0_client: + set_global_mem0_client(mem0_client) + + # 创建节点 llm_node = None if llm is not None: llm_node = create_llm_call_node(llm, tools or []) + retrieve_memory_node = None + summarize_node = None + if mem0_client: + retrieve_memory_node = create_retrieve_memory_node(mem0_client) + summarize_node = create_summarize_node(mem0_client) + # ========== 添加节点 ========== - # 1. 初始化节点 + # 第一阶段:记忆检索(来自旧图) + if retrieve_memory_node: + graph.add_node("retrieve_memory", adapt_old_node_for_new_state(retrieve_memory_node)) + graph.add_node("memory_trigger", memory_trigger_node) + + # 第二阶段:React 循环推理(来自新图) graph.add_node("init_state", init_state_node) - - # 2. React 推理节点 graph.add_node("react_reason", react_reason_node) - - # 3. RAG 检索节点 graph.add_node("rag_retrieve", rag_retrieve_node) - - # 4. 联网搜索节点 graph.add_node("web_search", web_search_node) - - # 5. 错误处理节点 graph.add_node("handle_error", error_handling_node) - # 6. LLM 调用节点(真正的大模型输出) if llm_node is not None: graph.add_node("llm_call", llm_node) - # ========== 添加子图节点 ========== - - # 构建并包装子图(带错误处理) + # 子图节点 contact_graph = build_contact_subgraph() dictionary_graph = build_dictionary_subgraph() news_analysis_graph = build_news_analysis_subgraph() @@ -149,39 +230,40 @@ def build_react_main_graph(llm=None, tools=None) -> StateGraph: wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis") ) + # 第三阶段:完成处理(来自旧图) + if summarize_node: + graph.add_node("summarize", adapt_old_node_for_new_state(summarize_node)) + graph.add_node("finalize", finalize_node) + # ========== 添加边 ========== - # 1. START → init_state - graph.add_edge(START, "init_state") + # 第一阶段:记忆检索 + if retrieve_memory_node: + graph.add_edge(START, "retrieve_memory") + graph.add_edge("retrieve_memory", "memory_trigger") + else: + graph.add_edge(START, "memory_trigger") - # 2. init_state → react_reason + # 进入第二阶段 + graph.add_edge("memory_trigger", "init_state") graph.add_edge("init_state", "react_reason") - # 3. 条件路由:react_reason → 各分支 + # 第二阶段:React 循环推理 graph.add_conditional_edges( "react_reason", route_by_reasoning, { - # 检索分支 → 检索后回到推理 "rag_retrieve": "rag_retrieve", - - # 联网搜索分支 "web_search": "web_search", - - # 子图分支 → 子图后回到推理 "contact_subgraph": "contact_subgraph", "dictionary_subgraph": "dictionary_subgraph", "news_analysis_subgraph": "news_analysis_subgraph", - - # 错误处理分支 "handle_error": "handle_error", - - # LLM 调用分支 → 直接输出给用户 "llm_call": "llm_call" } ) - - # 4. 循环边:检索/搜索/子图/错误处理后 → 回到推理 + + # 循环边:检索/搜索/子图/错误处理后 → 回到推理 graph.add_edge("rag_retrieve", "react_reason") graph.add_edge("web_search", "react_reason") graph.add_edge("contact_subgraph", "react_reason") @@ -189,10 +271,27 @@ def build_react_main_graph(llm=None, tools=None) -> StateGraph: graph.add_edge("news_analysis_subgraph", "react_reason") graph.add_edge("handle_error", "react_reason") - # 5. 条件路由:llm_call 后检查是否有工具调用 - # 注意:这里简化处理,先直接 END,后续再完善工具调用循环 + # 第三阶段:llm_call 后进入完成处理 if llm_node is not None: - graph.add_edge("llm_call", END) + if summarize_node: + # 检查是否需要总结 + graph.add_conditional_edges( + "llm_call", + should_summarize, + { + "summarize": "summarize", + "finalize": "finalize" + } + ) + graph.add_edge("summarize", "finalize") + else: + # 没有 summarize 节点,直接 finalize + graph.add_edge("llm_call", "finalize") + + # 完成 + graph.add_edge("finalize", END) + + info("✅ [图构建] 整合后的完整主图构建完成") return graph @@ -209,5 +308,6 @@ def build_main_graph() -> StateGraph: __all__ = [ "build_react_main_graph", "build_main_graph", - "wrap_subgraph_for_error_handling" + "wrap_subgraph_for_error_handling", + "set_global_mem0_client" ]