Files
ailine/backend/app/main_graph/nodes/retrieve_memory.py
root 615b4b6eed
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 6m39s
修复状态兼容性问题:让旧节点同时支持 dict 和 dataclass
2026-05-01 22:45:42 +08:00

89 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
记忆检索节点模块
负责从 Mem0 检索相关长期记忆
"""
from typing import Any, Dict
# 本地模块
from app.main_graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug
def _get_attr(state, attr_name, default=None):
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
if isinstance(state, dict):
return state.get(attr_name, default)
else:
return getattr(state, attr_name, default)
def create_retrieve_memory_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆检索节点
Args:
mem0_client: Mem0 客户端实例
Returns:
异步节点函数
"""
from langchain_core.runnables.config import RunnableConfig
async def retrieve_memory(state, config: RunnableConfig) -> Dict[str, Any]:
"""
记忆检索节点 - 使用 Mem0
Args:
state: 当前对话状态(兼容 dict 和 dataclass
config: 运行时配置
Returns:
包含 memory_context 的状态更新
"""
log_state_change("retrieve_memory", state, "进入")
# 从 metadata 中获取 user_id
user_id = config.get("metadata", {}).get("user_id", "default_user")
# 兼容 dict 和对象两种消息格式
messages = _get_attr(state, "messages", [])
last_msg = messages[-1] if messages else None
if last_msg:
if isinstance(last_msg, dict):
query = str(last_msg.get("content", ""))
else:
query = str(last_msg.content)
else:
query = ""
memory_text_parts = []
# 确保 Mem0 已初始化(懒加载)
if not mem0_client._initialized:
await mem0_client.initialize()
if mem0_client.mem0:
try:
# 异步调用 Mem0 语义检索
facts = await mem0_client.search_memories(query, user_id=user_id, limit=5)
if facts:
memory_text_parts.append(f"【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts))
else:
debug("🔍 [记忆检索] 未找到相关记忆")
except Exception as e:
from app.logger import warning
warning(f"⚠️ Mem0 检索失败: {e}")
else:
from app.logger import warning
warning("⚠️ Mem0 未初始化,跳过记忆检索")
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
result = {"memory_context": memory_context}
log_state_change("retrieve_memory", {**state, **result} if isinstance(state, dict) else state, "离开")
return result
return retrieve_memory