彻底重构状态系统:整合所有旧状态到 MainGraphState,修复所有节点
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m35s

This commit is contained in:
2026-05-01 23:20:31 +08:00
parent 9a58eb8e6d
commit 9386b9fa7a
7 changed files with 155 additions and 193 deletions

View File

@@ -4,32 +4,23 @@
"""
from typing import Any, Dict
from app.main_graph.config import get_stream_writer
# 本地模块
from app.main_graph.state import MessagesState
from app.main_graph.state import MainGraphState
from app.utils.logging import log_state_change
from app.logger import info, error
from app.logger import info, warning
from langchain_core.runnables.config import RunnableConfig
def _get_attr(state, attr_name, default=None):
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
if isinstance(state, dict):
return state.get(attr_name, default)
else:
return getattr(state, attr_name, default)
async def finalize_node(state, config: RunnableConfig) -> Dict[str, Any]:
async def finalize_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
"""
完成事件节点 - 发送完成事件包含token使用情况和耗时信息
Args:
state: 当前对话状态(兼容 dict 和 dataclass
state: 当前对话状态
config: 运行时配置
Returns:
空字典(完成节点,无状态更新)
"""
@@ -37,18 +28,25 @@ async def finalize_node(state, config: RunnableConfig) -> Dict[str, Any]:
try:
# 获取流式写入器并发送完成事件
from app.main_graph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "custom",
"data": {
"type": "done",
"token_usage": _get_attr(state, "last_token_usage", {}),
"elapsed_time": _get_attr(state, "last_elapsed_time", 0.0)
}
})
info("🏁 [完成事件] 已发送完成事件包含token使用情况和耗时信息")
# 只在 writer 存在且不是 noop 时才发送
if writer and hasattr(writer, '__call__'):
try:
writer({
"type": "custom",
"data": {
"type": "done",
"token_usage": state.last_token_usage,
"elapsed_time": state.last_elapsed_time
}
})
info("🏁 [完成事件] 已发送完成事件包含token使用情况和耗时信息")
except Exception as e:
warning(f"⚠️ [完成事件] 发送完成事件失败 (非致命): {e}")
except Exception as e:
error(f" [完成事件] 发送完成事件时发生异常: {e}")
warning(f"⚠️ [完成事件] 处理失败 (非致命): {e}")
log_state_change("finalize", state, "离开")
return {}

View File

@@ -1,18 +1,10 @@
from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig
from app.main_graph.state import MessagesState
from app.main_graph.state import MainGraphState
from app.memory.mem0_client import Mem0Client
from app.logger import info
def _get_attr(state, attr_name, default=None):
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
if isinstance(state, dict):
return state.get(attr_name, default)
else:
return getattr(state, attr_name, default)
# 全局变量,在 GraphBuilder 中注入
_mem0_client: Mem0Client = None
@@ -22,12 +14,12 @@ def set_mem0_client(client: Mem0Client):
_mem0_client = client
async def memory_trigger_node(state, config: RunnableConfig) -> Dict[str, Any]:
async def memory_trigger_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
"""检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储"""
if _mem0_client is None:
return {}
messages = _get_attr(state, "messages", [])
messages = state.messages
if not messages:
return {}

View File

@@ -6,20 +6,12 @@
from typing import Any, Dict
# 本地模块
from app.main_graph.state import MessagesState
from app.main_graph.state import MainGraphState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug
def _get_attr(state, attr_name, default=None):
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
if isinstance(state, dict):
return state.get(attr_name, default)
else:
return getattr(state, attr_name, default)
def create_retrieve_memory_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆检索节点
@@ -33,12 +25,12 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
from langchain_core.runnables.config import RunnableConfig
async def retrieve_memory(state, config: RunnableConfig) -> Dict[str, Any]:
async def retrieve_memory(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
"""
记忆检索节点 - 使用 Mem0
Args:
state: 当前对话状态(兼容 dict 和 dataclass
state: 当前对话状态
config: 运行时配置
Returns:
@@ -49,16 +41,16 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
# 从 metadata 中获取 user_id
user_id = config.get("metadata", {}).get("user_id", "default_user")
# 兼容 dict 和对象两种消息格式
messages = _get_attr(state, "messages", [])
# 获取最后一条消息
messages = state.messages
last_msg = messages[-1] if messages else None
query = ""
if last_msg:
if isinstance(last_msg, dict):
query = str(last_msg.get("content", ""))
else:
query = str(last_msg.content)
else:
query = ""
memory_text_parts = []
# 确保 Mem0 已初始化(懒加载)
@@ -83,7 +75,7 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
result = {"memory_context": memory_context}
log_state_change("retrieve_memory", {**state, **result} if isinstance(state, dict) else state, "离开")
log_state_change("retrieve_memory", state, "离开")
return result
return retrieve_memory

View File

@@ -6,20 +6,12 @@
from typing import Any, Dict
# 本地模块
from app.main_graph.state import MessagesState
from app.main_graph.state import MainGraphState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug, info, error, warning
def _get_attr(state, attr_name, default=None):
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
if isinstance(state, dict):
return state.get(attr_name, default)
else:
return getattr(state, attr_name, default)
def create_summarize_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆存储节点
@@ -33,12 +25,12 @@ def create_summarize_node(mem0_client: Mem0Client):
from langchain_core.runnables.config import RunnableConfig
async def summarize_conversation(state, config: RunnableConfig) -> Dict[str, Any]:
async def summarize_conversation(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
"""
记忆存储节点 - 使用 Mem0
Args:
state: 当前对话状态(兼容 dict 和 dataclass
state: 当前对话状态
config: 运行时配置
Returns:
@@ -46,7 +38,7 @@ def create_summarize_node(mem0_client: Mem0Client):
"""
log_state_change("summarize", state, "进入")
messages = _get_attr(state, "messages", [])
messages = state.messages
if len(messages) < 4:
debug("📝 [记忆添加] 对话过短,跳过")
return {"turns_since_last_summary": 0}