89 lines
2.9 KiB
Python
89 lines
2.9 KiB
Python
"""
|
||
记忆检索节点模块
|
||
负责从 Mem0 检索相关长期记忆
|
||
"""
|
||
|
||
from typing import Any, Dict
|
||
|
||
# 本地模块
|
||
from app.main_graph.state import MessagesState
|
||
from app.memory.mem0_client import Mem0Client
|
||
from app.utils.logging import log_state_change
|
||
from app.logger import debug
|
||
|
||
|
||
def _get_attr(state, attr_name, default=None):
|
||
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
||
if isinstance(state, dict):
|
||
return state.get(attr_name, default)
|
||
else:
|
||
return getattr(state, attr_name, default)
|
||
|
||
|
||
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||
"""
|
||
工厂函数:创建记忆检索节点
|
||
|
||
Args:
|
||
mem0_client: Mem0 客户端实例
|
||
|
||
Returns:
|
||
异步节点函数
|
||
"""
|
||
|
||
from langchain_core.runnables.config import RunnableConfig
|
||
|
||
async def retrieve_memory(state, config: RunnableConfig) -> Dict[str, Any]:
|
||
"""
|
||
记忆检索节点 - 使用 Mem0
|
||
|
||
Args:
|
||
state: 当前对话状态(兼容 dict 和 dataclass)
|
||
config: 运行时配置
|
||
|
||
Returns:
|
||
包含 memory_context 的状态更新
|
||
"""
|
||
log_state_change("retrieve_memory", state, "进入")
|
||
|
||
# 从 metadata 中获取 user_id
|
||
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
||
|
||
# 兼容 dict 和对象两种消息格式
|
||
messages = _get_attr(state, "messages", [])
|
||
last_msg = messages[-1] if messages else None
|
||
if last_msg:
|
||
if isinstance(last_msg, dict):
|
||
query = str(last_msg.get("content", ""))
|
||
else:
|
||
query = str(last_msg.content)
|
||
else:
|
||
query = ""
|
||
memory_text_parts = []
|
||
|
||
# 确保 Mem0 已初始化(懒加载)
|
||
if not mem0_client._initialized:
|
||
await mem0_client.initialize()
|
||
|
||
if mem0_client.mem0:
|
||
try:
|
||
# 异步调用 Mem0 语义检索
|
||
facts = await mem0_client.search_memories(query, user_id=user_id, limit=5)
|
||
|
||
if facts:
|
||
memory_text_parts.append(f"【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts))
|
||
else:
|
||
debug("🔍 [记忆检索] 未找到相关记忆")
|
||
except Exception as e:
|
||
from app.logger import warning
|
||
warning(f"⚠️ Mem0 检索失败: {e}")
|
||
else:
|
||
from app.logger import warning
|
||
warning("⚠️ Mem0 未初始化,跳过记忆检索")
|
||
|
||
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
|
||
result = {"memory_context": memory_context}
|
||
log_state_change("retrieve_memory", {**state, **result} if isinstance(state, dict) else state, "离开")
|
||
return result
|
||
|
||
return retrieve_memory |