修复状态兼容性问题:让旧节点同时支持 dict 和 dataclass
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 6m39s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 6m39s
This commit is contained in:
@@ -11,27 +11,36 @@ from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug
|
||||
|
||||
|
||||
def _get_attr(state, attr_name, default=None):
|
||||
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
||||
if isinstance(state, dict):
|
||||
return state.get(attr_name, default)
|
||||
else:
|
||||
return getattr(state, attr_name, default)
|
||||
|
||||
|
||||
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
"""
|
||||
工厂函数:创建记忆检索节点
|
||||
|
||||
Args:
|
||||
mem0_client: Mem0 客户端实例
|
||||
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
async def retrieve_memory(state, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
记忆检索节点 - 使用 Mem0
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
state: 当前对话状态(兼容 dict 和 dataclass)
|
||||
config: 运行时配置
|
||||
|
||||
|
||||
Returns:
|
||||
包含 memory_context 的状态更新
|
||||
"""
|
||||
@@ -41,11 +50,15 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
||||
|
||||
# 兼容 dict 和对象两种消息格式
|
||||
last_msg = state["messages"][-1]
|
||||
if isinstance(last_msg, dict):
|
||||
query = str(last_msg.get("content", ""))
|
||||
messages = _get_attr(state, "messages", [])
|
||||
last_msg = messages[-1] if messages else None
|
||||
if last_msg:
|
||||
if isinstance(last_msg, dict):
|
||||
query = str(last_msg.get("content", ""))
|
||||
else:
|
||||
query = str(last_msg.content)
|
||||
else:
|
||||
query = str(last_msg.content)
|
||||
query = ""
|
||||
memory_text_parts = []
|
||||
|
||||
# 确保 Mem0 已初始化(懒加载)
|
||||
@@ -70,7 +83,7 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
|
||||
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
|
||||
result = {"memory_context": memory_context}
|
||||
log_state_change("retrieve_memory", {**state, **result}, "离开")
|
||||
log_state_change("retrieve_memory", {**state, **result} if isinstance(state, dict) else state, "离开")
|
||||
return result
|
||||
|
||||
return retrieve_memory
|
||||
return retrieve_memory
|
||||
Reference in New Issue
Block a user