refactor: 单图方案重构 + 动态模型选择 + chat_services优化
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s
## 核心改动 ### 1. 单图方案重构 - 删除了多图(self.graphs),改为单图(self.graph) - 新增 MainGraphState.current_model 字段用于运行时注入模型 - llm_call 节点改为动态选择模型(create_dynamic_llm_call_node) ### 2. chat_services 优化 - 添加 _cached_services 缓存,避免重复初始化 - 新增 get_cached_chat_services() 函数,用于单图注入 - 新增 _check_http_service_available() 统一HTTP探测逻辑 - 减少重复代码,LocalVLLMChatProvider和LocalSmallModelProvider共用探测方法 ### 3. AIAgentService 重构 - initialize() 只构建一次图,传入 chat_services 字典 - 新增 _resolve_model() 模型回退逻辑 - 新增 _build_invocation() 统一构建调用参数 - process_message() 和 process_message_stream() 改为注入 current_model - 流式处理代码拆分,增加可读性 ### 4. 新增和删除文件 - 新增:backend/app/main_graph/main_graph_builder.py(图构建) - 新增:backend/app/main_graph/subgraph_wrapper.py(子图封装) - 新增:tools/test/test_tavily_search.py(测试) - 删除:backend/app/main_graph/graph.py(旧图) - 删除:backend/app/main_graph/utils/main_graph_builder.py(旧构建器) - 删除:backend/app/main_graph/utils/__init__.py ### 5. 其他更新 - README.md:新增模型服务使用情况详解章节 - backend/app/model_services/__init__.py:新增 get_cached_chat_services 导出 ## 方案优势 - 内存优化:N张图 → 1张图 - 灵活性:运行时动态选择模型,支持同会话不同模型 - 性能:模型服务缓存,初始化仅一次 - 可维护性:减少重复代码,统一HTTP探测逻辑
This commit is contained in:
@@ -60,6 +60,13 @@ MEMORY_SUMMARIZE_INTERVAL=10
|
||||
# 是否启用 Graph 执行追踪(调试用)
|
||||
ENABLE_GRAPH_TRACE=true
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tavily 搜索配置
|
||||
# 免费额度:1000次/天,官网:https://app.tavily.com
|
||||
# -----------------------------------------------------------------------------
|
||||
TAVILY_API_KEY=你的Tavily_API密钥
|
||||
TAVILY_MAX_RESULTS=5
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 稀疏模型配置
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
205
README.md
205
README.md
@@ -7,15 +7,15 @@
|
||||
## 📑 目录导航
|
||||
|
||||
- [核心功能](#-核心功能) - 面向用户的功能和技术特性
|
||||
- [使用指南](#-使用指南) - 基础对话、工具调用、多模型切换
|
||||
- [技术架构](#️-技术架构) - 技术栈、系统架构图、工作流流程图
|
||||
- [模型服务使用情况](#55-模型服务使用情况详解) - 模型选型、Token估算、成本分析
|
||||
- [核心算法与实现原理](#-核心算法与实现原理) - LangGraph 工作流、多模型路由、SSE 流式响应
|
||||
- [快速开始](#-快速开始) - Docker 和本地部署指南
|
||||
- [使用指南](#-使用指南) - 基础对话、工具调用、多模型切换
|
||||
- [开发指南](#-开发指南) - 添加工具、添加模型、Docker 部署
|
||||
- [实现指南与最佳实践](#️-实现指南与最佳实践) - 性能优化、扩展开发、部署实践
|
||||
- [环境配置](#️-环境配置) - 配置文件、环境变量
|
||||
- [故障排查](#-故障排查) - 常见问题
|
||||
|
||||
---
|
||||
|
||||
## 🎯 核心功能
|
||||
@@ -50,6 +50,48 @@
|
||||
|
||||
---
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 📖 使用指南
|
||||
|
||||
### 基础对话
|
||||
|
||||
直接在聊天框输入问题即可:
|
||||
|
||||
```
|
||||
你好,请介绍一下自己
|
||||
帮我写一个 Python 脚本
|
||||
```
|
||||
|
||||
### 主要功能
|
||||
|
||||
| 功能 | 说明 | 示例提问 |
|
||||
|------|------|---------|
|
||||
| 🧠 混合路由智能分流 | 自动判断任务类型,选择最佳路径 | 自然对话即可 |
|
||||
| ⚡ 快速路径 | 闲聊、RAG查询、工具调用可走快速路径 | "你好"、"什么是 RAG" |
|
||||
| 🔄 React 推理循环 | 复杂任务走完整的思考-行动-观察循环 | "帮我分析一下这个文档" |
|
||||
| 🌐 联网搜索 | 免费 DuckDuckGo 搜索 | "今天北京天气怎么样?" |
|
||||
| 📚 RAG 知识库检索 | 检索本地知识库 | "如何配置系统?" |
|
||||
| 📇 通讯录管理 | 联系人 CRUD、邮件处理 | "帮我查看一下张三的联系方式" |
|
||||
| 📖 智能词典 | 翻译、生词本、专业术语提取 | "帮我翻译这句话" |
|
||||
| 📰 资讯分析 | 资讯获取、内容分析 | "帮我分析一下这篇新闻" |
|
||||
| 📊 可视化图表 | 支持 Mermaid 图表生成 | "帮我画一个流程图" |
|
||||
|
||||
### 多模型切换
|
||||
|
||||
1. 在左侧边栏选择模型:
|
||||
- **智谱 GLM-4**:在线服务,速度快
|
||||
- **DeepSeek V3**:深度推理模型
|
||||
- **OpenAI GPT-4o-mini**:通用对话模型
|
||||
- **本地 Qwen3.5-9B**:本地部署,隐私性好
|
||||
|
||||
2. 可随时切换,甚至在同一会话中
|
||||
|
||||
3. 点击 "🔄 新会话" 清空当前对话
|
||||
|
||||
---
|
||||
|
||||
## 🏗️ 技术架构
|
||||
|
||||
### 1. 技术栈总览
|
||||
@@ -385,6 +427,126 @@ flowchart TB
|
||||
|
||||
---
|
||||
|
||||
|
||||
---
|
||||
|
||||
### 5.5. 模型服务使用情况详解
|
||||
|
||||
#### 5.5.1. 模型服务架构总览
|
||||
|
||||
本项目采用**分层模型策略**,根据任务复杂度选择不同能力和成本的模型:
|
||||
|
||||
| 模型类型 | 用途 | 主要来源 | 成本考量 |
|
||||
|---------|------|---------|---------|
|
||||
| 小模型 (Small LLM) | 意图分类、路由决策、查询改写 | 本地模型 / DeepSeek小模型 | 低成本,高频率 |
|
||||
| 大模型 (Main LLM) | 对话生成、推理、工具调用 | 智谱 / DeepSeek / 本地 | 高能力,低频率 |
|
||||
| Embedding模型 | 文本向量化、语义检索 | 本地 llama.cpp / 智谱API | 批量处理 |
|
||||
| Rerank模型 | 检索结果重排序 | 硅基流动 / 智谱API | 精准排序 |
|
||||
| Sparse模型 | BM25稀疏检索 | FastEmbed本地 | 关键词匹配 |
|
||||
|
||||
#### 5.5.2. 小模型使用场景及Token估算
|
||||
|
||||
**小模型**主要用于高频率、低复杂度的任务:
|
||||
|
||||
| 使用场景 | 位置文件 | 用途描述 | 单次Token估算 | 调用频率 |
|
||||
|---------|---------|---------|-------------|---------|
|
||||
| 意图分类 (1) | `app/core/intent_classifier.py` | 判断用户意图类型 | ~300输入 + ~50输出 | 每轮对话1次 |
|
||||
| 意图分类 (2) | `app/main_graph/nodes/hybrid_router.py` | 混合路由决策 | ~200输入 + ~50输出 | 每轮对话1次 |
|
||||
| 闲聊回复 | `app/main_graph/nodes/fast_paths.py` | 快速回复问候语 | ~50输入 + ~30输出 | 按需调用 |
|
||||
|
||||
**Token估算说明**:
|
||||
- 单次意图分类:总计 ~350-600 tokens
|
||||
- 小模型成本通常是大模型的 1/10 - 1/100
|
||||
- 每日1000次对话,小模型仅消耗 ~350k-600k tokens
|
||||
|
||||
#### 5.5.3. 大模型使用场景及Token估算
|
||||
|
||||
**大模型**用于核心对话生成和复杂推理:
|
||||
|
||||
| 使用场景 | 位置文件 | 用途描述 | 单次Token估算 | 调用频率 |
|
||||
|---------|---------|---------|-------------|---------|
|
||||
| RAG查询改写 | `app/main_graph/utils/rag_initializer.py` | 生成多角度查询 | ~100输入 + ~150输出 | RAG调用时 |
|
||||
| 主对话生成 | `app/main_graph/nodes/llm_call.py` | 用户查询响应 | ~500-2000输入 + ~200-1000输出 | 每轮对话1次 |
|
||||
| React推理 | `app/main_graph/nodes/reasoning.py` | 任务分解与规划 | ~300-1000输入 + ~100-500输出 | 复杂任务多次 |
|
||||
| 记忆摘要 | `app/memory/mem0_client.py` | 长期记忆压缩 | ~500-2000输入 + ~200-500输出 | 每N轮对话1次 |
|
||||
|
||||
**Token估算说明**:
|
||||
- 普通对话:总计 ~1000-3000 tokens
|
||||
- RAG查询:额外 ~250 tokens
|
||||
- 复杂多步推理:可能额外增加 500-3000 tokens
|
||||
- 每日1000次对话,大模型预计消耗 1M-3M tokens
|
||||
|
||||
#### 5.5.4. Embedding模型使用场景
|
||||
|
||||
**Embedding模型**用于语义检索和向量存储:
|
||||
|
||||
| 使用场景 | 位置文件 | 用途描述 | 估算 |
|
||||
|---------|---------|---------|------|
|
||||
| RAG文档索引 | `rag_indexer/index_builder.py` | 文档分片向量化 | 每个文档片段1次 |
|
||||
| 在线检索 | `app/rag/retriever.py` | 查询向量化 + 相似度检索 | 每次检索1次 |
|
||||
| 记忆向量化 | `app/memory/mem0_client.py` | 记忆内容向量化存储 | 每次记忆更新1次 |
|
||||
|
||||
**Embedding说明**:
|
||||
- 向量维度:1024 (Qwen3-Embedding-0.6B) 或 2048 (智谱 embedding-3)
|
||||
- 批量处理:建议使用 batch_size=10-20 提高效率
|
||||
- 本地优先:优先使用 llama.cpp 服务,降低API调用成本
|
||||
|
||||
#### 5.5.5. Rerank模型使用场景
|
||||
|
||||
**Rerank模型**用于检索结果精细化排序:
|
||||
|
||||
| 使用场景 | 位置文件 | 用途描述 | 估算 |
|
||||
|---------|---------|---------|------|
|
||||
| RAG结果重排 | `app/rag/rerank.py` | 提升检索相关性 | 每次检索调用 |
|
||||
| 混合检索重排 | `app/rag/retriever.py` | 稀疏+稠密结果融合排序 | 每次检索调用 |
|
||||
|
||||
**Rerank说明**:
|
||||
- 通常在 RRF 融合后使用,进一步提升精准度
|
||||
- 重排数量建议:rerank_top_n=3-10
|
||||
- 成本权衡:rerank 会增加额外调用成本,但精度提升明显
|
||||
|
||||
#### 5.5.6. 模型服务选型参考对比
|
||||
|
||||
为方便不同部署场景选择,提供以下模型选型参考:
|
||||
|
||||
| 维度 | 本地优先方案 | 云端优先方案 | 混合方案 |
|
||||
|------|------------|------------|---------|
|
||||
| **小模型** | Qwen3.5-9B (本地) | DeepSeek-Chat (API) | 本地+DeepSeek降级 |
|
||||
| **大模型** | Qwen3.5-9B (本地) | 智谱 GLM-4 / DeepSeek | 本地+云端降级链 |
|
||||
| **Embedding** | Qwen3-Embedding-0.6B (本地llama.cpp) | 智谱 embedding-2 | 本地优先,智谱降级 |
|
||||
| **Rerank** | (可选本地) | 硅基流动 bge-reranker-v2-m3 | 硅基流动API |
|
||||
| **Sparse** | FastEmbed BM25 (本地) | FastEmbed BM25 (本地) | 本地 |
|
||||
|
||||
**成本参考对比**(每1M tokens,仅作示例):
|
||||
|
||||
| 模型 | 输入成本 | 输出成本 | 适用场景 |
|
||||
|------|---------|---------|---------|
|
||||
| **本地模型** | ~0元 | ~0元 | 有GPU机器,隐私敏感 |
|
||||
| **DeepSeek-Chat** | ~¥0.5 | ~¥1.0 | 通用推理,成本适中 |
|
||||
| **智谱 GLM-4** | ~¥1.0 | ~¥2.0 | 高质量对话 |
|
||||
| **智谱 embedding-2** | ~¥0.2 | - | 向量嵌入 |
|
||||
| **硅基流动 Rerank** | ~¥0.3 | - | 精准重排 |
|
||||
|
||||
**部署建议**:
|
||||
- **个人/测试**:全云端方案,快速上手
|
||||
- **小团队**:小模型本地,大模型云端降级
|
||||
- **企业/隐私敏感**:全本地部署,或使用私有API
|
||||
- **生产环境**:核心能力本地+云端降级链,保证高可用
|
||||
|
||||
#### 5.5.7. 模型服务降级链路设计
|
||||
|
||||
本项目所有模型服务都设计了**自动降级链路**,保证服务高可用:
|
||||
|
||||
| 服务类型 | 主服务 | 降级服务 |
|
||||
|---------|--------|---------|
|
||||
| 对话生成 | 本地模型 → 智谱GLM-4 → DeepSeek |
|
||||
| Embedding | 本地llama.cpp → 智谱embedding-2 |
|
||||
| Rerank | 硅基流动 → 智谱rerank-2 |
|
||||
|
||||
降级逻辑实现在 `app/model_services/base.py: FallbackServiceChain`,对上层业务透明。
|
||||
|
||||
---
|
||||
|
||||
### 6. 模型服务层
|
||||
|
||||
#### 6.1 多模型降级链
|
||||
@@ -1582,45 +1744,6 @@ streamlit run frontend/src/frontend_main.py
|
||||
|
||||
---
|
||||
|
||||
## 📖 使用指南
|
||||
|
||||
### 基础对话
|
||||
|
||||
直接在聊天框输入问题即可:
|
||||
|
||||
```
|
||||
你好,请介绍一下自己
|
||||
帮我写一个 Python 脚本
|
||||
```
|
||||
|
||||
### 主要功能
|
||||
|
||||
| 功能 | 说明 | 示例提问 |
|
||||
|------|------|---------|
|
||||
| 🧠 混合路由智能分流 | 自动判断任务类型,选择最佳路径 | 自然对话即可 |
|
||||
| ⚡ 快速路径 | 闲聊、RAG查询、工具调用可走快速路径 | "你好"、"什么是 RAG" |
|
||||
| 🔄 React 推理循环 | 复杂任务走完整的思考-行动-观察循环 | "帮我分析一下这个文档" |
|
||||
| 🌐 联网搜索 | 免费 DuckDuckGo 搜索 | "今天北京天气怎么样?" |
|
||||
| 📚 RAG 知识库检索 | 检索本地知识库 | "如何配置系统?" |
|
||||
| 📇 通讯录管理 | 联系人 CRUD、邮件处理 | "帮我查看一下张三的联系方式" |
|
||||
| 📖 智能词典 | 翻译、生词本、专业术语提取 | "帮我翻译这句话" |
|
||||
| 📰 资讯分析 | 资讯获取、内容分析 | "帮我分析一下这篇新闻" |
|
||||
| 📊 可视化图表 | 支持 Mermaid 图表生成 | "帮我画一个流程图" |
|
||||
|
||||
### 多模型切换
|
||||
|
||||
1. 在左侧边栏选择模型:
|
||||
- **智谱 GLM-4**:在线服务,速度快
|
||||
- **DeepSeek V3**:深度推理模型
|
||||
- **OpenAI GPT-4o-mini**:通用对话模型
|
||||
- **本地 Qwen3.5-9B**:本地部署,隐私性好
|
||||
|
||||
2. 可随时切换,甚至在同一会话中
|
||||
|
||||
3. 点击 "🔄 新会话" 清空当前对话
|
||||
|
||||
---
|
||||
|
||||
## 🔧 开发指南
|
||||
|
||||
### 添加新工具
|
||||
|
||||
@@ -1,25 +1,28 @@
|
||||
"""
|
||||
AI Agent 服务类 - 支持多模型动态切换
|
||||
AI Agent 服务类 - 单图方案 + 动态模型选择
|
||||
接收外部传入的 checkpointer,不负责管理连接生命周期
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
|
||||
|
||||
# 本地模块
|
||||
from ..main_graph.utils.main_graph_builder import build_react_main_graph
|
||||
from ..model_services import get_cached_chat_services
|
||||
from ..main_graph.main_graph_builder import build_react_main_graph
|
||||
from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from ..main_graph.config import set_stream_writer
|
||||
from ..main_graph.utils.rag_initializer import init_rag_tool
|
||||
from ..core.intent_classifier import get_intent_classifier
|
||||
from ..logger import info, warning, error
|
||||
from ..logger import debug, info, warning, error
|
||||
from ..main_graph.state import MainGraphState, CurrentAction
|
||||
|
||||
|
||||
class AIAgentService:
|
||||
def __init__(self, checkpointer):
|
||||
self.checkpointer = checkpointer
|
||||
self.graphs = {}
|
||||
self.graph = None # 只有一张图
|
||||
self.chat_services = None # 缓存的模型字典
|
||||
self.tools = AVAILABLE_TOOLS.copy()
|
||||
self.tools_by_name = TOOLS_BY_NAME.copy()
|
||||
# 添加:意图分类器
|
||||
@@ -40,64 +43,94 @@ class AIAgentService:
|
||||
self.tools.append(rag_tool)
|
||||
self.tools_by_name[rag_tool.name] = rag_tool
|
||||
self.rag_tool = rag_tool # 保存到实例变量,供 config 注入
|
||||
|
||||
# 2. 构建各模型的 Graph(使用新版 React 模式)
|
||||
for name, llm in chat_services.items():
|
||||
try:
|
||||
info(f"🔄 初始化模型 '{name}'...")
|
||||
graph = build_react_main_graph(
|
||||
llm=llm,
|
||||
tools=self.tools,
|
||||
mem0_client=self.mem0_client
|
||||
).compile(checkpointer=self.checkpointer)
|
||||
self.graphs[name] = graph
|
||||
info(f"✅ 模型 '{name}' 初始化成功")
|
||||
except Exception as e:
|
||||
warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")
|
||||
|
||||
if not self.graphs:
|
||||
raise RuntimeError("没有可用的模型")
|
||||
|
||||
# 2. 获取缓存的模型字典
|
||||
self.chat_services = get_cached_chat_services()
|
||||
info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")
|
||||
|
||||
# 3. 只构建一次图(传入 chat_services 字典)
|
||||
info(f"🔄 构建单图...")
|
||||
graph_builder = build_react_main_graph(
|
||||
chat_services=self.chat_services,
|
||||
tools=self.tools,
|
||||
mem0_client=self.mem0_client
|
||||
)
|
||||
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
|
||||
info(f"✅ 单图初始化完成")
|
||||
|
||||
return self
|
||||
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
|
||||
"""处理用户消息,返回包含回复、token统计和耗时的字典"""
|
||||
if model not in self.graphs:
|
||||
# 回退到第一个可用模型
|
||||
available = list(self.graphs.keys())
|
||||
if not available:
|
||||
raise RuntimeError("没有可用的模型")
|
||||
model = available[0]
|
||||
warning(f"模型 '{model}' 不可用,已回退到 '{model}'")
|
||||
def _resolve_model(self, model: str) -> str:
|
||||
"""
|
||||
解析并验证模型名称,不可用时回退到第一个可用模型
|
||||
|
||||
Args:
|
||||
model: 目标模型名称
|
||||
|
||||
Returns:
|
||||
实际使用的模型名称
|
||||
"""
|
||||
if not model or model not in self.chat_services:
|
||||
fallback = next(iter(self.chat_services.keys()))
|
||||
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
|
||||
return fallback
|
||||
return model
|
||||
|
||||
graph = self.graphs[model]
|
||||
def _build_invocation(
|
||||
self, message: str, thread_id: str, model: str, user_id: str
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
构建图调用所需的 config 和 input_state
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
thread_id: 会话 ID
|
||||
model: 模型名称
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
(config, input_state) 元组
|
||||
"""
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"rag_tool": getattr(self, "rag_tool", None), # 注入 RAG 工具
|
||||
"rag_tool": getattr(self, "rag_tool", None),
|
||||
},
|
||||
"metadata": {"user_id": user_id}
|
||||
}
|
||||
# 新版状态输入:传入完整的 MainGraphState,关键是 user_query
|
||||
from ..main_graph.state import MainGraphState, CurrentAction
|
||||
input_state = {
|
||||
"user_query": message,
|
||||
"messages": [{"role": "user", "content": message}],
|
||||
"user_id": user_id,
|
||||
"current_model": model,
|
||||
"current_action": CurrentAction.NONE
|
||||
}
|
||||
return config, input_state
|
||||
|
||||
result = await graph.ainvoke(input_state, config=config)
|
||||
async def process_message(
|
||||
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
||||
) -> dict:
|
||||
"""处理用户消息,返回包含回复、token统计和耗时的字典"""
|
||||
# 解析模型名称
|
||||
resolved_model = self._resolve_model(model)
|
||||
|
||||
# 构建调用参数
|
||||
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
|
||||
|
||||
result = await self.graph.ainvoke(input_state, config=config)
|
||||
|
||||
reply = result.get("final_result", "")
|
||||
if not reply and result.get("messages"):
|
||||
reply = result["messages"][-1].content
|
||||
token_usage = result.get("debug_info", {}).get("token_usage", {})
|
||||
elapsed_time = result.get("debug_info", {}).get("elapsed_time", 0.0)
|
||||
token_usage = result.get("last_token_usage", {})
|
||||
elapsed_time = result.get("last_elapsed_time", 0.0)
|
||||
actual_model = result.get("current_model", resolved_model)
|
||||
|
||||
return {
|
||||
"reply": reply,
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time
|
||||
"elapsed_time": elapsed_time,
|
||||
"model_used": actual_model
|
||||
}
|
||||
|
||||
def _serialize_value(self, value):
|
||||
@@ -121,31 +154,169 @@ class AIAgentService:
|
||||
except (TypeError, ValueError):
|
||||
return str(value)
|
||||
|
||||
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
|
||||
"""流式处理消息,返回异步生成器(全部走 React 模式)"""
|
||||
graph = self.graphs.get(model_name)
|
||||
if not graph:
|
||||
raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
|
||||
async def _handle_message_chunk(
|
||||
self, chunk: Dict[str, Any], current_node: Optional[str], tool_calls_in_progress: Dict[str, Any]
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""处理 messages 类型的 chunk"""
|
||||
message_chunk, metadata = chunk["data"]
|
||||
node_name = metadata.get("langgraph_node", "unknown")
|
||||
new_current_node = current_node
|
||||
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"rag_tool": getattr(self, "rag_tool", None), # 注入 RAG 工具
|
||||
},
|
||||
"metadata": {"user_id": user_id}
|
||||
}
|
||||
input_state = {
|
||||
"user_query": message,
|
||||
"messages": [{"role": "user", "content": message}],
|
||||
"user_id": user_id,
|
||||
"current_action": CurrentAction.NONE
|
||||
}
|
||||
# 检测节点变化,发送节点开始事件
|
||||
if node_name != current_node:
|
||||
if current_node:
|
||||
yield {"type": "node_end", "node": current_node}
|
||||
yield {"type": "node_start", "node": node_name}
|
||||
new_current_node = node_name
|
||||
|
||||
# ========== 意图识别(保留用于日志)==========
|
||||
# 处理消息内容
|
||||
token_content = getattr(message_chunk, 'content', str(message_chunk))
|
||||
reasoning_token = ""
|
||||
if hasattr(message_chunk, 'additional_kwargs'):
|
||||
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
# 处理思考过程
|
||||
if reasoning_token:
|
||||
yield {
|
||||
"type": "llm_token",
|
||||
"node": node_name,
|
||||
"reasoning_token": reasoning_token
|
||||
}
|
||||
# 处理工具调用
|
||||
elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
|
||||
for tool_call in message_chunk.tool_calls:
|
||||
tool_call_id = tool_call.get("id", "")
|
||||
tool_name = tool_call.get("name", "")
|
||||
tool_args = tool_call.get("args", {})
|
||||
|
||||
# 记录工具调用开始,避免重复
|
||||
if tool_call_id and tool_call_id not in tool_calls_in_progress:
|
||||
tool_calls_in_progress[tool_call_id] = {
|
||||
"name": tool_name,
|
||||
"args": tool_args
|
||||
}
|
||||
yield {
|
||||
"type": "tool_call_start",
|
||||
"tool": tool_name,
|
||||
"args": tool_args,
|
||||
"id": tool_call_id
|
||||
}
|
||||
# 处理普通 token
|
||||
elif token_content:
|
||||
yield {
|
||||
"type": "llm_token",
|
||||
"node": node_name,
|
||||
"token": token_content,
|
||||
"reasoning_token": reasoning_token
|
||||
}
|
||||
|
||||
# 返回更新后的 current_node
|
||||
yield {"type": "_update_state", "current_node": new_current_node}
|
||||
|
||||
async def _handle_updates_chunk(
|
||||
self, chunk: Dict[str, Any], tool_calls_in_progress: Dict[str, Any], actual_model_used: str
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""处理 updates 类型的 chunk"""
|
||||
updates_data = chunk["data"]
|
||||
new_actual_model = actual_model_used
|
||||
|
||||
debug(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}")
|
||||
|
||||
# 特别检查 final_result 和 current_model
|
||||
if isinstance(updates_data, dict):
|
||||
if "final_result" in updates_data:
|
||||
debug(f"[Stream] 收到 final_result: {str(updates_data['final_result'])[:100]}...")
|
||||
if "current_model" in updates_data:
|
||||
new_actual_model = updates_data["current_model"]
|
||||
info(f"[Stream] 实际使用模型: {new_actual_model}")
|
||||
|
||||
serialized_data = self._serialize_value(updates_data)
|
||||
|
||||
# 检查是否有人工审核请求
|
||||
if "review_pending" in serialized_data and serialized_data["review_pending"]:
|
||||
review_id = serialized_data.get("review_id", "")
|
||||
content_to_review = serialized_data.get("content_to_review", "")
|
||||
yield {
|
||||
"type": "human_review_request",
|
||||
"review_id": review_id,
|
||||
"content": content_to_review
|
||||
}
|
||||
|
||||
# 检查是否有工具结果
|
||||
if "messages" in serialized_data:
|
||||
for msg in serialized_data["messages"]:
|
||||
# 检测工具结果消息
|
||||
if msg.get("role") == "tool":
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
tool_name = msg.get("name", "")
|
||||
tool_result = msg.get("content", "")
|
||||
|
||||
if tool_call_id and tool_call_id in tool_calls_in_progress:
|
||||
yield {
|
||||
"type": "tool_call_end",
|
||||
"tool": tool_name,
|
||||
"id": tool_call_id,
|
||||
"result": tool_result
|
||||
}
|
||||
del tool_calls_in_progress[tool_call_id]
|
||||
|
||||
yield {
|
||||
"type": "state_update",
|
||||
"data": serialized_data
|
||||
}
|
||||
|
||||
# 返回更新后的模型
|
||||
yield {"type": "_update_state", "actual_model_used": new_actual_model}
|
||||
|
||||
async def _handle_custom_chunk(self, chunk: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""处理 custom 类型的 chunk"""
|
||||
custom_data = chunk["data"]
|
||||
|
||||
# 处理我们从 react_reason_node 发送的自定义推理事件
|
||||
if isinstance(custom_data, dict):
|
||||
# 检查是否是我们的推理事件
|
||||
if "action" in custom_data and "reasoning" in custom_data:
|
||||
yield {
|
||||
"type": "react_reasoning",
|
||||
"step": custom_data.get("step", 1),
|
||||
"action": custom_data.get("action", "unknown"),
|
||||
"confidence": custom_data.get("confidence", 0),
|
||||
"reasoning": custom_data.get("reasoning", "")
|
||||
}
|
||||
else:
|
||||
# 处理其他自定义事件
|
||||
serialized_data = self._serialize_value(custom_data)
|
||||
yield {
|
||||
"type": "custom",
|
||||
"data": serialized_data
|
||||
}
|
||||
else:
|
||||
# 处理其他自定义事件
|
||||
serialized_data = self._serialize_value(custom_data)
|
||||
yield {
|
||||
"type": "custom",
|
||||
"data": serialized_data
|
||||
}
|
||||
|
||||
async def process_message_stream(
|
||||
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""流式处理消息,返回异步生成器"""
|
||||
# 解析模型名称
|
||||
resolved_model = self._resolve_model(model)
|
||||
|
||||
# 构建调用参数
|
||||
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
|
||||
|
||||
# ========== 意图识别(保留用于日志和后续路由)==========
|
||||
intent_result = await self.intent_classifier.classify(message)
|
||||
info(f"🧠 意图识别: {intent_result.intent_type} (置信度: {intent_result.confidence:.2f})")
|
||||
info(f"📝 推理: {intent_result.reasoning}")
|
||||
|
||||
# 注入意图到状态(让 hybrid_router 可以利用)
|
||||
input_state["intent_type"] = intent_result.intent_type.value
|
||||
input_state["intent_confidence"] = intent_result.confidence
|
||||
|
||||
# 发送意图分类事件
|
||||
yield {
|
||||
"type": "intent_classified",
|
||||
@@ -154,25 +325,26 @@ class AIAgentService:
|
||||
"reasoning": intent_result.reasoning
|
||||
}
|
||||
|
||||
# 发送路径决策事件(现在都是 react_loop)
|
||||
# 发送路径决策事件(目前硬编码,但状态中有意图信息供后续使用)
|
||||
yield {
|
||||
"type": "path_decision",
|
||||
"path": "react_loop",
|
||||
"intent": intent_result.intent_type.value
|
||||
}
|
||||
# ========================================
|
||||
# =============================================
|
||||
|
||||
# ========== React 循环路径 ==========
|
||||
info(f"🚀 开始执行 React 图,模型: {model_name}")
|
||||
info(f"🚀 开始执行单图,指定模型: {resolved_model}")
|
||||
current_node = None
|
||||
tool_calls_in_progress = {}
|
||||
tool_calls_in_progress: Dict[str, Any] = {}
|
||||
actual_model_used = resolved_model
|
||||
chunk_count = 0
|
||||
full_message_content = ""
|
||||
|
||||
try:
|
||||
info(f"📡 开始调用 graph.astream()...")
|
||||
chunk_count = 0
|
||||
full_message_content = "" # 收集完整消息内容
|
||||
|
||||
async for chunk in graph.astream(
|
||||
async for chunk in self.graph.astream(
|
||||
input_state,
|
||||
config=config,
|
||||
stream_mode=["messages", "updates", "custom"],
|
||||
@@ -181,156 +353,58 @@ class AIAgentService:
|
||||
):
|
||||
chunk_count += 1
|
||||
chunk_type = chunk["type"]
|
||||
processed_event = {}
|
||||
|
||||
if chunk_type == "messages":
|
||||
message_chunk, metadata = chunk["data"]
|
||||
node_name = metadata.get("langgraph_node", "unknown")
|
||||
|
||||
# 检测节点变化,发送节点开始事件
|
||||
if node_name != current_node:
|
||||
if current_node:
|
||||
yield {
|
||||
"type": "node_end",
|
||||
"node": current_node
|
||||
}
|
||||
yield {
|
||||
"type": "node_start",
|
||||
"node": node_name
|
||||
}
|
||||
current_node = node_name
|
||||
|
||||
# 处理消息内容
|
||||
token_content = getattr(message_chunk, 'content', str(message_chunk))
|
||||
reasoning_token = ""
|
||||
if hasattr(message_chunk, 'additional_kwargs'):
|
||||
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
# 处理思考过程
|
||||
if reasoning_token:
|
||||
processed_event = {
|
||||
"type": "llm_token",
|
||||
"node": node_name,
|
||||
"reasoning_token": reasoning_token
|
||||
}
|
||||
# 处理工具调用
|
||||
elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
|
||||
for tool_call in message_chunk.tool_calls:
|
||||
tool_call_id = tool_call.get("id", "")
|
||||
tool_name = tool_call.get("name", "")
|
||||
tool_args = tool_call.get("args", {})
|
||||
|
||||
# 记录工具调用开始
|
||||
if tool_call_id not in tool_calls_in_progress:
|
||||
tool_calls_in_progress[tool_call_id] = {
|
||||
"name": tool_name,
|
||||
"args": tool_args
|
||||
}
|
||||
yield {
|
||||
"type": "tool_call_start",
|
||||
"tool": tool_name,
|
||||
"args": tool_args,
|
||||
"id": tool_call_id
|
||||
}
|
||||
# 处理普通 token - 只收集,不打印单个 token
|
||||
elif token_content:
|
||||
processed_event = {
|
||||
"type": "llm_token",
|
||||
"node": node_name,
|
||||
"token": token_content,
|
||||
"reasoning_token": reasoning_token
|
||||
}
|
||||
if node_name == "llm_call":
|
||||
full_message_content += token_content
|
||||
async for event in self._handle_message_chunk(
|
||||
chunk, current_node, tool_calls_in_progress
|
||||
):
|
||||
if event.get("type") == "_update_state":
|
||||
current_node = event.get("current_node", current_node)
|
||||
else:
|
||||
# 如果是 llm_call 节点的 token,收集完整消息
|
||||
if (
|
||||
event.get("type") == "llm_token"
|
||||
and event.get("node") == "llm_call"
|
||||
and "token" in event
|
||||
):
|
||||
full_message_content += event["token"]
|
||||
yield event
|
||||
|
||||
elif chunk_type == "updates":
|
||||
updates_data = chunk["data"]
|
||||
info(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}")
|
||||
# 特别检查 final_result
|
||||
if isinstance(updates_data, dict) and "final_result" in updates_data:
|
||||
info(f"[Stream] 收到 final_result: {str(updates_data['final_result'])[:100]}...")
|
||||
serialized_data = self._serialize_value(updates_data)
|
||||
if "review_pending" in serialized_data and serialized_data["review_pending"]:
|
||||
review_id = serialized_data.get("review_id", "")
|
||||
content_to_review = serialized_data.get("content_to_review", "")
|
||||
yield {
|
||||
"type": "human_review_request",
|
||||
"review_id": review_id,
|
||||
"content": content_to_review
|
||||
}
|
||||
|
||||
# 检查是否有工具结果
|
||||
if "messages" in serialized_data:
|
||||
for msg in serialized_data["messages"]:
|
||||
# 检测工具结果消息
|
||||
if msg.get("role") == "tool":
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
tool_name = msg.get("name", "")
|
||||
tool_output = msg.get("content", "")
|
||||
|
||||
if tool_call_id in tool_calls_in_progress:
|
||||
yield {
|
||||
"type": "tool_call_end",
|
||||
"tool": tool_name,
|
||||
"id": tool_call_id,
|
||||
"result": tool_output
|
||||
}
|
||||
del tool_calls_in_progress[tool_call_id]
|
||||
|
||||
processed_event = {
|
||||
"type": "state_update",
|
||||
"data": serialized_data
|
||||
}
|
||||
async for event in self._handle_updates_chunk(
|
||||
chunk, tool_calls_in_progress, actual_model_used
|
||||
):
|
||||
if event.get("type") == "_update_state":
|
||||
actual_model_used = event.get("actual_model_used", actual_model_used)
|
||||
else:
|
||||
yield event
|
||||
|
||||
elif chunk_type == "custom":
|
||||
custom_data = chunk["data"]
|
||||
|
||||
# 处理我们从 react_reason_node 发送的自定义推理事件
|
||||
if isinstance(custom_data, dict):
|
||||
# 检查是否是我们的推理事件
|
||||
if "action" in custom_data and "reasoning" in custom_data:
|
||||
yield {
|
||||
"type": "react_reasoning",
|
||||
"step": custom_data.get("step", 1),
|
||||
"action": custom_data.get("action", "unknown"),
|
||||
"confidence": custom_data.get("confidence", 0),
|
||||
"reasoning": custom_data.get("reasoning", "")
|
||||
}
|
||||
else:
|
||||
# 处理其他自定义事件
|
||||
serialized_data = self._serialize_value(custom_data)
|
||||
processed_event = {
|
||||
"type": "custom",
|
||||
"data": serialized_data
|
||||
}
|
||||
else:
|
||||
# 处理其他自定义事件
|
||||
serialized_data = self._serialize_value(custom_data)
|
||||
processed_event = {
|
||||
"type": "custom",
|
||||
"data": serialized_data
|
||||
}
|
||||
|
||||
if processed_event:
|
||||
yield processed_event
|
||||
async for event in self._handle_custom_chunk(chunk):
|
||||
yield event
|
||||
|
||||
# 完整消息集合完成后,一次性打印
|
||||
info(f"✅ graph.astream() 完成,共 {chunk_count} 个chunks")
|
||||
info(f"✅ graph.astream() 完成,共 {chunk_count} 个 chunks")
|
||||
if full_message_content:
|
||||
info(f"📄 完整消息内容: {repr(full_message_content)}")
|
||||
info(f"🤖 实际使用模型: {actual_model_used}")
|
||||
|
||||
except Exception as e:
|
||||
error(f"❌ 执行 React 图时出错: {e}")
|
||||
error(f"❌ 执行单图时出错: {e}")
|
||||
import traceback
|
||||
error(f"📋 堆栈: {traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
# 发送结束事件
|
||||
if current_node:
|
||||
yield {
|
||||
"type": "node_end",
|
||||
"node": current_node
|
||||
"type": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
finally:
|
||||
# 无论成功或失败,都发送结束事件,保证前端平稳关闭
|
||||
if current_node:
|
||||
yield {
|
||||
"type": "node_end",
|
||||
"node": current_node
|
||||
}
|
||||
yield {
|
||||
"type": "done",
|
||||
"model_used": actual_model_used
|
||||
}
|
||||
yield {
|
||||
"type": "done"
|
||||
}
|
||||
@@ -22,9 +22,8 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
|
||||
"3. 📇 通讯录子系统 - 查询联系人、添加联系人、管理通讯录\n"
|
||||
"4. 🔍 RAG检索 - 从知识库中检索相关信息回答问题\n\n"
|
||||
"【用户背景信息】\n"
|
||||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
|
||||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳:\n"
|
||||
"{memory_context}\n"
|
||||
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
|
||||
"【可用工具与使用规则】\n"
|
||||
f"{tools_section}\n"
|
||||
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
|
||||
|
||||
@@ -127,6 +127,13 @@ BACKEND_PORT = _get_int("BACKEND_PORT")
|
||||
MEMORY_SUMMARIZE_INTERVAL = _get_int("MEMORY_SUMMARIZE_INTERVAL")
|
||||
|
||||
|
||||
# ========== Tavily 搜索配置 ==========
|
||||
# Tavily API:https://app.tavily.com
|
||||
# 免费额度:1000次/天
|
||||
TAVILY_API_KEY = _get_str("TAVILY_API_KEY")
|
||||
TAVILY_MAX_RESULTS = _get_int("TAVILY_MAX_RESULTS") or 5
|
||||
|
||||
|
||||
# ========== Graph 执行追踪配置 ==========
|
||||
# 是否启用 Graph 流转追踪(通过环境变量控制)
|
||||
ENABLE_GRAPH_TRACE = _get_bool("ENABLE_GRAPH_TRACE")
|
||||
|
||||
@@ -33,18 +33,29 @@ class WebSearchTool:
|
||||
|
||||
def search(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
||||
"""
|
||||
使用多种方式搜索
|
||||
|
||||
使用多种方式搜索,按优先级尝试
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
max_results: 返回结果数量,默认使用初始化时的设置
|
||||
|
||||
|
||||
Returns:
|
||||
搜索结果列表
|
||||
"""
|
||||
num_results = max_results or self.max_results
|
||||
|
||||
# 方式 1: 尝试用 ddgs 包
|
||||
# 方式 1: Tavily (需要 API Key,质量最高)
|
||||
try:
|
||||
return self._search_tavily(query, num_results)
|
||||
except ImportError:
|
||||
print("[WebSearch] tavily 未安装,尝试其他搜索方式")
|
||||
except Exception as e:
|
||||
if "API_KEY" in str(e) or "未配置" in str(e):
|
||||
print(f"[WebSearch] Tavily API Key 未配置: {e}")
|
||||
else:
|
||||
print(f"[WebSearch] Tavily 搜索失败: {e}")
|
||||
|
||||
# 方式 2: 尝试用 ddgs 包
|
||||
try:
|
||||
from ddgs import DDGS
|
||||
print(f"[WebSearch] 使用 ddgs 搜索: {query}")
|
||||
@@ -65,29 +76,7 @@ class WebSearchTool:
|
||||
print("[WebSearch] ddgs 未安装,尝试 duckduckgo-search")
|
||||
except Exception as e:
|
||||
print(f"[WebSearch] ddgs 搜索失败: {e}")
|
||||
|
||||
# 方式 2: 尝试用旧的 duckduckgo-search 包
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
print(f"[WebSearch] 使用 duckduckgo-search 搜索: {query}")
|
||||
with DDGS() as ddgs:
|
||||
results = list(ddgs.text(query, max_results=num_results))
|
||||
if results:
|
||||
search_results = []
|
||||
for r in results:
|
||||
search_results.append(SearchResult(
|
||||
title=r.get("title", ""),
|
||||
url=r.get("href", ""),
|
||||
snippet=r.get("body", ""),
|
||||
source="DuckDuckGo"
|
||||
))
|
||||
print(f"[WebSearch] duckduckgo-search 返回 {len(search_results)} 条结果")
|
||||
return search_results
|
||||
except ImportError:
|
||||
print("[WebSearch] duckduckgo-search 未安装")
|
||||
except Exception as e:
|
||||
print(f"[WebSearch] duckduckgo-search 搜索失败: {e}")
|
||||
|
||||
|
||||
# 方式 3: 尝试用简单 HTTP 请求
|
||||
try:
|
||||
return self._search_http(query, num_results)
|
||||
@@ -97,6 +86,34 @@ class WebSearchTool:
|
||||
# 方式 4: 返回模拟数据作为最后兜底
|
||||
return self._search_mock(query, num_results)
|
||||
|
||||
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""使用 Tavily API 搜索"""
|
||||
from tavily import TavilyClient
|
||||
from app.config import TAVILY_API_KEY, TAVILY_MAX_RESULTS
|
||||
|
||||
if not TAVILY_API_KEY:
|
||||
raise ValueError("TAVILY_API_KEY 未配置")
|
||||
|
||||
client = TavilyClient(api_key=TAVILY_API_KEY)
|
||||
response = client.search(
|
||||
query=query,
|
||||
max_results=min(max_results, TAVILY_MAX_RESULTS or 5),
|
||||
include_answer=True,
|
||||
include_raw_content=False
|
||||
)
|
||||
|
||||
results = []
|
||||
for item in response.get("results", []):
|
||||
results.append(SearchResult(
|
||||
title=item.get("title", ""),
|
||||
url=item.get("url", ""),
|
||||
snippet=item.get("content", ""),
|
||||
source="Tavily"
|
||||
))
|
||||
|
||||
print(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
|
||||
return results
|
||||
|
||||
def _search_http(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""用简单 HTTP 请求搜索(备用方案)- 尝试多个国内源"""
|
||||
print(f"[WebSearch] 尝试 HTTP 搜索")
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""
|
||||
LangGraph 核心组件重新导出
|
||||
统一导入入口,避免直接依赖 langgraph
|
||||
"""
|
||||
|
||||
from langgraph.graph import StateGraph, START, END, add_messages
|
||||
|
||||
__all__ = ["StateGraph", "START", "END", "add_messages"]
|
||||
229
backend/app/main_graph/main_graph_builder.py
Normal file
229
backend/app/main_graph/main_graph_builder.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
主图构建器 - 构建整合后的完整主图
|
||||
"""
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from typing import Dict, Any
|
||||
|
||||
from .state import MainGraphState
|
||||
from .nodes.reasoning import react_reason_node
|
||||
from .nodes.web_search import web_search_node
|
||||
from .nodes.error_handling import error_handling_node
|
||||
from .nodes.routing import init_state_node, route_by_reasoning, should_summarize
|
||||
from .nodes.hybrid_router import (
|
||||
hybrid_router_node,
|
||||
route_from_hybrid_decision,
|
||||
check_fast_path_success,
|
||||
)
|
||||
from .nodes.fast_paths import (
|
||||
fast_chitchat_node,
|
||||
fast_rag_node,
|
||||
fast_tool_node,
|
||||
)
|
||||
from .nodes.llm_call import create_dynamic_llm_call_node
|
||||
from .nodes.rag_nodes import rag_retrieve_node
|
||||
from .nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from .nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from .nodes.summarize import create_summarize_node
|
||||
from .nodes.finalize import finalize_node
|
||||
from ..subgraphs.contact import build_contact_subgraph
|
||||
from ..subgraphs.dictionary import build_dictionary_subgraph
|
||||
from ..subgraphs.news_analysis import build_news_analysis_subgraph
|
||||
from ..logger import info
|
||||
|
||||
from .subgraph_wrapper import create_subgraph_nodes
|
||||
|
||||
|
||||
# ========== 主图构建 ==========
|
||||
|
||||
def build_react_main_graph(
|
||||
chat_services: dict,
|
||||
tools=None,
|
||||
mem0_client=None,
|
||||
use_hybrid_router: bool = True
|
||||
) -> StateGraph:
|
||||
"""
|
||||
构建整合后的完整主图(支持混合路由 + 动态模型选择)
|
||||
|
||||
Args:
|
||||
chat_services: 模型名称 -> ChatModel 实例 的字典
|
||||
tools: 工具列表
|
||||
mem0_client: Mem0 客户端实例
|
||||
use_hybrid_router: 是否使用混合路由(快速路径 + React 循环)
|
||||
|
||||
Returns:
|
||||
StateGraph: 构建好的图
|
||||
"""
|
||||
# 创建图
|
||||
graph = StateGraph(MainGraphState)
|
||||
|
||||
# 设置全局 mem0_client
|
||||
if mem0_client:
|
||||
set_mem0_client(mem0_client)
|
||||
|
||||
# ========== 创建节点 ==========
|
||||
|
||||
# LLM 调用节点
|
||||
llm_node = create_dynamic_llm_call_node(chat_services, tools or [])
|
||||
|
||||
# 记忆节点
|
||||
retrieve_memory_node = None
|
||||
summarize_node = None
|
||||
if mem0_client:
|
||||
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
|
||||
summarize_node = create_summarize_node(mem0_client)
|
||||
|
||||
# 子图节点
|
||||
contact_graph = build_contact_subgraph()
|
||||
dictionary_graph = build_dictionary_subgraph()
|
||||
news_analysis_graph = build_news_analysis_subgraph()
|
||||
subgraph_nodes = create_subgraph_nodes(
|
||||
contact_graph, dictionary_graph, news_analysis_graph
|
||||
)
|
||||
|
||||
# ========== 添加节点到图 ==========
|
||||
|
||||
# 阶段 1: 记忆检索
|
||||
if retrieve_memory_node:
|
||||
graph.add_node("retrieve_memory", retrieve_memory_node)
|
||||
graph.add_node("memory_trigger", memory_trigger_node)
|
||||
|
||||
# 阶段 2: 初始化
|
||||
graph.add_node("init_state", init_state_node)
|
||||
|
||||
# 阶段 3: 混合路由(可选)
|
||||
if use_hybrid_router:
|
||||
graph.add_node("hybrid_router", hybrid_router_node)
|
||||
graph.add_node("fast_chitchat", fast_chitchat_node)
|
||||
graph.add_node("fast_rag", fast_rag_node)
|
||||
graph.add_node("fast_tool", fast_tool_node)
|
||||
|
||||
# 阶段 4: React 循环推理(始终保留)
|
||||
graph.add_node("react_reason", react_reason_node)
|
||||
graph.add_node("rag_retrieve", rag_retrieve_node)
|
||||
graph.add_node("web_search", web_search_node)
|
||||
graph.add_node("handle_error", error_handling_node)
|
||||
|
||||
if llm_node is not None:
|
||||
graph.add_node("llm_call", llm_node)
|
||||
|
||||
# 子图节点
|
||||
for node_name, node_func in subgraph_nodes.items():
|
||||
graph.add_node(node_name, node_func)
|
||||
|
||||
# 阶段 5: 完成处理
|
||||
if summarize_node:
|
||||
graph.add_node("summarize", summarize_node)
|
||||
graph.add_node("finalize", finalize_node)
|
||||
|
||||
# ========== 添加边 ==========
|
||||
|
||||
# 阶段 1: 记忆检索
|
||||
_add_memory_edges(graph, retrieve_memory_node)
|
||||
|
||||
# 阶段 2: 初始化
|
||||
graph.add_edge("memory_trigger", "init_state")
|
||||
|
||||
# 阶段 3: 路由分支
|
||||
_add_routing_edges(graph, use_hybrid_router, llm_node)
|
||||
|
||||
# 阶段 4: React 循环边
|
||||
_add_react_loop_edges(graph, subgraph_nodes)
|
||||
|
||||
# 阶段 5: 完成阶段
|
||||
_add_finalize_edges(graph, llm_node, summarize_node)
|
||||
|
||||
info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router})")
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def _add_memory_edges(graph: StateGraph, retrieve_memory_node) -> None:
|
||||
"""添加记忆检索阶段的边"""
|
||||
if retrieve_memory_node:
|
||||
graph.add_edge(START, "retrieve_memory")
|
||||
graph.add_edge("retrieve_memory", "memory_trigger")
|
||||
else:
|
||||
graph.add_edge(START, "memory_trigger")
|
||||
|
||||
|
||||
def _add_routing_edges(graph: StateGraph, use_hybrid_router: bool, llm_node) -> None:
|
||||
"""添加路由阶段的边"""
|
||||
if use_hybrid_router:
|
||||
graph.add_edge("init_state", "hybrid_router")
|
||||
|
||||
# 混合路由条件分支
|
||||
graph.add_conditional_edges(
|
||||
"hybrid_router",
|
||||
route_from_hybrid_decision,
|
||||
{
|
||||
"fast_chitchat": "fast_chitchat",
|
||||
"fast_rag": "fast_rag",
|
||||
"fast_tool": "fast_tool",
|
||||
"react_loop": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
# 快速路径的完成检查
|
||||
for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]:
|
||||
graph.add_conditional_edges(
|
||||
fast_node,
|
||||
check_fast_path_success,
|
||||
{
|
||||
"llm_call": "llm_call",
|
||||
"escalate": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
info(f"✅ [图构建] 混合路由模式已启用")
|
||||
else:
|
||||
graph.add_edge("init_state", "react_reason")
|
||||
info(f"✅ [图构建] 纯 React 模式")
|
||||
|
||||
|
||||
def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) -> None:
|
||||
"""添加 React 循环阶段的边"""
|
||||
subgraph_names = list(subgraph_nodes.keys())
|
||||
|
||||
# React 推理的条件分支
|
||||
graph.add_conditional_edges(
|
||||
"react_reason",
|
||||
route_by_reasoning,
|
||||
{
|
||||
"rag_retrieve": "rag_retrieve",
|
||||
"web_search": "web_search",
|
||||
**{name: name for name in subgraph_names},
|
||||
"handle_error": "handle_error",
|
||||
"llm_call": "llm_call"
|
||||
}
|
||||
)
|
||||
|
||||
# 循环边(回到 react_reason)
|
||||
loop_back_nodes = ["rag_retrieve", "web_search", "handle_error"] + subgraph_names
|
||||
for node_name in loop_back_nodes:
|
||||
graph.add_edge(node_name, "react_reason")
|
||||
|
||||
|
||||
def _add_finalize_edges(graph: StateGraph, llm_node, summarize_node) -> None:
|
||||
"""添加完成阶段的边"""
|
||||
if llm_node is not None:
|
||||
if summarize_node:
|
||||
graph.add_conditional_edges(
|
||||
"llm_call",
|
||||
should_summarize,
|
||||
{
|
||||
"summarize": "summarize",
|
||||
"finalize": "finalize"
|
||||
}
|
||||
)
|
||||
graph.add_edge("summarize", "finalize")
|
||||
else:
|
||||
graph.add_edge("llm_call", "finalize")
|
||||
|
||||
graph.add_edge("finalize", END)
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
__all__ = [
|
||||
"build_react_main_graph",
|
||||
]
|
||||
@@ -6,8 +6,8 @@
|
||||
from .reasoning import react_reason_node
|
||||
from .web_search import web_search_node
|
||||
from .error_handling import error_handling_node
|
||||
from .routing import init_state_node, route_by_reasoning
|
||||
from .llm_call import create_llm_call_node
|
||||
from .routing import init_state_node, route_by_reasoning, should_summarize
|
||||
from .llm_call import create_dynamic_llm_call_node
|
||||
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node
|
||||
|
||||
# 记忆节点
|
||||
@@ -38,7 +38,8 @@ __all__ = [
|
||||
"web_search_node",
|
||||
"error_handling_node",
|
||||
"route_by_reasoning",
|
||||
"create_llm_call_node",
|
||||
"should_summarize",
|
||||
"create_dynamic_llm_call_node",
|
||||
"rag_retrieve_node",
|
||||
"rag_re_retrieve_node",
|
||||
# 记忆节点
|
||||
|
||||
@@ -5,7 +5,7 @@ LLM 调用节点模块
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
@@ -14,29 +14,34 @@ from app.agent.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error
|
||||
|
||||
def create_llm_call_node(llm, tools: list):
|
||||
|
||||
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
|
||||
"""
|
||||
工厂函数:创建 LLM 调用节点
|
||||
|
||||
工厂函数:创建动态 LLM 调用节点(根据 state.current_model 选择模型)
|
||||
|
||||
Args:
|
||||
llm: LangChain LLM 实例
|
||||
chat_services: 模型名称 -> ChatModel 实例 的字典
|
||||
tools: 工具列表
|
||||
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
# 构建调用链
|
||||
# 预构建所有模型的 tools 绑定(避免每次调用都 bind)
|
||||
bound_models: Dict[str, Any] = {}
|
||||
for name, llm in chat_services.items():
|
||||
if tools:
|
||||
bound_models[name] = llm.bind_tools(tools)
|
||||
else:
|
||||
bound_models[name] = llm
|
||||
|
||||
# 预构建 prompt
|
||||
prompt = create_system_prompt(tools)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
|
||||
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
|
||||
chain = prompt | llm_with_tools
|
||||
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
LLM 调用节点(动态选择模型)
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
@@ -46,7 +51,7 @@ def create_llm_call_node(llm, tools: list):
|
||||
更新后的状态字典
|
||||
"""
|
||||
log_state_change("llm_call", state, "进入")
|
||||
|
||||
|
||||
memory_context = getattr(state, "memory_context", "暂无用户信息")
|
||||
start_time = time.time()
|
||||
|
||||
@@ -62,9 +67,20 @@ def create_llm_call_node(llm, tools: list):
|
||||
"last_elapsed_time": elapsed_time,
|
||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||
}
|
||||
|
||||
|
||||
# 动态选择模型
|
||||
model_name = getattr(state, "current_model", "")
|
||||
if not model_name or model_name not in bound_models:
|
||||
# 回退到第一个可用模型
|
||||
fallback_name = next(iter(bound_models.keys()))
|
||||
info(f"[llm_call] 模型 '{model_name}' 不可用,回退到 '{fallback_name}'")
|
||||
model_name = fallback_name
|
||||
|
||||
llm_with_tools = bound_models[model_name]
|
||||
info(f"[llm_call] 使用模型: {model_name}")
|
||||
|
||||
try:
|
||||
# 添加 RAG 上下文到消息
|
||||
# 添加上下文到消息
|
||||
messages_with_context = list(state.messages)
|
||||
if state.rag_context:
|
||||
from langchain_core.messages import SystemMessage
|
||||
@@ -77,9 +93,10 @@ def create_llm_call_node(llm, tools: list):
|
||||
break
|
||||
if not inserted:
|
||||
messages_with_context.insert(0, rag_system_msg)
|
||||
|
||||
|
||||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||||
# LangGraph 会自动监听这期间产生的所有 token。
|
||||
chain = prompt | llm_with_tools
|
||||
chunks = []
|
||||
async for chunk in chain.astream(
|
||||
{
|
||||
@@ -89,7 +106,7 @@ def create_llm_call_node(llm, tools: list):
|
||||
config=config
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0]
|
||||
@@ -97,14 +114,14 @@ def create_llm_call_node(llm, tools: list):
|
||||
response = response + chunk
|
||||
else:
|
||||
response = AIMessage(content="")
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
|
||||
token_usage = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
|
||||
# 尝试从 response_metadata 中提取
|
||||
if hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
meta = response.response_metadata
|
||||
@@ -112,33 +129,33 @@ def create_llm_call_node(llm, tools: list):
|
||||
token_usage = meta['token_usage']
|
||||
elif 'usage' in meta:
|
||||
token_usage = meta['usage']
|
||||
|
||||
|
||||
# 尝试从 additional_kwargs 中提取
|
||||
if not token_usage and hasattr(response, 'additional_kwargs'):
|
||||
add_kwargs = response.additional_kwargs
|
||||
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
|
||||
token_usage = add_kwargs['llm_output']['token_usage']
|
||||
|
||||
|
||||
# 提取具体的 token 数值
|
||||
if token_usage:
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
|
||||
|
||||
# 打印 LLM 的完整输出
|
||||
debug("\n" + "="*80)
|
||||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||||
debug(f"📥 [LLM输出] 模型: {model_name} 返回的完整响应:")
|
||||
debug(f" 消息类型: {response.type.upper()}")
|
||||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||||
debug("-"*80)
|
||||
debug(f"{response.content}")
|
||||
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
|
||||
# 检查是否有工具调用
|
||||
has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0
|
||||
|
||||
@@ -151,21 +168,22 @@ def create_llm_call_node(llm, tools: list):
|
||||
"final_result": response.content,
|
||||
"success": True,
|
||||
"current_phase": "done",
|
||||
"has_tool_calls": has_tool_calls
|
||||
"has_tool_calls": has_tool_calls,
|
||||
"current_model": model_name # 记录实际使用的模型
|
||||
}
|
||||
|
||||
|
||||
log_state_change("llm_call", state, "离开")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f"\n❌ [LLM错误] 模型 {model_name} 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f" 错误类型: {type(e).__name__}")
|
||||
error(f" 错误信息: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
error(f"📋 堆栈: {traceback.format_exc()}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
|
||||
# 返回一个友好的错误消息
|
||||
error_response = AIMessage(
|
||||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||||
@@ -178,10 +196,11 @@ def create_llm_call_node(llm, tools: list):
|
||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||
"final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。",
|
||||
"success": False,
|
||||
"current_phase": "done"
|
||||
"current_phase": "done",
|
||||
"current_model": model_name
|
||||
}
|
||||
|
||||
|
||||
log_state_change("llm_call", state, "离开(异常)")
|
||||
return error_result
|
||||
|
||||
return call_llm
|
||||
|
||||
return call_llm
|
||||
|
||||
@@ -118,3 +118,21 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
|
||||
info(f"[条件路由] 动作={latest_action}, 目标={target}")
|
||||
return target
|
||||
|
||||
|
||||
# ========== 完成阶段条件路由函数 ==========
|
||||
|
||||
def should_summarize(state: MainGraphState) -> str:
|
||||
"""
|
||||
检查是否需要总结对话(对话足够长时)
|
||||
|
||||
Args:
|
||||
state: 当前图状态
|
||||
|
||||
Returns:
|
||||
"summarize" 或 "finalize"
|
||||
"""
|
||||
if state.turns_since_last_summary >= 5: # 每5轮对话总结一次
|
||||
return "summarize"
|
||||
else:
|
||||
return "finalize"
|
||||
|
||||
@@ -6,7 +6,7 @@ Main Graph State Definition - React Mode Enhanced
|
||||
from enum import Enum, auto
|
||||
from typing import Optional, Dict, Any, Annotated, Sequence, TypedDict, List
|
||||
from dataclasses import dataclass, field
|
||||
from app.main_graph.graph import add_messages
|
||||
from langgraph.graph import add_messages
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ class MainGraphState:
|
||||
# ========== 主图控制字段 ==========
|
||||
user_query: str = ""
|
||||
current_action: CurrentAction = CurrentAction.NONE
|
||||
current_model: str = "" # 新增:本次请求使用的模型
|
||||
intent_confidence: float = 0.0
|
||||
|
||||
# ========== React 推理专用字段 ==========
|
||||
|
||||
159
backend/app/main_graph/subgraph_wrapper.py
Normal file
159
backend/app/main_graph/subgraph_wrapper.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
子图包装器 - 为子图添加错误处理和事件追踪
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ..logger import info
|
||||
|
||||
|
||||
def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
"""
|
||||
包装子图,使其错误能传递给主图
|
||||
|
||||
Args:
|
||||
subgraph: 编译好的子图
|
||||
name: 子图名称(用于错误标识)
|
||||
|
||||
Returns: 包装后的节点函数
|
||||
"""
|
||||
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
# 发送子图开始事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_start",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"开始执行 {name} 子图..."
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送开始事件: {e}")
|
||||
|
||||
try:
|
||||
# 调用子图
|
||||
result = subgraph.invoke(state)
|
||||
|
||||
# 更新主图状态
|
||||
subgraph_result = None
|
||||
if name == "contact":
|
||||
state.contact_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "dictionary":
|
||||
state.dictionary_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "news_analysis":
|
||||
state.news_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
|
||||
# 设置最终结果
|
||||
if subgraph_result:
|
||||
state.final_result = subgraph_result
|
||||
else:
|
||||
state.final_result = "子图执行完成"
|
||||
|
||||
# 标记成功
|
||||
state.success = True
|
||||
state.current_phase = "done"
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "subgraph_completed",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name}子图执行完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 发送子图完成事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_complete",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行完成"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送完成事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
# 捕获子图错误,传递给主图
|
||||
error_record = ErrorRecord(
|
||||
error_type=f"{name}SubgraphError",
|
||||
error_message=str(e),
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source=f"{name}_subgraph",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=0,
|
||||
max_retries=1,
|
||||
context={"user_query": state.user_query}
|
||||
)
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
state.success = False
|
||||
|
||||
# 发送子图错误事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_error",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行失败: {str(e)}"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送错误事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
return wrapped_node
|
||||
|
||||
|
||||
def create_subgraph_nodes(contact_graph, dictionary_graph, news_analysis_graph) -> Dict[str, Any]:
|
||||
"""
|
||||
创建所有子图节点的字典
|
||||
|
||||
Args:
|
||||
contact_graph: 联系人子图
|
||||
dictionary_graph: 词典子图
|
||||
news_analysis_graph: 新闻分析子图
|
||||
|
||||
Returns:
|
||||
子图节点字典 {name: wrapped_node}
|
||||
"""
|
||||
return {
|
||||
"contact_subgraph": wrap_subgraph_for_error_handling(
|
||||
contact_graph.compile(), "contact"
|
||||
),
|
||||
"dictionary_subgraph": wrap_subgraph_for_error_handling(
|
||||
dictionary_graph.compile(), "dictionary"
|
||||
),
|
||||
"news_analysis_subgraph": wrap_subgraph_for_error_handling(
|
||||
news_analysis_graph.compile(), "news_analysis"
|
||||
),
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
"""主图工具函数"""
|
||||
@@ -1,371 +0,0 @@
|
||||
"""
|
||||
整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState
|
||||
"""
|
||||
|
||||
from ..graph import StateGraph, START, END
|
||||
from typing import Dict, Any, Optional
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ..state import MainGraphState
|
||||
from ..nodes.reasoning import react_reason_node
|
||||
from ..nodes.web_search import web_search_node
|
||||
from ..nodes.error_handling import error_handling_node
|
||||
from ..nodes.routing import init_state_node, route_by_reasoning
|
||||
from ..nodes.hybrid_router import (
|
||||
hybrid_router_node,
|
||||
route_from_hybrid_decision,
|
||||
check_fast_path_success,
|
||||
)
|
||||
from ..nodes.fast_paths import (
|
||||
fast_chitchat_node,
|
||||
fast_rag_node,
|
||||
fast_tool_node,
|
||||
)
|
||||
from ..nodes.llm_call import create_llm_call_node
|
||||
from ..nodes.rag_nodes import rag_retrieve_node
|
||||
from ..nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from ..nodes.summarize import create_summarize_node
|
||||
from ..nodes.finalize import finalize_node
|
||||
from ...subgraphs.contact import build_contact_subgraph
|
||||
from ...subgraphs.dictionary import build_dictionary_subgraph
|
||||
from ...subgraphs.news_analysis import build_news_analysis_subgraph
|
||||
from ...memory.mem0_client import Mem0Client
|
||||
from ...logger import info, debug
|
||||
|
||||
|
||||
# ========== 检查是否需要总结 ==========
|
||||
def should_summarize(state: MainGraphState) -> str:
|
||||
"""
|
||||
检查是否需要总结对话(对话足够长时)
|
||||
|
||||
Args:
|
||||
state: 当前图状态
|
||||
|
||||
Returns:
|
||||
"summarize" 或 "finalize"
|
||||
"""
|
||||
if state.turns_since_last_summary >= 5: # 每5轮对话总结一次
|
||||
return "summarize"
|
||||
else:
|
||||
return "finalize"
|
||||
|
||||
|
||||
# ========== 子图包装器(处理子图错误传递)==========
|
||||
def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
"""
|
||||
包装子图,使其错误能传递给主图
|
||||
|
||||
Args:
|
||||
subgraph: 编译好的子图
|
||||
name: 子图名称(用于错误标识)
|
||||
|
||||
Returns: 包装后的节点函数
|
||||
"""
|
||||
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
# 发送子图开始事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_start",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"开始执行 {name} 子图..."
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送开始事件: {e}")
|
||||
|
||||
try:
|
||||
# 调用子图
|
||||
result = subgraph.invoke(state)
|
||||
|
||||
# 更新主图状态
|
||||
subgraph_result = None
|
||||
if name == "contact":
|
||||
state.contact_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "dictionary":
|
||||
state.dictionary_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "news_analysis":
|
||||
state.news_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
|
||||
# 关键:设置最终结果
|
||||
if subgraph_result:
|
||||
state.final_result = subgraph_result
|
||||
else:
|
||||
state.final_result = "子图执行完成"
|
||||
|
||||
# 标记成功
|
||||
state.success = True
|
||||
state.current_phase = "done"
|
||||
# 标记不再需要推理,避免循环
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "subgraph_completed",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name}子图执行完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 发送子图完成事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_complete",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行完成"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送完成事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
# 捕获子图错误,传递给主图
|
||||
from ..state import ErrorRecord, ErrorSeverity
|
||||
from datetime import datetime
|
||||
|
||||
error_record = ErrorRecord(
|
||||
error_type=f"{name}SubgraphError",
|
||||
error_message=str(e),
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source=f"{name}_subgraph",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=0,
|
||||
max_retries=1,
|
||||
context={"user_query": state.user_query}
|
||||
)
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
state.success = False
|
||||
|
||||
# 发送子图错误事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_error",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行失败: {str(e)}"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送错误事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
return wrapped_node
|
||||
|
||||
# ========== 主图构建 ==========
|
||||
|
||||
def build_react_main_graph(llm=None, tools=None, mem0_client=None, use_hybrid_router: bool = True) -> StateGraph:
|
||||
"""
|
||||
构建整合后的完整主图(支持混合路由)
|
||||
|
||||
Args:
|
||||
llm: LangChain ChatModel 实例
|
||||
tools: 工具列表
|
||||
mem0_client: Mem0 客户端实例
|
||||
use_hybrid_router: 是否使用混合路由(快速路径 + React 循环)
|
||||
|
||||
Returns:
|
||||
StateGraph: 构建好的图
|
||||
"""
|
||||
# 创建图
|
||||
graph = StateGraph(MainGraphState)
|
||||
|
||||
# 设置全局 mem0_client
|
||||
if mem0_client:
|
||||
set_mem0_client(mem0_client)
|
||||
|
||||
# 创建节点
|
||||
llm_node = None
|
||||
if llm is not None:
|
||||
llm_node = create_llm_call_node(llm, tools or [])
|
||||
|
||||
retrieve_memory_node = None
|
||||
summarize_node = None
|
||||
if mem0_client:
|
||||
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
|
||||
summarize_node = create_summarize_node(mem0_client)
|
||||
|
||||
# ========== 添加节点 ==========
|
||||
|
||||
# 第一阶段:记忆检索
|
||||
if retrieve_memory_node:
|
||||
graph.add_node("retrieve_memory", retrieve_memory_node)
|
||||
graph.add_node("memory_trigger", memory_trigger_node)
|
||||
|
||||
# 第二阶段:初始化
|
||||
graph.add_node("init_state", init_state_node)
|
||||
|
||||
# ========== 混合路由节点(如果启用) ==========
|
||||
if use_hybrid_router:
|
||||
graph.add_node("hybrid_router", hybrid_router_node)
|
||||
graph.add_node("fast_chitchat", fast_chitchat_node)
|
||||
graph.add_node("fast_rag", fast_rag_node)
|
||||
graph.add_node("fast_tool", fast_tool_node)
|
||||
|
||||
# 第三阶段:React 循环推理(始终保留)
|
||||
graph.add_node("react_reason", react_reason_node)
|
||||
graph.add_node("rag_retrieve", rag_retrieve_node)
|
||||
graph.add_node("web_search", web_search_node)
|
||||
graph.add_node("handle_error", error_handling_node)
|
||||
|
||||
if llm_node is not None:
|
||||
graph.add_node("llm_call", llm_node)
|
||||
|
||||
# 子图节点
|
||||
contact_graph = build_contact_subgraph()
|
||||
dictionary_graph = build_dictionary_subgraph()
|
||||
news_analysis_graph = build_news_analysis_subgraph()
|
||||
|
||||
graph.add_node(
|
||||
"contact_subgraph",
|
||||
wrap_subgraph_for_error_handling(contact_graph.compile(), "contact")
|
||||
)
|
||||
graph.add_node(
|
||||
"dictionary_subgraph",
|
||||
wrap_subgraph_for_error_handling(dictionary_graph.compile(), "dictionary")
|
||||
)
|
||||
graph.add_node(
|
||||
"news_analysis_subgraph",
|
||||
wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis")
|
||||
)
|
||||
|
||||
# 第四阶段:完成处理
|
||||
if summarize_node:
|
||||
graph.add_node("summarize", summarize_node)
|
||||
graph.add_node("finalize", finalize_node)
|
||||
|
||||
# ========== 添加边 ==========
|
||||
|
||||
# 第一阶段:记忆检索
|
||||
if retrieve_memory_node:
|
||||
graph.add_edge(START, "retrieve_memory")
|
||||
graph.add_edge("retrieve_memory", "memory_trigger")
|
||||
else:
|
||||
graph.add_edge(START, "memory_trigger")
|
||||
|
||||
# 进入初始化
|
||||
graph.add_edge("memory_trigger", "init_state")
|
||||
|
||||
# ========== 混合路由分支(如果启用) ==========
|
||||
if use_hybrid_router:
|
||||
graph.add_edge("init_state", "hybrid_router")
|
||||
|
||||
# 从 hybrid_router 条件分支
|
||||
graph.add_conditional_edges(
|
||||
"hybrid_router",
|
||||
route_from_hybrid_decision,
|
||||
{
|
||||
"fast_chitchat": "fast_chitchat",
|
||||
"fast_rag": "fast_rag",
|
||||
"fast_tool": "fast_tool",
|
||||
"react_loop": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
# 快速路径的完成检查
|
||||
for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]:
|
||||
graph.add_conditional_edges(
|
||||
fast_node,
|
||||
check_fast_path_success,
|
||||
{
|
||||
"llm_call": "llm_call",
|
||||
"escalate": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
info(f"✅ [图构建] 混合路由模式已启用")
|
||||
else:
|
||||
# 无混合路由,直接到 react_reason
|
||||
graph.add_edge("init_state", "react_reason")
|
||||
info(f"✅ [图构建] 纯 React 模式")
|
||||
|
||||
# ========== React 循环边(始终保留) ==========
|
||||
graph.add_conditional_edges(
|
||||
"react_reason",
|
||||
route_by_reasoning,
|
||||
{
|
||||
"rag_retrieve": "rag_retrieve",
|
||||
"web_search": "web_search",
|
||||
"contact_subgraph": "contact_subgraph",
|
||||
"dictionary_subgraph": "dictionary_subgraph",
|
||||
"news_analysis_subgraph": "news_analysis_subgraph",
|
||||
"handle_error": "handle_error",
|
||||
"llm_call": "llm_call"
|
||||
}
|
||||
)
|
||||
|
||||
# 循环边(rag、web_search、子图、error都回到 reason)
|
||||
graph.add_edge("rag_retrieve", "react_reason")
|
||||
graph.add_edge("web_search", "react_reason")
|
||||
graph.add_edge("contact_subgraph", "react_reason")
|
||||
graph.add_edge("dictionary_subgraph", "react_reason")
|
||||
graph.add_edge("news_analysis_subgraph", "react_reason")
|
||||
graph.add_edge("handle_error", "react_reason")
|
||||
|
||||
# ========== 最终完成阶段 ==========
|
||||
if llm_node is not None:
|
||||
if summarize_node:
|
||||
# 检查是否需要总结
|
||||
graph.add_conditional_edges(
|
||||
"llm_call",
|
||||
should_summarize,
|
||||
{
|
||||
"summarize": "summarize",
|
||||
"finalize": "finalize"
|
||||
}
|
||||
)
|
||||
graph.add_edge("summarize", "finalize")
|
||||
else:
|
||||
# 没有 summarize 节点,直接 finalize
|
||||
graph.add_edge("llm_call", "finalize")
|
||||
|
||||
# 完成
|
||||
graph.add_edge("finalize", END)
|
||||
|
||||
info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router})")
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# ========== 兼容性:保留旧的函数名 ==========
|
||||
def build_main_graph() -> StateGraph:
|
||||
"""
|
||||
兼容性函数:旧代码调用 build_main_graph() 时返回 React 版本
|
||||
"""
|
||||
return build_react_main_graph()
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
__all__ = [
|
||||
"build_react_main_graph",
|
||||
"build_main_graph",
|
||||
"wrap_subgraph_for_error_handling"
|
||||
]
|
||||
@@ -6,11 +6,17 @@
|
||||
|
||||
from .embedding_services import get_embedding_service
|
||||
from .rerank_services import get_rerank_service, BaseRerankService
|
||||
from .chat_services import get_small_llm_service
|
||||
from .chat_services import (
|
||||
get_small_llm_service,
|
||||
get_cached_chat_services,
|
||||
get_all_chat_services
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_embedding_service",
|
||||
"get_rerank_service",
|
||||
"get_small_llm_service",
|
||||
"get_cached_chat_services",
|
||||
"get_all_chat_services",
|
||||
"BaseRerankService"
|
||||
]
|
||||
|
||||
@@ -33,6 +33,21 @@ from app.config import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存已初始化的模型字典
|
||||
_cached_services: Dict[str, BaseChatModel] | None = None
|
||||
|
||||
|
||||
def _check_http_service_available(base_url: str, api_key: str = "", timeout: float = 2.0) -> bool:
|
||||
"""通过探测 /models 端点检查 HTTP API 是否可用(内部工具函数)"""
|
||||
try:
|
||||
import httpx
|
||||
client = httpx.Client(base_url=base_url.rstrip('/'), timeout=timeout)
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
resp = client.get("/models", headers=headers)
|
||||
return resp.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
@@ -54,46 +69,8 @@ class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]):
|
||||
logger.warning("VLLM_BASE_URL 未配置")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 先测试主机名能否解析
|
||||
import httpx
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(VLLM_BASE_URL)
|
||||
host = parsed_url.hostname
|
||||
port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443)
|
||||
|
||||
# 测试能否建立 TCP 连接(快速失败)
|
||||
import socket
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(2.0)
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
sock.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"本地 VLLM 服务无法连接: {host}:{port} - {e}")
|
||||
return False
|
||||
|
||||
# 再尝试调用简单的 API(比如 models 接口)
|
||||
client = httpx.Client(base_url=VLLM_BASE_URL.rstrip('/'), timeout=5.0)
|
||||
headers = {}
|
||||
if LLM_API_KEY:
|
||||
headers["Authorization"] = f"Bearer {LLM_API_KEY}"
|
||||
|
||||
try:
|
||||
response = client.get("/models", headers=headers)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"本地 VLLM 服务可用: {self._model}")
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 如果 /v1/models 失败,也认为服务不可用
|
||||
logger.warning(f"本地 VLLM 服务响应异常")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"本地 VLLM 服务不可用: {e}")
|
||||
return False
|
||||
# 使用统一的 HTTP 探测方法
|
||||
return _check_http_service_available(VLLM_BASE_URL, LLM_API_KEY, timeout=2.0)
|
||||
|
||||
def get_service(self) -> BaseChatModel:
|
||||
"""
|
||||
@@ -238,45 +215,8 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
|
||||
logger.warning("SMALL_VLLM_BASE_URL 未配置,本地小模型不可用")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 先测试主机名能否解析
|
||||
import httpx
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(self._base_url)
|
||||
host = parsed_url.hostname
|
||||
port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443)
|
||||
|
||||
# 测试能否建立 TCP 连接(快速失败)
|
||||
import socket
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(2.0)
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
sock.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"本地小模型服务无法连接: {host}:{port} - {e}")
|
||||
return False
|
||||
|
||||
# 再尝试调用简单的 API
|
||||
client = httpx.Client(base_url=self._base_url.rstrip('/'), timeout=5.0)
|
||||
headers = {}
|
||||
if self._api_key:
|
||||
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||
|
||||
try:
|
||||
response = client.get("/models", headers=headers)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"本地小模型服务可用: {self._model}")
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.warning(f"本地小模型服务响应异常")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"本地小模型服务不可用: {e}")
|
||||
return False
|
||||
# 使用统一的 HTTP 探测方法
|
||||
return _check_http_service_available(self._base_url, self._api_key, timeout=2.0)
|
||||
|
||||
def get_service(self) -> BaseChatModel:
|
||||
"""获取本地小模型服务"""
|
||||
@@ -358,25 +298,18 @@ def get_chat_service() -> BaseChatModel:
|
||||
return chain.get_available_service()
|
||||
|
||||
|
||||
def get_all_chat_services() -> Dict[str, BaseChatModel]:
|
||||
"""
|
||||
获取所有可用的生成式大模型服务(用于多模型切换)
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseChatModel]: 模型名称 -> ChatModel 实例 的字典
|
||||
"""
|
||||
def _init_chat_services() -> Dict[str, BaseChatModel]:
|
||||
"""实际初始化所有可用模型(仅在首次调用)"""
|
||||
services = {}
|
||||
|
||||
for name, provider_factory in CHAT_PROVIDERS.items():
|
||||
try:
|
||||
provider = provider_factory()
|
||||
if provider.is_available():
|
||||
logger.info(f"模型 '{name}' 可用")
|
||||
services[name] = provider.get_service()
|
||||
else:
|
||||
logger.warning(f"模型 '{name}' 不可用,跳过")
|
||||
logger.info(f"已加载模型: {name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"初始化模型 '{name}' 失败: {e}")
|
||||
logger.warning(f"模型 {name} 初始化失败: {e}")
|
||||
|
||||
if not services:
|
||||
raise RuntimeError(f"没有可用的生成式大模型,尝试了: {list(CHAT_PROVIDERS.keys())}")
|
||||
@@ -384,6 +317,25 @@ def get_all_chat_services() -> Dict[str, BaseChatModel]:
|
||||
return services
|
||||
|
||||
|
||||
def get_cached_chat_services() -> Dict[str, BaseChatModel]:
|
||||
"""获取缓存的可用模型字典(用于单图动态注入)"""
|
||||
global _cached_services
|
||||
if _cached_services is None:
|
||||
_cached_services = _init_chat_services()
|
||||
return _cached_services
|
||||
|
||||
|
||||
def get_all_chat_services() -> Dict[str, BaseChatModel]:
|
||||
"""
|
||||
获取所有可用的生成式大模型服务(用于多模型切换,保留兼容性)
|
||||
新代码请使用 get_cached_chat_services() 获取缓存版本
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseChatModel]: 模型名称 -> ChatModel 实例 的字典
|
||||
"""
|
||||
return get_cached_chat_services()
|
||||
|
||||
|
||||
def get_small_llm_service() -> BaseChatModel:
|
||||
"""
|
||||
获取轻量级大模型服务(用于查询改写、意图分类等简单任务)
|
||||
|
||||
@@ -4,7 +4,7 @@ Contact Subgraph Builder
|
||||
支持 API 注入的工厂模式
|
||||
"""
|
||||
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
|
||||
from .state import ContactState
|
||||
from .nodes import create_contact_nodes
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Dictionary Subgraph Builder - Complete
|
||||
"""
|
||||
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
|
||||
from .state import DictionaryState
|
||||
from .nodes import (
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
News Analysis Subgraph Builder
|
||||
"""
|
||||
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
|
||||
from .state import NewsAnalysisState
|
||||
from .nodes import (
|
||||
|
||||
@@ -42,6 +42,7 @@ PyYAML>=6.0.3
|
||||
numpy>=1.26.2
|
||||
pyjwt>=2.8.0
|
||||
ddgs>=6.0.0 # 免费联网搜索(原 duckduckgo-search 已重命名)
|
||||
tavily-python>=0.5.0 # Tavily 搜索 API(需要 API Key,质量更高)
|
||||
matplotlib>=3.9.0 # 可视化图表
|
||||
|
||||
# Document Processing
|
||||
|
||||
@@ -66,6 +66,12 @@ services:
|
||||
- MEMORY_SUMMARIZE_INTERVAL=${MEMORY_SUMMARIZE_INTERVAL:-10}
|
||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-/app/fastembed_cache}
|
||||
|
||||
# =========================================================================
|
||||
# Tavily 搜索配置(可选,有 API Key 时优先使用)
|
||||
# =========================================================================
|
||||
- TAVILY_API_KEY=${TAVILY_API_KEY:-}
|
||||
- TAVILY_MAX_RESULTS=${TAVILY_MAX_RESULTS:-5}
|
||||
|
||||
# =========================================================================
|
||||
# 前端通信地址(Docker 内部网络)
|
||||
# =========================================================================
|
||||
|
||||
@@ -11,6 +11,5 @@ sys.path.insert(0, str(backend_path))
|
||||
load_dotenv(project_root / ".env")
|
||||
|
||||
if __name__ == "__main__":
|
||||
from tools.test.test_graph_branches import main
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
from tools.test import test_tavily_search
|
||||
test_tavily_search.main()
|
||||
|
||||
149
tools/test/test_tavily_search.py
Normal file
149
tools/test/test_tavily_search.py
Normal file
@@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python3
|
||||
"""测试 Tavily 搜索功能 - 直接调用 API"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 路径设置
|
||||
project_root = Path(__file__).resolve().parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
load_dotenv(project_root / ".env")
|
||||
|
||||
import os
|
||||
|
||||
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
|
||||
TAVILY_MAX_RESULTS = int(os.getenv("TAVILY_MAX_RESULTS") or "5")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""搜索结果数据类"""
|
||||
title: str
|
||||
url: str
|
||||
snippet: str
|
||||
source: str = "DuckDuckGo"
|
||||
timestamp: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.timestamp is None:
|
||||
self.timestamp = datetime.now()
|
||||
|
||||
|
||||
def test_tavily_api_key():
|
||||
"""测试 API Key 配置"""
|
||||
print("=" * 60)
|
||||
print("测试 1: 检查 Tavily API Key")
|
||||
print("=" * 60)
|
||||
|
||||
if TAVILY_API_KEY:
|
||||
print(f"✓ TAVILY_API_KEY 已配置: {TAVILY_API_KEY[:15]}...")
|
||||
else:
|
||||
print("✗ TAVILY_API_KEY 未配置")
|
||||
print()
|
||||
|
||||
|
||||
def test_tavily_search_direct():
|
||||
"""直接测试 Tavily API"""
|
||||
print("=" * 60)
|
||||
print("测试 2: 直接调用 Tavily API")
|
||||
print("=" * 60)
|
||||
|
||||
if not TAVILY_API_KEY:
|
||||
print("✗ 未配置 API Key,跳过测试")
|
||||
return
|
||||
|
||||
from tavily import TavilyClient
|
||||
|
||||
client = TavilyClient(api_key=TAVILY_API_KEY)
|
||||
|
||||
test_queries = [
|
||||
"Python 编程语言最新版本",
|
||||
"LangGraph AI 框架",
|
||||
]
|
||||
|
||||
for query in test_queries:
|
||||
print(f"\n搜索: {query}")
|
||||
print("-" * 40)
|
||||
try:
|
||||
response = client.search(
|
||||
query=query,
|
||||
max_results=3,
|
||||
include_answer=True,
|
||||
include_raw_content=False
|
||||
)
|
||||
|
||||
print(f"✓ 搜索成功")
|
||||
print(f" - 结果数量: {len(response.get('results', []))}")
|
||||
|
||||
# 打印结果
|
||||
for i, item in enumerate(response.get("results", []), 1):
|
||||
print(f"\n [{i}] {item.get('title', '')}")
|
||||
print(f" URL: {item.get('url', '')}")
|
||||
print(f" 摘要: {item.get('content', '')[:100]}...")
|
||||
|
||||
# 如果有 answer
|
||||
if response.get("answer"):
|
||||
print(f"\n 🤖 AI 摘要: {response['answer'][:200]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 搜索失败: {e}")
|
||||
print()
|
||||
|
||||
|
||||
def test_web_search_integration():
|
||||
"""测试 web_search 模块集成"""
|
||||
print("=" * 60)
|
||||
print("测试 3: 测试 web_search 模块集成")
|
||||
print("=" * 60)
|
||||
|
||||
# 直接导入 web_search 模块(避免循环依赖)
|
||||
web_search_path = project_root / "backend" / "app" / "core" / "web_search.py"
|
||||
if not web_search_path.exists():
|
||||
print(f"✗ web_search.py 不存在于 {web_search_path}")
|
||||
return
|
||||
|
||||
print(f"✓ 找到 web_search.py: {web_search_path}")
|
||||
|
||||
# 使用 exec 动态加载模块
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("web_search_module", web_search_path)
|
||||
web_search_module = importlib.util.module_from_spec(spec)
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(web_search_module)
|
||||
print("✓ web_search 模块加载成功")
|
||||
except Exception as e:
|
||||
print(f"✗ 模块加载失败: {e}")
|
||||
return
|
||||
|
||||
# 测试搜索
|
||||
print("\n执行搜索测试:")
|
||||
try:
|
||||
result = web_search_module.web_search("今天天气怎么样", max_results=3)
|
||||
print(f"✓ 搜索成功,返回 {len(result)} 字符")
|
||||
print("-" * 40)
|
||||
print(result[:800] + "..." if len(result) > 800 else result)
|
||||
except Exception as e:
|
||||
print(f"✗ 搜索失败: {e}")
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
print("\n" + "=" * 60)
|
||||
print("🚀 Tavily 搜索功能测试")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
test_tavily_api_key()
|
||||
test_tavily_search_direct()
|
||||
test_web_search_integration()
|
||||
|
||||
print("=" * 60)
|
||||
print("✅ 测试完成")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -6,28 +6,19 @@ LangGraph 图结构可视化脚本
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 确定项目根目录(Agent1 目录)
|
||||
# 当前文件位置:tools/visualize_graph.py
|
||||
# 向上 1 级到 Agent1
|
||||
# 路径设置
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
BACKEND_DIR = PROJECT_ROOT / "backend"
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
# 关键:把 backend 目录加入 sys.path,这样才能找到 rag_core
|
||||
# 注意:这只对直接运行脚本有效,对 -m 方式无效(因为 -m 方式在脚本运行前就导入了)
|
||||
if str(BACKEND_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_DIR))
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(PROJECT_ROOT / ".env")
|
||||
|
||||
import asyncio
|
||||
from backend.app.agent.agent_service import AIAgentService
|
||||
from backend.app.config import DB_URI
|
||||
from backend.app.main_graph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
import asyncio
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
async def visualize_graph():
|
||||
"""可视化 LangGraph 结构"""
|
||||
@@ -37,6 +28,8 @@ async def visualize_graph():
|
||||
print(f"项目根目录: {PROJECT_ROOT}")
|
||||
print(f"Backend 目录: {BACKEND_DIR}")
|
||||
|
||||
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
|
||||
await checkpointer.setup()
|
||||
|
||||
@@ -45,37 +38,52 @@ async def visualize_graph():
|
||||
agent_service = AIAgentService(checkpointer)
|
||||
await agent_service.initialize()
|
||||
|
||||
for model_name, graph in agent_service.graphs.items():
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f" 模型: {model_name}")
|
||||
print(f"{'=' * 80}")
|
||||
# 获取图(单图方案)
|
||||
graph = agent_service.graph
|
||||
print("\n✅ Agent 服务初始化完成")
|
||||
|
||||
# 获取图结构
|
||||
graph_structure = graph.get_graph()
|
||||
# 获取图结构
|
||||
graph_structure = graph.get_graph()
|
||||
|
||||
# 1. 直接打印节点和边
|
||||
print("\n[1] 节点列表:")
|
||||
print("-" * 80)
|
||||
for node_id, node in graph_structure.nodes.items():
|
||||
print(f" - {node_id}: {node.name}")
|
||||
# 1. 直接打印节点和边
|
||||
print("\n" + "=" * 80)
|
||||
print("[1] 节点列表")
|
||||
print("=" * 80)
|
||||
for node_id, node in graph_structure.nodes.items():
|
||||
print(f" 📦 {node_id}: {node.name}")
|
||||
|
||||
print("\n[2] 边列表:")
|
||||
print("-" * 80)
|
||||
for edge in graph_structure.edges:
|
||||
print(f" {edge.source} --> {edge.target}")
|
||||
print("\n" + "=" * 80)
|
||||
print("[2] 边列表")
|
||||
print("=" * 80)
|
||||
for edge in graph_structure.edges:
|
||||
print(f" {edge.source} --> {edge.target}")
|
||||
|
||||
# 3. ASCII 字符画(需要 grandalf)
|
||||
print("\n[3] ASCII 字符画:")
|
||||
print("-" * 80)
|
||||
try:
|
||||
print(graph_structure.draw_ascii())
|
||||
except Exception as e:
|
||||
print(f"⚠️ ASCII 绘制失败: {e}")
|
||||
# 2. ASCII 字符画
|
||||
print("\n" + "=" * 80)
|
||||
print("[3] ASCII 字符画")
|
||||
print("=" * 80)
|
||||
try:
|
||||
ascii_graph = graph_structure.draw_ascii()
|
||||
print(ascii_graph)
|
||||
except Exception as e:
|
||||
print(f"⚠️ ASCII 绘制失败: {e}")
|
||||
|
||||
# 4. Mermaid 源码
|
||||
print("\n[4] Mermaid 源码 (可复制到 https://mermaid.live/):")
|
||||
print("-" * 80)
|
||||
print(graph_structure.draw_mermaid())
|
||||
# 3. Mermaid 源码
|
||||
print("\n" + "=" * 80)
|
||||
print("[4] Mermaid 源码 (复制到 https://mermaid.live/)")
|
||||
print("=" * 80)
|
||||
try:
|
||||
mermaid_code = graph_structure.draw_mermaid()
|
||||
print(mermaid_code)
|
||||
except Exception as e:
|
||||
print(f"⚠️ Mermaid 生成失败: {e}")
|
||||
|
||||
# 4. 节点统计
|
||||
print("\n" + "=" * 80)
|
||||
print("[5] 图统计")
|
||||
print("=" * 80)
|
||||
print(f" 节点数量: {len(graph_structure.nodes)}")
|
||||
print(f" 边数量: {len(graph_structure.edges)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user