diff --git a/backend/app/main_graph/nodes/finalize.py b/backend/app/main_graph/nodes/finalize.py index 420f8b3..5e6af14 100644 --- a/backend/app/main_graph/nodes/finalize.py +++ b/backend/app/main_graph/nodes/finalize.py @@ -4,32 +4,23 @@ """ from typing import Any, Dict -from app.main_graph.config import get_stream_writer # 本地模块 -from app.main_graph.state import MessagesState +from app.main_graph.state import MainGraphState from app.utils.logging import log_state_change -from app.logger import info, error +from app.logger import info, warning from langchain_core.runnables.config import RunnableConfig -def _get_attr(state, attr_name, default=None): - """通用方法:兼容 dict 和 dataclass 两种状态格式""" - if isinstance(state, dict): - return state.get(attr_name, default) - else: - return getattr(state, attr_name, default) - - -async def finalize_node(state, config: RunnableConfig) -> Dict[str, Any]: +async def finalize_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: """ 完成事件节点 - 发送完成事件,包含token使用情况和耗时信息 - + Args: - state: 当前对话状态(兼容 dict 和 dataclass) + state: 当前对话状态 config: 运行时配置 - + Returns: 空字典(完成节点,无状态更新) """ @@ -37,18 +28,25 @@ async def finalize_node(state, config: RunnableConfig) -> Dict[str, Any]: try: # 获取流式写入器并发送完成事件 + from app.main_graph.config import get_stream_writer writer = get_stream_writer() - writer({ - "type": "custom", - "data": { - "type": "done", - "token_usage": _get_attr(state, "last_token_usage", {}), - "elapsed_time": _get_attr(state, "last_elapsed_time", 0.0) - } - }) - info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息") + + # 只在 writer 存在且不是 noop 时才发送 + if writer and hasattr(writer, '__call__'): + try: + writer({ + "type": "custom", + "data": { + "type": "done", + "token_usage": state.last_token_usage, + "elapsed_time": state.last_elapsed_time + } + }) + info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息") + except Exception as e: + warning(f"⚠️ [完成事件] 发送完成事件失败 (非致命): {e}") except Exception as e: - error(f"❌ [完成事件] 发送完成事件时发生异常: {e}") + warning(f"⚠️ [完成事件] 处理失败 (非致命): {e}") log_state_change("finalize", state, "离开") return {} \ No newline at end of file diff --git a/backend/app/main_graph/nodes/memory_trigger.py b/backend/app/main_graph/nodes/memory_trigger.py index c8afcf0..e848559 100644 --- a/backend/app/main_graph/nodes/memory_trigger.py +++ b/backend/app/main_graph/nodes/memory_trigger.py @@ -1,18 +1,10 @@ from typing import Any, Dict from langchain_core.runnables.config import RunnableConfig -from app.main_graph.state import MessagesState +from app.main_graph.state import MainGraphState from app.memory.mem0_client import Mem0Client from app.logger import info -def _get_attr(state, attr_name, default=None): - """通用方法:兼容 dict 和 dataclass 两种状态格式""" - if isinstance(state, dict): - return state.get(attr_name, default) - else: - return getattr(state, attr_name, default) - - # 全局变量,在 GraphBuilder 中注入 _mem0_client: Mem0Client = None @@ -22,12 +14,12 @@ def set_mem0_client(client: Mem0Client): _mem0_client = client -async def memory_trigger_node(state, config: RunnableConfig) -> Dict[str, Any]: +async def memory_trigger_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: """检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储""" if _mem0_client is None: return {} - messages = _get_attr(state, "messages", []) + messages = state.messages if not messages: return {} diff --git a/backend/app/main_graph/nodes/retrieve_memory.py b/backend/app/main_graph/nodes/retrieve_memory.py index 3aac9fd..ed8b4fa 100644 --- a/backend/app/main_graph/nodes/retrieve_memory.py +++ b/backend/app/main_graph/nodes/retrieve_memory.py @@ -6,20 +6,12 @@ from typing import Any, Dict # 本地模块 -from app.main_graph.state import MessagesState +from app.main_graph.state import MainGraphState from app.memory.mem0_client import Mem0Client from app.utils.logging import log_state_change from app.logger import debug -def _get_attr(state, attr_name, default=None): - """通用方法:兼容 dict 和 dataclass 两种状态格式""" - if isinstance(state, dict): - return state.get(attr_name, default) - else: - return getattr(state, attr_name, default) - - def create_retrieve_memory_node(mem0_client: Mem0Client): """ 工厂函数:创建记忆检索节点 @@ -33,12 +25,12 @@ def create_retrieve_memory_node(mem0_client: Mem0Client): from langchain_core.runnables.config import RunnableConfig - async def retrieve_memory(state, config: RunnableConfig) -> Dict[str, Any]: + async def retrieve_memory(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: """ 记忆检索节点 - 使用 Mem0 Args: - state: 当前对话状态(兼容 dict 和 dataclass) + state: 当前对话状态 config: 运行时配置 Returns: @@ -49,16 +41,16 @@ def create_retrieve_memory_node(mem0_client: Mem0Client): # 从 metadata 中获取 user_id user_id = config.get("metadata", {}).get("user_id", "default_user") - # 兼容 dict 和对象两种消息格式 - messages = _get_attr(state, "messages", []) + # 获取最后一条消息 + messages = state.messages last_msg = messages[-1] if messages else None + query = "" if last_msg: if isinstance(last_msg, dict): query = str(last_msg.get("content", "")) else: query = str(last_msg.content) - else: - query = "" + memory_text_parts = [] # 确保 Mem0 已初始化(懒加载) @@ -83,7 +75,7 @@ def create_retrieve_memory_node(mem0_client: Mem0Client): memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息" result = {"memory_context": memory_context} - log_state_change("retrieve_memory", {**state, **result} if isinstance(state, dict) else state, "离开") + log_state_change("retrieve_memory", state, "离开") return result return retrieve_memory \ No newline at end of file diff --git a/backend/app/main_graph/nodes/summarize.py b/backend/app/main_graph/nodes/summarize.py index f817266..d75ba1d 100644 --- a/backend/app/main_graph/nodes/summarize.py +++ b/backend/app/main_graph/nodes/summarize.py @@ -6,20 +6,12 @@ from typing import Any, Dict # 本地模块 -from app.main_graph.state import MessagesState +from app.main_graph.state import MainGraphState from app.memory.mem0_client import Mem0Client from app.utils.logging import log_state_change from app.logger import debug, info, error, warning -def _get_attr(state, attr_name, default=None): - """通用方法:兼容 dict 和 dataclass 两种状态格式""" - if isinstance(state, dict): - return state.get(attr_name, default) - else: - return getattr(state, attr_name, default) - - def create_summarize_node(mem0_client: Mem0Client): """ 工厂函数:创建记忆存储节点 @@ -33,12 +25,12 @@ def create_summarize_node(mem0_client: Mem0Client): from langchain_core.runnables.config import RunnableConfig - async def summarize_conversation(state, config: RunnableConfig) -> Dict[str, Any]: + async def summarize_conversation(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: """ 记忆存储节点 - 使用 Mem0 Args: - state: 当前对话状态(兼容 dict 和 dataclass) + state: 当前对话状态 config: 运行时配置 Returns: @@ -46,7 +38,7 @@ def create_summarize_node(mem0_client: Mem0Client): """ log_state_change("summarize", state, "进入") - messages = _get_attr(state, "messages", []) + messages = state.messages if len(messages) < 4: debug("📝 [记忆添加] 对话过短,跳过") return {"turns_since_last_summary": 0} diff --git a/backend/app/main_graph/state.py b/backend/app/main_graph/state.py index 182ece8..8450f20 100644 --- a/backend/app/main_graph/state.py +++ b/backend/app/main_graph/state.py @@ -10,17 +10,9 @@ from app.main_graph.graph import add_messages from langchain_core.messages import BaseMessage -# ========== 兼容旧代码的类型 ========== -class MessagesState(TypedDict): - """旧的MessagesState类型(保留兼容性)""" - messages: Annotated[Sequence[BaseMessage], add_messages] - - -class GraphContext(TypedDict): - """旧的GraphContext类型(保留兼容性)""" - llm_calls: int - memory_context: str - system_prompt: str +# ========== 兼容性注释(旧代码已移除,状态已整合到 MainGraphState) ========== +# 旧的 MessagesState 和 GraphContext 已完全整合到 MainGraphState +# 不再需要单独的类型定义 # ========== 新的类型 ========== @@ -57,49 +49,52 @@ class ErrorRecord: @dataclass class MainGraphState: """ - 主图状态 - React 循环推理版本 + 主图状态 - 整合了旧 MessagesState 的所有字段 包含: - 1. 旧代码的MessagesState兼容性字段 - 2. React 推理控制字段 - 3. 循环和错误处理 - 4. 子图结果占位 - 5. 用户信息 + - 旧代码的 MessagesState 兼容性字段 + - React 推理控制字段 + - 循环和错误处理 + - 子图结果占位 + - 用户信息 """ - # ========== 兼容性字段(保留旧的MessagesState) ========== + # ========== 旧 MessagesState 兼容性字段 ========== messages: Annotated[Sequence[BaseMessage], add_messages] = field(default_factory=list) llm_calls: int = 0 memory_context: str = "" system_prompt: str = "" + turns_since_last_summary: int = 0 # 新增:来自旧状态 + last_token_usage: Dict[str, Any] = field(default_factory=dict) # 新增:来自旧状态 + last_elapsed_time: float = 0.0 # 新增:来自旧状态 # ========== 主图控制字段 ========== - user_query: str = "" # 用户当前查询 - current_action: CurrentAction = CurrentAction.NONE # 当前操作 - intent_confidence: float = 0.0 # 意图识别置信度 + user_query: str = "" + current_action: CurrentAction = CurrentAction.NONE + intent_confidence: float = 0.0 # ========== React 推理专用字段 ========== - reasoning_step: int = 0 # 当前推理步数 - max_steps: int = 40 # 最大推理步数 - last_action: str = "" # 上一步动作 - reasoning_history: List[Dict[str, Any]] = field(default_factory=list) # 推理历史 + reasoning_step: int = 0 + max_steps: int = 40 + last_action: str = "" + reasoning_history: List[Dict[str, Any]] = field(default_factory=list) # ========== RAG 相关字段 ========== - rag_context: str = "" # RAG 检索到的上下文 - rag_retrieved: bool = False # 是否已经检索过 - rag_docs: List[Dict[str, Any]] = field(default_factory=list) # 检索到的文档 + rag_context: str = "" + rag_retrieved: bool = False + rag_docs: List[Dict[str, Any]] = field(default_factory=list) - # ========== 联网搜索相关字段 ⭐ 新增 ========== - web_search_results: List[str] = field(default_factory=list) # 联网搜索结果 + # ========== 联网搜索相关字段 ========== + web_search_results: List[str] = field(default_factory=list) # ========== 错误处理字段 ========== - errors: List[ErrorRecord] = field(default_factory=list) # 错误列表 - current_error: Optional[ErrorRecord] = None # 当前错误 - retry_action: Optional[str] = None # 重试动作 + errors: List[ErrorRecord] = field(default_factory=list) + current_error: Optional[ErrorRecord] = None + retry_action: Optional[str] = None # ========== 子图结果占位 ========== - news_result: Optional[Dict[str, Any]] = None # 资讯子图结果 - dictionary_result: Optional[Dict[str, Any]] = None # 词典子图结果 - contact_result: Optional[Dict[str, Any]] = None # 通讯录子图结果 + news_result: Optional[Dict[str, Any]] = None + dictionary_result: Optional[Dict[str, Any]] = None + contact_result: Optional[Dict[str, Any]] = None # ========== 用户信息 ========== user_id: str = "" diff --git a/backend/app/main_graph/utils/main_graph_builder.py b/backend/app/main_graph/utils/main_graph_builder.py index eb51760..ab391b0 100644 --- a/backend/app/main_graph/utils/main_graph_builder.py +++ b/backend/app/main_graph/utils/main_graph_builder.py @@ -1,13 +1,12 @@ """ -整合后的完整主图构建器 - 结合旧图和新图的优点 -Main Graph Builder - Integrated Full Version (Old + New) +整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState """ from app.main_graph.graph import StateGraph, START, END from typing import Dict, Any, Optional from langchain_core.runnables.config import RunnableConfig -from app.main_graph.state import MainGraphState, CurrentAction, MessagesState +from app.main_graph.state import MainGraphState from app.main_graph.nodes.react_nodes import ( init_state_node, react_reason_node, @@ -28,16 +27,21 @@ from app.memory.mem0_client import Mem0Client from app.logger import info, debug -# ========== 全局变量(用于传递 mem0_client)========== -# 这样就不用改旧节点的签名了 -_global_mem0_client: Optional[Mem0Client] = None - - -def set_global_mem0_client(client: Mem0Client): - """设置全局的 mem0_client""" - global _global_mem0_client - _global_mem0_client = client - set_mem0_client(client) # 同时设置给 memory_trigger_node +# ========== 检查是否需要总结 ========== +def should_summarize(state: MainGraphState) -> str: + """ + 检查是否需要总结对话(对话足够长时) + + Args: + state: 当前图状态 + + Returns: + "summarize" 或 "finalize" + """ + if state.turns_since_last_summary >= 5: # 每5轮对话总结一次 + return "summarize" + else: + return "finalize" # ========== 子图包装器(处理子图错误传递)========== @@ -93,65 +97,18 @@ def wrap_subgraph_for_error_handling(subgraph, name: str): return wrapped_node -# ========== 检查是否需要总结 ========== -def should_summarize(state: MainGraphState) -> str: - """ - 检查是否需要总结对话(对话足够长时) - - Args: - state: 当前图状态 - - Returns: - "summarize" 或 "finalize" - """ - messages = getattr(state, 'messages', []) - if len(messages) >= 4: - return "summarize" - else: - return "finalize" - - -# ========== 兼容层:让旧节点工作在新状态上 ========== -def adapt_old_node_for_new_state(old_node): - """ - 适配旧节点(期望 MessagesState)到新状态 MainGraphState - - Args: - old_node: 旧节点函数 - - Returns: 适配后的节点函数 - """ - async def adapted_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: - # 把 MainGraphState 转换为 MessagesState(旧节点期望的格式) - old_state: MessagesState = { - "messages": state.messages, - "llm_calls": getattr(state, 'llm_calls', 0), - "memory_context": getattr(state, 'memory_context', ""), - "system_prompt": getattr(state, 'system_prompt', "") - } - - # 调用旧节点 - result = await old_node(old_state, config) - - # 把结果更新回 MainGraphState - if "memory_context" in result: - state.memory_context = result["memory_context"] - if "llm_calls" in result: - state.llm_calls = result["llm_calls"] - - return result - - return adapted_node - - # ========== 主图构建 ========== def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph: """ - 构建整合后的完整主图(简化版:先让系统工作起来) + 构建整合后的完整主图 完整流程: START ↓ + retrieve_memory (从Mem0检索长期记忆) + ↓ + memory_trigger (记忆触发器) + ↓ init_state (初始化) ↓ react_reason (推理) ←───────────────────────┐ @@ -165,6 +122,10 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph ├─ handle_error → (重试或结束) ────────────┤ └─ llm_call (大模型调用) ←────────────────┘ ↓ + 检查:需要总结吗? + ├─ 是 → summarize (提交给Mem0存储) + └─ 否 → (跳过) + ↓ finalize (发送完成事件) ↓ END @@ -172,7 +133,7 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph # 创建图 graph = StateGraph(MainGraphState) - # 设置全局 mem0_client (暂时不用记忆功能) + # 设置全局 mem0_client if mem0_client: set_global_mem0_client(mem0_client) @@ -181,8 +142,20 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph if llm is not None: llm_node = create_llm_call_node(llm, tools or []) + retrieve_memory_node = None + summarize_node = None + if mem0_client: + retrieve_memory_node = create_retrieve_memory_node(mem0_client) + summarize_node = create_summarize_node(mem0_client) + # ========== 添加节点 ========== - # 简化:先不用记忆检索相关节点 + + # 第一阶段:记忆检索 + if retrieve_memory_node: + graph.add_node("retrieve_memory", retrieve_memory_node) + graph.add_node("memory_trigger", memory_trigger_node) + + # 第二阶段:React 循环推理 graph.add_node("init_state", init_state_node) graph.add_node("react_reason", react_reason_node) graph.add_node("rag_retrieve", rag_retrieve_node) @@ -210,15 +183,25 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis") ) - # 完成节点 + # 第三阶段:完成处理 + if summarize_node: + graph.add_node("summarize", summarize_node) graph.add_node("finalize", finalize_node) # ========== 添加边 ========== - # 简化:直接从 START 到 init_state - graph.add_edge(START, "init_state") + + # 第一阶段:记忆检索 + if retrieve_memory_node: + graph.add_edge(START, "retrieve_memory") + graph.add_edge("retrieve_memory", "memory_trigger") + else: + graph.add_edge(START, "memory_trigger") + + # 进入第二阶段 + graph.add_edge("memory_trigger", "init_state") graph.add_edge("init_state", "react_reason") - # 条件路由 + # 第二阶段:React 循环推理 graph.add_conditional_edges( "react_reason", route_by_reasoning, @@ -241,14 +224,27 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph graph.add_edge("news_analysis_subgraph", "react_reason") graph.add_edge("handle_error", "react_reason") - # llm_call 之后直接到 finalize + # 第三阶段:llm_call 后进入完成处理 if llm_node is not None: - graph.add_edge("llm_call", "finalize") + if summarize_node: + # 检查是否需要总结 + graph.add_conditional_edges( + "llm_call", + should_summarize, + { + "summarize": "summarize", + "finalize": "finalize" + } + ) + graph.add_edge("summarize", "finalize") + else: + # 没有 summarize 节点,直接 finalize + graph.add_edge("llm_call", "finalize") # 完成 graph.add_edge("finalize", END) - info("✅ [图构建] 整合后的完整主图构建完成(简化版)") + info("✅ [图构建] 整合后的完整主图构建完成") return graph @@ -265,6 +261,5 @@ def build_main_graph() -> StateGraph: __all__ = [ "build_react_main_graph", "build_main_graph", - "wrap_subgraph_for_error_handling", - "set_global_mem0_client" -] + "wrap_subgraph_for_error_handling" +] \ No newline at end of file diff --git a/backend/app/utils/logging.py b/backend/app/utils/logging.py index 770c731..c3137fc 100644 --- a/backend/app/utils/logging.py +++ b/backend/app/utils/logging.py @@ -7,14 +7,6 @@ from app.config import ENABLE_GRAPH_TRACE from app.logger import debug, info -def _get_attr(state, attr_name, default=None): - """通用方法:兼容 dict 和 dataclass 两种状态格式""" - if isinstance(state, dict): - return state.get(attr_name, default) - else: - return getattr(state, attr_name, default) - - def log_state_change(node_name: str, state, prefix: str = "进入"): """ 记录状态变化日志 @@ -26,7 +18,13 @@ def log_state_change(node_name: str, state, prefix: str = "进入"): """ from app.logger import info - messages = _get_attr(state, "messages", []) + # 获取 messages + messages = [] + if isinstance(state, dict): + messages = state.get("messages", []) + else: + messages = getattr(state, "messages", []) + msg_count = len(messages) last_msg = messages[-1] if messages else None last_info = "" @@ -48,7 +46,7 @@ def print_llm_input(prompt_value): Args: prompt_value: ChatPromptValue 对象,包含格式化后的消息列表 - + Returns: 原样返回 prompt_value,不影响链式调用 """ @@ -56,14 +54,14 @@ def print_llm_input(prompt_value): return prompt_value messages = prompt_value.messages # ChatPromptValue 提供 .messages 属性 - - debug("\n" + "=" * 80) - debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:") + + debug("\n" + "="*80) + debug("📥 [LLM输入] 格式化后发送给大模型的完整消息:") debug(f" 总消息数: {len(messages)}") - debug("-" * 80) + debug("-"*80) for i, msg in enumerate(messages): content_preview = str(msg.content) # 完整输出 debug(f" [{i}] {msg.type.upper():10s}: {content_preview}") - debug("\n" + "=" * 80 + "\n") + debug("\n" + "="*80 + "\n") - return prompt_value + return prompt_value \ No newline at end of file