修复状态兼容性问题:让旧节点同时支持 dict 和 dataclass
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 6m39s

This commit is contained in:
2026-05-01 22:45:42 +08:00
parent 1f177f7dfd
commit 615b4b6eed
5 changed files with 75 additions and 25 deletions

View File

@@ -13,14 +13,23 @@ from app.logger import info, error
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
def _get_attr(state, attr_name, default=None):
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
if isinstance(state, dict):
return state.get(attr_name, default)
else:
return getattr(state, attr_name, default)
async def finalize_node(state, config: RunnableConfig) -> Dict[str, Any]:
""" """
完成事件节点 - 发送完成事件包含token使用情况和耗时信息 完成事件节点 - 发送完成事件包含token使用情况和耗时信息
Args: Args:
state: 当前对话状态 state: 当前对话状态(兼容 dict 和 dataclass
config: 运行时配置 config: 运行时配置
Returns: Returns:
空字典(完成节点,无状态更新) 空字典(完成节点,无状态更新)
""" """
@@ -33,8 +42,8 @@ async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[st
"type": "custom", "type": "custom",
"data": { "data": {
"type": "done", "type": "done",
"token_usage": state.get("last_token_usage", {}), "token_usage": _get_attr(state, "last_token_usage", {}),
"elapsed_time": state.get("last_elapsed_time", 0.0) "elapsed_time": _get_attr(state, "last_elapsed_time", 0.0)
} }
}) })
info("🏁 [完成事件] 已发送完成事件包含token使用情况和耗时信息") info("🏁 [完成事件] 已发送完成事件包含token使用情况和耗时信息")

View File

@@ -4,19 +4,30 @@ from app.main_graph.state import MessagesState
from app.memory.mem0_client import Mem0Client from app.memory.mem0_client import Mem0Client
from app.logger import info from app.logger import info
def _get_attr(state, attr_name, default=None):
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
if isinstance(state, dict):
return state.get(attr_name, default)
else:
return getattr(state, attr_name, default)
# 全局变量,在 GraphBuilder 中注入 # 全局变量,在 GraphBuilder 中注入
_mem0_client: Mem0Client = None _mem0_client: Mem0Client = None
def set_mem0_client(client: Mem0Client): def set_mem0_client(client: Mem0Client):
global _mem0_client global _mem0_client
_mem0_client = client _mem0_client = client
async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
async def memory_trigger_node(state, config: RunnableConfig) -> Dict[str, Any]:
"""检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储""" """检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储"""
if _mem0_client is None: if _mem0_client is None:
return {} return {}
messages = state.get("messages", []) messages = _get_attr(state, "messages", [])
if not messages: if not messages:
return {} return {}

View File

@@ -11,27 +11,36 @@ from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change from app.utils.logging import log_state_change
from app.logger import debug from app.logger import debug
def _get_attr(state, attr_name, default=None):
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
if isinstance(state, dict):
return state.get(attr_name, default)
else:
return getattr(state, attr_name, default)
def create_retrieve_memory_node(mem0_client: Mem0Client): def create_retrieve_memory_node(mem0_client: Mem0Client):
""" """
工厂函数:创建记忆检索节点 工厂函数:创建记忆检索节点
Args: Args:
mem0_client: Mem0 客户端实例 mem0_client: Mem0 客户端实例
Returns: Returns:
异步节点函数 异步节点函数
""" """
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: async def retrieve_memory(state, config: RunnableConfig) -> Dict[str, Any]:
""" """
记忆检索节点 - 使用 Mem0 记忆检索节点 - 使用 Mem0
Args: Args:
state: 当前对话状态 state: 当前对话状态(兼容 dict 和 dataclass
config: 运行时配置 config: 运行时配置
Returns: Returns:
包含 memory_context 的状态更新 包含 memory_context 的状态更新
""" """
@@ -41,11 +50,15 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
user_id = config.get("metadata", {}).get("user_id", "default_user") user_id = config.get("metadata", {}).get("user_id", "default_user")
# 兼容 dict 和对象两种消息格式 # 兼容 dict 和对象两种消息格式
last_msg = state["messages"][-1] messages = _get_attr(state, "messages", [])
if isinstance(last_msg, dict): last_msg = messages[-1] if messages else None
query = str(last_msg.get("content", "")) if last_msg:
if isinstance(last_msg, dict):
query = str(last_msg.get("content", ""))
else:
query = str(last_msg.content)
else: else:
query = str(last_msg.content) query = ""
memory_text_parts = [] memory_text_parts = []
# 确保 Mem0 已初始化(懒加载) # 确保 Mem0 已初始化(懒加载)
@@ -70,7 +83,7 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息" memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
result = {"memory_context": memory_context} result = {"memory_context": memory_context}
log_state_change("retrieve_memory", {**state, **result}, "离开") log_state_change("retrieve_memory", {**state, **result} if isinstance(state, dict) else state, "离开")
return result return result
return retrieve_memory return retrieve_memory

View File

@@ -11,33 +11,42 @@ from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change from app.utils.logging import log_state_change
from app.logger import debug, info, error, warning from app.logger import debug, info, error, warning
def _get_attr(state, attr_name, default=None):
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
if isinstance(state, dict):
return state.get(attr_name, default)
else:
return getattr(state, attr_name, default)
def create_summarize_node(mem0_client: Mem0Client): def create_summarize_node(mem0_client: Mem0Client):
""" """
工厂函数:创建记忆存储节点 工厂函数:创建记忆存储节点
Args: Args:
mem0_client: Mem0 客户端实例 mem0_client: Mem0 客户端实例
Returns: Returns:
异步节点函数 异步节点函数
""" """
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: async def summarize_conversation(state, config: RunnableConfig) -> Dict[str, Any]:
""" """
记忆存储节点 - 使用 Mem0 记忆存储节点 - 使用 Mem0
Args: Args:
state: 当前对话状态 state: 当前对话状态(兼容 dict 和 dataclass
config: 运行时配置 config: 运行时配置
Returns: Returns:
重置计数器的状态更新 重置计数器的状态更新
""" """
log_state_change("summarize", state, "进入") log_state_change("summarize", state, "进入")
messages = state["messages"] messages = _get_attr(state, "messages", [])
if len(messages) < 4: if len(messages) < 4:
debug("📝 [记忆添加] 对话过短,跳过") debug("📝 [记忆添加] 对话过短,跳过")
return {"turns_since_last_summary": 0} return {"turns_since_last_summary": 0}

View File

@@ -7,18 +7,26 @@ from app.config import ENABLE_GRAPH_TRACE
from app.logger import debug, info from app.logger import debug, info
def log_state_change(node_name: str, state: dict, prefix: str = "进入"): def _get_attr(state, attr_name, default=None):
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
if isinstance(state, dict):
return state.get(attr_name, default)
else:
return getattr(state, attr_name, default)
def log_state_change(node_name: str, state, prefix: str = "进入"):
""" """
记录状态变化日志 记录状态变化日志
Args: Args:
node_name: 节点名称 node_name: 节点名称
state: 当前状态 state: 当前状态(兼容 dict 和 dataclass
prefix: 日志前缀("进入""离开" prefix: 日志前缀("进入""离开"
""" """
from app.logger import info from app.logger import info
messages = state.get("messages", []) messages = _get_attr(state, "messages", [])
msg_count = len(messages) msg_count = len(messages)
last_msg = messages[-1] if messages else None last_msg = messages[-1] if messages else None
last_info = "" last_info = ""