彻底重构状态系统:整合所有旧状态到 MainGraphState,修复所有节点
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m35s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m35s
This commit is contained in:
@@ -4,30 +4,21 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from app.main_graph.config import get_stream_writer
|
|
||||||
|
|
||||||
# 本地模块
|
# 本地模块
|
||||||
from app.main_graph.state import MessagesState
|
from app.main_graph.state import MainGraphState
|
||||||
from app.utils.logging import log_state_change
|
from app.utils.logging import log_state_change
|
||||||
from app.logger import info, error
|
from app.logger import info, warning
|
||||||
|
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
|
|
||||||
def _get_attr(state, attr_name, default=None):
|
async def finalize_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
||||||
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
|
||||||
if isinstance(state, dict):
|
|
||||||
return state.get(attr_name, default)
|
|
||||||
else:
|
|
||||||
return getattr(state, attr_name, default)
|
|
||||||
|
|
||||||
|
|
||||||
async def finalize_node(state, config: RunnableConfig) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
完成事件节点 - 发送完成事件,包含token使用情况和耗时信息
|
完成事件节点 - 发送完成事件,包含token使用情况和耗时信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 当前对话状态(兼容 dict 和 dataclass)
|
state: 当前对话状态
|
||||||
config: 运行时配置
|
config: 运行时配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -37,18 +28,25 @@ async def finalize_node(state, config: RunnableConfig) -> Dict[str, Any]:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取流式写入器并发送完成事件
|
# 获取流式写入器并发送完成事件
|
||||||
|
from app.main_graph.config import get_stream_writer
|
||||||
writer = get_stream_writer()
|
writer = get_stream_writer()
|
||||||
|
|
||||||
|
# 只在 writer 存在且不是 noop 时才发送
|
||||||
|
if writer and hasattr(writer, '__call__'):
|
||||||
|
try:
|
||||||
writer({
|
writer({
|
||||||
"type": "custom",
|
"type": "custom",
|
||||||
"data": {
|
"data": {
|
||||||
"type": "done",
|
"type": "done",
|
||||||
"token_usage": _get_attr(state, "last_token_usage", {}),
|
"token_usage": state.last_token_usage,
|
||||||
"elapsed_time": _get_attr(state, "last_elapsed_time", 0.0)
|
"elapsed_time": state.last_elapsed_time
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息")
|
info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"❌ [完成事件] 发送完成事件时发生异常: {e}")
|
warning(f"⚠️ [完成事件] 发送完成事件失败 (非致命): {e}")
|
||||||
|
except Exception as e:
|
||||||
|
warning(f"⚠️ [完成事件] 处理失败 (非致命): {e}")
|
||||||
|
|
||||||
log_state_change("finalize", state, "离开")
|
log_state_change("finalize", state, "离开")
|
||||||
return {}
|
return {}
|
||||||
@@ -1,18 +1,10 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
from app.main_graph.state import MessagesState
|
from app.main_graph.state import MainGraphState
|
||||||
from app.memory.mem0_client import Mem0Client
|
from app.memory.mem0_client import Mem0Client
|
||||||
from app.logger import info
|
from app.logger import info
|
||||||
|
|
||||||
|
|
||||||
def _get_attr(state, attr_name, default=None):
|
|
||||||
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
|
||||||
if isinstance(state, dict):
|
|
||||||
return state.get(attr_name, default)
|
|
||||||
else:
|
|
||||||
return getattr(state, attr_name, default)
|
|
||||||
|
|
||||||
|
|
||||||
# 全局变量,在 GraphBuilder 中注入
|
# 全局变量,在 GraphBuilder 中注入
|
||||||
_mem0_client: Mem0Client = None
|
_mem0_client: Mem0Client = None
|
||||||
|
|
||||||
@@ -22,12 +14,12 @@ def set_mem0_client(client: Mem0Client):
|
|||||||
_mem0_client = client
|
_mem0_client = client
|
||||||
|
|
||||||
|
|
||||||
async def memory_trigger_node(state, config: RunnableConfig) -> Dict[str, Any]:
|
async def memory_trigger_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
||||||
"""检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储"""
|
"""检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储"""
|
||||||
if _mem0_client is None:
|
if _mem0_client is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
messages = _get_attr(state, "messages", [])
|
messages = state.messages
|
||||||
if not messages:
|
if not messages:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|||||||
@@ -6,20 +6,12 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
# 本地模块
|
# 本地模块
|
||||||
from app.main_graph.state import MessagesState
|
from app.main_graph.state import MainGraphState
|
||||||
from app.memory.mem0_client import Mem0Client
|
from app.memory.mem0_client import Mem0Client
|
||||||
from app.utils.logging import log_state_change
|
from app.utils.logging import log_state_change
|
||||||
from app.logger import debug
|
from app.logger import debug
|
||||||
|
|
||||||
|
|
||||||
def _get_attr(state, attr_name, default=None):
|
|
||||||
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
|
||||||
if isinstance(state, dict):
|
|
||||||
return state.get(attr_name, default)
|
|
||||||
else:
|
|
||||||
return getattr(state, attr_name, default)
|
|
||||||
|
|
||||||
|
|
||||||
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||||
"""
|
"""
|
||||||
工厂函数:创建记忆检索节点
|
工厂函数:创建记忆检索节点
|
||||||
@@ -33,12 +25,12 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
|||||||
|
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
async def retrieve_memory(state, config: RunnableConfig) -> Dict[str, Any]:
|
async def retrieve_memory(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
记忆检索节点 - 使用 Mem0
|
记忆检索节点 - 使用 Mem0
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 当前对话状态(兼容 dict 和 dataclass)
|
state: 当前对话状态
|
||||||
config: 运行时配置
|
config: 运行时配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -49,16 +41,16 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
|||||||
# 从 metadata 中获取 user_id
|
# 从 metadata 中获取 user_id
|
||||||
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
||||||
|
|
||||||
# 兼容 dict 和对象两种消息格式
|
# 获取最后一条消息
|
||||||
messages = _get_attr(state, "messages", [])
|
messages = state.messages
|
||||||
last_msg = messages[-1] if messages else None
|
last_msg = messages[-1] if messages else None
|
||||||
|
query = ""
|
||||||
if last_msg:
|
if last_msg:
|
||||||
if isinstance(last_msg, dict):
|
if isinstance(last_msg, dict):
|
||||||
query = str(last_msg.get("content", ""))
|
query = str(last_msg.get("content", ""))
|
||||||
else:
|
else:
|
||||||
query = str(last_msg.content)
|
query = str(last_msg.content)
|
||||||
else:
|
|
||||||
query = ""
|
|
||||||
memory_text_parts = []
|
memory_text_parts = []
|
||||||
|
|
||||||
# 确保 Mem0 已初始化(懒加载)
|
# 确保 Mem0 已初始化(懒加载)
|
||||||
@@ -83,7 +75,7 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
|||||||
|
|
||||||
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
|
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
|
||||||
result = {"memory_context": memory_context}
|
result = {"memory_context": memory_context}
|
||||||
log_state_change("retrieve_memory", {**state, **result} if isinstance(state, dict) else state, "离开")
|
log_state_change("retrieve_memory", state, "离开")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return retrieve_memory
|
return retrieve_memory
|
||||||
@@ -6,20 +6,12 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
# 本地模块
|
# 本地模块
|
||||||
from app.main_graph.state import MessagesState
|
from app.main_graph.state import MainGraphState
|
||||||
from app.memory.mem0_client import Mem0Client
|
from app.memory.mem0_client import Mem0Client
|
||||||
from app.utils.logging import log_state_change
|
from app.utils.logging import log_state_change
|
||||||
from app.logger import debug, info, error, warning
|
from app.logger import debug, info, error, warning
|
||||||
|
|
||||||
|
|
||||||
def _get_attr(state, attr_name, default=None):
|
|
||||||
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
|
||||||
if isinstance(state, dict):
|
|
||||||
return state.get(attr_name, default)
|
|
||||||
else:
|
|
||||||
return getattr(state, attr_name, default)
|
|
||||||
|
|
||||||
|
|
||||||
def create_summarize_node(mem0_client: Mem0Client):
|
def create_summarize_node(mem0_client: Mem0Client):
|
||||||
"""
|
"""
|
||||||
工厂函数:创建记忆存储节点
|
工厂函数:创建记忆存储节点
|
||||||
@@ -33,12 +25,12 @@ def create_summarize_node(mem0_client: Mem0Client):
|
|||||||
|
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
async def summarize_conversation(state, config: RunnableConfig) -> Dict[str, Any]:
|
async def summarize_conversation(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
记忆存储节点 - 使用 Mem0
|
记忆存储节点 - 使用 Mem0
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 当前对话状态(兼容 dict 和 dataclass)
|
state: 当前对话状态
|
||||||
config: 运行时配置
|
config: 运行时配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -46,7 +38,7 @@ def create_summarize_node(mem0_client: Mem0Client):
|
|||||||
"""
|
"""
|
||||||
log_state_change("summarize", state, "进入")
|
log_state_change("summarize", state, "进入")
|
||||||
|
|
||||||
messages = _get_attr(state, "messages", [])
|
messages = state.messages
|
||||||
if len(messages) < 4:
|
if len(messages) < 4:
|
||||||
debug("📝 [记忆添加] 对话过短,跳过")
|
debug("📝 [记忆添加] 对话过短,跳过")
|
||||||
return {"turns_since_last_summary": 0}
|
return {"turns_since_last_summary": 0}
|
||||||
|
|||||||
@@ -10,17 +10,9 @@ from app.main_graph.graph import add_messages
|
|||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
|
|
||||||
# ========== 兼容旧代码的类型 ==========
|
# ========== 兼容性注释(旧代码已移除,状态已整合到 MainGraphState) ==========
|
||||||
class MessagesState(TypedDict):
|
# 旧的 MessagesState 和 GraphContext 已完全整合到 MainGraphState
|
||||||
"""旧的MessagesState类型(保留兼容性)"""
|
# 不再需要单独的类型定义
|
||||||
messages: Annotated[Sequence[BaseMessage], add_messages]
|
|
||||||
|
|
||||||
|
|
||||||
class GraphContext(TypedDict):
|
|
||||||
"""旧的GraphContext类型(保留兼容性)"""
|
|
||||||
llm_calls: int
|
|
||||||
memory_context: str
|
|
||||||
system_prompt: str
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 新的类型 ==========
|
# ========== 新的类型 ==========
|
||||||
@@ -57,49 +49,52 @@ class ErrorRecord:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MainGraphState:
|
class MainGraphState:
|
||||||
"""
|
"""
|
||||||
主图状态 - React 循环推理版本
|
主图状态 - 整合了旧 MessagesState 的所有字段
|
||||||
|
|
||||||
包含:
|
包含:
|
||||||
1. 旧代码的MessagesState兼容性字段
|
- 旧代码的 MessagesState 兼容性字段
|
||||||
2. React 推理控制字段
|
- React 推理控制字段
|
||||||
3. 循环和错误处理
|
- 循环和错误处理
|
||||||
4. 子图结果占位
|
- 子图结果占位
|
||||||
5. 用户信息
|
- 用户信息
|
||||||
"""
|
"""
|
||||||
# ========== 兼容性字段(保留旧的MessagesState) ==========
|
# ========== 旧 MessagesState 兼容性字段 ==========
|
||||||
messages: Annotated[Sequence[BaseMessage], add_messages] = field(default_factory=list)
|
messages: Annotated[Sequence[BaseMessage], add_messages] = field(default_factory=list)
|
||||||
llm_calls: int = 0
|
llm_calls: int = 0
|
||||||
memory_context: str = ""
|
memory_context: str = ""
|
||||||
system_prompt: str = ""
|
system_prompt: str = ""
|
||||||
|
turns_since_last_summary: int = 0 # 新增:来自旧状态
|
||||||
|
last_token_usage: Dict[str, Any] = field(default_factory=dict) # 新增:来自旧状态
|
||||||
|
last_elapsed_time: float = 0.0 # 新增:来自旧状态
|
||||||
|
|
||||||
# ========== 主图控制字段 ==========
|
# ========== 主图控制字段 ==========
|
||||||
user_query: str = "" # 用户当前查询
|
user_query: str = ""
|
||||||
current_action: CurrentAction = CurrentAction.NONE # 当前操作
|
current_action: CurrentAction = CurrentAction.NONE
|
||||||
intent_confidence: float = 0.0 # 意图识别置信度
|
intent_confidence: float = 0.0
|
||||||
|
|
||||||
# ========== React 推理专用字段 ==========
|
# ========== React 推理专用字段 ==========
|
||||||
reasoning_step: int = 0 # 当前推理步数
|
reasoning_step: int = 0
|
||||||
max_steps: int = 40 # 最大推理步数
|
max_steps: int = 40
|
||||||
last_action: str = "" # 上一步动作
|
last_action: str = ""
|
||||||
reasoning_history: List[Dict[str, Any]] = field(default_factory=list) # 推理历史
|
reasoning_history: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
# ========== RAG 相关字段 ==========
|
# ========== RAG 相关字段 ==========
|
||||||
rag_context: str = "" # RAG 检索到的上下文
|
rag_context: str = ""
|
||||||
rag_retrieved: bool = False # 是否已经检索过
|
rag_retrieved: bool = False
|
||||||
rag_docs: List[Dict[str, Any]] = field(default_factory=list) # 检索到的文档
|
rag_docs: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
# ========== 联网搜索相关字段 ⭐ 新增 ==========
|
# ========== 联网搜索相关字段 ==========
|
||||||
web_search_results: List[str] = field(default_factory=list) # 联网搜索结果
|
web_search_results: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
# ========== 错误处理字段 ==========
|
# ========== 错误处理字段 ==========
|
||||||
errors: List[ErrorRecord] = field(default_factory=list) # 错误列表
|
errors: List[ErrorRecord] = field(default_factory=list)
|
||||||
current_error: Optional[ErrorRecord] = None # 当前错误
|
current_error: Optional[ErrorRecord] = None
|
||||||
retry_action: Optional[str] = None # 重试动作
|
retry_action: Optional[str] = None
|
||||||
|
|
||||||
# ========== 子图结果占位 ==========
|
# ========== 子图结果占位 ==========
|
||||||
news_result: Optional[Dict[str, Any]] = None # 资讯子图结果
|
news_result: Optional[Dict[str, Any]] = None
|
||||||
dictionary_result: Optional[Dict[str, Any]] = None # 词典子图结果
|
dictionary_result: Optional[Dict[str, Any]] = None
|
||||||
contact_result: Optional[Dict[str, Any]] = None # 通讯录子图结果
|
contact_result: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
# ========== 用户信息 ==========
|
# ========== 用户信息 ==========
|
||||||
user_id: str = ""
|
user_id: str = ""
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
整合后的完整主图构建器 - 结合旧图和新图的优点
|
整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState
|
||||||
Main Graph Builder - Integrated Full Version (Old + New)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.main_graph.graph import StateGraph, START, END
|
from app.main_graph.graph import StateGraph, START, END
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
from app.main_graph.state import MainGraphState, CurrentAction, MessagesState
|
from app.main_graph.state import MainGraphState
|
||||||
from app.main_graph.nodes.react_nodes import (
|
from app.main_graph.nodes.react_nodes import (
|
||||||
init_state_node,
|
init_state_node,
|
||||||
react_reason_node,
|
react_reason_node,
|
||||||
@@ -28,16 +27,21 @@ from app.memory.mem0_client import Mem0Client
|
|||||||
from app.logger import info, debug
|
from app.logger import info, debug
|
||||||
|
|
||||||
|
|
||||||
# ========== 全局变量(用于传递 mem0_client)==========
|
# ========== 检查是否需要总结 ==========
|
||||||
# 这样就不用改旧节点的签名了
|
def should_summarize(state: MainGraphState) -> str:
|
||||||
_global_mem0_client: Optional[Mem0Client] = None
|
"""
|
||||||
|
检查是否需要总结对话(对话足够长时)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 当前图状态
|
||||||
|
|
||||||
def set_global_mem0_client(client: Mem0Client):
|
Returns:
|
||||||
"""设置全局的 mem0_client"""
|
"summarize" 或 "finalize"
|
||||||
global _global_mem0_client
|
"""
|
||||||
_global_mem0_client = client
|
if state.turns_since_last_summary >= 5: # 每5轮对话总结一次
|
||||||
set_mem0_client(client) # 同时设置给 memory_trigger_node
|
return "summarize"
|
||||||
|
else:
|
||||||
|
return "finalize"
|
||||||
|
|
||||||
|
|
||||||
# ========== 子图包装器(处理子图错误传递)==========
|
# ========== 子图包装器(处理子图错误传递)==========
|
||||||
@@ -93,65 +97,18 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
|||||||
return wrapped_node
|
return wrapped_node
|
||||||
|
|
||||||
|
|
||||||
# ========== 检查是否需要总结 ==========
|
|
||||||
def should_summarize(state: MainGraphState) -> str:
|
|
||||||
"""
|
|
||||||
检查是否需要总结对话(对话足够长时)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: 当前图状态
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
"summarize" 或 "finalize"
|
|
||||||
"""
|
|
||||||
messages = getattr(state, 'messages', [])
|
|
||||||
if len(messages) >= 4:
|
|
||||||
return "summarize"
|
|
||||||
else:
|
|
||||||
return "finalize"
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 兼容层:让旧节点工作在新状态上 ==========
|
|
||||||
def adapt_old_node_for_new_state(old_node):
|
|
||||||
"""
|
|
||||||
适配旧节点(期望 MessagesState)到新状态 MainGraphState
|
|
||||||
|
|
||||||
Args:
|
|
||||||
old_node: 旧节点函数
|
|
||||||
|
|
||||||
Returns: 适配后的节点函数
|
|
||||||
"""
|
|
||||||
async def adapted_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
|
||||||
# 把 MainGraphState 转换为 MessagesState(旧节点期望的格式)
|
|
||||||
old_state: MessagesState = {
|
|
||||||
"messages": state.messages,
|
|
||||||
"llm_calls": getattr(state, 'llm_calls', 0),
|
|
||||||
"memory_context": getattr(state, 'memory_context', ""),
|
|
||||||
"system_prompt": getattr(state, 'system_prompt', "")
|
|
||||||
}
|
|
||||||
|
|
||||||
# 调用旧节点
|
|
||||||
result = await old_node(old_state, config)
|
|
||||||
|
|
||||||
# 把结果更新回 MainGraphState
|
|
||||||
if "memory_context" in result:
|
|
||||||
state.memory_context = result["memory_context"]
|
|
||||||
if "llm_calls" in result:
|
|
||||||
state.llm_calls = result["llm_calls"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
return adapted_node
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 主图构建 ==========
|
# ========== 主图构建 ==========
|
||||||
def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph:
|
def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph:
|
||||||
"""
|
"""
|
||||||
构建整合后的完整主图(简化版:先让系统工作起来)
|
构建整合后的完整主图
|
||||||
|
|
||||||
完整流程:
|
完整流程:
|
||||||
START
|
START
|
||||||
↓
|
↓
|
||||||
|
retrieve_memory (从Mem0检索长期记忆)
|
||||||
|
↓
|
||||||
|
memory_trigger (记忆触发器)
|
||||||
|
↓
|
||||||
init_state (初始化)
|
init_state (初始化)
|
||||||
↓
|
↓
|
||||||
react_reason (推理) ←───────────────────────┐
|
react_reason (推理) ←───────────────────────┐
|
||||||
@@ -165,6 +122,10 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph
|
|||||||
├─ handle_error → (重试或结束) ────────────┤
|
├─ handle_error → (重试或结束) ────────────┤
|
||||||
└─ llm_call (大模型调用) ←────────────────┘
|
└─ llm_call (大模型调用) ←────────────────┘
|
||||||
↓
|
↓
|
||||||
|
检查:需要总结吗?
|
||||||
|
├─ 是 → summarize (提交给Mem0存储)
|
||||||
|
└─ 否 → (跳过)
|
||||||
|
↓
|
||||||
finalize (发送完成事件)
|
finalize (发送完成事件)
|
||||||
↓
|
↓
|
||||||
END
|
END
|
||||||
@@ -172,7 +133,7 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph
|
|||||||
# 创建图
|
# 创建图
|
||||||
graph = StateGraph(MainGraphState)
|
graph = StateGraph(MainGraphState)
|
||||||
|
|
||||||
# 设置全局 mem0_client (暂时不用记忆功能)
|
# 设置全局 mem0_client
|
||||||
if mem0_client:
|
if mem0_client:
|
||||||
set_global_mem0_client(mem0_client)
|
set_global_mem0_client(mem0_client)
|
||||||
|
|
||||||
@@ -181,8 +142,20 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph
|
|||||||
if llm is not None:
|
if llm is not None:
|
||||||
llm_node = create_llm_call_node(llm, tools or [])
|
llm_node = create_llm_call_node(llm, tools or [])
|
||||||
|
|
||||||
|
retrieve_memory_node = None
|
||||||
|
summarize_node = None
|
||||||
|
if mem0_client:
|
||||||
|
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
|
||||||
|
summarize_node = create_summarize_node(mem0_client)
|
||||||
|
|
||||||
# ========== 添加节点 ==========
|
# ========== 添加节点 ==========
|
||||||
# 简化:先不用记忆检索相关节点
|
|
||||||
|
# 第一阶段:记忆检索
|
||||||
|
if retrieve_memory_node:
|
||||||
|
graph.add_node("retrieve_memory", retrieve_memory_node)
|
||||||
|
graph.add_node("memory_trigger", memory_trigger_node)
|
||||||
|
|
||||||
|
# 第二阶段:React 循环推理
|
||||||
graph.add_node("init_state", init_state_node)
|
graph.add_node("init_state", init_state_node)
|
||||||
graph.add_node("react_reason", react_reason_node)
|
graph.add_node("react_reason", react_reason_node)
|
||||||
graph.add_node("rag_retrieve", rag_retrieve_node)
|
graph.add_node("rag_retrieve", rag_retrieve_node)
|
||||||
@@ -210,15 +183,25 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph
|
|||||||
wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis")
|
wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis")
|
||||||
)
|
)
|
||||||
|
|
||||||
# 完成节点
|
# 第三阶段:完成处理
|
||||||
|
if summarize_node:
|
||||||
|
graph.add_node("summarize", summarize_node)
|
||||||
graph.add_node("finalize", finalize_node)
|
graph.add_node("finalize", finalize_node)
|
||||||
|
|
||||||
# ========== 添加边 ==========
|
# ========== 添加边 ==========
|
||||||
# 简化:直接从 START 到 init_state
|
|
||||||
graph.add_edge(START, "init_state")
|
# 第一阶段:记忆检索
|
||||||
|
if retrieve_memory_node:
|
||||||
|
graph.add_edge(START, "retrieve_memory")
|
||||||
|
graph.add_edge("retrieve_memory", "memory_trigger")
|
||||||
|
else:
|
||||||
|
graph.add_edge(START, "memory_trigger")
|
||||||
|
|
||||||
|
# 进入第二阶段
|
||||||
|
graph.add_edge("memory_trigger", "init_state")
|
||||||
graph.add_edge("init_state", "react_reason")
|
graph.add_edge("init_state", "react_reason")
|
||||||
|
|
||||||
# 条件路由
|
# 第二阶段:React 循环推理
|
||||||
graph.add_conditional_edges(
|
graph.add_conditional_edges(
|
||||||
"react_reason",
|
"react_reason",
|
||||||
route_by_reasoning,
|
route_by_reasoning,
|
||||||
@@ -241,14 +224,27 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph
|
|||||||
graph.add_edge("news_analysis_subgraph", "react_reason")
|
graph.add_edge("news_analysis_subgraph", "react_reason")
|
||||||
graph.add_edge("handle_error", "react_reason")
|
graph.add_edge("handle_error", "react_reason")
|
||||||
|
|
||||||
# llm_call 之后直接到 finalize
|
# 第三阶段:llm_call 后进入完成处理
|
||||||
if llm_node is not None:
|
if llm_node is not None:
|
||||||
|
if summarize_node:
|
||||||
|
# 检查是否需要总结
|
||||||
|
graph.add_conditional_edges(
|
||||||
|
"llm_call",
|
||||||
|
should_summarize,
|
||||||
|
{
|
||||||
|
"summarize": "summarize",
|
||||||
|
"finalize": "finalize"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
graph.add_edge("summarize", "finalize")
|
||||||
|
else:
|
||||||
|
# 没有 summarize 节点,直接 finalize
|
||||||
graph.add_edge("llm_call", "finalize")
|
graph.add_edge("llm_call", "finalize")
|
||||||
|
|
||||||
# 完成
|
# 完成
|
||||||
graph.add_edge("finalize", END)
|
graph.add_edge("finalize", END)
|
||||||
|
|
||||||
info("✅ [图构建] 整合后的完整主图构建完成(简化版)")
|
info("✅ [图构建] 整合后的完整主图构建完成")
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
@@ -265,6 +261,5 @@ def build_main_graph() -> StateGraph:
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"build_react_main_graph",
|
"build_react_main_graph",
|
||||||
"build_main_graph",
|
"build_main_graph",
|
||||||
"wrap_subgraph_for_error_handling",
|
"wrap_subgraph_for_error_handling"
|
||||||
"set_global_mem0_client"
|
|
||||||
]
|
]
|
||||||
@@ -7,14 +7,6 @@ from app.config import ENABLE_GRAPH_TRACE
|
|||||||
from app.logger import debug, info
|
from app.logger import debug, info
|
||||||
|
|
||||||
|
|
||||||
def _get_attr(state, attr_name, default=None):
|
|
||||||
"""通用方法:兼容 dict 和 dataclass 两种状态格式"""
|
|
||||||
if isinstance(state, dict):
|
|
||||||
return state.get(attr_name, default)
|
|
||||||
else:
|
|
||||||
return getattr(state, attr_name, default)
|
|
||||||
|
|
||||||
|
|
||||||
def log_state_change(node_name: str, state, prefix: str = "进入"):
|
def log_state_change(node_name: str, state, prefix: str = "进入"):
|
||||||
"""
|
"""
|
||||||
记录状态变化日志
|
记录状态变化日志
|
||||||
@@ -26,7 +18,13 @@ def log_state_change(node_name: str, state, prefix: str = "进入"):
|
|||||||
"""
|
"""
|
||||||
from app.logger import info
|
from app.logger import info
|
||||||
|
|
||||||
messages = _get_attr(state, "messages", [])
|
# 获取 messages
|
||||||
|
messages = []
|
||||||
|
if isinstance(state, dict):
|
||||||
|
messages = state.get("messages", [])
|
||||||
|
else:
|
||||||
|
messages = getattr(state, "messages", [])
|
||||||
|
|
||||||
msg_count = len(messages)
|
msg_count = len(messages)
|
||||||
last_msg = messages[-1] if messages else None
|
last_msg = messages[-1] if messages else None
|
||||||
last_info = ""
|
last_info = ""
|
||||||
@@ -57,13 +55,13 @@ def print_llm_input(prompt_value):
|
|||||||
|
|
||||||
messages = prompt_value.messages # ChatPromptValue 提供 .messages 属性
|
messages = prompt_value.messages # ChatPromptValue 提供 .messages 属性
|
||||||
|
|
||||||
debug("\n" + "=" * 80)
|
debug("\n" + "="*80)
|
||||||
debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:")
|
debug("📥 [LLM输入] 格式化后发送给大模型的完整消息:")
|
||||||
debug(f" 总消息数: {len(messages)}")
|
debug(f" 总消息数: {len(messages)}")
|
||||||
debug("-" * 80)
|
debug("-"*80)
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
content_preview = str(msg.content) # 完整输出
|
content_preview = str(msg.content) # 完整输出
|
||||||
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
|
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
|
||||||
debug("\n" + "=" * 80 + "\n")
|
debug("\n" + "="*80 + "\n")
|
||||||
|
|
||||||
return prompt_value
|
return prompt_value
|
||||||
Reference in New Issue
Block a user