diff --git a/backend/app/main_graph/nodes/finalize.py b/backend/app/main_graph/nodes/finalize.py index 134379a..420f8b3 100644 --- a/backend/app/main_graph/nodes/finalize.py +++ b/backend/app/main_graph/nodes/finalize.py @@ -13,14 +13,23 @@ from app.logger import info, error from langchain_core.runnables.config import RunnableConfig -async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + +def _get_attr(state, attr_name, default=None): + """通用方法:兼容 dict 和 dataclass 两种状态格式""" + if isinstance(state, dict): + return state.get(attr_name, default) + else: + return getattr(state, attr_name, default) + + +async def finalize_node(state, config: RunnableConfig) -> Dict[str, Any]: """ 完成事件节点 - 发送完成事件,包含token使用情况和耗时信息 Args: - state: 当前对话状态 + state: 当前对话状态(兼容 dict 和 dataclass) config: 运行时配置 - + Returns: 空字典(完成节点,无状态更新) """ @@ -33,8 +42,8 @@ async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[st "type": "custom", "data": { "type": "done", - "token_usage": state.get("last_token_usage", {}), - "elapsed_time": state.get("last_elapsed_time", 0.0) + "token_usage": _get_attr(state, "last_token_usage", {}), + "elapsed_time": _get_attr(state, "last_elapsed_time", 0.0) } }) info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息") diff --git a/backend/app/main_graph/nodes/memory_trigger.py b/backend/app/main_graph/nodes/memory_trigger.py index 523e52c..c8afcf0 100644 --- a/backend/app/main_graph/nodes/memory_trigger.py +++ b/backend/app/main_graph/nodes/memory_trigger.py @@ -4,19 +4,30 @@ from app.main_graph.state import MessagesState from app.memory.mem0_client import Mem0Client from app.logger import info + +def _get_attr(state, attr_name, default=None): + """通用方法:兼容 dict 和 dataclass 两种状态格式""" + if isinstance(state, dict): + return state.get(attr_name, default) + else: + return getattr(state, attr_name, default) + + # 全局变量,在 GraphBuilder 中注入 _mem0_client: Mem0Client = None + def set_mem0_client(client: Mem0Client): global _mem0_client _mem0_client = client -async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + +async def memory_trigger_node(state, config: RunnableConfig) -> Dict[str, Any]: """检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储""" if _mem0_client is None: return {} - messages = state.get("messages", []) + messages = _get_attr(state, "messages", []) if not messages: return {} diff --git a/backend/app/main_graph/nodes/retrieve_memory.py b/backend/app/main_graph/nodes/retrieve_memory.py index ec85655..3aac9fd 100644 --- a/backend/app/main_graph/nodes/retrieve_memory.py +++ b/backend/app/main_graph/nodes/retrieve_memory.py @@ -11,27 +11,36 @@ from app.memory.mem0_client import Mem0Client from app.utils.logging import log_state_change from app.logger import debug + +def _get_attr(state, attr_name, default=None): + """通用方法:兼容 dict 和 dataclass 两种状态格式""" + if isinstance(state, dict): + return state.get(attr_name, default) + else: + return getattr(state, attr_name, default) + + def create_retrieve_memory_node(mem0_client: Mem0Client): """ 工厂函数:创建记忆检索节点 Args: mem0_client: Mem0 客户端实例 - + Returns: 异步节点函数 """ from langchain_core.runnables.config import RunnableConfig - async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + async def retrieve_memory(state, config: RunnableConfig) -> Dict[str, Any]: """ 记忆检索节点 - 使用 Mem0 Args: - state: 当前对话状态 + state: 当前对话状态(兼容 dict 和 dataclass) config: 运行时配置 - + Returns: 包含 memory_context 的状态更新 """ @@ -41,11 +50,15 @@ def create_retrieve_memory_node(mem0_client: Mem0Client): user_id = config.get("metadata", {}).get("user_id", "default_user") # 兼容 dict 和对象两种消息格式 - last_msg = state["messages"][-1] - if isinstance(last_msg, dict): - query = str(last_msg.get("content", "")) + messages = _get_attr(state, "messages", []) + last_msg = messages[-1] if messages else None + if last_msg: + if isinstance(last_msg, dict): + query = str(last_msg.get("content", "")) + else: + query = str(last_msg.content) else: - query = str(last_msg.content) + query = "" memory_text_parts = [] # 确保 Mem0 已初始化(懒加载) @@ -70,7 +83,7 @@ def create_retrieve_memory_node(mem0_client: Mem0Client): memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息" result = {"memory_context": memory_context} - log_state_change("retrieve_memory", {**state, **result}, "离开") + log_state_change("retrieve_memory", {**state, **result} if isinstance(state, dict) else state, "离开") return result - return retrieve_memory + return retrieve_memory \ No newline at end of file diff --git a/backend/app/main_graph/nodes/summarize.py b/backend/app/main_graph/nodes/summarize.py index fc9ac02..f817266 100644 --- a/backend/app/main_graph/nodes/summarize.py +++ b/backend/app/main_graph/nodes/summarize.py @@ -11,33 +11,42 @@ from app.memory.mem0_client import Mem0Client from app.utils.logging import log_state_change from app.logger import debug, info, error, warning + +def _get_attr(state, attr_name, default=None): + """通用方法:兼容 dict 和 dataclass 两种状态格式""" + if isinstance(state, dict): + return state.get(attr_name, default) + else: + return getattr(state, attr_name, default) + + def create_summarize_node(mem0_client: Mem0Client): """ 工厂函数:创建记忆存储节点 Args: mem0_client: Mem0 客户端实例 - + Returns: 异步节点函数 """ from langchain_core.runnables.config import RunnableConfig - async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + async def summarize_conversation(state, config: RunnableConfig) -> Dict[str, Any]: """ 记忆存储节点 - 使用 Mem0 Args: - state: 当前对话状态 + state: 当前对话状态(兼容 dict 和 dataclass) config: 运行时配置 - + Returns: 重置计数器的状态更新 """ log_state_change("summarize", state, "进入") - messages = state["messages"] + messages = _get_attr(state, "messages", []) if len(messages) < 4: debug("📝 [记忆添加] 对话过短,跳过") return {"turns_since_last_summary": 0} diff --git a/backend/app/utils/logging.py b/backend/app/utils/logging.py index 8228366..770c731 100644 --- a/backend/app/utils/logging.py +++ b/backend/app/utils/logging.py @@ -7,18 +7,26 @@ from app.config import ENABLE_GRAPH_TRACE from app.logger import debug, info -def log_state_change(node_name: str, state: dict, prefix: str = "进入"): +def _get_attr(state, attr_name, default=None): + """通用方法:兼容 dict 和 dataclass 两种状态格式""" + if isinstance(state, dict): + return state.get(attr_name, default) + else: + return getattr(state, attr_name, default) + + +def log_state_change(node_name: str, state, prefix: str = "进入"): """ 记录状态变化日志 Args: node_name: 节点名称 - state: 当前状态 + state: 当前状态(兼容 dict 和 dataclass) prefix: 日志前缀("进入" 或 "离开") """ from app.logger import info - messages = state.get("messages", []) + messages = _get_attr(state, "messages", []) msg_count = len(messages) last_msg = messages[-1] if messages else None last_info = ""