diff --git a/.env.docker b/.env.docker index 070f717..3bfb0e4 100644 --- a/.env.docker +++ b/.env.docker @@ -60,6 +60,13 @@ MEMORY_SUMMARIZE_INTERVAL=10 # 是否启用 Graph 执行追踪(调试用) ENABLE_GRAPH_TRACE=true +# ----------------------------------------------------------------------------- +# Tavily 搜索配置 +# 免费额度:1000次/天,官网:https://app.tavily.com +# ----------------------------------------------------------------------------- +TAVILY_API_KEY=你的Tavily_API密钥 +TAVILY_MAX_RESULTS=5 + # ----------------------------------------------------------------------------- # 稀疏模型配置 # ----------------------------------------------------------------------------- diff --git a/README.md b/README.md index 4b35643..022c23b 100644 --- a/README.md +++ b/README.md @@ -7,15 +7,15 @@ ## 📑 目录导航 - [核心功能](#-核心功能) - 面向用户的功能和技术特性 +- [使用指南](#-使用指南) - 基础对话、工具调用、多模型切换 - [技术架构](#️-技术架构) - 技术栈、系统架构图、工作流流程图 +- [模型服务使用情况](#55-模型服务使用情况详解) - 模型选型、Token估算、成本分析 - [核心算法与实现原理](#-核心算法与实现原理) - LangGraph 工作流、多模型路由、SSE 流式响应 - [快速开始](#-快速开始) - Docker 和本地部署指南 -- [使用指南](#-使用指南) - 基础对话、工具调用、多模型切换 - [开发指南](#-开发指南) - 添加工具、添加模型、Docker 部署 - [实现指南与最佳实践](#️-实现指南与最佳实践) - 性能优化、扩展开发、部署实践 - [环境配置](#️-环境配置) - 配置文件、环境变量 - [故障排查](#-故障排查) - 常见问题 - --- ## 🎯 核心功能 @@ -50,6 +50,48 @@ --- + +--- + +## 📖 使用指南 + +### 基础对话 + +直接在聊天框输入问题即可: + +``` +你好,请介绍一下自己 +帮我写一个 Python 脚本 +``` + +### 主要功能 + +| 功能 | 说明 | 示例提问 | +|------|------|---------| +| 🧠 混合路由智能分流 | 自动判断任务类型,选择最佳路径 | 自然对话即可 | +| ⚡ 快速路径 | 闲聊、RAG查询、工具调用可走快速路径 | "你好"、"什么是 RAG" | +| 🔄 React 推理循环 | 复杂任务走完整的思考-行动-观察循环 | "帮我分析一下这个文档" | +| 🌐 联网搜索 | 免费 DuckDuckGo 搜索 | "今天北京天气怎么样?" | +| 📚 RAG 知识库检索 | 检索本地知识库 | "如何配置系统?" | +| 📇 通讯录管理 | 联系人 CRUD、邮件处理 | "帮我查看一下张三的联系方式" | +| 📖 智能词典 | 翻译、生词本、专业术语提取 | "帮我翻译这句话" | +| 📰 资讯分析 | 资讯获取、内容分析 | "帮我分析一下这篇新闻" | +| 📊 可视化图表 | 支持 Mermaid 图表生成 | "帮我画一个流程图" | + +### 多模型切换 + +1. 在左侧边栏选择模型: + - **智谱 GLM-4**:在线服务,速度快 + - **DeepSeek V3**:深度推理模型 + - **OpenAI GPT-4o-mini**:通用对话模型 + - **本地 Qwen3.5-9B**:本地部署,隐私性好 + +2. 可随时切换,甚至在同一会话中 + +3. 
点击 "🔄 新会话" 清空当前对话 + +--- + ## 🏗️ 技术架构 ### 1. 技术栈总览 @@ -385,6 +427,126 @@ flowchart TB --- + +--- + +### 5.5. 模型服务使用情况详解 + +#### 5.5.1. 模型服务架构总览 + +本项目采用**分层模型策略**,根据任务复杂度选择不同能力和成本的模型: + +| 模型类型 | 用途 | 主要来源 | 成本考量 | +|---------|------|---------|---------| +| 小模型 (Small LLM) | 意图分类、路由决策、查询改写 | 本地模型 / DeepSeek小模型 | 低成本,高频率 | +| 大模型 (Main LLM) | 对话生成、推理、工具调用 | 智谱 / DeepSeek / 本地 | 高能力,低频率 | +| Embedding模型 | 文本向量化、语义检索 | 本地 llama.cpp / 智谱API | 批量处理 | +| Rerank模型 | 检索结果重排序 | 硅基流动 / 智谱API | 精准排序 | +| Sparse模型 | BM25稀疏检索 | FastEmbed本地 | 关键词匹配 | + +#### 5.5.2. 小模型使用场景及Token估算 + +**小模型**主要用于高频率、低复杂度的任务: + +| 使用场景 | 位置文件 | 用途描述 | 单次Token估算 | 调用频率 | +|---------|---------|---------|-------------|---------| +| 意图分类 (1) | `app/core/intent_classifier.py` | 判断用户意图类型 | ~300输入 + ~50输出 | 每轮对话1次 | +| 意图分类 (2) | `app/main_graph/nodes/hybrid_router.py` | 混合路由决策 | ~200输入 + ~50输出 | 每轮对话1次 | +| 闲聊回复 | `app/main_graph/nodes/fast_paths.py` | 快速回复问候语 | ~50输入 + ~30输出 | 按需调用 | + +**Token估算说明**: +- 单次意图分类:总计 ~350-600 tokens +- 小模型成本通常是大模型的 1/10 - 1/100 +- 每日1000次对话,小模型仅消耗 ~350k-600k tokens + +#### 5.5.3. 大模型使用场景及Token估算 + +**大模型**用于核心对话生成和复杂推理: + +| 使用场景 | 位置文件 | 用途描述 | 单次Token估算 | 调用频率 | +|---------|---------|---------|-------------|---------| +| RAG查询改写 | `app/main_graph/utils/rag_initializer.py` | 生成多角度查询 | ~100输入 + ~150输出 | RAG调用时 | +| 主对话生成 | `app/main_graph/nodes/llm_call.py` | 用户查询响应 | ~500-2000输入 + ~200-1000输出 | 每轮对话1次 | +| React推理 | `app/main_graph/nodes/reasoning.py` | 任务分解与规划 | ~300-1000输入 + ~100-500输出 | 复杂任务多次 | +| 记忆摘要 | `app/memory/mem0_client.py` | 长期记忆压缩 | ~500-2000输入 + ~200-500输出 | 每N轮对话1次 | + +**Token估算说明**: +- 普通对话:总计 ~1000-3000 tokens +- RAG查询:额外 ~250 tokens +- 复杂多步推理:可能额外增加 500-3000 tokens +- 每日1000次对话,大模型预计消耗 1M-3M tokens + +#### 5.5.4. 
Embedding模型使用场景 + +**Embedding模型**用于语义检索和向量存储: + +| 使用场景 | 位置文件 | 用途描述 | 估算 | +|---------|---------|---------|------| +| RAG文档索引 | `rag_indexer/index_builder.py` | 文档分片向量化 | 每个文档片段1次 | +| 在线检索 | `app/rag/retriever.py` | 查询向量化 + 相似度检索 | 每次检索1次 | +| 记忆向量化 | `app/memory/mem0_client.py` | 记忆内容向量化存储 | 每次记忆更新1次 | + +**Embedding说明**: +- 向量维度:1024 (Qwen3-Embedding-0.6B) 或 2048 (智谱 embedding-3) +- 批量处理:建议使用 batch_size=10-20 提高效率 +- 本地优先:优先使用 llama.cpp 服务,降低API调用成本 + +#### 5.5.5. Rerank模型使用场景 + +**Rerank模型**用于检索结果精细化排序: + +| 使用场景 | 位置文件 | 用途描述 | 估算 | +|---------|---------|---------|------| +| RAG结果重排 | `app/rag/rerank.py` | 提升检索相关性 | 每次检索调用 | +| 混合检索重排 | `app/rag/retriever.py` | 稀疏+稠密结果融合排序 | 每次检索调用 | + +**Rerank说明**: +- 通常在 RRF 融合后使用,进一步提升精准度 +- 重排数量建议:rerank_top_n=3-10 +- 成本权衡:rerank 会增加额外调用成本,但精度提升明显 + +#### 5.5.6. 模型服务选型参考对比 + +为方便不同部署场景选择,提供以下模型选型参考: + +| 维度 | 本地优先方案 | 云端优先方案 | 混合方案 | +|------|------------|------------|---------| +| **小模型** | Qwen3.5-9B (本地) | DeepSeek-Chat (API) | 本地+DeepSeek降级 | +| **大模型** | Qwen3.5-9B (本地) | 智谱 GLM-4 / DeepSeek | 本地+云端降级链 | +| **Embedding** | Qwen3-Embedding-0.6B (本地llama.cpp) | 智谱 embedding-2 | 本地优先,智谱降级 | +| **Rerank** | (可选本地) | 硅基流动 bge-reranker-v2-m3 | 硅基流动API | +| **Sparse** | FastEmbed BM25 (本地) | FastEmbed BM25 (本地) | 本地 | + +**成本参考对比**(每1M tokens,仅作示例): + +| 模型 | 输入成本 | 输出成本 | 适用场景 | +|------|---------|---------|---------| +| **本地模型** | ~0元 | ~0元 | 有GPU机器,隐私敏感 | +| **DeepSeek-Chat** | ~¥0.5 | ~¥1.0 | 通用推理,成本适中 | +| **智谱 GLM-4** | ~¥1.0 | ~¥2.0 | 高质量对话 | +| **智谱 embedding-2** | ~¥0.2 | - | 向量嵌入 | +| **硅基流动 Rerank** | ~¥0.3 | - | 精准重排 | + +**部署建议**: +- **个人/测试**:全云端方案,快速上手 +- **小团队**:小模型本地,大模型云端降级 +- **企业/隐私敏感**:全本地部署,或使用私有API +- **生产环境**:核心能力本地+云端降级链,保证高可用 + +#### 5.5.7. 
模型服务降级链路设计 + +本项目所有模型服务都设计了**自动降级链路**,保证服务高可用: + +| 服务类型 | 主服务 | 降级服务 | +|---------|--------|---------| +| 对话生成 | 本地模型 | 智谱GLM-4 → DeepSeek | +| Embedding | 本地llama.cpp | 智谱embedding-2 | +| Rerank | 硅基流动 | 智谱rerank-2 | + +降级逻辑实现在 `app/model_services/base.py: FallbackServiceChain`,对上层业务透明。 + +--- + ### 6. 模型服务层 #### 6.1 多模型降级链 @@ -1582,45 +1744,6 @@ streamlit run frontend/src/frontend_main.py --- -## 📖 使用指南 - -### 基础对话 - -直接在聊天框输入问题即可: - -``` -你好,请介绍一下自己 -帮我写一个 Python 脚本 -``` - -### 主要功能 - -| 功能 | 说明 | 示例提问 | -|------|------|---------| -| 🧠 混合路由智能分流 | 自动判断任务类型,选择最佳路径 | 自然对话即可 | -| ⚡ 快速路径 | 闲聊、RAG查询、工具调用可走快速路径 | "你好"、"什么是 RAG" | -| 🔄 React 推理循环 | 复杂任务走完整的思考-行动-观察循环 | "帮我分析一下这个文档" | -| 🌐 联网搜索 | 免费 DuckDuckGo 搜索 | "今天北京天气怎么样?" | -| 📚 RAG 知识库检索 | 检索本地知识库 | "如何配置系统?" | -| 📇 通讯录管理 | 联系人 CRUD、邮件处理 | "帮我查看一下张三的联系方式" | -| 📖 智能词典 | 翻译、生词本、专业术语提取 | "帮我翻译这句话" | -| 📰 资讯分析 | 资讯获取、内容分析 | "帮我分析一下这篇新闻" | -| 📊 可视化图表 | 支持 Mermaid 图表生成 | "帮我画一个流程图" | - -### 多模型切换 - -1. 在左侧边栏选择模型: - - **智谱 GLM-4**:在线服务,速度快 - - **DeepSeek V3**:深度推理模型 - - **OpenAI GPT-4o-mini**:通用对话模型 - - **本地 Qwen3.5-9B**:本地部署,隐私性好 - -2. 可随时切换,甚至在同一会话中 - -3. 
点击 "🔄 新会话" 清空当前对话 - ---- - ## 🔧 开发指南 ### 添加新工具 diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index aa8a8c8..f3d7a79 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -1,25 +1,28 @@ """ -AI Agent 服务类 - 支持多模型动态切换 +AI Agent 服务类 - 单图方案 + 动态模型选择 接收外部传入的 checkpointer,不负责管理连接生命周期 """ import json import asyncio +from typing import AsyncGenerator, Dict, Any, Optional, Tuple # 本地模块 -from ..main_graph.utils.main_graph_builder import build_react_main_graph +from ..model_services import get_cached_chat_services +from ..main_graph.main_graph_builder import build_react_main_graph from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME from ..main_graph.config import set_stream_writer from ..main_graph.utils.rag_initializer import init_rag_tool from ..core.intent_classifier import get_intent_classifier -from ..logger import info, warning, error +from ..logger import debug, info, warning, error from ..main_graph.state import MainGraphState, CurrentAction class AIAgentService: def __init__(self, checkpointer): self.checkpointer = checkpointer - self.graphs = {} + self.graph = None # 只有一张图 + self.chat_services = None # 缓存的模型字典 self.tools = AVAILABLE_TOOLS.copy() self.tools_by_name = TOOLS_BY_NAME.copy() # 添加:意图分类器 @@ -40,64 +43,94 @@ class AIAgentService: self.tools.append(rag_tool) self.tools_by_name[rag_tool.name] = rag_tool self.rag_tool = rag_tool # 保存到实例变量,供 config 注入 - - # 2. 构建各模型的 Graph(使用新版 React 模式) - for name, llm in chat_services.items(): - try: - info(f"🔄 初始化模型 '{name}'...") - graph = build_react_main_graph( - llm=llm, - tools=self.tools, - mem0_client=self.mem0_client - ).compile(checkpointer=self.checkpointer) - self.graphs[name] = graph - info(f"✅ 模型 '{name}' 初始化成功") - except Exception as e: - warning(f"⚠️ 模型 '{name}' 初始化失败: {e}") - - if not self.graphs: - raise RuntimeError("没有可用的模型") + + # 2. 
获取缓存的模型字典 + self.chat_services = get_cached_chat_services() + info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}") + + # 3. 只构建一次图(传入 chat_services 字典) + info(f"🔄 构建单图...") + graph_builder = build_react_main_graph( + chat_services=self.chat_services, + tools=self.tools, + mem0_client=self.mem0_client + ) + self.graph = graph_builder.compile(checkpointer=self.checkpointer) + info(f"✅ 单图初始化完成") + return self - async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict: - """处理用户消息,返回包含回复、token统计和耗时的字典""" - if model not in self.graphs: - # 回退到第一个可用模型 - available = list(self.graphs.keys()) - if not available: - raise RuntimeError("没有可用的模型") - model = available[0] - warning(f"模型 '{model}' 不可用,已回退到 '{model}'") + def _resolve_model(self, model: str) -> str: + """ + 解析并验证模型名称,不可用时回退到第一个可用模型 + + Args: + model: 目标模型名称 + + Returns: + 实际使用的模型名称 + """ + if not model or model not in self.chat_services: + fallback = next(iter(self.chat_services.keys())) + warning(f"模型 '{model}' 不可用,回退到 '{fallback}'") + return fallback + return model - graph = self.graphs[model] + def _build_invocation( + self, message: str, thread_id: str, model: str, user_id: str + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + 构建图调用所需的 config 和 input_state + + Args: + message: 用户消息 + thread_id: 会话 ID + model: 模型名称 + user_id: 用户 ID + + Returns: + (config, input_state) 元组 + """ config = { "configurable": { "thread_id": thread_id, - "rag_tool": getattr(self, "rag_tool", None), # 注入 RAG 工具 + "rag_tool": getattr(self, "rag_tool", None), }, "metadata": {"user_id": user_id} } - # 新版状态输入:传入完整的 MainGraphState,关键是 user_query - from ..main_graph.state import MainGraphState, CurrentAction input_state = { "user_query": message, "messages": [{"role": "user", "content": message}], "user_id": user_id, + "current_model": model, "current_action": CurrentAction.NONE } + return config, input_state - result = await 
graph.ainvoke(input_state, config=config) + async def process_message( + self, message: str, thread_id: str, model: str = "", user_id: str = "default_user" + ) -> dict: + """处理用户消息,返回包含回复、token统计和耗时的字典""" + # 解析模型名称 + resolved_model = self._resolve_model(model) + + # 构建调用参数 + config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id) + + result = await self.graph.ainvoke(input_state, config=config) reply = result.get("final_result", "") if not reply and result.get("messages"): reply = result["messages"][-1].content - token_usage = result.get("debug_info", {}).get("token_usage", {}) - elapsed_time = result.get("debug_info", {}).get("elapsed_time", 0.0) + token_usage = result.get("last_token_usage", {}) + elapsed_time = result.get("last_elapsed_time", 0.0) + actual_model = result.get("current_model", resolved_model) return { "reply": reply, "token_usage": token_usage, - "elapsed_time": elapsed_time + "elapsed_time": elapsed_time, + "model_used": actual_model } def _serialize_value(self, value): @@ -121,31 +154,169 @@ class AIAgentService: except (TypeError, ValueError): return str(value) - async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"): - """流式处理消息,返回异步生成器(全部走 React 模式)""" - graph = self.graphs.get(model_name) - if not graph: - raise ValueError(f"模型 '{model_name}' 未找到或未初始化") + async def _handle_message_chunk( + self, chunk: Dict[str, Any], current_node: Optional[str], tool_calls_in_progress: Dict[str, Any] + ) -> AsyncGenerator[Dict[str, Any], None]: + """处理 messages 类型的 chunk""" + message_chunk, metadata = chunk["data"] + node_name = metadata.get("langgraph_node", "unknown") + new_current_node = current_node - config = { - "configurable": { - "thread_id": thread_id, - "rag_tool": getattr(self, "rag_tool", None), # 注入 RAG 工具 - }, - "metadata": {"user_id": user_id} - } - input_state = { - "user_query": message, - "messages": [{"role": "user", "content": message}], - 
"user_id": user_id, - "current_action": CurrentAction.NONE - } + # 检测节点变化,发送节点开始事件 + if node_name != current_node: + if current_node: + yield {"type": "node_end", "node": current_node} + yield {"type": "node_start", "node": node_name} + new_current_node = node_name - # ========== 意图识别(保留用于日志)========== + # 处理消息内容 + token_content = getattr(message_chunk, 'content', str(message_chunk)) + reasoning_token = "" + if hasattr(message_chunk, 'additional_kwargs'): + reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "") + + # 处理思考过程 + if reasoning_token: + yield { + "type": "llm_token", + "node": node_name, + "reasoning_token": reasoning_token + } + # 处理工具调用 + elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls: + for tool_call in message_chunk.tool_calls: + tool_call_id = tool_call.get("id", "") + tool_name = tool_call.get("name", "") + tool_args = tool_call.get("args", {}) + + # 记录工具调用开始,避免重复 + if tool_call_id and tool_call_id not in tool_calls_in_progress: + tool_calls_in_progress[tool_call_id] = { + "name": tool_name, + "args": tool_args + } + yield { + "type": "tool_call_start", + "tool": tool_name, + "args": tool_args, + "id": tool_call_id + } + # 处理普通 token + elif token_content: + yield { + "type": "llm_token", + "node": node_name, + "token": token_content, + "reasoning_token": reasoning_token + } + + # 返回更新后的 current_node + yield {"type": "_update_state", "current_node": new_current_node} + + async def _handle_updates_chunk( + self, chunk: Dict[str, Any], tool_calls_in_progress: Dict[str, Any], actual_model_used: str + ) -> AsyncGenerator[Dict[str, Any], None]: + """处理 updates 类型的 chunk""" + updates_data = chunk["data"] + new_actual_model = actual_model_used + + debug(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}") + + # 特别检查 final_result 和 current_model + if isinstance(updates_data, dict): + if "final_result" in updates_data: + debug(f"[Stream] 收到 final_result: 
{str(updates_data['final_result'])[:100]}...") + if "current_model" in updates_data: + new_actual_model = updates_data["current_model"] + info(f"[Stream] 实际使用模型: {new_actual_model}") + + serialized_data = self._serialize_value(updates_data) + + # 检查是否有人工审核请求 + if "review_pending" in serialized_data and serialized_data["review_pending"]: + review_id = serialized_data.get("review_id", "") + content_to_review = serialized_data.get("content_to_review", "") + yield { + "type": "human_review_request", + "review_id": review_id, + "content": content_to_review + } + + # 检查是否有工具结果 + if "messages" in serialized_data: + for msg in serialized_data["messages"]: + # 检测工具结果消息 + if msg.get("role") == "tool": + tool_call_id = msg.get("tool_call_id", "") + tool_name = msg.get("name", "") + tool_result = msg.get("content", "") + + if tool_call_id and tool_call_id in tool_calls_in_progress: + yield { + "type": "tool_call_end", + "tool": tool_name, + "id": tool_call_id, + "result": tool_result + } + del tool_calls_in_progress[tool_call_id] + + yield { + "type": "state_update", + "data": serialized_data + } + + # 返回更新后的模型 + yield {"type": "_update_state", "actual_model_used": new_actual_model} + + async def _handle_custom_chunk(self, chunk: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: + """处理 custom 类型的 chunk""" + custom_data = chunk["data"] + + # 处理我们从 react_reason_node 发送的自定义推理事件 + if isinstance(custom_data, dict): + # 检查是否是我们的推理事件 + if "action" in custom_data and "reasoning" in custom_data: + yield { + "type": "react_reasoning", + "step": custom_data.get("step", 1), + "action": custom_data.get("action", "unknown"), + "confidence": custom_data.get("confidence", 0), + "reasoning": custom_data.get("reasoning", "") + } + else: + # 处理其他自定义事件 + serialized_data = self._serialize_value(custom_data) + yield { + "type": "custom", + "data": serialized_data + } + else: + # 处理其他自定义事件 + serialized_data = self._serialize_value(custom_data) + yield { + "type": "custom", + "data": 
serialized_data + } + + async def process_message_stream( + self, message: str, thread_id: str, model: str = "", user_id: str = "default_user" + ) -> AsyncGenerator[Dict[str, Any], None]: + """流式处理消息,返回异步生成器""" + # 解析模型名称 + resolved_model = self._resolve_model(model) + + # 构建调用参数 + config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id) + + # ========== 意图识别(保留用于日志和后续路由)========== intent_result = await self.intent_classifier.classify(message) info(f"🧠 意图识别: {intent_result.intent_type} (置信度: {intent_result.confidence:.2f})") info(f"📝 推理: {intent_result.reasoning}") + # 注入意图到状态(让 hybrid_router 可以利用) + input_state["intent_type"] = intent_result.intent_type.value + input_state["intent_confidence"] = intent_result.confidence + # 发送意图分类事件 yield { "type": "intent_classified", @@ -154,25 +325,26 @@ class AIAgentService: "reasoning": intent_result.reasoning } - # 发送路径决策事件(现在都是 react_loop) + # 发送路径决策事件(目前硬编码,但状态中有意图信息供后续使用) yield { "type": "path_decision", "path": "react_loop", "intent": intent_result.intent_type.value } - # ======================================== + # ============================================= # ========== React 循环路径 ========== - info(f"🚀 开始执行 React 图,模型: {model_name}") + info(f"🚀 开始执行单图,指定模型: {resolved_model}") current_node = None - tool_calls_in_progress = {} + tool_calls_in_progress: Dict[str, Any] = {} + actual_model_used = resolved_model + chunk_count = 0 + full_message_content = "" try: info(f"📡 开始调用 graph.astream()...") - chunk_count = 0 - full_message_content = "" # 收集完整消息内容 - async for chunk in graph.astream( + async for chunk in self.graph.astream( input_state, config=config, stream_mode=["messages", "updates", "custom"], @@ -181,156 +353,58 @@ class AIAgentService: ): chunk_count += 1 chunk_type = chunk["type"] - processed_event = {} if chunk_type == "messages": - message_chunk, metadata = chunk["data"] - node_name = metadata.get("langgraph_node", "unknown") - - # 检测节点变化,发送节点开始事件 - if node_name != 
current_node: - if current_node: - yield { - "type": "node_end", - "node": current_node - } - yield { - "type": "node_start", - "node": node_name - } - current_node = node_name - - # 处理消息内容 - token_content = getattr(message_chunk, 'content', str(message_chunk)) - reasoning_token = "" - if hasattr(message_chunk, 'additional_kwargs'): - reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "") - - # 处理思考过程 - if reasoning_token: - processed_event = { - "type": "llm_token", - "node": node_name, - "reasoning_token": reasoning_token - } - # 处理工具调用 - elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls: - for tool_call in message_chunk.tool_calls: - tool_call_id = tool_call.get("id", "") - tool_name = tool_call.get("name", "") - tool_args = tool_call.get("args", {}) - - # 记录工具调用开始 - if tool_call_id not in tool_calls_in_progress: - tool_calls_in_progress[tool_call_id] = { - "name": tool_name, - "args": tool_args - } - yield { - "type": "tool_call_start", - "tool": tool_name, - "args": tool_args, - "id": tool_call_id - } - # 处理普通 token - 只收集,不打印单个 token - elif token_content: - processed_event = { - "type": "llm_token", - "node": node_name, - "token": token_content, - "reasoning_token": reasoning_token - } - if node_name == "llm_call": - full_message_content += token_content + async for event in self._handle_message_chunk( + chunk, current_node, tool_calls_in_progress + ): + if event.get("type") == "_update_state": + current_node = event.get("current_node", current_node) + else: + # 如果是 llm_call 节点的 token,收集完整消息 + if ( + event.get("type") == "llm_token" + and event.get("node") == "llm_call" + and "token" in event + ): + full_message_content += event["token"] + yield event elif chunk_type == "updates": - updates_data = chunk["data"] - info(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}") - # 特别检查 final_result - if isinstance(updates_data, dict) and "final_result" in 
updates_data: - info(f"[Stream] 收到 final_result: {str(updates_data['final_result'])[:100]}...") - serialized_data = self._serialize_value(updates_data) - if "review_pending" in serialized_data and serialized_data["review_pending"]: - review_id = serialized_data.get("review_id", "") - content_to_review = serialized_data.get("content_to_review", "") - yield { - "type": "human_review_request", - "review_id": review_id, - "content": content_to_review - } - - # 检查是否有工具结果 - if "messages" in serialized_data: - for msg in serialized_data["messages"]: - # 检测工具结果消息 - if msg.get("role") == "tool": - tool_call_id = msg.get("tool_call_id", "") - tool_name = msg.get("name", "") - tool_output = msg.get("content", "") - - if tool_call_id in tool_calls_in_progress: - yield { - "type": "tool_call_end", - "tool": tool_name, - "id": tool_call_id, - "result": tool_output - } - del tool_calls_in_progress[tool_call_id] - - processed_event = { - "type": "state_update", - "data": serialized_data - } + async for event in self._handle_updates_chunk( + chunk, tool_calls_in_progress, actual_model_used + ): + if event.get("type") == "_update_state": + actual_model_used = event.get("actual_model_used", actual_model_used) + else: + yield event elif chunk_type == "custom": - custom_data = chunk["data"] - - # 处理我们从 react_reason_node 发送的自定义推理事件 - if isinstance(custom_data, dict): - # 检查是否是我们的推理事件 - if "action" in custom_data and "reasoning" in custom_data: - yield { - "type": "react_reasoning", - "step": custom_data.get("step", 1), - "action": custom_data.get("action", "unknown"), - "confidence": custom_data.get("confidence", 0), - "reasoning": custom_data.get("reasoning", "") - } - else: - # 处理其他自定义事件 - serialized_data = self._serialize_value(custom_data) - processed_event = { - "type": "custom", - "data": serialized_data - } - else: - # 处理其他自定义事件 - serialized_data = self._serialize_value(custom_data) - processed_event = { - "type": "custom", - "data": serialized_data - } - - if processed_event: - 
yield processed_event + async for event in self._handle_custom_chunk(chunk): + yield event # 完整消息集合完成后,一次性打印 - info(f"✅ graph.astream() 完成,共 {chunk_count} 个chunks") + info(f"✅ graph.astream() 完成,共 {chunk_count} 个 chunks") if full_message_content: info(f"📄 完整消息内容: {repr(full_message_content)}") + info(f"🤖 实际使用模型: {actual_model_used}") except Exception as e: - error(f"❌ 执行 React 图时出错: {e}") + error(f"❌ 执行单图时出错: {e}") import traceback error(f"📋 堆栈: {traceback.format_exc()}") - raise - - # 发送结束事件 - if current_node: yield { - "type": "node_end", - "node": current_node + "type": "error", + "message": str(e) + } + finally: + # 无论成功或失败,都发送结束事件,保证前端平稳关闭 + if current_node: + yield { + "type": "node_end", + "node": current_node + } + yield { + "type": "done", + "model_used": actual_model_used } - yield { - "type": "done" - } \ No newline at end of file diff --git a/backend/app/agent/prompts.py b/backend/app/agent/prompts.py index 848b8c6..8fc0e98 100644 --- a/backend/app/agent/prompts.py +++ b/backend/app/agent/prompts.py @@ -22,9 +22,8 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate: "3. 📇 通讯录子系统 - 查询联系人、添加联系人、管理通讯录\n" "4. 
🔍 RAG检索 - 从知识库中检索相关信息回答问题\n\n" "【用户背景信息】\n" - "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n" + "以下是对当前用户的已知信息和长期记忆,你必须优先采纳:\n" "{memory_context}\n" - "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n" "【可用工具与使用规则】\n" f"{tools_section}\n" "工具调用时请直接返回所需参数,无需额外说明。\n\n" diff --git a/backend/app/config.py b/backend/app/config.py index 9b46560..86a48cf 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -127,6 +127,13 @@ BACKEND_PORT = _get_int("BACKEND_PORT") MEMORY_SUMMARIZE_INTERVAL = _get_int("MEMORY_SUMMARIZE_INTERVAL") +# ========== Tavily 搜索配置 ========== +# Tavily API:https://app.tavily.com +# 免费额度:1000次/天 +TAVILY_API_KEY = _get_str("TAVILY_API_KEY") +TAVILY_MAX_RESULTS = _get_int("TAVILY_MAX_RESULTS") or 5 + + # ========== Graph 执行追踪配置 ========== # 是否启用 Graph 流转追踪(通过环境变量控制) ENABLE_GRAPH_TRACE = _get_bool("ENABLE_GRAPH_TRACE") diff --git a/backend/app/core/web_search.py b/backend/app/core/web_search.py index d7d0603..db8fac2 100644 --- a/backend/app/core/web_search.py +++ b/backend/app/core/web_search.py @@ -33,18 +33,29 @@ class WebSearchTool: def search(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]: """ - 使用多种方式搜索 - + 使用多种方式搜索,按优先级尝试 + Args: query: 搜索关键词 max_results: 返回结果数量,默认使用初始化时的设置 - + Returns: 搜索结果列表 """ num_results = max_results or self.max_results - # 方式 1: 尝试用 ddgs 包 + # 方式 1: Tavily (需要 API Key,质量最高) + try: + return self._search_tavily(query, num_results) + except ImportError: + print("[WebSearch] tavily 未安装,尝试其他搜索方式") + except Exception as e: + if "API_KEY" in str(e) or "未配置" in str(e): + print(f"[WebSearch] Tavily API Key 未配置: {e}") + else: + print(f"[WebSearch] Tavily 搜索失败: {e}") + + # 方式 2: 尝试用 ddgs 包 try: from ddgs import DDGS print(f"[WebSearch] 使用 ddgs 搜索: {query}") @@ -65,29 +76,7 @@ class WebSearchTool: print("[WebSearch] ddgs 未安装,尝试 duckduckgo-search") except Exception as e: print(f"[WebSearch] ddgs 搜索失败: {e}") - - # 方式 2: 尝试用旧的 duckduckgo-search 包 - try: - from duckduckgo_search import DDGS - print(f"[WebSearch] 使用 
duckduckgo-search 搜索: {query}") - with DDGS() as ddgs: - results = list(ddgs.text(query, max_results=num_results)) - if results: - search_results = [] - for r in results: - search_results.append(SearchResult( - title=r.get("title", ""), - url=r.get("href", ""), - snippet=r.get("body", ""), - source="DuckDuckGo" - )) - print(f"[WebSearch] duckduckgo-search 返回 {len(search_results)} 条结果") - return search_results - except ImportError: - print("[WebSearch] duckduckgo-search 未安装") - except Exception as e: - print(f"[WebSearch] duckduckgo-search 搜索失败: {e}") - + # 方式 3: 尝试用简单 HTTP 请求 try: return self._search_http(query, num_results) @@ -97,6 +86,34 @@ class WebSearchTool: # 方式 4: 返回模拟数据作为最后兜底 return self._search_mock(query, num_results) + def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]: + """使用 Tavily API 搜索""" + from tavily import TavilyClient + from app.config import TAVILY_API_KEY, TAVILY_MAX_RESULTS + + if not TAVILY_API_KEY: + raise ValueError("TAVILY_API_KEY 未配置") + + client = TavilyClient(api_key=TAVILY_API_KEY) + response = client.search( + query=query, + max_results=min(max_results, TAVILY_MAX_RESULTS or 5), + include_answer=True, + include_raw_content=False + ) + + results = [] + for item in response.get("results", []): + results.append(SearchResult( + title=item.get("title", ""), + url=item.get("url", ""), + snippet=item.get("content", ""), + source="Tavily" + )) + + print(f"[WebSearch] Tavily 返回 {len(results)} 条结果") + return results + def _search_http(self, query: str, max_results: int) -> List[SearchResult]: """用简单 HTTP 请求搜索(备用方案)- 尝试多个国内源""" print(f"[WebSearch] 尝试 HTTP 搜索") diff --git a/backend/app/main_graph/graph.py b/backend/app/main_graph/graph.py deleted file mode 100644 index e46b75a..0000000 --- a/backend/app/main_graph/graph.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -LangGraph 核心组件重新导出 -统一导入入口,避免直接依赖 langgraph -""" - -from langgraph.graph import StateGraph, START, END, add_messages - -__all__ = ["StateGraph", "START", "END", 
"add_messages"] diff --git a/backend/app/main_graph/main_graph_builder.py b/backend/app/main_graph/main_graph_builder.py new file mode 100644 index 0000000..bde6d0f --- /dev/null +++ b/backend/app/main_graph/main_graph_builder.py @@ -0,0 +1,229 @@ +""" +主图构建器 - 构建整合后的完整主图 +""" + +from langgraph.graph import StateGraph, START, END +from typing import Dict, Any + +from .state import MainGraphState +from .nodes.reasoning import react_reason_node +from .nodes.web_search import web_search_node +from .nodes.error_handling import error_handling_node +from .nodes.routing import init_state_node, route_by_reasoning, should_summarize +from .nodes.hybrid_router import ( + hybrid_router_node, + route_from_hybrid_decision, + check_fast_path_success, +) +from .nodes.fast_paths import ( + fast_chitchat_node, + fast_rag_node, + fast_tool_node, +) +from .nodes.llm_call import create_dynamic_llm_call_node +from .nodes.rag_nodes import rag_retrieve_node +from .nodes.retrieve_memory import create_retrieve_memory_node +from .nodes.memory_trigger import memory_trigger_node, set_mem0_client +from .nodes.summarize import create_summarize_node +from .nodes.finalize import finalize_node +from ..subgraphs.contact import build_contact_subgraph +from ..subgraphs.dictionary import build_dictionary_subgraph +from ..subgraphs.news_analysis import build_news_analysis_subgraph +from ..logger import info + +from .subgraph_wrapper import create_subgraph_nodes + + +# ========== 主图构建 ========== + +def build_react_main_graph( + chat_services: dict, + tools=None, + mem0_client=None, + use_hybrid_router: bool = True +) -> StateGraph: + """ + 构建整合后的完整主图(支持混合路由 + 动态模型选择) + + Args: + chat_services: 模型名称 -> ChatModel 实例 的字典 + tools: 工具列表 + mem0_client: Mem0 客户端实例 + use_hybrid_router: 是否使用混合路由(快速路径 + React 循环) + + Returns: + StateGraph: 构建好的图 + """ + # 创建图 + graph = StateGraph(MainGraphState) + + # 设置全局 mem0_client + if mem0_client: + set_mem0_client(mem0_client) + + # ========== 创建节点 ========== + + # LLM 调用节点 
+ llm_node = create_dynamic_llm_call_node(chat_services, tools or []) + + # 记忆节点 + retrieve_memory_node = None + summarize_node = None + if mem0_client: + retrieve_memory_node = create_retrieve_memory_node(mem0_client) + summarize_node = create_summarize_node(mem0_client) + + # 子图节点 + contact_graph = build_contact_subgraph() + dictionary_graph = build_dictionary_subgraph() + news_analysis_graph = build_news_analysis_subgraph() + subgraph_nodes = create_subgraph_nodes( + contact_graph, dictionary_graph, news_analysis_graph + ) + + # ========== 添加节点到图 ========== + + # 阶段 1: 记忆检索 + if retrieve_memory_node: + graph.add_node("retrieve_memory", retrieve_memory_node) + graph.add_node("memory_trigger", memory_trigger_node) + + # 阶段 2: 初始化 + graph.add_node("init_state", init_state_node) + + # 阶段 3: 混合路由(可选) + if use_hybrid_router: + graph.add_node("hybrid_router", hybrid_router_node) + graph.add_node("fast_chitchat", fast_chitchat_node) + graph.add_node("fast_rag", fast_rag_node) + graph.add_node("fast_tool", fast_tool_node) + + # 阶段 4: React 循环推理(始终保留) + graph.add_node("react_reason", react_reason_node) + graph.add_node("rag_retrieve", rag_retrieve_node) + graph.add_node("web_search", web_search_node) + graph.add_node("handle_error", error_handling_node) + + if llm_node is not None: + graph.add_node("llm_call", llm_node) + + # 子图节点 + for node_name, node_func in subgraph_nodes.items(): + graph.add_node(node_name, node_func) + + # 阶段 5: 完成处理 + if summarize_node: + graph.add_node("summarize", summarize_node) + graph.add_node("finalize", finalize_node) + + # ========== 添加边 ========== + + # 阶段 1: 记忆检索 + _add_memory_edges(graph, retrieve_memory_node) + + # 阶段 2: 初始化 + graph.add_edge("memory_trigger", "init_state") + + # 阶段 3: 路由分支 + _add_routing_edges(graph, use_hybrid_router, llm_node) + + # 阶段 4: React 循环边 + _add_react_loop_edges(graph, subgraph_nodes) + + # 阶段 5: 完成阶段 + _add_finalize_edges(graph, llm_node, summarize_node) + + info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: 
{use_hybrid_router})") + + return graph + + +def _add_memory_edges(graph: StateGraph, retrieve_memory_node) -> None: + """添加记忆检索阶段的边""" + if retrieve_memory_node: + graph.add_edge(START, "retrieve_memory") + graph.add_edge("retrieve_memory", "memory_trigger") + else: + graph.add_edge(START, "memory_trigger") + + +def _add_routing_edges(graph: StateGraph, use_hybrid_router: bool, llm_node) -> None: + """添加路由阶段的边""" + if use_hybrid_router: + graph.add_edge("init_state", "hybrid_router") + + # 混合路由条件分支 + graph.add_conditional_edges( + "hybrid_router", + route_from_hybrid_decision, + { + "fast_chitchat": "fast_chitchat", + "fast_rag": "fast_rag", + "fast_tool": "fast_tool", + "react_loop": "react_reason" + } + ) + + # 快速路径的完成检查 + for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]: + graph.add_conditional_edges( + fast_node, + check_fast_path_success, + { + "llm_call": "llm_call", + "escalate": "react_reason" + } + ) + + info(f"✅ [图构建] 混合路由模式已启用") + else: + graph.add_edge("init_state", "react_reason") + info(f"✅ [图构建] 纯 React 模式") + + +def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) -> None: + """添加 React 循环阶段的边""" + subgraph_names = list(subgraph_nodes.keys()) + + # React 推理的条件分支 + graph.add_conditional_edges( + "react_reason", + route_by_reasoning, + { + "rag_retrieve": "rag_retrieve", + "web_search": "web_search", + **{name: name for name in subgraph_names}, + "handle_error": "handle_error", + "llm_call": "llm_call" + } + ) + + # 循环边(回到 react_reason) + loop_back_nodes = ["rag_retrieve", "web_search", "handle_error"] + subgraph_names + for node_name in loop_back_nodes: + graph.add_edge(node_name, "react_reason") + + +def _add_finalize_edges(graph: StateGraph, llm_node, summarize_node) -> None: + """添加完成阶段的边""" + if llm_node is not None: + if summarize_node: + graph.add_conditional_edges( + "llm_call", + should_summarize, + { + "summarize": "summarize", + "finalize": "finalize" + } + ) + graph.add_edge("summarize", "finalize") + 
else: + graph.add_edge("llm_call", "finalize") + + graph.add_edge("finalize", END) + + +# ========== 导出 ========== +__all__ = [ + "build_react_main_graph", +] diff --git a/backend/app/main_graph/nodes/__init__.py b/backend/app/main_graph/nodes/__init__.py index a3cb165..bd00186 100644 --- a/backend/app/main_graph/nodes/__init__.py +++ b/backend/app/main_graph/nodes/__init__.py @@ -6,8 +6,8 @@ from .reasoning import react_reason_node from .web_search import web_search_node from .error_handling import error_handling_node -from .routing import init_state_node, route_by_reasoning -from .llm_call import create_llm_call_node +from .routing import init_state_node, route_by_reasoning, should_summarize +from .llm_call import create_dynamic_llm_call_node from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node # 记忆节点 @@ -38,7 +38,8 @@ __all__ = [ "web_search_node", "error_handling_node", "route_by_reasoning", - "create_llm_call_node", + "should_summarize", + "create_dynamic_llm_call_node", "rag_retrieve_node", "rag_re_retrieve_node", # 记忆节点 diff --git a/backend/app/main_graph/nodes/llm_call.py b/backend/app/main_graph/nodes/llm_call.py index 24bca23..40c278f 100644 --- a/backend/app/main_graph/nodes/llm_call.py +++ b/backend/app/main_graph/nodes/llm_call.py @@ -5,7 +5,7 @@ LLM 调用节点模块 import time from typing import Any, Dict -from langchain_core.language_models import BaseLLM +from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage # 本地模块 @@ -14,29 +14,34 @@ from app.agent.prompts import create_system_prompt from app.utils.logging import log_state_change from app.logger import debug, info, error -def create_llm_call_node(llm, tools: list): + +def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list): """ - 工厂函数:创建 LLM 调用节点 - + 工厂函数:创建动态 LLM 调用节点(根据 state.current_model 选择模型) + Args: - llm: LangChain LLM 实例 + chat_services: 模型名称 -> ChatModel 实例 的字典 tools: 工具列表 - + Returns: 异步节点函数 """ - # 构建调用链 + # 
预构建所有模型的 tools 绑定(避免每次调用都 bind) + bound_models: Dict[str, Any] = {} + for name, llm in chat_services.items(): + if tools: + bound_models[name] = llm.bind_tools(tools) + else: + bound_models[name] = llm + + # 预构建 prompt prompt = create_system_prompt(tools) - llm_with_tools = llm.bind_tools(tools) - - # 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历 - chain = prompt | llm_with_tools - + from langchain_core.runnables.config import RunnableConfig - + async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: """ - LLM 调用节点(异步方法) + LLM 调用节点(动态选择模型) Args: state: 当前对话状态 @@ -46,7 +51,7 @@ def create_llm_call_node(llm, tools: list): 更新后的状态字典 """ log_state_change("llm_call", state, "进入") - + memory_context = getattr(state, "memory_context", "暂无用户信息") start_time = time.time() @@ -62,9 +67,20 @@ def create_llm_call_node(llm, tools: list): "last_elapsed_time": elapsed_time, "turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1, } - + + # 动态选择模型 + model_name = getattr(state, "current_model", "") + if not model_name or model_name not in bound_models: + # 回退到第一个可用模型 + fallback_name = next(iter(bound_models.keys())) + info(f"[llm_call] 模型 '{model_name}' 不可用,回退到 '{fallback_name}'") + model_name = fallback_name + + llm_with_tools = bound_models[model_name] + info(f"[llm_call] 使用模型: {model_name}") + try: - # 添加 RAG 上下文到消息 + # 添加上下文到消息 messages_with_context = list(state.messages) if state.rag_context: from langchain_core.messages import SystemMessage @@ -77,9 +93,10 @@ def create_llm_call_node(llm, tools: list): break if not inserted: messages_with_context.insert(0, rag_system_msg) - + # 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。 # LangGraph 会自动监听这期间产生的所有 token。 + chain = prompt | llm_with_tools chunks = [] async for chunk in chain.astream( { @@ -89,7 +106,7 @@ def create_llm_call_node(llm, tools: list): config=config ): chunks.append(chunk) - + # 将所有 chunk 合并成最终的 AIMessage if chunks: response = chunks[0] @@ -97,14 +114,14 @@ def 
create_llm_call_node(llm, tools: list): response = response + chunk else: response = AIMessage(content="") - + elapsed_time = time.time() - start_time - + # 提取 token 用量(兼容不同 LLM 提供商的元数据格式) token_usage = {} input_tokens = 0 output_tokens = 0 - + # 尝试从 response_metadata 中提取 if hasattr(response, 'response_metadata') and response.response_metadata: meta = response.response_metadata @@ -112,33 +129,33 @@ def create_llm_call_node(llm, tools: list): token_usage = meta['token_usage'] elif 'usage' in meta: token_usage = meta['usage'] - + # 尝试从 additional_kwargs 中提取 if not token_usage and hasattr(response, 'additional_kwargs'): add_kwargs = response.additional_kwargs if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']: token_usage = add_kwargs['llm_output']['token_usage'] - + # 提取具体的 token 数值 if token_usage: input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0)) output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0)) - + # 打印 LLM 的完整输出 debug("\n" + "="*80) - debug("📥 [LLM输出] 大模型返回的完整响应:") + debug(f"📥 [LLM输出] 模型: {model_name} 返回的完整响应:") debug(f" 消息类型: {response.type.upper()}") debug(f" 内容长度: {len(str(response.content))} 字符") debug("-"*80) debug(f"{response.content}") - + # 打印响应统计信息 info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒") info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}") if token_usage: debug(f"📋 [LLM统计] 详细用量: {token_usage}") debug("="*80 + "\n") - + # 检查是否有工具调用 has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0 @@ -151,21 +168,22 @@ def create_llm_call_node(llm, tools: list): "final_result": response.content, "success": True, "current_phase": "done", - "has_tool_calls": has_tool_calls + "has_tool_calls": has_tool_calls, + "current_model": model_name # 记录实际使用的模型 } - + log_state_change("llm_call", state, "离开") return result - + except Exception as e: elapsed_time = time.time() - start_time - error(f"\n❌ 
[LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)") + error(f"\n❌ [LLM错误] 模型 {model_name} 调用失败 (耗时: {elapsed_time:.2f}秒)") error(f" 错误类型: {type(e).__name__}") error(f" 错误信息: {str(e)}") import traceback - traceback.print_exc() + error(f"📋 堆栈: {traceback.format_exc()}") debug("="*80 + "\n") - + # 返回一个友好的错误消息 error_response = AIMessage( content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。" @@ -178,10 +196,11 @@ def create_llm_call_node(llm, tools: list): "turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1, "final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。", "success": False, - "current_phase": "done" + "current_phase": "done", + "current_model": model_name } - + log_state_change("llm_call", state, "离开(异常)") return error_result - - return call_llm \ No newline at end of file + + return call_llm diff --git a/backend/app/main_graph/nodes/routing.py b/backend/app/main_graph/nodes/routing.py index 1f43bd7..24985c4 100644 --- a/backend/app/main_graph/nodes/routing.py +++ b/backend/app/main_graph/nodes/routing.py @@ -118,3 +118,21 @@ def route_by_reasoning(state: MainGraphState) -> str: info(f"[条件路由] 动作={latest_action}, 目标={target}") return target + + +# ========== 完成阶段条件路由函数 ========== + +def should_summarize(state: MainGraphState) -> str: + """ + 检查是否需要总结对话(对话足够长时) + + Args: + state: 当前图状态 + + Returns: + "summarize" 或 "finalize" + """ + if state.turns_since_last_summary >= 5: # 每5轮对话总结一次 + return "summarize" + else: + return "finalize" diff --git a/backend/app/main_graph/state.py b/backend/app/main_graph/state.py index 5599a17..92267df 100644 --- a/backend/app/main_graph/state.py +++ b/backend/app/main_graph/state.py @@ -6,7 +6,7 @@ Main Graph State Definition - React Mode Enhanced from enum import Enum, auto from typing import Optional, Dict, Any, Annotated, Sequence, TypedDict, List from dataclasses import dataclass, field -from app.main_graph.graph import add_messages +from langgraph.graph import add_messages from langchain_core.messages import BaseMessage @@ -62,6 
+62,7 @@ class MainGraphState: # ========== 主图控制字段 ========== user_query: str = "" current_action: CurrentAction = CurrentAction.NONE + current_model: str = "" # 新增:本次请求使用的模型 intent_confidence: float = 0.0 # ========== React 推理专用字段 ========== diff --git a/backend/app/main_graph/subgraph_wrapper.py b/backend/app/main_graph/subgraph_wrapper.py new file mode 100644 index 0000000..0be9b4c --- /dev/null +++ b/backend/app/main_graph/subgraph_wrapper.py @@ -0,0 +1,159 @@ +""" +子图包装器 - 为子图添加错误处理和事件追踪 +""" + +from typing import Dict, Any, Optional +from datetime import datetime + +from .state import MainGraphState, ErrorRecord, ErrorSeverity +from ..logger import info + + +def wrap_subgraph_for_error_handling(subgraph, name: str): + """ + 包装子图,使其错误能传递给主图 + + Args: + subgraph: 编译好的子图 + name: 子图名称(用于错误标识) + + Returns: 包装后的节点函数 + """ + async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: + # 发送子图开始事件 + if config: + try: + from langchain_core.callbacks.manager import adispatch_custom_event + callbacks = config.get("callbacks") + if callbacks: + await adispatch_custom_event( + "react_reasoning", + { + "step": state.reasoning_step, + "action": f"{name}_subgraph_start", + "confidence": 1.0, + "reasoning": f"开始执行 {name} 子图..." 
+ }, + callbacks=callbacks + ) + except Exception as e: + info(f"[{name}_subgraph] 无法发送开始事件: {e}") + + try: + # 调用子图 + result = subgraph.invoke(state) + + # 更新主图状态 + subgraph_result = None + if name == "contact": + state.contact_result = result + subgraph_result = result.get("final_result", "") + elif name == "dictionary": + state.dictionary_result = result + subgraph_result = result.get("final_result", "") + elif name == "news_analysis": + state.news_result = result + subgraph_result = result.get("final_result", "") + + # 设置最终结果 + if subgraph_result: + state.final_result = subgraph_result + else: + state.final_result = "子图执行完成" + + # 标记成功 + state.success = True + state.current_phase = "done" + state.reasoning_history.append({ + "step": state.reasoning_step, + "action": "subgraph_completed", + "confidence": 1.0, + "reasoning": f"{name}子图执行完成", + "timestamp": datetime.now().isoformat() + }) + + # 发送子图完成事件 + if config: + try: + from langchain_core.callbacks.manager import adispatch_custom_event + callbacks = config.get("callbacks") + if callbacks: + await adispatch_custom_event( + "react_reasoning", + { + "step": state.reasoning_step, + "action": f"{name}_subgraph_complete", + "confidence": 1.0, + "reasoning": f"{name} 子图执行完成" + }, + callbacks=callbacks + ) + except Exception as e: + info(f"[{name}_subgraph] 无法发送完成事件: {e}") + + return state + + except Exception as e: + # 捕获子图错误,传递给主图 + error_record = ErrorRecord( + error_type=f"{name}SubgraphError", + error_message=str(e), + severity=ErrorSeverity.WARNING, + source=f"{name}_subgraph", + timestamp=datetime.now().isoformat(), + retry_count=0, + max_retries=1, + context={"user_query": state.user_query} + ) + state.errors.append(error_record) + state.current_error = error_record + state.current_phase = "error_handling" + state.success = False + + # 发送子图错误事件 + if config: + try: + from langchain_core.callbacks.manager import adispatch_custom_event + callbacks = config.get("callbacks") + if callbacks: + await 
adispatch_custom_event( + "react_reasoning", + { + "step": state.reasoning_step, + "action": f"{name}_subgraph_error", + "confidence": 1.0, + "reasoning": f"{name} 子图执行失败: {str(e)}" + }, + callbacks=callbacks + ) + except Exception as e: + info(f"[{name}_subgraph] 无法发送错误事件: {e}") + + return state + + return wrapped_node + + +def create_subgraph_nodes(contact_graph, dictionary_graph, news_analysis_graph) -> Dict[str, Any]: + """ + 创建所有子图节点的字典 + + Args: + contact_graph: 联系人子图 + dictionary_graph: 词典子图 + news_analysis_graph: 新闻分析子图 + + Returns: + 子图节点字典 {name: wrapped_node} + """ + return { + "contact_subgraph": wrap_subgraph_for_error_handling( + contact_graph.compile(), "contact" + ), + "dictionary_subgraph": wrap_subgraph_for_error_handling( + dictionary_graph.compile(), "dictionary" + ), + "news_analysis_subgraph": wrap_subgraph_for_error_handling( + news_analysis_graph.compile(), "news_analysis" + ), + } diff --git a/backend/app/main_graph/utils/__init__.py b/backend/app/main_graph/utils/__init__.py deleted file mode 100644 index a03be40..0000000 --- a/backend/app/main_graph/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""主图工具函数""" diff --git a/backend/app/main_graph/utils/main_graph_builder.py b/backend/app/main_graph/utils/main_graph_builder.py deleted file mode 100644 index cae948f..0000000 --- a/backend/app/main_graph/utils/main_graph_builder.py +++ /dev/null @@ -1,371 +0,0 @@ -""" -整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState -""" - -from ..graph import StateGraph, START, END -from typing import Dict, Any, Optional -from langchain_core.runnables.config import RunnableConfig - -from ..state import MainGraphState -from ..nodes.reasoning import react_reason_node -from ..nodes.web_search import web_search_node -from ..nodes.error_handling import error_handling_node -from ..nodes.routing import init_state_node, route_by_reasoning -from ..nodes.hybrid_router import ( - hybrid_router_node, - route_from_hybrid_decision, - check_fast_path_success, -) -from 
..nodes.fast_paths import ( - fast_chitchat_node, - fast_rag_node, - fast_tool_node, -) -from ..nodes.llm_call import create_llm_call_node -from ..nodes.rag_nodes import rag_retrieve_node -from ..nodes.retrieve_memory import create_retrieve_memory_node -from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client -from ..nodes.summarize import create_summarize_node -from ..nodes.finalize import finalize_node -from ...subgraphs.contact import build_contact_subgraph -from ...subgraphs.dictionary import build_dictionary_subgraph -from ...subgraphs.news_analysis import build_news_analysis_subgraph -from ...memory.mem0_client import Mem0Client -from ...logger import info, debug - - -# ========== 检查是否需要总结 ========== -def should_summarize(state: MainGraphState) -> str: - """ - 检查是否需要总结对话(对话足够长时) - - Args: - state: 当前图状态 - - Returns: - "summarize" 或 "finalize" - """ - if state.turns_since_last_summary >= 5: # 每5轮对话总结一次 - return "summarize" - else: - return "finalize" - - -# ========== 子图包装器(处理子图错误传递)========== -def wrap_subgraph_for_error_handling(subgraph, name: str): - """ - 包装子图,使其错误能传递给主图 - - Args: - subgraph: 编译好的子图 - name: 子图名称(用于错误标识) - - Returns: 包装后的节点函数 - """ - async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: - # 发送子图开始事件 - if config: - try: - from langchain_core.callbacks.manager import adispatch_custom_event - callbacks = config.get("callbacks") - if callbacks: - await adispatch_custom_event( - "react_reasoning", - { - "step": state.reasoning_step, - "action": f"{name}_subgraph_start", - "confidence": 1.0, - "reasoning": f"开始执行 {name} 子图..." 
- }, - callbacks=callbacks - ) - except Exception as e: - info(f"[{name}_subgraph] 无法发送开始事件: {e}") - - try: - # 调用子图 - result = subgraph.invoke(state) - - # 更新主图状态 - subgraph_result = None - if name == "contact": - state.contact_result = result - subgraph_result = result.get("final_result", "") - elif name == "dictionary": - state.dictionary_result = result - subgraph_result = result.get("final_result", "") - elif name == "news_analysis": - state.news_result = result - subgraph_result = result.get("final_result", "") - - # 关键:设置最终结果 - if subgraph_result: - state.final_result = subgraph_result - else: - state.final_result = "子图执行完成" - - # 标记成功 - state.success = True - state.current_phase = "done" - # 标记不再需要推理,避免循环 - state.reasoning_history.append({ - "step": state.reasoning_step, - "action": "subgraph_completed", - "confidence": 1.0, - "reasoning": f"{name}子图执行完成", - "timestamp": datetime.now().isoformat() - }) - - # 发送子图完成事件 - if config: - try: - from langchain_core.callbacks.manager import adispatch_custom_event - callbacks = config.get("callbacks") - if callbacks: - await adispatch_custom_event( - "react_reasoning", - { - "step": state.reasoning_step, - "action": f"{name}_subgraph_complete", - "confidence": 1.0, - "reasoning": f"{name} 子图执行完成" - }, - callbacks=callbacks - ) - except Exception as e: - info(f"[{name}_subgraph] 无法发送完成事件: {e}") - - return state - - except Exception as e: - # 捕获子图错误,传递给主图 - from ..state import ErrorRecord, ErrorSeverity - from datetime import datetime - - error_record = ErrorRecord( - error_type=f"{name}SubgraphError", - error_message=str(e), - severity=ErrorSeverity.WARNING, - source=f"{name}_subgraph", - timestamp=datetime.now().isoformat(), - retry_count=0, - max_retries=1, - context={"user_query": state.user_query} - ) - state.errors.append(error_record) - state.current_error = error_record - state.current_phase = "error_handling" - state.success = False - - # 发送子图错误事件 - if config: - try: - from langchain_core.callbacks.manager 
import adispatch_custom_event - callbacks = config.get("callbacks") - if callbacks: - await adispatch_custom_event( - "react_reasoning", - { - "step": state.reasoning_step, - "action": f"{name}_subgraph_error", - "confidence": 1.0, - "reasoning": f"{name} 子图执行失败: {str(e)}" - }, - callbacks=callbacks - ) - except Exception as e: - info(f"[{name}_subgraph] 无法发送错误事件: {e}") - - return state - - return wrapped_node - -# ========== 主图构建 ========== - -def build_react_main_graph(llm=None, tools=None, mem0_client=None, use_hybrid_router: bool = True) -> StateGraph: - """ - 构建整合后的完整主图(支持混合路由) - - Args: - llm: LangChain ChatModel 实例 - tools: 工具列表 - mem0_client: Mem0 客户端实例 - use_hybrid_router: 是否使用混合路由(快速路径 + React 循环) - - Returns: - StateGraph: 构建好的图 - """ - # 创建图 - graph = StateGraph(MainGraphState) - - # 设置全局 mem0_client - if mem0_client: - set_mem0_client(mem0_client) - - # 创建节点 - llm_node = None - if llm is not None: - llm_node = create_llm_call_node(llm, tools or []) - - retrieve_memory_node = None - summarize_node = None - if mem0_client: - retrieve_memory_node = create_retrieve_memory_node(mem0_client) - summarize_node = create_summarize_node(mem0_client) - - # ========== 添加节点 ========== - - # 第一阶段:记忆检索 - if retrieve_memory_node: - graph.add_node("retrieve_memory", retrieve_memory_node) - graph.add_node("memory_trigger", memory_trigger_node) - - # 第二阶段:初始化 - graph.add_node("init_state", init_state_node) - - # ========== 混合路由节点(如果启用) ========== - if use_hybrid_router: - graph.add_node("hybrid_router", hybrid_router_node) - graph.add_node("fast_chitchat", fast_chitchat_node) - graph.add_node("fast_rag", fast_rag_node) - graph.add_node("fast_tool", fast_tool_node) - - # 第三阶段:React 循环推理(始终保留) - graph.add_node("react_reason", react_reason_node) - graph.add_node("rag_retrieve", rag_retrieve_node) - graph.add_node("web_search", web_search_node) - graph.add_node("handle_error", error_handling_node) - - if llm_node is not None: - graph.add_node("llm_call", llm_node) - - # 子图节点 
- contact_graph = build_contact_subgraph() - dictionary_graph = build_dictionary_subgraph() - news_analysis_graph = build_news_analysis_subgraph() - - graph.add_node( - "contact_subgraph", - wrap_subgraph_for_error_handling(contact_graph.compile(), "contact") - ) - graph.add_node( - "dictionary_subgraph", - wrap_subgraph_for_error_handling(dictionary_graph.compile(), "dictionary") - ) - graph.add_node( - "news_analysis_subgraph", - wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis") - ) - - # 第四阶段:完成处理 - if summarize_node: - graph.add_node("summarize", summarize_node) - graph.add_node("finalize", finalize_node) - - # ========== 添加边 ========== - - # 第一阶段:记忆检索 - if retrieve_memory_node: - graph.add_edge(START, "retrieve_memory") - graph.add_edge("retrieve_memory", "memory_trigger") - else: - graph.add_edge(START, "memory_trigger") - - # 进入初始化 - graph.add_edge("memory_trigger", "init_state") - - # ========== 混合路由分支(如果启用) ========== - if use_hybrid_router: - graph.add_edge("init_state", "hybrid_router") - - # 从 hybrid_router 条件分支 - graph.add_conditional_edges( - "hybrid_router", - route_from_hybrid_decision, - { - "fast_chitchat": "fast_chitchat", - "fast_rag": "fast_rag", - "fast_tool": "fast_tool", - "react_loop": "react_reason" - } - ) - - # 快速路径的完成检查 - for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]: - graph.add_conditional_edges( - fast_node, - check_fast_path_success, - { - "llm_call": "llm_call", - "escalate": "react_reason" - } - ) - - info(f"✅ [图构建] 混合路由模式已启用") - else: - # 无混合路由,直接到 react_reason - graph.add_edge("init_state", "react_reason") - info(f"✅ [图构建] 纯 React 模式") - - # ========== React 循环边(始终保留) ========== - graph.add_conditional_edges( - "react_reason", - route_by_reasoning, - { - "rag_retrieve": "rag_retrieve", - "web_search": "web_search", - "contact_subgraph": "contact_subgraph", - "dictionary_subgraph": "dictionary_subgraph", - "news_analysis_subgraph": "news_analysis_subgraph", - "handle_error": 
"handle_error", - "llm_call": "llm_call" - } - ) - - # 循环边(rag、web_search、子图、error都回到 reason) - graph.add_edge("rag_retrieve", "react_reason") - graph.add_edge("web_search", "react_reason") - graph.add_edge("contact_subgraph", "react_reason") - graph.add_edge("dictionary_subgraph", "react_reason") - graph.add_edge("news_analysis_subgraph", "react_reason") - graph.add_edge("handle_error", "react_reason") - - # ========== 最终完成阶段 ========== - if llm_node is not None: - if summarize_node: - # 检查是否需要总结 - graph.add_conditional_edges( - "llm_call", - should_summarize, - { - "summarize": "summarize", - "finalize": "finalize" - } - ) - graph.add_edge("summarize", "finalize") - else: - # 没有 summarize 节点,直接 finalize - graph.add_edge("llm_call", "finalize") - - # 完成 - graph.add_edge("finalize", END) - - info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router})") - - return graph - - -# ========== 兼容性:保留旧的函数名 ========== -def build_main_graph() -> StateGraph: - """ - 兼容性函数:旧代码调用 build_main_graph() 时返回 React 版本 - """ - return build_react_main_graph() - - -# ========== 导出 ========== -__all__ = [ - "build_react_main_graph", - "build_main_graph", - "wrap_subgraph_for_error_handling" -] \ No newline at end of file diff --git a/backend/app/model_services/__init__.py b/backend/app/model_services/__init__.py index 5d7f173..9ee51e5 100644 --- a/backend/app/model_services/__init__.py +++ b/backend/app/model_services/__init__.py @@ -6,11 +6,17 @@ from .embedding_services import get_embedding_service from .rerank_services import get_rerank_service, BaseRerankService -from .chat_services import get_small_llm_service +from .chat_services import ( + get_small_llm_service, + get_cached_chat_services, + get_all_chat_services +) __all__ = [ "get_embedding_service", "get_rerank_service", "get_small_llm_service", + "get_cached_chat_services", + "get_all_chat_services", "BaseRerankService" ] diff --git a/backend/app/model_services/chat_services.py b/backend/app/model_services/chat_services.py index 
f75dd7c..581a2c3 100644 --- a/backend/app/model_services/chat_services.py +++ b/backend/app/model_services/chat_services.py @@ -33,6 +33,21 @@ from app.config import ( logger = logging.getLogger(__name__) +# 缓存已初始化的模型字典 +_cached_services: Dict[str, BaseChatModel] | None = None + + +def _check_http_service_available(base_url: str, api_key: str = "", timeout: float = 2.0) -> bool: + """通过探测 /models 端点检查 HTTP API 是否可用(内部工具函数)""" + try: + import httpx + client = httpx.Client(base_url=base_url.rstrip('/'), timeout=timeout) + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + resp = client.get("/models", headers=headers) + return resp.status_code == 200 + except Exception: + return False + class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]): """ @@ -54,46 +69,8 @@ class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]): logger.warning("VLLM_BASE_URL 未配置") return False - try: - # 先测试主机名能否解析 - import httpx - from urllib.parse import urlparse - - parsed_url = urlparse(VLLM_BASE_URL) - host = parsed_url.hostname - port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443) - - # 测试能否建立 TCP 连接(快速失败) - import socket - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2.0) - try: - sock.connect((host, port)) - sock.close() - except Exception as e: - logger.warning(f"本地 VLLM 服务无法连接: {host}:{port} - {e}") - return False - - # 再尝试调用简单的 API(比如 models 接口) - client = httpx.Client(base_url=VLLM_BASE_URL.rstrip('/'), timeout=5.0) - headers = {} - if LLM_API_KEY: - headers["Authorization"] = f"Bearer {LLM_API_KEY}" - - try: - response = client.get("/models", headers=headers) - if response.status_code == 200: - logger.info(f"本地 VLLM 服务可用: {self._model}") - return True - except Exception: - pass - - # 如果 /v1/models 失败,也认为服务不可用 - logger.warning(f"本地 VLLM 服务响应异常") - return False - except Exception as e: - logger.warning(f"本地 VLLM 服务不可用: {e}") - return False + # 使用统一的 HTTP 探测方法 + return 
_check_http_service_available(VLLM_BASE_URL, LLM_API_KEY, timeout=2.0) def get_service(self) -> BaseChatModel: """ @@ -238,45 +215,8 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]): logger.warning("SMALL_VLLM_BASE_URL 未配置,本地小模型不可用") return False - try: - # 先测试主机名能否解析 - import httpx - from urllib.parse import urlparse - - parsed_url = urlparse(self._base_url) - host = parsed_url.hostname - port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443) - - # 测试能否建立 TCP 连接(快速失败) - import socket - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2.0) - try: - sock.connect((host, port)) - sock.close() - except Exception as e: - logger.warning(f"本地小模型服务无法连接: {host}:{port} - {e}") - return False - - # 再尝试调用简单的 API - client = httpx.Client(base_url=self._base_url.rstrip('/'), timeout=5.0) - headers = {} - if self._api_key: - headers["Authorization"] = f"Bearer {self._api_key}" - - try: - response = client.get("/models", headers=headers) - if response.status_code == 200: - logger.info(f"本地小模型服务可用: {self._model}") - return True - except Exception: - pass - - logger.warning(f"本地小模型服务响应异常") - return False - except Exception as e: - logger.warning(f"本地小模型服务不可用: {e}") - return False + # 使用统一的 HTTP 探测方法 + return _check_http_service_available(self._base_url, self._api_key, timeout=2.0) def get_service(self) -> BaseChatModel: """获取本地小模型服务""" @@ -358,25 +298,18 @@ def get_chat_service() -> BaseChatModel: return chain.get_available_service() -def get_all_chat_services() -> Dict[str, BaseChatModel]: - """ - 获取所有可用的生成式大模型服务(用于多模型切换) - - Returns: - Dict[str, BaseChatModel]: 模型名称 -> ChatModel 实例 的字典 - """ +def _init_chat_services() -> Dict[str, BaseChatModel]: + """实际初始化所有可用模型(仅在首次调用)""" services = {} for name, provider_factory in CHAT_PROVIDERS.items(): try: provider = provider_factory() if provider.is_available(): - logger.info(f"模型 '{name}' 可用") services[name] = provider.get_service() - else: - logger.warning(f"模型 '{name}' 不可用,跳过") 
+ logger.info(f"已加载模型: {name}") except Exception as e: - logger.warning(f"初始化模型 '{name}' 失败: {e}") + logger.warning(f"模型 {name} 初始化失败: {e}") if not services: raise RuntimeError(f"没有可用的生成式大模型,尝试了: {list(CHAT_PROVIDERS.keys())}") @@ -384,6 +317,25 @@ def get_all_chat_services() -> Dict[str, BaseChatModel]: return services +def get_cached_chat_services() -> Dict[str, BaseChatModel]: + """获取缓存的可用模型字典(用于单图动态注入)""" + global _cached_services + if _cached_services is None: + _cached_services = _init_chat_services() + return _cached_services + + +def get_all_chat_services() -> Dict[str, BaseChatModel]: + """ + 获取所有可用的生成式大模型服务(用于多模型切换,保留兼容性) + 新代码请使用 get_cached_chat_services() 获取缓存版本 + + Returns: + Dict[str, BaseChatModel]: 模型名称 -> ChatModel 实例 的字典 + """ + return get_cached_chat_services() + + def get_small_llm_service() -> BaseChatModel: """ 获取轻量级大模型服务(用于查询改写、意图分类等简单任务) diff --git a/backend/app/subgraphs/contact/graph.py b/backend/app/subgraphs/contact/graph.py index 4026179..0c00360 100644 --- a/backend/app/subgraphs/contact/graph.py +++ b/backend/app/subgraphs/contact/graph.py @@ -4,7 +4,7 @@ Contact Subgraph Builder 支持 API 注入的工厂模式 """ -from app.main_graph.graph import StateGraph, START, END +from langgraph.graph import StateGraph, START, END from .state import ContactState from .nodes import create_contact_nodes diff --git a/backend/app/subgraphs/dictionary/graph.py b/backend/app/subgraphs/dictionary/graph.py index bc65340..6f3ce31 100644 --- a/backend/app/subgraphs/dictionary/graph.py +++ b/backend/app/subgraphs/dictionary/graph.py @@ -3,7 +3,7 @@ Dictionary Subgraph Builder - Complete """ -from app.main_graph.graph import StateGraph, START, END +from langgraph.graph import StateGraph, START, END from .state import DictionaryState from .nodes import ( diff --git a/backend/app/subgraphs/news_analysis/graph.py b/backend/app/subgraphs/news_analysis/graph.py index 1aa07c8..9dfbf90 100644 --- a/backend/app/subgraphs/news_analysis/graph.py +++ 
b/backend/app/subgraphs/news_analysis/graph.py @@ -3,7 +3,7 @@ News Analysis Subgraph Builder """ -from app.main_graph.graph import StateGraph, START, END +from langgraph.graph import StateGraph, START, END from .state import NewsAnalysisState from .nodes import ( diff --git a/backend/requirements.txt b/backend/requirements.txt index fef6bd2..32e0bb3 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -42,6 +42,7 @@ PyYAML>=6.0.3 numpy>=1.26.2 pyjwt>=2.8.0 ddgs>=6.0.0 # 免费联网搜索(原 duckduckgo-search 已重命名) +tavily-python>=0.5.0 # Tavily 搜索 API(需要 API Key,质量更高) matplotlib>=3.9.0 # 可视化图表 # Document Processing diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 30b8d65..9267a28 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -66,6 +66,12 @@ services: - MEMORY_SUMMARIZE_INTERVAL=${MEMORY_SUMMARIZE_INTERVAL:-10} - FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-/app/fastembed_cache} + # ========================================================================= + # Tavily 搜索配置(可选,有 API Key 时优先使用) + # ========================================================================= + - TAVILY_API_KEY=${TAVILY_API_KEY:-} + - TAVILY_MAX_RESULTS=${TAVILY_MAX_RESULTS:-5} + # ========================================================================= # 前端通信地址(Docker 内部网络) # ========================================================================= diff --git a/tools/run.py b/tools/run.py index 446ee55..d6a6ae9 100644 --- a/tools/run.py +++ b/tools/run.py @@ -11,6 +11,5 @@ sys.path.insert(0, str(backend_path)) load_dotenv(project_root / ".env") if __name__ == "__main__": - from tools.test.test_graph_branches import main - import asyncio - asyncio.run(main()) + from tools.test import test_tavily_search + test_tavily_search.main() diff --git a/tools/test/test_tavily_search.py b/tools/test/test_tavily_search.py new file mode 100644 index 0000000..1abbd2a --- /dev/null +++ b/tools/test/test_tavily_search.py @@ -0,0 +1,149 @@ 
+#!/usr/bin/env python3 +"""测试 Tavily 搜索功能 - 直接调用 API""" +import sys +from pathlib import Path +from dataclasses import dataclass +from datetime import datetime +from typing import List, Optional +from dotenv import load_dotenv + +# 路径设置 +project_root = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(project_root)) +load_dotenv(project_root / ".env") + +import os + +TAVILY_API_KEY = os.getenv("TAVILY_API_KEY") +TAVILY_MAX_RESULTS = int(os.getenv("TAVILY_MAX_RESULTS") or "5") + + +@dataclass +class SearchResult: + """搜索结果数据类""" + title: str + url: str + snippet: str + source: str = "DuckDuckGo" + timestamp: datetime = None + + def __post_init__(self): + if self.timestamp is None: + self.timestamp = datetime.now() + + +def test_tavily_api_key(): + """测试 API Key 配置""" + print("=" * 60) + print("测试 1: 检查 Tavily API Key") + print("=" * 60) + + if TAVILY_API_KEY: + print(f"✓ TAVILY_API_KEY 已配置: {TAVILY_API_KEY[:15]}...") + else: + print("✗ TAVILY_API_KEY 未配置") + print() + + +def test_tavily_search_direct(): + """直接测试 Tavily API""" + print("=" * 60) + print("测试 2: 直接调用 Tavily API") + print("=" * 60) + + if not TAVILY_API_KEY: + print("✗ 未配置 API Key,跳过测试") + return + + from tavily import TavilyClient + + client = TavilyClient(api_key=TAVILY_API_KEY) + + test_queries = [ + "Python 编程语言最新版本", + "LangGraph AI 框架", + ] + + for query in test_queries: + print(f"\n搜索: {query}") + print("-" * 40) + try: + response = client.search( + query=query, + max_results=3, + include_answer=True, + include_raw_content=False + ) + + print(f"✓ 搜索成功") + print(f" - 结果数量: {len(response.get('results', []))}") + + # 打印结果 + for i, item in enumerate(response.get("results", []), 1): + print(f"\n [{i}] {item.get('title', '')}") + print(f" URL: {item.get('url', '')}") + print(f" 摘要: {item.get('content', '')[:100]}...") + + # 如果有 answer + if response.get("answer"): + print(f"\n 🤖 AI 摘要: {response['answer'][:200]}...") + + except Exception as e: + print(f"✗ 搜索失败: {e}") + print() + + 
+def test_web_search_integration(): + """测试 web_search 模块集成""" + print("=" * 60) + print("测试 3: 测试 web_search 模块集成") + print("=" * 60) + + # 直接导入 web_search 模块(避免循环依赖) + web_search_path = project_root / "backend" / "app" / "core" / "web_search.py" + if not web_search_path.exists(): + print(f"✗ web_search.py 不存在于 {web_search_path}") + return + + print(f"✓ 找到 web_search.py: {web_search_path}") + + # 使用 exec 动态加载模块 + import importlib.util + spec = importlib.util.spec_from_file_location("web_search_module", web_search_path) + web_search_module = importlib.util.module_from_spec(spec) + + try: + spec.loader.exec_module(web_search_module) + print("✓ web_search 模块加载成功") + except Exception as e: + print(f"✗ 模块加载失败: {e}") + return + + # 测试搜索 + print("\n执行搜索测试:") + try: + result = web_search_module.web_search("今天天气怎么样", max_results=3) + print(f"✓ 搜索成功,返回 {len(result)} 字符") + print("-" * 40) + print(result[:800] + "..." if len(result) > 800 else result) + except Exception as e: + print(f"✗ 搜索失败: {e}") + print() + + +def main(): + print("\n" + "=" * 60) + print("🚀 Tavily 搜索功能测试") + print("=" * 60 + "\n") + + test_tavily_api_key() + test_tavily_search_direct() + test_web_search_integration() + + print("=" * 60) + print("✅ 测试完成") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/tools/visualize_graph.py b/tools/visualize_graph.py index a510e8a..bb0557d 100644 --- a/tools/visualize_graph.py +++ b/tools/visualize_graph.py @@ -6,28 +6,19 @@ LangGraph 图结构可视化脚本 """ import sys from pathlib import Path -from dotenv import load_dotenv -# 确定项目根目录(Agent1 目录) -# 当前文件位置:tools/visualize_graph.py -# 向上 1 级到 Agent1 +# 路径设置 PROJECT_ROOT = Path(__file__).parent.parent BACKEND_DIR = PROJECT_ROOT / "backend" +sys.path.insert(0, str(PROJECT_ROOT)) -# 关键:把 backend 目录加入 sys.path,这样才能找到 rag_core -# 注意:这只对直接运行脚本有效,对 -m 方式无效(因为 -m 方式在脚本运行前就导入了) -if str(BACKEND_DIR) not in sys.path: - sys.path.insert(0, str(BACKEND_DIR)) -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, 
str(PROJECT_ROOT)) - +from dotenv import load_dotenv load_dotenv(PROJECT_ROOT / ".env") +import asyncio from backend.app.agent.agent_service import AIAgentService from backend.app.config import DB_URI -from backend.app.main_graph.checkpoint.postgres.aio import AsyncPostgresSaver -import asyncio - +from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver async def visualize_graph(): """可视化 LangGraph 结构""" @@ -37,6 +28,8 @@ async def visualize_graph(): print(f"项目根目录: {PROJECT_ROOT}") print(f"Backend 目录: {BACKEND_DIR}") + + async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: await checkpointer.setup() @@ -45,37 +38,52 @@ async def visualize_graph(): agent_service = AIAgentService(checkpointer) await agent_service.initialize() - for model_name, graph in agent_service.graphs.items(): - print(f"\n{'=' * 80}") - print(f" 模型: {model_name}") - print(f"{'=' * 80}") + # 获取图(单图方案) + graph = agent_service.graph + print("\n✅ Agent 服务初始化完成") - # 获取图结构 - graph_structure = graph.get_graph() + # 获取图结构 + graph_structure = graph.get_graph() - # 1. 直接打印节点和边 - print("\n[1] 节点列表:") - print("-" * 80) - for node_id, node in graph_structure.nodes.items(): - print(f" - {node_id}: {node.name}") + # 1. 直接打印节点和边 + print("\n" + "=" * 80) + print("[1] 节点列表") + print("=" * 80) + for node_id, node in graph_structure.nodes.items(): + print(f" 📦 {node_id}: {node.name}") - print("\n[2] 边列表:") - print("-" * 80) - for edge in graph_structure.edges: - print(f" {edge.source} --> {edge.target}") + print("\n" + "=" * 80) + print("[2] 边列表") + print("=" * 80) + for edge in graph_structure.edges: + print(f" {edge.source} --> {edge.target}") - # 3. ASCII 字符画(需要 grandalf) - print("\n[3] ASCII 字符画:") - print("-" * 80) - try: - print(graph_structure.draw_ascii()) - except Exception as e: - print(f"⚠️ ASCII 绘制失败: {e}") + # 2. 
ASCII 字符画 + print("\n" + "=" * 80) + print("[3] ASCII 字符画") + print("=" * 80) + try: + ascii_graph = graph_structure.draw_ascii() + print(ascii_graph) + except Exception as e: + print(f"⚠️ ASCII 绘制失败: {e}") - # 4. Mermaid 源码 - print("\n[4] Mermaid 源码 (可复制到 https://mermaid.live/):") - print("-" * 80) - print(graph_structure.draw_mermaid()) + # 3. Mermaid 源码 + print("\n" + "=" * 80) + print("[4] Mermaid 源码 (复制到 https://mermaid.live/)") + print("=" * 80) + try: + mermaid_code = graph_structure.draw_mermaid() + print(mermaid_code) + except Exception as e: + print(f"⚠️ Mermaid 生成失败: {e}") + + # 4. 节点统计 + print("\n" + "=" * 80) + print("[5] 图统计") + print("=" * 80) + print(f" 节点数量: {len(graph_structure.nodes)}") + print(f" 边数量: {len(graph_structure.edges)}") if __name__ == "__main__":