整合旧图和新图:添加完整的记忆检索、总结和完成流程
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m42s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m42s
This commit is contained in:
@@ -13,7 +13,7 @@ from app.main_graph.config import set_stream_writer
|
|||||||
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
|
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
|
||||||
from app.main_graph.utils.rag_initializer import init_rag_tool
|
from app.main_graph.utils.rag_initializer import init_rag_tool
|
||||||
from app.core.intent_classifier import get_intent_classifier
|
from app.core.intent_classifier import get_intent_classifier
|
||||||
from app.logger import info, warning
|
from app.logger import info, warning, error
|
||||||
from app.main_graph.state import MainGraphState, CurrentAction
|
from app.main_graph.state import MainGraphState, CurrentAction
|
||||||
|
|
||||||
|
|
||||||
@@ -27,8 +27,19 @@ class AIAgentService:
|
|||||||
self.intent_classifier = get_intent_classifier()
|
self.intent_classifier = get_intent_classifier()
|
||||||
# RAG 管道(可选,需要时设置)
|
# RAG 管道(可选,需要时设置)
|
||||||
self.rag_pipeline = None
|
self.rag_pipeline = None
|
||||||
|
# Mem0 客户端
|
||||||
|
self.mem0_client = None
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
# 0. 初始化 Mem0 客户端
|
||||||
|
from app.memory.mem0_client import Mem0Client
|
||||||
|
# 创建一个临时的 LLM 用于 Mem0(用第一个可用的)
|
||||||
|
chat_services = get_all_chat_services()
|
||||||
|
temp_llm = None
|
||||||
|
if chat_services:
|
||||||
|
temp_llm = list(chat_services.values())[0]
|
||||||
|
self.mem0_client = Mem0Client(temp_llm)
|
||||||
|
|
||||||
# 1. 初始化 RAG 工具(如果需要)
|
# 1. 初始化 RAG 工具(如果需要)
|
||||||
def create_local_llm():
|
def create_local_llm():
|
||||||
provider = LocalVLLMChatProvider()
|
provider = LocalVLLMChatProvider()
|
||||||
@@ -42,11 +53,14 @@ class AIAgentService:
|
|||||||
set_global_rag_tool(rag_tool)
|
set_global_rag_tool(rag_tool)
|
||||||
|
|
||||||
# 2. 构建各模型的 Graph(使用新版 React 模式)
|
# 2. 构建各模型的 Graph(使用新版 React 模式)
|
||||||
chat_services = get_all_chat_services()
|
|
||||||
for name, llm in chat_services.items():
|
for name, llm in chat_services.items():
|
||||||
try:
|
try:
|
||||||
info(f"🔄 初始化模型 '{name}'...")
|
info(f"🔄 初始化模型 '{name}'...")
|
||||||
graph = build_react_main_graph(llm=llm, tools=self.tools).compile(checkpointer=self.checkpointer)
|
graph = build_react_main_graph(
|
||||||
|
llm=llm,
|
||||||
|
tools=self.tools,
|
||||||
|
mem0_client=self.mem0_client
|
||||||
|
).compile(checkpointer=self.checkpointer)
|
||||||
self.graphs[name] = graph
|
self.graphs[name] = graph
|
||||||
info(f"✅ 模型 '{name}' 初始化成功")
|
info(f"✅ 模型 '{name}' 初始化成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -126,6 +126,9 @@ def create_llm_call_node(llm, tools: list):
|
|||||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||||
debug("="*80 + "\n")
|
debug("="*80 + "\n")
|
||||||
|
|
||||||
|
# 检查是否有工具调用
|
||||||
|
has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"messages": [response],
|
"messages": [response],
|
||||||
"llm_calls": getattr(state, 'llm_calls', 0) + 1,
|
"llm_calls": getattr(state, 'llm_calls', 0) + 1,
|
||||||
@@ -134,7 +137,8 @@ def create_llm_call_node(llm, tools: list):
|
|||||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||||
"final_result": response.content,
|
"final_result": response.content,
|
||||||
"success": True,
|
"success": True,
|
||||||
"current_phase": "done"
|
"current_phase": "done",
|
||||||
|
"has_tool_calls": has_tool_calls
|
||||||
}
|
}
|
||||||
|
|
||||||
log_state_change("llm_call", {**state, **result}, "离开")
|
log_state_change("llm_call", {**state, **result}, "离开")
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
React 模式主图构建器 - 完整循环推理版本
|
整合后的完整主图构建器 - 结合旧图和新图的优点
|
||||||
Main Graph Builder - Full React Mode with Loop Reasoning
|
Main Graph Builder - Integrated Full Version (Old + New)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.main_graph.graph import StateGraph, START, END
|
from app.main_graph.graph import StateGraph, START, END
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, Optional
|
||||||
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
from app.main_graph.state import MainGraphState, CurrentAction
|
from app.main_graph.state import MainGraphState, CurrentAction, MessagesState
|
||||||
from app.main_graph.nodes.react_nodes import (
|
from app.main_graph.nodes.react_nodes import (
|
||||||
init_state_node,
|
init_state_node,
|
||||||
react_reason_node,
|
react_reason_node,
|
||||||
@@ -16,12 +17,30 @@ from app.main_graph.nodes.react_nodes import (
|
|||||||
)
|
)
|
||||||
from app.main_graph.nodes.llm_call import create_llm_call_node
|
from app.main_graph.nodes.llm_call import create_llm_call_node
|
||||||
from app.main_graph.nodes.rag_nodes import rag_retrieve_node
|
from app.main_graph.nodes.rag_nodes import rag_retrieve_node
|
||||||
|
from app.main_graph.nodes.retrieve_memory import create_retrieve_memory_node
|
||||||
|
from app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||||
|
from app.main_graph.nodes.summarize import create_summarize_node
|
||||||
|
from app.main_graph.nodes.finalize import finalize_node
|
||||||
from app.subgraphs.contact import build_contact_subgraph
|
from app.subgraphs.contact import build_contact_subgraph
|
||||||
from app.subgraphs.dictionary import build_dictionary_subgraph
|
from app.subgraphs.dictionary import build_dictionary_subgraph
|
||||||
from app.subgraphs.news_analysis import build_news_analysis_subgraph
|
from app.subgraphs.news_analysis import build_news_analysis_subgraph
|
||||||
|
from app.memory.mem0_client import Mem0Client
|
||||||
|
from app.logger import info, debug
|
||||||
|
|
||||||
|
|
||||||
# ========== 子图包装器(处理子图错误传递) ==========
|
# ========== 全局变量(用于传递 mem0_client)==========
|
||||||
|
# 这样就不用改旧节点的签名了
|
||||||
|
_global_mem0_client: Optional[Mem0Client] = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_mem0_client(client: Mem0Client):
|
||||||
|
"""设置全局的 mem0_client"""
|
||||||
|
global _global_mem0_client
|
||||||
|
_global_mem0_client = client
|
||||||
|
set_mem0_client(client) # 同时设置给 memory_trigger_node
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 子图包装器(处理子图错误传递)==========
|
||||||
def wrap_subgraph_for_error_handling(subgraph, name: str):
|
def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||||
"""
|
"""
|
||||||
包装子图,使其错误能传递给主图
|
包装子图,使其错误能传递给主图
|
||||||
@@ -74,64 +93,126 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
|||||||
return wrapped_node
|
return wrapped_node
|
||||||
|
|
||||||
|
|
||||||
# ========== 主图构建 ==========
|
# ========== 检查是否需要总结 ==========
|
||||||
def build_react_main_graph(llm=None, tools=None) -> StateGraph:
|
def should_summarize(state: MainGraphState) -> str:
|
||||||
"""
|
"""
|
||||||
构建完整的 React 模式主图
|
检查是否需要总结对话(对话足够长时)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 当前图状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
"summarize" 或 "finalize"
|
||||||
|
"""
|
||||||
|
messages = getattr(state, 'messages', [])
|
||||||
|
if len(messages) >= 4:
|
||||||
|
return "summarize"
|
||||||
|
else:
|
||||||
|
return "finalize"
|
||||||
|
|
||||||
流程:
|
|
||||||
|
# ========== 兼容层:让旧节点工作在新状态上 ==========
|
||||||
|
def adapt_old_node_for_new_state(old_node):
|
||||||
|
"""
|
||||||
|
适配旧节点(期望 MessagesState)到新状态 MainGraphState
|
||||||
|
|
||||||
|
Args:
|
||||||
|
old_node: 旧节点函数
|
||||||
|
|
||||||
|
Returns: 适配后的节点函数
|
||||||
|
"""
|
||||||
|
async def adapted_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
||||||
|
# 把 MainGraphState 转换为 MessagesState(旧节点期望的格式)
|
||||||
|
old_state: MessagesState = {
|
||||||
|
"messages": state.messages,
|
||||||
|
"llm_calls": getattr(state, 'llm_calls', 0),
|
||||||
|
"memory_context": getattr(state, 'memory_context', ""),
|
||||||
|
"system_prompt": getattr(state, 'system_prompt', "")
|
||||||
|
}
|
||||||
|
|
||||||
|
# 调用旧节点
|
||||||
|
result = await old_node(old_state, config)
|
||||||
|
|
||||||
|
# 把结果更新回 MainGraphState
|
||||||
|
if "memory_context" in result:
|
||||||
|
state.memory_context = result["memory_context"]
|
||||||
|
if "llm_calls" in result:
|
||||||
|
state.llm_calls = result["llm_calls"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return adapted_node
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 主图构建 ==========
|
||||||
|
def build_react_main_graph(llm=None, tools=None, mem0_client=None) -> StateGraph:
|
||||||
|
"""
|
||||||
|
构建整合后的完整主图
|
||||||
|
|
||||||
|
完整流程:
|
||||||
START
|
START
|
||||||
↓
|
↓
|
||||||
init_state (初始化)
|
retrieve_memory (从Mem0检索长期记忆) ← 来自旧图
|
||||||
↓
|
↓
|
||||||
react_reason (推理) ←──────────────┐
|
memory_trigger (记忆触发器) ← 来自旧图
|
||||||
↓ │
|
↓
|
||||||
条件路由 │
|
init_state (初始化) ← 来自新图
|
||||||
├─ rag_retrieve →───────────────┤
|
↓
|
||||||
├─ contact_subgraph →───────────┤
|
react_reason (推理) ←──────────────────────┐
|
||||||
├─ dictionary_subgraph →────────┤
|
↓ │
|
||||||
├─ news_analysis_subgraph →─────┤
|
条件路由 │
|
||||||
├─ handle_error → (重试或结束) ─┤
|
├─ rag_retrieve →─────────────────────────┤
|
||||||
└─ llm_call (大模型调用) ←──────┘
|
├─ contact_subgraph →─────────────────────┤
|
||||||
|
├─ dictionary_subgraph →──────────────────┤
|
||||||
|
├─ news_analysis_subgraph →───────────────┤
|
||||||
|
├─ web_search →───────────────────────────┤
|
||||||
|
├─ handle_error → (重试或结束) ───────────┤
|
||||||
|
└─ llm_call (大模型调用) ←────────────────┘
|
||||||
↓
|
↓
|
||||||
🔍 观察 (检查 tool_calls)
|
检查:需要总结?
|
||||||
|
├─ 是 → summarize (提交给Mem0存储) ← 来自旧图
|
||||||
|
└─ 否 → (跳过)
|
||||||
↓
|
↓
|
||||||
[有工具调用?]
|
finalize (发送完成事件) ← 来自旧图
|
||||||
├─ 是 → 执行工具 → 回到 llm_call
|
↓
|
||||||
└─ 否 → END
|
END
|
||||||
"""
|
"""
|
||||||
# 创建图
|
# 创建图
|
||||||
graph = StateGraph(MainGraphState)
|
graph = StateGraph(MainGraphState)
|
||||||
|
|
||||||
# 创建 llm_call 节点
|
# 设置全局 mem0_client
|
||||||
|
if mem0_client:
|
||||||
|
set_global_mem0_client(mem0_client)
|
||||||
|
|
||||||
|
# 创建节点
|
||||||
llm_node = None
|
llm_node = None
|
||||||
if llm is not None:
|
if llm is not None:
|
||||||
llm_node = create_llm_call_node(llm, tools or [])
|
llm_node = create_llm_call_node(llm, tools or [])
|
||||||
|
|
||||||
|
retrieve_memory_node = None
|
||||||
|
summarize_node = None
|
||||||
|
if mem0_client:
|
||||||
|
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
|
||||||
|
summarize_node = create_summarize_node(mem0_client)
|
||||||
|
|
||||||
# ========== 添加节点 ==========
|
# ========== 添加节点 ==========
|
||||||
|
|
||||||
# 1. 初始化节点
|
# 第一阶段:记忆检索(来自旧图)
|
||||||
|
if retrieve_memory_node:
|
||||||
|
graph.add_node("retrieve_memory", adapt_old_node_for_new_state(retrieve_memory_node))
|
||||||
|
graph.add_node("memory_trigger", memory_trigger_node)
|
||||||
|
|
||||||
|
# 第二阶段:React 循环推理(来自新图)
|
||||||
graph.add_node("init_state", init_state_node)
|
graph.add_node("init_state", init_state_node)
|
||||||
|
|
||||||
# 2. React 推理节点
|
|
||||||
graph.add_node("react_reason", react_reason_node)
|
graph.add_node("react_reason", react_reason_node)
|
||||||
|
|
||||||
# 3. RAG 检索节点
|
|
||||||
graph.add_node("rag_retrieve", rag_retrieve_node)
|
graph.add_node("rag_retrieve", rag_retrieve_node)
|
||||||
|
|
||||||
# 4. 联网搜索节点
|
|
||||||
graph.add_node("web_search", web_search_node)
|
graph.add_node("web_search", web_search_node)
|
||||||
|
|
||||||
# 5. 错误处理节点
|
|
||||||
graph.add_node("handle_error", error_handling_node)
|
graph.add_node("handle_error", error_handling_node)
|
||||||
|
|
||||||
# 6. LLM 调用节点(真正的大模型输出)
|
|
||||||
if llm_node is not None:
|
if llm_node is not None:
|
||||||
graph.add_node("llm_call", llm_node)
|
graph.add_node("llm_call", llm_node)
|
||||||
|
|
||||||
# ========== 添加子图节点 ==========
|
# 子图节点
|
||||||
|
|
||||||
# 构建并包装子图(带错误处理)
|
|
||||||
contact_graph = build_contact_subgraph()
|
contact_graph = build_contact_subgraph()
|
||||||
dictionary_graph = build_dictionary_subgraph()
|
dictionary_graph = build_dictionary_subgraph()
|
||||||
news_analysis_graph = build_news_analysis_subgraph()
|
news_analysis_graph = build_news_analysis_subgraph()
|
||||||
@@ -149,39 +230,40 @@ def build_react_main_graph(llm=None, tools=None) -> StateGraph:
|
|||||||
wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis")
|
wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 第三阶段:完成处理(来自旧图)
|
||||||
|
if summarize_node:
|
||||||
|
graph.add_node("summarize", adapt_old_node_for_new_state(summarize_node))
|
||||||
|
graph.add_node("finalize", finalize_node)
|
||||||
|
|
||||||
# ========== 添加边 ==========
|
# ========== 添加边 ==========
|
||||||
|
|
||||||
# 1. START → init_state
|
# 第一阶段:记忆检索
|
||||||
graph.add_edge(START, "init_state")
|
if retrieve_memory_node:
|
||||||
|
graph.add_edge(START, "retrieve_memory")
|
||||||
|
graph.add_edge("retrieve_memory", "memory_trigger")
|
||||||
|
else:
|
||||||
|
graph.add_edge(START, "memory_trigger")
|
||||||
|
|
||||||
# 2. init_state → react_reason
|
# 进入第二阶段
|
||||||
|
graph.add_edge("memory_trigger", "init_state")
|
||||||
graph.add_edge("init_state", "react_reason")
|
graph.add_edge("init_state", "react_reason")
|
||||||
|
|
||||||
# 3. 条件路由:react_reason → 各分支
|
# 第二阶段:React 循环推理
|
||||||
graph.add_conditional_edges(
|
graph.add_conditional_edges(
|
||||||
"react_reason",
|
"react_reason",
|
||||||
route_by_reasoning,
|
route_by_reasoning,
|
||||||
{
|
{
|
||||||
# 检索分支 → 检索后回到推理
|
|
||||||
"rag_retrieve": "rag_retrieve",
|
"rag_retrieve": "rag_retrieve",
|
||||||
|
|
||||||
# 联网搜索分支
|
|
||||||
"web_search": "web_search",
|
"web_search": "web_search",
|
||||||
|
|
||||||
# 子图分支 → 子图后回到推理
|
|
||||||
"contact_subgraph": "contact_subgraph",
|
"contact_subgraph": "contact_subgraph",
|
||||||
"dictionary_subgraph": "dictionary_subgraph",
|
"dictionary_subgraph": "dictionary_subgraph",
|
||||||
"news_analysis_subgraph": "news_analysis_subgraph",
|
"news_analysis_subgraph": "news_analysis_subgraph",
|
||||||
|
|
||||||
# 错误处理分支
|
|
||||||
"handle_error": "handle_error",
|
"handle_error": "handle_error",
|
||||||
|
|
||||||
# LLM 调用分支 → 直接输出给用户
|
|
||||||
"llm_call": "llm_call"
|
"llm_call": "llm_call"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 循环边:检索/搜索/子图/错误处理后 → 回到推理
|
# 循环边:检索/搜索/子图/错误处理后 → 回到推理
|
||||||
graph.add_edge("rag_retrieve", "react_reason")
|
graph.add_edge("rag_retrieve", "react_reason")
|
||||||
graph.add_edge("web_search", "react_reason")
|
graph.add_edge("web_search", "react_reason")
|
||||||
graph.add_edge("contact_subgraph", "react_reason")
|
graph.add_edge("contact_subgraph", "react_reason")
|
||||||
@@ -189,10 +271,27 @@ def build_react_main_graph(llm=None, tools=None) -> StateGraph:
|
|||||||
graph.add_edge("news_analysis_subgraph", "react_reason")
|
graph.add_edge("news_analysis_subgraph", "react_reason")
|
||||||
graph.add_edge("handle_error", "react_reason")
|
graph.add_edge("handle_error", "react_reason")
|
||||||
|
|
||||||
# 5. 条件路由:llm_call 后检查是否有工具调用
|
# 第三阶段:llm_call 后进入完成处理
|
||||||
# 注意:这里简化处理,先直接 END,后续再完善工具调用循环
|
|
||||||
if llm_node is not None:
|
if llm_node is not None:
|
||||||
graph.add_edge("llm_call", END)
|
if summarize_node:
|
||||||
|
# 检查是否需要总结
|
||||||
|
graph.add_conditional_edges(
|
||||||
|
"llm_call",
|
||||||
|
should_summarize,
|
||||||
|
{
|
||||||
|
"summarize": "summarize",
|
||||||
|
"finalize": "finalize"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
graph.add_edge("summarize", "finalize")
|
||||||
|
else:
|
||||||
|
# 没有 summarize 节点,直接 finalize
|
||||||
|
graph.add_edge("llm_call", "finalize")
|
||||||
|
|
||||||
|
# 完成
|
||||||
|
graph.add_edge("finalize", END)
|
||||||
|
|
||||||
|
info("✅ [图构建] 整合后的完整主图构建完成")
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
@@ -209,5 +308,6 @@ def build_main_graph() -> StateGraph:
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"build_react_main_graph",
|
"build_react_main_graph",
|
||||||
"build_main_graph",
|
"build_main_graph",
|
||||||
"wrap_subgraph_for_error_handling"
|
"wrap_subgraph_for_error_handling",
|
||||||
|
"set_global_mem0_client"
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user