refactor: 单图方案重构 + 动态模型选择 + chat_services优化
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s

## 核心改动

### 1. 单图方案重构
- 删除了多图(self.graphs),改为单图(self.graph)
- 新增 MainGraphState.current_model 字段用于运行时注入模型
- llm_call 节点改为动态选择模型(create_dynamic_llm_call_node)

### 2. chat_services 优化
- 添加 _cached_services 缓存,避免重复初始化
- 新增 get_cached_chat_services() 函数,用于单图注入
- 新增 _check_http_service_available() 统一HTTP探测逻辑
- 减少重复代码,LocalVLLMChatProvider和LocalSmallModelProvider共用探测方法

### 3. AIAgentService 重构
- initialize() 只构建一次图,传入 chat_services 字典
- 新增 _resolve_model() 模型回退逻辑
- 新增 _build_invocation() 统一构建调用参数
- process_message() 和 process_message_stream() 改为注入 current_model
- 流式处理代码拆分,增加可读性

### 4. 新增和删除文件
- 新增:backend/app/main_graph/main_graph_builder.py(图构建)
- 新增:backend/app/main_graph/subgraph_wrapper.py(子图封装)
- 新增:tools/test/test_tavily_search.py(测试)
- 删除:backend/app/main_graph/graph.py(旧图)
- 删除:backend/app/main_graph/utils/main_graph_builder.py(旧构建器)
- 删除:backend/app/main_graph/utils/__init__.py

### 5. 其他更新
- README.md:新增模型服务使用情况详解章节
- backend/app/model_services/__init__.py:新增 get_cached_chat_services 导出

## 方案优势

- 内存优化:N张图 → 1张图
- 灵活性:运行时动态选择模型,支持同会话不同模型
- 性能:模型服务缓存,初始化仅一次
- 可维护性:减少重复代码,统一HTTP探测逻辑
This commit is contained in:
2026-05-05 17:30:55 +08:00
parent 8b5fbbd395
commit b5c15ef445
25 changed files with 1225 additions and 830 deletions

View File

@@ -1,25 +1,28 @@
"""
AI Agent 服务类 - 支持多模型动态切换
AI Agent 服务类 - 单图方案 + 动态模型选择
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
# 本地模块
from ..main_graph.utils.main_graph_builder import build_react_main_graph
from ..model_services import get_cached_chat_services
from ..main_graph.main_graph_builder import build_react_main_graph
from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from ..main_graph.config import set_stream_writer
from ..main_graph.utils.rag_initializer import init_rag_tool
from ..core.intent_classifier import get_intent_classifier
from ..logger import info, warning, error
from ..logger import debug, info, warning, error
from ..main_graph.state import MainGraphState, CurrentAction
class AIAgentService:
def __init__(self, checkpointer):
self.checkpointer = checkpointer
self.graphs = {}
self.graph = None # 只有一张图
self.chat_services = None # 缓存的模型字典
self.tools = AVAILABLE_TOOLS.copy()
self.tools_by_name = TOOLS_BY_NAME.copy()
# 添加:意图分类器
@@ -40,64 +43,94 @@ class AIAgentService:
self.tools.append(rag_tool)
self.tools_by_name[rag_tool.name] = rag_tool
self.rag_tool = rag_tool # 保存到实例变量,供 config 注入
# 2. 构建各模型的 Graph使用新版 React 模式)
for name, llm in chat_services.items():
try:
info(f"🔄 初始化模型 '{name}'...")
graph = build_react_main_graph(
llm=llm,
tools=self.tools,
mem0_client=self.mem0_client
).compile(checkpointer=self.checkpointer)
self.graphs[name] = graph
info(f"✅ 模型 '{name}' 初始化成功")
except Exception as e:
warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")
if not self.graphs:
raise RuntimeError("没有可用的模型")
# 2. 获取缓存的模型字典
self.chat_services = get_cached_chat_services()
info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")
# 3. 只构建一次图(传入 chat_services 字典)
info(f"🔄 构建单图...")
graph_builder = build_react_main_graph(
chat_services=self.chat_services,
tools=self.tools,
mem0_client=self.mem0_client
)
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
info(f"✅ 单图初始化完成")
return self
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
"""处理用户消息返回包含回复、token统计和耗时的字典"""
if model not in self.graphs:
# 回退到第一个可用模型
available = list(self.graphs.keys())
if not available:
raise RuntimeError("没有可用的模型")
model = available[0]
warning(f"模型 '{model}' 不可用,已回退到 '{model}'")
def _resolve_model(self, model: str) -> str:
"""
解析并验证模型名称,不可用时回退到第一个可用模型
Args:
model: 目标模型名称
Returns:
实际使用的模型名称
"""
if not model or model not in self.chat_services:
fallback = next(iter(self.chat_services.keys()))
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
return fallback
return model
graph = self.graphs[model]
def _build_invocation(
self, message: str, thread_id: str, model: str, user_id: str
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
构建图调用所需的 config 和 input_state
Args:
message: 用户消息
thread_id: 会话 ID
model: 模型名称
user_id: 用户 ID
Returns:
(config, input_state) 元组
"""
config = {
"configurable": {
"thread_id": thread_id,
"rag_tool": getattr(self, "rag_tool", None), # 注入 RAG 工具
"rag_tool": getattr(self, "rag_tool", None),
},
"metadata": {"user_id": user_id}
}
# 新版状态输入:传入完整的 MainGraphState关键是 user_query
from ..main_graph.state import MainGraphState, CurrentAction
input_state = {
"user_query": message,
"messages": [{"role": "user", "content": message}],
"user_id": user_id,
"current_model": model,
"current_action": CurrentAction.NONE
}
return config, input_state
result = await graph.ainvoke(input_state, config=config)
async def process_message(
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
) -> dict:
"""处理用户消息返回包含回复、token统计和耗时的字典"""
# 解析模型名称
resolved_model = self._resolve_model(model)
# 构建调用参数
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
result = await self.graph.ainvoke(input_state, config=config)
reply = result.get("final_result", "")
if not reply and result.get("messages"):
reply = result["messages"][-1].content
token_usage = result.get("debug_info", {}).get("token_usage", {})
elapsed_time = result.get("debug_info", {}).get("elapsed_time", 0.0)
token_usage = result.get("last_token_usage", {})
elapsed_time = result.get("last_elapsed_time", 0.0)
actual_model = result.get("current_model", resolved_model)
return {
"reply": reply,
"token_usage": token_usage,
"elapsed_time": elapsed_time
"elapsed_time": elapsed_time,
"model_used": actual_model
}
def _serialize_value(self, value):
@@ -121,31 +154,169 @@ class AIAgentService:
except (TypeError, ValueError):
return str(value)
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
"""流式处理消息,返回异步生成器(全部走 React 模式)"""
graph = self.graphs.get(model_name)
if not graph:
raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
async def _handle_message_chunk(
self, chunk: Dict[str, Any], current_node: Optional[str], tool_calls_in_progress: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
"""处理 messages 类型的 chunk"""
message_chunk, metadata = chunk["data"]
node_name = metadata.get("langgraph_node", "unknown")
new_current_node = current_node
config = {
"configurable": {
"thread_id": thread_id,
"rag_tool": getattr(self, "rag_tool", None), # 注入 RAG 工具
},
"metadata": {"user_id": user_id}
}
input_state = {
"user_query": message,
"messages": [{"role": "user", "content": message}],
"user_id": user_id,
"current_action": CurrentAction.NONE
}
# 检测节点变化,发送节点开始事件
if node_name != current_node:
if current_node:
yield {"type": "node_end", "node": current_node}
yield {"type": "node_start", "node": node_name}
new_current_node = node_name
# ========== 意图识别(保留用于日志)==========
# 处理消息内容
token_content = getattr(message_chunk, 'content', str(message_chunk))
reasoning_token = ""
if hasattr(message_chunk, 'additional_kwargs'):
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
# 处理思考过程
if reasoning_token:
yield {
"type": "llm_token",
"node": node_name,
"reasoning_token": reasoning_token
}
# 处理工具调用
elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
for tool_call in message_chunk.tool_calls:
tool_call_id = tool_call.get("id", "")
tool_name = tool_call.get("name", "")
tool_args = tool_call.get("args", {})
# 记录工具调用开始,避免重复
if tool_call_id and tool_call_id not in tool_calls_in_progress:
tool_calls_in_progress[tool_call_id] = {
"name": tool_name,
"args": tool_args
}
yield {
"type": "tool_call_start",
"tool": tool_name,
"args": tool_args,
"id": tool_call_id
}
# 处理普通 token
elif token_content:
yield {
"type": "llm_token",
"node": node_name,
"token": token_content,
"reasoning_token": reasoning_token
}
# 返回更新后的 current_node
yield {"type": "_update_state", "current_node": new_current_node}
async def _handle_updates_chunk(
self, chunk: Dict[str, Any], tool_calls_in_progress: Dict[str, Any], actual_model_used: str
) -> AsyncGenerator[Dict[str, Any], None]:
"""处理 updates 类型的 chunk"""
updates_data = chunk["data"]
new_actual_model = actual_model_used
debug(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}")
# 特别检查 final_result 和 current_model
if isinstance(updates_data, dict):
if "final_result" in updates_data:
debug(f"[Stream] 收到 final_result: {str(updates_data['final_result'])[:100]}...")
if "current_model" in updates_data:
new_actual_model = updates_data["current_model"]
info(f"[Stream] 实际使用模型: {new_actual_model}")
serialized_data = self._serialize_value(updates_data)
# 检查是否有人工审核请求
if "review_pending" in serialized_data and serialized_data["review_pending"]:
review_id = serialized_data.get("review_id", "")
content_to_review = serialized_data.get("content_to_review", "")
yield {
"type": "human_review_request",
"review_id": review_id,
"content": content_to_review
}
# 检查是否有工具结果
if "messages" in serialized_data:
for msg in serialized_data["messages"]:
# 检测工具结果消息
if msg.get("role") == "tool":
tool_call_id = msg.get("tool_call_id", "")
tool_name = msg.get("name", "")
tool_result = msg.get("content", "")
if tool_call_id and tool_call_id in tool_calls_in_progress:
yield {
"type": "tool_call_end",
"tool": tool_name,
"id": tool_call_id,
"result": tool_result
}
del tool_calls_in_progress[tool_call_id]
yield {
"type": "state_update",
"data": serialized_data
}
# 返回更新后的模型
yield {"type": "_update_state", "actual_model_used": new_actual_model}
async def _handle_custom_chunk(self, chunk: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
"""处理 custom 类型的 chunk"""
custom_data = chunk["data"]
# 处理我们从 react_reason_node 发送的自定义推理事件
if isinstance(custom_data, dict):
# 检查是否是我们的推理事件
if "action" in custom_data and "reasoning" in custom_data:
yield {
"type": "react_reasoning",
"step": custom_data.get("step", 1),
"action": custom_data.get("action", "unknown"),
"confidence": custom_data.get("confidence", 0),
"reasoning": custom_data.get("reasoning", "")
}
else:
# 处理其他自定义事件
serialized_data = self._serialize_value(custom_data)
yield {
"type": "custom",
"data": serialized_data
}
else:
# 处理其他自定义事件
serialized_data = self._serialize_value(custom_data)
yield {
"type": "custom",
"data": serialized_data
}
async def process_message_stream(
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
) -> AsyncGenerator[Dict[str, Any], None]:
"""流式处理消息,返回异步生成器"""
# 解析模型名称
resolved_model = self._resolve_model(model)
# 构建调用参数
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
# ========== 意图识别(保留用于日志和后续路由)==========
intent_result = await self.intent_classifier.classify(message)
info(f"🧠 意图识别: {intent_result.intent_type} (置信度: {intent_result.confidence:.2f})")
info(f"📝 推理: {intent_result.reasoning}")
# 注入意图到状态(让 hybrid_router 可以利用)
input_state["intent_type"] = intent_result.intent_type.value
input_state["intent_confidence"] = intent_result.confidence
# 发送意图分类事件
yield {
"type": "intent_classified",
@@ -154,25 +325,26 @@ class AIAgentService:
"reasoning": intent_result.reasoning
}
# 发送路径决策事件(现在都是 react_loop
# 发送路径决策事件(目前硬编码,但状态中有意图信息供后续使用
yield {
"type": "path_decision",
"path": "react_loop",
"intent": intent_result.intent_type.value
}
# ========================================
# =============================================
# ========== React 循环路径 ==========
info(f"🚀 开始执行 React 图,模型: {model_name}")
info(f"🚀 开始执行单图,指定模型: {resolved_model}")
current_node = None
tool_calls_in_progress = {}
tool_calls_in_progress: Dict[str, Any] = {}
actual_model_used = resolved_model
chunk_count = 0
full_message_content = ""
try:
info(f"📡 开始调用 graph.astream()...")
chunk_count = 0
full_message_content = "" # 收集完整消息内容
async for chunk in graph.astream(
async for chunk in self.graph.astream(
input_state,
config=config,
stream_mode=["messages", "updates", "custom"],
@@ -181,156 +353,58 @@ class AIAgentService:
):
chunk_count += 1
chunk_type = chunk["type"]
processed_event = {}
if chunk_type == "messages":
message_chunk, metadata = chunk["data"]
node_name = metadata.get("langgraph_node", "unknown")
# 检测节点变化,发送节点开始事件
if node_name != current_node:
if current_node:
yield {
"type": "node_end",
"node": current_node
}
yield {
"type": "node_start",
"node": node_name
}
current_node = node_name
# 处理消息内容
token_content = getattr(message_chunk, 'content', str(message_chunk))
reasoning_token = ""
if hasattr(message_chunk, 'additional_kwargs'):
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
# 处理思考过程
if reasoning_token:
processed_event = {
"type": "llm_token",
"node": node_name,
"reasoning_token": reasoning_token
}
# 处理工具调用
elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
for tool_call in message_chunk.tool_calls:
tool_call_id = tool_call.get("id", "")
tool_name = tool_call.get("name", "")
tool_args = tool_call.get("args", {})
# 记录工具调用开始
if tool_call_id not in tool_calls_in_progress:
tool_calls_in_progress[tool_call_id] = {
"name": tool_name,
"args": tool_args
}
yield {
"type": "tool_call_start",
"tool": tool_name,
"args": tool_args,
"id": tool_call_id
}
# 处理普通 token - 只收集,不打印单个 token
elif token_content:
processed_event = {
"type": "llm_token",
"node": node_name,
"token": token_content,
"reasoning_token": reasoning_token
}
if node_name == "llm_call":
full_message_content += token_content
async for event in self._handle_message_chunk(
chunk, current_node, tool_calls_in_progress
):
if event.get("type") == "_update_state":
current_node = event.get("current_node", current_node)
else:
# 如果是 llm_call 节点的 token收集完整消息
if (
event.get("type") == "llm_token"
and event.get("node") == "llm_call"
and "token" in event
):
full_message_content += event["token"]
yield event
elif chunk_type == "updates":
updates_data = chunk["data"]
info(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}")
# 特别检查 final_result
if isinstance(updates_data, dict) and "final_result" in updates_data:
info(f"[Stream] 收到 final_result: {str(updates_data['final_result'])[:100]}...")
serialized_data = self._serialize_value(updates_data)
if "review_pending" in serialized_data and serialized_data["review_pending"]:
review_id = serialized_data.get("review_id", "")
content_to_review = serialized_data.get("content_to_review", "")
yield {
"type": "human_review_request",
"review_id": review_id,
"content": content_to_review
}
# 检查是否有工具结果
if "messages" in serialized_data:
for msg in serialized_data["messages"]:
# 检测工具结果消息
if msg.get("role") == "tool":
tool_call_id = msg.get("tool_call_id", "")
tool_name = msg.get("name", "")
tool_output = msg.get("content", "")
if tool_call_id in tool_calls_in_progress:
yield {
"type": "tool_call_end",
"tool": tool_name,
"id": tool_call_id,
"result": tool_output
}
del tool_calls_in_progress[tool_call_id]
processed_event = {
"type": "state_update",
"data": serialized_data
}
async for event in self._handle_updates_chunk(
chunk, tool_calls_in_progress, actual_model_used
):
if event.get("type") == "_update_state":
actual_model_used = event.get("actual_model_used", actual_model_used)
else:
yield event
elif chunk_type == "custom":
custom_data = chunk["data"]
# 处理我们从 react_reason_node 发送的自定义推理事件
if isinstance(custom_data, dict):
# 检查是否是我们的推理事件
if "action" in custom_data and "reasoning" in custom_data:
yield {
"type": "react_reasoning",
"step": custom_data.get("step", 1),
"action": custom_data.get("action", "unknown"),
"confidence": custom_data.get("confidence", 0),
"reasoning": custom_data.get("reasoning", "")
}
else:
# 处理其他自定义事件
serialized_data = self._serialize_value(custom_data)
processed_event = {
"type": "custom",
"data": serialized_data
}
else:
# 处理其他自定义事件
serialized_data = self._serialize_value(custom_data)
processed_event = {
"type": "custom",
"data": serialized_data
}
if processed_event:
yield processed_event
async for event in self._handle_custom_chunk(chunk):
yield event
# 完整消息集合完成后,一次性打印
info(f"✅ graph.astream() 完成,共 {chunk_count} 个chunks")
info(f"✅ graph.astream() 完成,共 {chunk_count} chunks")
if full_message_content:
info(f"📄 完整消息内容: {repr(full_message_content)}")
info(f"🤖 实际使用模型: {actual_model_used}")
except Exception as e:
error(f"❌ 执行 React 图时出错: {e}")
error(f"❌ 执行图时出错: {e}")
import traceback
error(f"📋 堆栈: {traceback.format_exc()}")
raise
# 发送结束事件
if current_node:
yield {
"type": "node_end",
"node": current_node
"type": "error",
"message": str(e)
}
finally:
# 无论成功或失败,都发送结束事件,保证前端平稳关闭
if current_node:
yield {
"type": "node_end",
"node": current_node
}
yield {
"type": "done",
"model_used": actual_model_used
}
yield {
"type": "done"
}

View File

@@ -22,9 +22,8 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
"3. 📇 通讯录子系统 - 查询联系人、添加联系人、管理通讯录\n"
"4. 🔍 RAG检索 - 从知识库中检索相关信息回答问题\n\n"
"【用户背景信息】\n"
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现\n"
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳:\n"
"{memory_context}\n"
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
"【可用工具与使用规则】\n"
f"{tools_section}\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"

View File

@@ -127,6 +127,13 @@ BACKEND_PORT = _get_int("BACKEND_PORT")
MEMORY_SUMMARIZE_INTERVAL = _get_int("MEMORY_SUMMARIZE_INTERVAL")
# ========== Tavily 搜索配置 ==========
# Tavily APIhttps://app.tavily.com
# 免费额度1000次/天
TAVILY_API_KEY = _get_str("TAVILY_API_KEY")
TAVILY_MAX_RESULTS = _get_int("TAVILY_MAX_RESULTS") or 5
# ========== Graph 执行追踪配置 ==========
# 是否启用 Graph 流转追踪(通过环境变量控制)
ENABLE_GRAPH_TRACE = _get_bool("ENABLE_GRAPH_TRACE")

View File

@@ -33,18 +33,29 @@ class WebSearchTool:
def search(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
"""
使用多种方式搜索
使用多种方式搜索,按优先级尝试
Args:
query: 搜索关键词
max_results: 返回结果数量,默认使用初始化时的设置
Returns:
搜索结果列表
"""
num_results = max_results or self.max_results
# 方式 1: 尝试用 ddgs 包
# 方式 1: Tavily (需要 API Key质量最高)
try:
return self._search_tavily(query, num_results)
except ImportError:
print("[WebSearch] tavily 未安装,尝试其他搜索方式")
except Exception as e:
if "API_KEY" in str(e) or "未配置" in str(e):
print(f"[WebSearch] Tavily API Key 未配置: {e}")
else:
print(f"[WebSearch] Tavily 搜索失败: {e}")
# 方式 2: 尝试用 ddgs 包
try:
from ddgs import DDGS
print(f"[WebSearch] 使用 ddgs 搜索: {query}")
@@ -65,29 +76,7 @@ class WebSearchTool:
print("[WebSearch] ddgs 未安装,尝试 duckduckgo-search")
except Exception as e:
print(f"[WebSearch] ddgs 搜索失败: {e}")
# 方式 2: 尝试用旧的 duckduckgo-search 包
try:
from duckduckgo_search import DDGS
print(f"[WebSearch] 使用 duckduckgo-search 搜索: {query}")
with DDGS() as ddgs:
results = list(ddgs.text(query, max_results=num_results))
if results:
search_results = []
for r in results:
search_results.append(SearchResult(
title=r.get("title", ""),
url=r.get("href", ""),
snippet=r.get("body", ""),
source="DuckDuckGo"
))
print(f"[WebSearch] duckduckgo-search 返回 {len(search_results)} 条结果")
return search_results
except ImportError:
print("[WebSearch] duckduckgo-search 未安装")
except Exception as e:
print(f"[WebSearch] duckduckgo-search 搜索失败: {e}")
# 方式 3: 尝试用简单 HTTP 请求
try:
return self._search_http(query, num_results)
@@ -97,6 +86,34 @@ class WebSearchTool:
# 方式 4: 返回模拟数据作为最后兜底
return self._search_mock(query, num_results)
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
"""使用 Tavily API 搜索"""
from tavily import TavilyClient
from app.config import TAVILY_API_KEY, TAVILY_MAX_RESULTS
if not TAVILY_API_KEY:
raise ValueError("TAVILY_API_KEY 未配置")
client = TavilyClient(api_key=TAVILY_API_KEY)
response = client.search(
query=query,
max_results=min(max_results, TAVILY_MAX_RESULTS or 5),
include_answer=True,
include_raw_content=False
)
results = []
for item in response.get("results", []):
results.append(SearchResult(
title=item.get("title", ""),
url=item.get("url", ""),
snippet=item.get("content", ""),
source="Tavily"
))
print(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
return results
def _search_http(self, query: str, max_results: int) -> List[SearchResult]:
"""用简单 HTTP 请求搜索(备用方案)- 尝试多个国内源"""
print(f"[WebSearch] 尝试 HTTP 搜索")

View File

@@ -1,8 +0,0 @@
"""
LangGraph 核心组件重新导出
统一导入入口,避免直接依赖 langgraph
"""
from langgraph.graph import StateGraph, START, END, add_messages
__all__ = ["StateGraph", "START", "END", "add_messages"]

View File

@@ -0,0 +1,229 @@
"""
主图构建器 - 构建整合后的完整主图
"""
from langgraph.graph import StateGraph, START, END
from typing import Dict, Any
from .state import MainGraphState
from .nodes.reasoning import react_reason_node
from .nodes.web_search import web_search_node
from .nodes.error_handling import error_handling_node
from .nodes.routing import init_state_node, route_by_reasoning, should_summarize
from .nodes.hybrid_router import (
hybrid_router_node,
route_from_hybrid_decision,
check_fast_path_success,
)
from .nodes.fast_paths import (
fast_chitchat_node,
fast_rag_node,
fast_tool_node,
)
from .nodes.llm_call import create_dynamic_llm_call_node
from .nodes.rag_nodes import rag_retrieve_node
from .nodes.retrieve_memory import create_retrieve_memory_node
from .nodes.memory_trigger import memory_trigger_node, set_mem0_client
from .nodes.summarize import create_summarize_node
from .nodes.finalize import finalize_node
from ..subgraphs.contact import build_contact_subgraph
from ..subgraphs.dictionary import build_dictionary_subgraph
from ..subgraphs.news_analysis import build_news_analysis_subgraph
from ..logger import info
from .subgraph_wrapper import create_subgraph_nodes
# ========== 主图构建 ==========
def build_react_main_graph(
chat_services: dict,
tools=None,
mem0_client=None,
use_hybrid_router: bool = True
) -> StateGraph:
"""
构建整合后的完整主图(支持混合路由 + 动态模型选择)
Args:
chat_services: 模型名称 -> ChatModel 实例 的字典
tools: 工具列表
mem0_client: Mem0 客户端实例
use_hybrid_router: 是否使用混合路由(快速路径 + React 循环)
Returns:
StateGraph: 构建好的图
"""
# 创建图
graph = StateGraph(MainGraphState)
# 设置全局 mem0_client
if mem0_client:
set_mem0_client(mem0_client)
# ========== 创建节点 ==========
# LLM 调用节点
llm_node = create_dynamic_llm_call_node(chat_services, tools or [])
# 记忆节点
retrieve_memory_node = None
summarize_node = None
if mem0_client:
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
summarize_node = create_summarize_node(mem0_client)
# 子图节点
contact_graph = build_contact_subgraph()
dictionary_graph = build_dictionary_subgraph()
news_analysis_graph = build_news_analysis_subgraph()
subgraph_nodes = create_subgraph_nodes(
contact_graph, dictionary_graph, news_analysis_graph
)
# ========== 添加节点到图 ==========
# 阶段 1: 记忆检索
if retrieve_memory_node:
graph.add_node("retrieve_memory", retrieve_memory_node)
graph.add_node("memory_trigger", memory_trigger_node)
# 阶段 2: 初始化
graph.add_node("init_state", init_state_node)
# 阶段 3: 混合路由(可选)
if use_hybrid_router:
graph.add_node("hybrid_router", hybrid_router_node)
graph.add_node("fast_chitchat", fast_chitchat_node)
graph.add_node("fast_rag", fast_rag_node)
graph.add_node("fast_tool", fast_tool_node)
# 阶段 4: React 循环推理(始终保留)
graph.add_node("react_reason", react_reason_node)
graph.add_node("rag_retrieve", rag_retrieve_node)
graph.add_node("web_search", web_search_node)
graph.add_node("handle_error", error_handling_node)
if llm_node is not None:
graph.add_node("llm_call", llm_node)
# 子图节点
for node_name, node_func in subgraph_nodes.items():
graph.add_node(node_name, node_func)
# 阶段 5: 完成处理
if summarize_node:
graph.add_node("summarize", summarize_node)
graph.add_node("finalize", finalize_node)
# ========== 添加边 ==========
# 阶段 1: 记忆检索
_add_memory_edges(graph, retrieve_memory_node)
# 阶段 2: 初始化
graph.add_edge("memory_trigger", "init_state")
# 阶段 3: 路由分支
_add_routing_edges(graph, use_hybrid_router, llm_node)
# 阶段 4: React 循环边
_add_react_loop_edges(graph, subgraph_nodes)
# 阶段 5: 完成阶段
_add_finalize_edges(graph, llm_node, summarize_node)
info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router}")
return graph
def _add_memory_edges(graph: StateGraph, retrieve_memory_node) -> None:
"""添加记忆检索阶段的边"""
if retrieve_memory_node:
graph.add_edge(START, "retrieve_memory")
graph.add_edge("retrieve_memory", "memory_trigger")
else:
graph.add_edge(START, "memory_trigger")
def _add_routing_edges(graph: StateGraph, use_hybrid_router: bool, llm_node) -> None:
"""添加路由阶段的边"""
if use_hybrid_router:
graph.add_edge("init_state", "hybrid_router")
# 混合路由条件分支
graph.add_conditional_edges(
"hybrid_router",
route_from_hybrid_decision,
{
"fast_chitchat": "fast_chitchat",
"fast_rag": "fast_rag",
"fast_tool": "fast_tool",
"react_loop": "react_reason"
}
)
# 快速路径的完成检查
for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]:
graph.add_conditional_edges(
fast_node,
check_fast_path_success,
{
"llm_call": "llm_call",
"escalate": "react_reason"
}
)
info(f"✅ [图构建] 混合路由模式已启用")
else:
graph.add_edge("init_state", "react_reason")
info(f"✅ [图构建] 纯 React 模式")
def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) -> None:
"""添加 React 循环阶段的边"""
subgraph_names = list(subgraph_nodes.keys())
# React 推理的条件分支
graph.add_conditional_edges(
"react_reason",
route_by_reasoning,
{
"rag_retrieve": "rag_retrieve",
"web_search": "web_search",
**{name: name for name in subgraph_names},
"handle_error": "handle_error",
"llm_call": "llm_call"
}
)
# 循环边(回到 react_reason
loop_back_nodes = ["rag_retrieve", "web_search", "handle_error"] + subgraph_names
for node_name in loop_back_nodes:
graph.add_edge(node_name, "react_reason")
def _add_finalize_edges(graph: StateGraph, llm_node, summarize_node) -> None:
"""添加完成阶段的边"""
if llm_node is not None:
if summarize_node:
graph.add_conditional_edges(
"llm_call",
should_summarize,
{
"summarize": "summarize",
"finalize": "finalize"
}
)
graph.add_edge("summarize", "finalize")
else:
graph.add_edge("llm_call", "finalize")
graph.add_edge("finalize", END)
# ========== 导出 ==========
__all__ = [
"build_react_main_graph",
]

View File

@@ -6,8 +6,8 @@
from .reasoning import react_reason_node
from .web_search import web_search_node
from .error_handling import error_handling_node
from .routing import init_state_node, route_by_reasoning
from .llm_call import create_llm_call_node
from .routing import init_state_node, route_by_reasoning, should_summarize
from .llm_call import create_dynamic_llm_call_node
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node
# 记忆节点
@@ -38,7 +38,8 @@ __all__ = [
"web_search_node",
"error_handling_node",
"route_by_reasoning",
"create_llm_call_node",
"should_summarize",
"create_dynamic_llm_call_node",
"rag_retrieve_node",
"rag_re_retrieve_node",
# 记忆节点

View File

@@ -5,7 +5,7 @@ LLM 调用节点模块
import time
from typing import Any, Dict
from langchain_core.language_models import BaseLLM
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
# 本地模块
@@ -14,29 +14,34 @@ from app.agent.prompts import create_system_prompt
from app.utils.logging import log_state_change
from app.logger import debug, info, error
def create_llm_call_node(llm, tools: list):
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
"""
工厂函数:创建 LLM 调用节点
工厂函数:创建动态 LLM 调用节点(根据 state.current_model 选择模型)
Args:
llm: LangChain LLM 实例
chat_services: 模型名称 -> ChatModel 实例 的字典
tools: 工具列表
Returns:
异步节点函数
"""
# 构建调用链
# 构建所有模型的 tools 绑定(避免每次调用都 bind
bound_models: Dict[str, Any] = {}
for name, llm in chat_services.items():
if tools:
bound_models[name] = llm.bind_tools(tools)
else:
bound_models[name] = llm
# 预构建 prompt
prompt = create_system_prompt(tools)
llm_with_tools = llm.bind_tools(tools)
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
chain = prompt | llm_with_tools
from langchain_core.runnables.config import RunnableConfig
async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
"""
LLM 调用节点(异步方法
LLM 调用节点(动态选择模型
Args:
state: 当前对话状态
@@ -46,7 +51,7 @@ def create_llm_call_node(llm, tools: list):
更新后的状态字典
"""
log_state_change("llm_call", state, "进入")
memory_context = getattr(state, "memory_context", "暂无用户信息")
start_time = time.time()
@@ -62,9 +67,20 @@ def create_llm_call_node(llm, tools: list):
"last_elapsed_time": elapsed_time,
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
}
# 动态选择模型
model_name = getattr(state, "current_model", "")
if not model_name or model_name not in bound_models:
# 回退到第一个可用模型
fallback_name = next(iter(bound_models.keys()))
info(f"[llm_call] 模型 '{model_name}' 不可用,回退到 '{fallback_name}'")
model_name = fallback_name
llm_with_tools = bound_models[model_name]
info(f"[llm_call] 使用模型: {model_name}")
try:
# 添加 RAG 上下文到消息
# 添加上下文到消息
messages_with_context = list(state.messages)
if state.rag_context:
from langchain_core.messages import SystemMessage
@@ -77,9 +93,10 @@ def create_llm_call_node(llm, tools: list):
break
if not inserted:
messages_with_context.insert(0, rag_system_msg)
# 恢复为:手动进行 astream并将所有的 chunk 拼接成最终的 response 返回。
# LangGraph 会自动监听这期间产生的所有 token。
chain = prompt | llm_with_tools
chunks = []
async for chunk in chain.astream(
{
@@ -89,7 +106,7 @@ def create_llm_call_node(llm, tools: list):
config=config
):
chunks.append(chunk)
# 将所有 chunk 合并成最终的 AIMessage
if chunks:
response = chunks[0]
@@ -97,14 +114,14 @@ def create_llm_call_node(llm, tools: list):
response = response + chunk
else:
response = AIMessage(content="")
elapsed_time = time.time() - start_time
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
token_usage = {}
input_tokens = 0
output_tokens = 0
# 尝试从 response_metadata 中提取
if hasattr(response, 'response_metadata') and response.response_metadata:
meta = response.response_metadata
@@ -112,33 +129,33 @@ def create_llm_call_node(llm, tools: list):
token_usage = meta['token_usage']
elif 'usage' in meta:
token_usage = meta['usage']
# 尝试从 additional_kwargs 中提取
if not token_usage and hasattr(response, 'additional_kwargs'):
add_kwargs = response.additional_kwargs
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
token_usage = add_kwargs['llm_output']['token_usage']
# 提取具体的 token 数值
if token_usage:
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
# 打印 LLM 的完整输出
debug("\n" + "="*80)
debug("📥 [LLM输出] 模型返回的完整响应:")
debug(f"📥 [LLM输出] 模型: {model_name} 返回的完整响应:")
debug(f" 消息类型: {response.type.upper()}")
debug(f" 内容长度: {len(str(response.content))} 字符")
debug("-"*80)
debug(f"{response.content}")
# 打印响应统计信息
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}")
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
if token_usage:
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
debug("="*80 + "\n")
# 检查是否有工具调用
has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0
@@ -151,21 +168,22 @@ def create_llm_call_node(llm, tools: list):
"final_result": response.content,
"success": True,
"current_phase": "done",
"has_tool_calls": has_tool_calls
"has_tool_calls": has_tool_calls,
"current_model": model_name # 记录实际使用的模型
}
log_state_change("llm_call", state, "离开")
return result
except Exception as e:
elapsed_time = time.time() - start_time
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
error(f"\n❌ [LLM错误] 模型 {model_name} 调用失败 (耗时: {elapsed_time:.2f}秒)")
error(f" 错误类型: {type(e).__name__}")
error(f" 错误信息: {str(e)}")
import traceback
traceback.print_exc()
error(f"📋 堆栈: {traceback.format_exc()}")
debug("="*80 + "\n")
# 返回一个友好的错误消息
error_response = AIMessage(
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
@@ -178,10 +196,11 @@ def create_llm_call_node(llm, tools: list):
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
"final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。",
"success": False,
"current_phase": "done"
"current_phase": "done",
"current_model": model_name
}
log_state_change("llm_call", state, "离开(异常)")
return error_result
return call_llm
return call_llm

View File

@@ -118,3 +118,21 @@ def route_by_reasoning(state: MainGraphState) -> str:
info(f"[条件路由] 动作={latest_action}, 目标={target}")
return target
# ========== 完成阶段条件路由函数 ==========
def should_summarize(state: MainGraphState) -> str:
"""
检查是否需要总结对话(对话足够长时)
Args:
state: 当前图状态
Returns:
"summarize""finalize"
"""
if state.turns_since_last_summary >= 5: # 每5轮对话总结一次
return "summarize"
else:
return "finalize"

View File

@@ -6,7 +6,7 @@ Main Graph State Definition - React Mode Enhanced
from enum import Enum, auto
from typing import Optional, Dict, Any, Annotated, Sequence, TypedDict, List
from dataclasses import dataclass, field
from app.main_graph.graph import add_messages
from langgraph.graph import add_messages
from langchain_core.messages import BaseMessage
@@ -62,6 +62,7 @@ class MainGraphState:
# ========== 主图控制字段 ==========
user_query: str = ""
current_action: CurrentAction = CurrentAction.NONE
current_model: str = "" # 新增:本次请求使用的模型
intent_confidence: float = 0.0
# ========== React 推理专用字段 ==========

View File

@@ -0,0 +1,159 @@
"""
子图包装器 - 为子图添加错误处理和事件追踪
"""
from typing import Dict, Any, Optional
from datetime import datetime
from .state import MainGraphState, ErrorRecord, ErrorSeverity
from ..logger import info
def wrap_subgraph_for_error_handling(subgraph, name: str):
"""
包装子图,使其错误能传递给主图
Args:
subgraph: 编译好的子图
name: 子图名称(用于错误标识)
Returns: 包装后的节点函数
"""
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
# 发送子图开始事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": f"{name}_subgraph_start",
"confidence": 1.0,
"reasoning": f"开始执行 {name} 子图..."
},
callbacks=callbacks
)
except Exception as e:
info(f"[{name}_subgraph] 无法发送开始事件: {e}")
try:
# 调用子图
result = subgraph.invoke(state)
# 更新主图状态
subgraph_result = None
if name == "contact":
state.contact_result = result
subgraph_result = result.get("final_result", "")
elif name == "dictionary":
state.dictionary_result = result
subgraph_result = result.get("final_result", "")
elif name == "news_analysis":
state.news_result = result
subgraph_result = result.get("final_result", "")
# 设置最终结果
if subgraph_result:
state.final_result = subgraph_result
else:
state.final_result = "子图执行完成"
# 标记成功
state.success = True
state.current_phase = "done"
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "subgraph_completed",
"confidence": 1.0,
"reasoning": f"{name}子图执行完成",
"timestamp": datetime.now().isoformat()
})
# 发送子图完成事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": f"{name}_subgraph_complete",
"confidence": 1.0,
"reasoning": f"{name} 子图执行完成"
},
callbacks=callbacks
)
except Exception as e:
info(f"[{name}_subgraph] 无法发送完成事件: {e}")
return state
except Exception as e:
# 捕获子图错误,传递给主图
error_record = ErrorRecord(
error_type=f"{name}SubgraphError",
error_message=str(e),
severity=ErrorSeverity.WARNING,
source=f"{name}_subgraph",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=1,
context={"user_query": state.user_query}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
state.success = False
# 发送子图错误事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": f"{name}_subgraph_error",
"confidence": 1.0,
"reasoning": f"{name} 子图执行失败: {str(e)}"
},
callbacks=callbacks
)
except Exception as e:
info(f"[{name}_subgraph] 无法发送错误事件: {e}")
return state
return wrapped_node
def create_subgraph_nodes(contact_graph, dictionary_graph, news_analysis_graph) -> Dict[str, Any]:
"""
创建所有子图节点的字典
Args:
contact_graph: 联系人子图
dictionary_graph: 词典子图
news_analysis_graph: 新闻分析子图
Returns:
子图节点字典 {name: wrapped_node}
"""
return {
"contact_subgraph": wrap_subgraph_for_error_handling(
contact_graph.compile(), "contact"
),
"dictionary_subgraph": wrap_subgraph_for_error_handling(
dictionary_graph.compile(), "dictionary"
),
"news_analysis_subgraph": wrap_subgraph_for_error_handling(
news_analysis_graph.compile(), "news_analysis"
),
}

View File

@@ -1 +0,0 @@
"""主图工具函数"""

View File

@@ -1,371 +0,0 @@
"""
整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState
"""
from ..graph import StateGraph, START, END
from typing import Dict, Any, Optional
from langchain_core.runnables.config import RunnableConfig
from ..state import MainGraphState
from ..nodes.reasoning import react_reason_node
from ..nodes.web_search import web_search_node
from ..nodes.error_handling import error_handling_node
from ..nodes.routing import init_state_node, route_by_reasoning
from ..nodes.hybrid_router import (
hybrid_router_node,
route_from_hybrid_decision,
check_fast_path_success,
)
from ..nodes.fast_paths import (
fast_chitchat_node,
fast_rag_node,
fast_tool_node,
)
from ..nodes.llm_call import create_llm_call_node
from ..nodes.rag_nodes import rag_retrieve_node
from ..nodes.retrieve_memory import create_retrieve_memory_node
from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client
from ..nodes.summarize import create_summarize_node
from ..nodes.finalize import finalize_node
from ...subgraphs.contact import build_contact_subgraph
from ...subgraphs.dictionary import build_dictionary_subgraph
from ...subgraphs.news_analysis import build_news_analysis_subgraph
from ...memory.mem0_client import Mem0Client
from ...logger import info, debug
# ========== 检查是否需要总结 ==========
def should_summarize(state: MainGraphState) -> str:
"""
检查是否需要总结对话(对话足够长时)
Args:
state: 当前图状态
Returns:
"summarize""finalize"
"""
if state.turns_since_last_summary >= 5: # 每5轮对话总结一次
return "summarize"
else:
return "finalize"
# ========== 子图包装器(处理子图错误传递)==========
def wrap_subgraph_for_error_handling(subgraph, name: str):
"""
包装子图,使其错误能传递给主图
Args:
subgraph: 编译好的子图
name: 子图名称(用于错误标识)
Returns: 包装后的节点函数
"""
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
# 发送子图开始事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": f"{name}_subgraph_start",
"confidence": 1.0,
"reasoning": f"开始执行 {name} 子图..."
},
callbacks=callbacks
)
except Exception as e:
info(f"[{name}_subgraph] 无法发送开始事件: {e}")
try:
# 调用子图
result = subgraph.invoke(state)
# 更新主图状态
subgraph_result = None
if name == "contact":
state.contact_result = result
subgraph_result = result.get("final_result", "")
elif name == "dictionary":
state.dictionary_result = result
subgraph_result = result.get("final_result", "")
elif name == "news_analysis":
state.news_result = result
subgraph_result = result.get("final_result", "")
# 关键:设置最终结果
if subgraph_result:
state.final_result = subgraph_result
else:
state.final_result = "子图执行完成"
# 标记成功
state.success = True
state.current_phase = "done"
# 标记不再需要推理,避免循环
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "subgraph_completed",
"confidence": 1.0,
"reasoning": f"{name}子图执行完成",
"timestamp": datetime.now().isoformat()
})
# 发送子图完成事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": f"{name}_subgraph_complete",
"confidence": 1.0,
"reasoning": f"{name} 子图执行完成"
},
callbacks=callbacks
)
except Exception as e:
info(f"[{name}_subgraph] 无法发送完成事件: {e}")
return state
except Exception as e:
# 捕获子图错误,传递给主图
from ..state import ErrorRecord, ErrorSeverity
from datetime import datetime
error_record = ErrorRecord(
error_type=f"{name}SubgraphError",
error_message=str(e),
severity=ErrorSeverity.WARNING,
source=f"{name}_subgraph",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=1,
context={"user_query": state.user_query}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
state.success = False
# 发送子图错误事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": f"{name}_subgraph_error",
"confidence": 1.0,
"reasoning": f"{name} 子图执行失败: {str(e)}"
},
callbacks=callbacks
)
except Exception as e:
info(f"[{name}_subgraph] 无法发送错误事件: {e}")
return state
return wrapped_node
# ========== 主图构建 ==========
def build_react_main_graph(llm=None, tools=None, mem0_client=None, use_hybrid_router: bool = True) -> StateGraph:
"""
构建整合后的完整主图(支持混合路由)
Args:
llm: LangChain ChatModel 实例
tools: 工具列表
mem0_client: Mem0 客户端实例
use_hybrid_router: 是否使用混合路由(快速路径 + React 循环)
Returns:
StateGraph: 构建好的图
"""
# 创建图
graph = StateGraph(MainGraphState)
# 设置全局 mem0_client
if mem0_client:
set_mem0_client(mem0_client)
# 创建节点
llm_node = None
if llm is not None:
llm_node = create_llm_call_node(llm, tools or [])
retrieve_memory_node = None
summarize_node = None
if mem0_client:
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
summarize_node = create_summarize_node(mem0_client)
# ========== 添加节点 ==========
# 第一阶段:记忆检索
if retrieve_memory_node:
graph.add_node("retrieve_memory", retrieve_memory_node)
graph.add_node("memory_trigger", memory_trigger_node)
# 第二阶段:初始化
graph.add_node("init_state", init_state_node)
# ========== 混合路由节点(如果启用) ==========
if use_hybrid_router:
graph.add_node("hybrid_router", hybrid_router_node)
graph.add_node("fast_chitchat", fast_chitchat_node)
graph.add_node("fast_rag", fast_rag_node)
graph.add_node("fast_tool", fast_tool_node)
# 第三阶段React 循环推理(始终保留)
graph.add_node("react_reason", react_reason_node)
graph.add_node("rag_retrieve", rag_retrieve_node)
graph.add_node("web_search", web_search_node)
graph.add_node("handle_error", error_handling_node)
if llm_node is not None:
graph.add_node("llm_call", llm_node)
# 子图节点
contact_graph = build_contact_subgraph()
dictionary_graph = build_dictionary_subgraph()
news_analysis_graph = build_news_analysis_subgraph()
graph.add_node(
"contact_subgraph",
wrap_subgraph_for_error_handling(contact_graph.compile(), "contact")
)
graph.add_node(
"dictionary_subgraph",
wrap_subgraph_for_error_handling(dictionary_graph.compile(), "dictionary")
)
graph.add_node(
"news_analysis_subgraph",
wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis")
)
# 第四阶段:完成处理
if summarize_node:
graph.add_node("summarize", summarize_node)
graph.add_node("finalize", finalize_node)
# ========== 添加边 ==========
# 第一阶段:记忆检索
if retrieve_memory_node:
graph.add_edge(START, "retrieve_memory")
graph.add_edge("retrieve_memory", "memory_trigger")
else:
graph.add_edge(START, "memory_trigger")
# 进入初始化
graph.add_edge("memory_trigger", "init_state")
# ========== 混合路由分支(如果启用) ==========
if use_hybrid_router:
graph.add_edge("init_state", "hybrid_router")
# 从 hybrid_router 条件分支
graph.add_conditional_edges(
"hybrid_router",
route_from_hybrid_decision,
{
"fast_chitchat": "fast_chitchat",
"fast_rag": "fast_rag",
"fast_tool": "fast_tool",
"react_loop": "react_reason"
}
)
# 快速路径的完成检查
for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]:
graph.add_conditional_edges(
fast_node,
check_fast_path_success,
{
"llm_call": "llm_call",
"escalate": "react_reason"
}
)
info(f"✅ [图构建] 混合路由模式已启用")
else:
# 无混合路由,直接到 react_reason
graph.add_edge("init_state", "react_reason")
info(f"✅ [图构建] 纯 React 模式")
# ========== React 循环边(始终保留) ==========
graph.add_conditional_edges(
"react_reason",
route_by_reasoning,
{
"rag_retrieve": "rag_retrieve",
"web_search": "web_search",
"contact_subgraph": "contact_subgraph",
"dictionary_subgraph": "dictionary_subgraph",
"news_analysis_subgraph": "news_analysis_subgraph",
"handle_error": "handle_error",
"llm_call": "llm_call"
}
)
# 循环边rag、web_search、子图、error都回到 reason
graph.add_edge("rag_retrieve", "react_reason")
graph.add_edge("web_search", "react_reason")
graph.add_edge("contact_subgraph", "react_reason")
graph.add_edge("dictionary_subgraph", "react_reason")
graph.add_edge("news_analysis_subgraph", "react_reason")
graph.add_edge("handle_error", "react_reason")
# ========== 最终完成阶段 ==========
if llm_node is not None:
if summarize_node:
# 检查是否需要总结
graph.add_conditional_edges(
"llm_call",
should_summarize,
{
"summarize": "summarize",
"finalize": "finalize"
}
)
graph.add_edge("summarize", "finalize")
else:
# 没有 summarize 节点,直接 finalize
graph.add_edge("llm_call", "finalize")
# 完成
graph.add_edge("finalize", END)
info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router}")
return graph
# ========== 兼容性:保留旧的函数名 ==========
def build_main_graph() -> StateGraph:
"""
兼容性函数:旧代码调用 build_main_graph() 时返回 React 版本
"""
return build_react_main_graph()
# ========== 导出 ==========
__all__ = [
"build_react_main_graph",
"build_main_graph",
"wrap_subgraph_for_error_handling"
]

View File

@@ -6,11 +6,17 @@
from .embedding_services import get_embedding_service
from .rerank_services import get_rerank_service, BaseRerankService
from .chat_services import get_small_llm_service
from .chat_services import (
get_small_llm_service,
get_cached_chat_services,
get_all_chat_services
)
__all__ = [
"get_embedding_service",
"get_rerank_service",
"get_small_llm_service",
"get_cached_chat_services",
"get_all_chat_services",
"BaseRerankService"
]

View File

@@ -33,6 +33,21 @@ from app.config import (
logger = logging.getLogger(__name__)
# 缓存已初始化的模型字典
_cached_services: Dict[str, BaseChatModel] | None = None
def _check_http_service_available(base_url: str, api_key: str = "", timeout: float = 2.0) -> bool:
"""通过探测 /models 端点检查 HTTP API 是否可用(内部工具函数)"""
try:
import httpx
client = httpx.Client(base_url=base_url.rstrip('/'), timeout=timeout)
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
resp = client.get("/models", headers=headers)
return resp.status_code == 200
except Exception:
return False
class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]):
"""
@@ -54,46 +69,8 @@ class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]):
logger.warning("VLLM_BASE_URL 未配置")
return False
try:
# 先测试主机名能否解析
import httpx
from urllib.parse import urlparse
parsed_url = urlparse(VLLM_BASE_URL)
host = parsed_url.hostname
port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443)
# 测试能否建立 TCP 连接(快速失败)
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(2.0)
try:
sock.connect((host, port))
sock.close()
except Exception as e:
logger.warning(f"本地 VLLM 服务无法连接: {host}:{port} - {e}")
return False
# 再尝试调用简单的 API比如 models 接口)
client = httpx.Client(base_url=VLLM_BASE_URL.rstrip('/'), timeout=5.0)
headers = {}
if LLM_API_KEY:
headers["Authorization"] = f"Bearer {LLM_API_KEY}"
try:
response = client.get("/models", headers=headers)
if response.status_code == 200:
logger.info(f"本地 VLLM 服务可用: {self._model}")
return True
except Exception:
pass
# 如果 /v1/models 失败,也认为服务不可用
logger.warning(f"本地 VLLM 服务响应异常")
return False
except Exception as e:
logger.warning(f"本地 VLLM 服务不可用: {e}")
return False
# 使用统一的 HTTP 探测方法
return _check_http_service_available(VLLM_BASE_URL, LLM_API_KEY, timeout=2.0)
def get_service(self) -> BaseChatModel:
"""
@@ -238,45 +215,8 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
logger.warning("SMALL_VLLM_BASE_URL 未配置,本地小模型不可用")
return False
try:
# 先测试主机名能否解析
import httpx
from urllib.parse import urlparse
parsed_url = urlparse(self._base_url)
host = parsed_url.hostname
port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443)
# 测试能否建立 TCP 连接(快速失败)
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(2.0)
try:
sock.connect((host, port))
sock.close()
except Exception as e:
logger.warning(f"本地小模型服务无法连接: {host}:{port} - {e}")
return False
# 再尝试调用简单的 API
client = httpx.Client(base_url=self._base_url.rstrip('/'), timeout=5.0)
headers = {}
if self._api_key:
headers["Authorization"] = f"Bearer {self._api_key}"
try:
response = client.get("/models", headers=headers)
if response.status_code == 200:
logger.info(f"本地小模型服务可用: {self._model}")
return True
except Exception:
pass
logger.warning(f"本地小模型服务响应异常")
return False
except Exception as e:
logger.warning(f"本地小模型服务不可用: {e}")
return False
# 使用统一的 HTTP 探测方法
return _check_http_service_available(self._base_url, self._api_key, timeout=2.0)
def get_service(self) -> BaseChatModel:
"""获取本地小模型服务"""
@@ -358,25 +298,18 @@ def get_chat_service() -> BaseChatModel:
return chain.get_available_service()
def get_all_chat_services() -> Dict[str, BaseChatModel]:
"""
获取所有可用的生成式大模型服务(用于多模型切换)
Returns:
Dict[str, BaseChatModel]: 模型名称 -> ChatModel 实例 的字典
"""
def _init_chat_services() -> Dict[str, BaseChatModel]:
"""实际初始化所有可用模型(仅在首次调用)"""
services = {}
for name, provider_factory in CHAT_PROVIDERS.items():
try:
provider = provider_factory()
if provider.is_available():
logger.info(f"模型 '{name}' 可用")
services[name] = provider.get_service()
else:
logger.warning(f"模型 '{name}' 不可用,跳过")
logger.info(f"已加载模型: {name}")
except Exception as e:
logger.warning(f"初始化模型 '{name}' 失败: {e}")
logger.warning(f"模型 {name} 初始化失败: {e}")
if not services:
raise RuntimeError(f"没有可用的生成式大模型,尝试了: {list(CHAT_PROVIDERS.keys())}")
@@ -384,6 +317,25 @@ def get_all_chat_services() -> Dict[str, BaseChatModel]:
return services
def get_cached_chat_services() -> Dict[str, BaseChatModel]:
"""获取缓存的可用模型字典(用于单图动态注入)"""
global _cached_services
if _cached_services is None:
_cached_services = _init_chat_services()
return _cached_services
def get_all_chat_services() -> Dict[str, BaseChatModel]:
"""
获取所有可用的生成式大模型服务(用于多模型切换,保留兼容性)
新代码请使用 get_cached_chat_services() 获取缓存版本
Returns:
Dict[str, BaseChatModel]: 模型名称 -> ChatModel 实例 的字典
"""
return get_cached_chat_services()
def get_small_llm_service() -> BaseChatModel:
"""
获取轻量级大模型服务(用于查询改写、意图分类等简单任务)

View File

@@ -4,7 +4,7 @@ Contact Subgraph Builder
支持 API 注入的工厂模式
"""
from app.main_graph.graph import StateGraph, START, END
from langgraph.graph import StateGraph, START, END
from .state import ContactState
from .nodes import create_contact_nodes

View File

@@ -3,7 +3,7 @@
Dictionary Subgraph Builder - Complete
"""
from app.main_graph.graph import StateGraph, START, END
from langgraph.graph import StateGraph, START, END
from .state import DictionaryState
from .nodes import (

View File

@@ -3,7 +3,7 @@
News Analysis Subgraph Builder
"""
from app.main_graph.graph import StateGraph, START, END
from langgraph.graph import StateGraph, START, END
from .state import NewsAnalysisState
from .nodes import (

View File

@@ -42,6 +42,7 @@ PyYAML>=6.0.3
numpy>=1.26.2
pyjwt>=2.8.0
ddgs>=6.0.0 # 免费联网搜索(原 duckduckgo-search 已重命名)
tavily-python>=0.5.0 # Tavily 搜索 API需要 API Key质量更高
matplotlib>=3.9.0 # 可视化图表
# Document Processing