修复状态兼容性问题:让旧节点同时支持 dict 和 dataclass
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 6m39s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 6m39s
This commit is contained in:
@@ -13,12 +13,21 @@ from app.logger import info, error
|
|||||||
|
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
|
||||||
|
def _get_attr(state, attr_name, default=None):
|
||||||
|
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
||||||
|
if isinstance(state, dict):
|
||||||
|
return state.get(attr_name, default)
|
||||||
|
else:
|
||||||
|
return getattr(state, attr_name, default)
|
||||||
|
|
||||||
|
|
||||||
|
async def finalize_node(state, config: RunnableConfig) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
完成事件节点 - 发送完成事件,包含token使用情况和耗时信息
|
完成事件节点 - 发送完成事件,包含token使用情况和耗时信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 当前对话状态
|
state: 当前对话状态(兼容 dict 和 dataclass)
|
||||||
config: 运行时配置
|
config: 运行时配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -33,8 +42,8 @@ async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[st
|
|||||||
"type": "custom",
|
"type": "custom",
|
||||||
"data": {
|
"data": {
|
||||||
"type": "done",
|
"type": "done",
|
||||||
"token_usage": state.get("last_token_usage", {}),
|
"token_usage": _get_attr(state, "last_token_usage", {}),
|
||||||
"elapsed_time": state.get("last_elapsed_time", 0.0)
|
"elapsed_time": _get_attr(state, "last_elapsed_time", 0.0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息")
|
info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息")
|
||||||
|
|||||||
@@ -4,19 +4,30 @@ from app.main_graph.state import MessagesState
|
|||||||
from app.memory.mem0_client import Mem0Client
|
from app.memory.mem0_client import Mem0Client
|
||||||
from app.logger import info
|
from app.logger import info
|
||||||
|
|
||||||
|
|
||||||
|
def _get_attr(state, attr_name, default=None):
|
||||||
|
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
||||||
|
if isinstance(state, dict):
|
||||||
|
return state.get(attr_name, default)
|
||||||
|
else:
|
||||||
|
return getattr(state, attr_name, default)
|
||||||
|
|
||||||
|
|
||||||
# 全局变量,在 GraphBuilder 中注入
|
# 全局变量,在 GraphBuilder 中注入
|
||||||
_mem0_client: Mem0Client = None
|
_mem0_client: Mem0Client = None
|
||||||
|
|
||||||
|
|
||||||
def set_mem0_client(client: Mem0Client):
|
def set_mem0_client(client: Mem0Client):
|
||||||
global _mem0_client
|
global _mem0_client
|
||||||
_mem0_client = client
|
_mem0_client = client
|
||||||
|
|
||||||
async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
|
||||||
|
async def memory_trigger_node(state, config: RunnableConfig) -> Dict[str, Any]:
|
||||||
"""检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储"""
|
"""检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储"""
|
||||||
if _mem0_client is None:
|
if _mem0_client is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
messages = state.get("messages", [])
|
messages = _get_attr(state, "messages", [])
|
||||||
if not messages:
|
if not messages:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,15 @@ from app.memory.mem0_client import Mem0Client
|
|||||||
from app.utils.logging import log_state_change
|
from app.utils.logging import log_state_change
|
||||||
from app.logger import debug
|
from app.logger import debug
|
||||||
|
|
||||||
|
|
||||||
|
def _get_attr(state, attr_name, default=None):
|
||||||
|
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
||||||
|
if isinstance(state, dict):
|
||||||
|
return state.get(attr_name, default)
|
||||||
|
else:
|
||||||
|
return getattr(state, attr_name, default)
|
||||||
|
|
||||||
|
|
||||||
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||||
"""
|
"""
|
||||||
工厂函数:创建记忆检索节点
|
工厂函数:创建记忆检索节点
|
||||||
@@ -24,12 +33,12 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
|||||||
|
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
async def retrieve_memory(state, config: RunnableConfig) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
记忆检索节点 - 使用 Mem0
|
记忆检索节点 - 使用 Mem0
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 当前对话状态
|
state: 当前对话状态(兼容 dict 和 dataclass)
|
||||||
config: 运行时配置
|
config: 运行时配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -41,11 +50,15 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
|||||||
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
||||||
|
|
||||||
# 兼容 dict 和对象两种消息格式
|
# 兼容 dict 和对象两种消息格式
|
||||||
last_msg = state["messages"][-1]
|
messages = _get_attr(state, "messages", [])
|
||||||
if isinstance(last_msg, dict):
|
last_msg = messages[-1] if messages else None
|
||||||
query = str(last_msg.get("content", ""))
|
if last_msg:
|
||||||
|
if isinstance(last_msg, dict):
|
||||||
|
query = str(last_msg.get("content", ""))
|
||||||
|
else:
|
||||||
|
query = str(last_msg.content)
|
||||||
else:
|
else:
|
||||||
query = str(last_msg.content)
|
query = ""
|
||||||
memory_text_parts = []
|
memory_text_parts = []
|
||||||
|
|
||||||
# 确保 Mem0 已初始化(懒加载)
|
# 确保 Mem0 已初始化(懒加载)
|
||||||
@@ -70,7 +83,7 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
|||||||
|
|
||||||
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
|
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
|
||||||
result = {"memory_context": memory_context}
|
result = {"memory_context": memory_context}
|
||||||
log_state_change("retrieve_memory", {**state, **result}, "离开")
|
log_state_change("retrieve_memory", {**state, **result} if isinstance(state, dict) else state, "离开")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return retrieve_memory
|
return retrieve_memory
|
||||||
@@ -11,6 +11,15 @@ from app.memory.mem0_client import Mem0Client
|
|||||||
from app.utils.logging import log_state_change
|
from app.utils.logging import log_state_change
|
||||||
from app.logger import debug, info, error, warning
|
from app.logger import debug, info, error, warning
|
||||||
|
|
||||||
|
|
||||||
|
def _get_attr(state, attr_name, default=None):
|
||||||
|
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
||||||
|
if isinstance(state, dict):
|
||||||
|
return state.get(attr_name, default)
|
||||||
|
else:
|
||||||
|
return getattr(state, attr_name, default)
|
||||||
|
|
||||||
|
|
||||||
def create_summarize_node(mem0_client: Mem0Client):
|
def create_summarize_node(mem0_client: Mem0Client):
|
||||||
"""
|
"""
|
||||||
工厂函数:创建记忆存储节点
|
工厂函数:创建记忆存储节点
|
||||||
@@ -24,12 +33,12 @@ def create_summarize_node(mem0_client: Mem0Client):
|
|||||||
|
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
async def summarize_conversation(state, config: RunnableConfig) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
记忆存储节点 - 使用 Mem0
|
记忆存储节点 - 使用 Mem0
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 当前对话状态
|
state: 当前对话状态(兼容 dict 和 dataclass)
|
||||||
config: 运行时配置
|
config: 运行时配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -37,7 +46,7 @@ def create_summarize_node(mem0_client: Mem0Client):
|
|||||||
"""
|
"""
|
||||||
log_state_change("summarize", state, "进入")
|
log_state_change("summarize", state, "进入")
|
||||||
|
|
||||||
messages = state["messages"]
|
messages = _get_attr(state, "messages", [])
|
||||||
if len(messages) < 4:
|
if len(messages) < 4:
|
||||||
debug("📝 [记忆添加] 对话过短,跳过")
|
debug("📝 [记忆添加] 对话过短,跳过")
|
||||||
return {"turns_since_last_summary": 0}
|
return {"turns_since_last_summary": 0}
|
||||||
|
|||||||
@@ -7,18 +7,26 @@ from app.config import ENABLE_GRAPH_TRACE
|
|||||||
from app.logger import debug, info
|
from app.logger import debug, info
|
||||||
|
|
||||||
|
|
||||||
def log_state_change(node_name: str, state: dict, prefix: str = "进入"):
|
def _get_attr(state, attr_name, default=None):
|
||||||
|
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
||||||
|
if isinstance(state, dict):
|
||||||
|
return state.get(attr_name, default)
|
||||||
|
else:
|
||||||
|
return getattr(state, attr_name, default)
|
||||||
|
|
||||||
|
|
||||||
|
def log_state_change(node_name: str, state, prefix: str = "进入"):
|
||||||
"""
|
"""
|
||||||
记录状态变化日志
|
记录状态变化日志
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_name: 节点名称
|
node_name: 节点名称
|
||||||
state: 当前状态
|
state: 当前状态(兼容 dict 和 dataclass)
|
||||||
prefix: 日志前缀("进入" 或 "离开")
|
prefix: 日志前缀("进入" 或 "离开")
|
||||||
"""
|
"""
|
||||||
from app.logger import info
|
from app.logger import info
|
||||||
|
|
||||||
messages = state.get("messages", [])
|
messages = _get_attr(state, "messages", [])
|
||||||
msg_count = len(messages)
|
msg_count = len(messages)
|
||||||
last_msg = messages[-1] if messages else None
|
last_msg = messages[-1] if messages else None
|
||||||
last_info = ""
|
last_info = ""
|
||||||
|
|||||||
Reference in New Issue
Block a user