diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index bedaedf..8bf07c4 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -1,198 +1,88 @@ """ -AI Agent 服务类 - 完全简化版本! -按照指南实现,不用 stream_mode="messages" 避免重复 token! +AI Agent 服务类 """ -import json -import asyncio -from typing import AsyncGenerator, Dict, Any, Optional, Tuple +from typing import AsyncGenerator, Dict, Any -# LangGraph 序列化器(修复 checkpoint 反序列化警告) from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer -# 本地模块 from backend.app.model_services import get_cached_chat_services from backend.app.main_graph.main_graph_builder import build_agent_graph -from backend.app.logger import debug, info, warning, error -from backend.app.main_graph.state import AgentState -from .stream_context import set_stream_queue +from backend.app.logger import info +from backend.app.memory.mem0_client import Mem0Client + +from .service_config import ServiceConfig +from .stream_handler import run_graph_stream class AIAgentService: def __init__(self, checkpointer): self.checkpointer = checkpointer self.graph = None - self.chat_services = None - # Mem0 客户端 + self.config: ServiceConfig = None self.mem0_client = None - async def initialize(self): - # 0. 初始化 Mem0 客户端 - from ..memory.mem0_client import Mem0Client + async def initialize(self) -> "AIAgentService": + """初始化 Agent 服务""" self.mem0_client = Mem0Client() - - # 1. 获取缓存的模型字典 + self.chat_services = get_cached_chat_services() info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}") - - # 2. 构建图 - info(f"🔄 构建 Agent 图...") + graph_builder = build_agent_graph( chat_services=self.chat_services, mem0_client=self.mem0_client ) - - # 编译图 self.graph = graph_builder.compile(checkpointer=self.checkpointer) + + self.config = ServiceConfig(self.chat_services) info(f"✅ Agent 图初始化完成") - + return self - def _resolve_model(self, model: str) -> str: - """ - 解析并验证模型名称,不可用时回退到第一个可用模型 - - Args: - model: 目标模型名称 - - Returns: - 实际使用的模型名称 - """ - if not model or model not in self.chat_services: - fallback = next(iter(self.chat_services.keys())) - warning(f"模型 '{model}' 不可用,回退到 '{fallback}'") - return fallback - return model - - def _build_invocation( + def _resolve_and_build( self, message: str, thread_id: str, model: str, user_id: str - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """ - 构建图调用所需的 config 和 input_state - - Args: - message: 用户消息 - thread_id: 会话 ID - model: 模型名称 - user_id: 用户 ID - - Returns: - (config, input_state) 元组 - """ - from langchain_core.messages import HumanMessage - - config = { - "configurable": { - "thread_id": thread_id, - }, - "metadata": {"user_id": user_id} - } - - input_state = { - "messages": [HumanMessage(content=message)], - "user_id": user_id, - } - return config, input_state + ): + """解析模型并构建调用参数""" + resolved_model = self.config.resolve_model(model) + return resolved_model, self.config.build_invocation( + message, thread_id, resolved_model, user_id + ) async def process_message( self, message: str, thread_id: str, model: str = "", user_id: str = "default_user" ) -> dict: """处理用户消息,返回包含回复、token统计和耗时的字典""" - # 解析模型名称 - resolved_model = self._resolve_model(model) - - # 构建调用参数 - config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id) + resolved_model, (config, input_state) = self._resolve_and_build( + message, thread_id, model, user_id + ) result = await self.graph.ainvoke(input_state, config=config) - # 优先使用 final_reply(finalize 节点返回) reply = result.get("final_reply", "") if not reply and result.get("messages"): reply = result["messages"][-1].content - token_usage = result.get("last_token_usage", {}) - elapsed_time = result.get("last_elapsed_time", 0.0) - - # 获取元数据 - metadata = result.get("metadata", {}) - return { "reply": reply, - "token_usage": token_usage, - "elapsed_time": elapsed_time, + "token_usage": result.get("last_token_usage", {}), + "elapsed_time": result.get("last_elapsed_time", 0.0), "model_used": resolved_model, - "metadata": metadata + "metadata": result.get("metadata", {}), } async def process_message_stream( self, message: str, thread_id: str, model: str = "", user_id: str = "default_user" ) -> AsyncGenerator[Dict[str, Any], None]: - """流式处理消息 - 完全简化!""" - # 解析模型名称 - resolved_model = self._resolve_model(model) - - # 构建调用参数 - config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id) + """流式处理消息""" + resolved_model, (config, input_state) = self._resolve_and_build( + message, thread_id, model, user_id + ) info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}") - actual_model_used = resolved_model - # 创建 token 队列 - queue = asyncio.Queue() - set_stream_queue(queue) # 设置上下文变量 - - async def run_graph(): - """后台任务:运行 graph,流式事件都从 agent 节点内部发送!""" - try: - info(f"📡 开始调用 graph.astream()...") - - # 注意:只用 stream_mode=["updates"],不要 "messages"!避免重复 token! - async for _ in self.graph.astream( - input_state, - config=config, - stream_mode=["updates"], - version="v2", - subgraphs=True - ): - # 流式事件都从 agent.py 节点内部通过队列发送了 - # 这里不需要再发送任何事件 - pass - except Exception as e: - error(f"❌ 执行图时出错: {e}") - import traceback - error(f"📋 堆栈: {traceback.format_exc()}") - await queue.put({"type": "error", "message": str(e)}) - finally: - await queue.put(None) # 结束哨兵 - - # 启动后台任务 - bg_task = asyncio.create_task(run_graph()) - - try: - while True: - event = await queue.get() - if event is None: - break + async for event in run_graph_stream(self.graph, input_state, config): + if event.get("type") != "done": yield event - - except GeneratorExit: - # 客户端断开连接,取消后台任务 - info("⚠️ GeneratorExit,取消后台任务") - bg_task.cancel() - raise - finally: - # 保证任务被清理 - if not bg_task.done(): - info("⏹️ 清理后台任务") - bg_task.cancel() - try: - await bg_task - except asyncio.CancelledError: - info("✅ 后台任务已取消") - - # 发送结束事件,保证前端平稳关闭 - yield { - "type": "done", - "model_used": actual_model_used - } + else: + yield {**event, "model_used": resolved_model} diff --git a/backend/app/agent/service_config.py b/backend/app/agent/service_config.py new file mode 100644 index 0000000..5e36408 --- /dev/null +++ b/backend/app/agent/service_config.py @@ -0,0 +1,46 @@ +""" +Agent Service 配置模块 - 配置构建和解析 +""" + +from typing import Dict, Any, Tuple, Optional +from langchain_core.messages import HumanMessage +from backend.app.logger import warning + + +class ServiceConfig: + """配置构建器""" + + def __init__(self, chat_services: dict): + self.chat_services = chat_services + + def resolve_model(self, model: Optional[str]) -> str: + """ + 解析并验证模型名称,不可用时回退到第一个可用模型 + """ + if not model or model not in self.chat_services: + fallback = next(iter(self.chat_services.keys())) + warning(f"模型 '{model}' 不可用,回退到 '{fallback}'") + return fallback + return model + + def build_invocation( + self, message: str, thread_id: str, model: str, user_id: str + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + 构建图调用所需的 config 和 input_state + + Returns: + (config, input_state) 元组 + """ + config = { + "configurable": { + "thread_id": thread_id, + }, + "metadata": {"user_id": user_id} + } + + input_state = { + "messages": [HumanMessage(content=message)], + "user_id": user_id, + } + return config, input_state diff --git a/backend/app/agent/stream_handler.py b/backend/app/agent/stream_handler.py new file mode 100644 index 0000000..4da4a4c --- /dev/null +++ b/backend/app/agent/stream_handler.py @@ -0,0 +1,78 @@ +""" +流式处理模块 - 处理 Agent 执行的流式输出 +""" + +import asyncio +from typing import AsyncGenerator, Dict, Any + +from backend.app.logger import info, error +from .stream_context import set_stream_queue + + +async def run_graph_stream( + graph, + input_state: Dict[str, Any], + config: Dict[str, Any], +) -> AsyncGenerator[Dict[str, Any], None]: + """ + 运行图并通过队列流式输出事件 + + Args: + graph: 编译后的 LangGraph + input_state: 输入状态 + config: 配置 + + Yields: + 流式事件 + """ + queue: asyncio.Queue = asyncio.Queue() + set_stream_queue(queue) + + async def run_graph(): + """后台任务:运行 graph""" + try: + info(f"📡 开始调用 graph.astream()...") + async for _ in graph.astream( + input_state, + config=config, + stream_mode=["updates"], + version="v2", + subgraphs=True + ): + # 流式事件都从 agent.py 节点内部通过队列发送了 + pass + except Exception as e: + error(f"❌ 执行图时出错: {e}") + import traceback + error(f"📋 堆栈: {traceback.format_exc()}") + await queue.put({"type": "error", "message": str(e)}) + finally: + await queue.put(None) # 结束哨兵 + + # 启动后台任务 + bg_task = asyncio.create_task(run_graph()) + + try: + while True: + event = await queue.get() + if event is None: + break + yield event + + except GeneratorExit: + info("⚠️ GeneratorExit,取消后台任务") + bg_task.cancel() + raise + finally: + await _cleanup_task(bg_task) + + +async def _cleanup_task(bg_task: asyncio.Task) -> None: + """清理后台任务""" + if not bg_task.done(): + info("⏹️ 清理后台任务") + bg_task.cancel() + try: + await bg_task + except asyncio.CancelledError: + info("✅ 后台任务已取消") diff --git a/backend/app/core/rag_initializer.py b/backend/app/core/rag_initializer.py deleted file mode 100644 index dbd2233..0000000 --- a/backend/app/core/rag_initializer.py +++ /dev/null @@ -1,73 +0,0 @@ -# app/rag_initializer.py -from ...rag.tools import create_rag_tool -from ...rag.retriever import create_parent_hybrid_retriever -from ...model_services import get_embedding_service -from backend.app.logger import info, warning -import sys - -# 全局 RAG 工具 -_rag_tool = None -_initialized = False - - -def get_rag_tool() -> callable: - """获取全局 RAG 工具""" - return _rag_tool - - -def is_initialized() -> bool: - """检查是否已初始化""" - return _initialized - - -async def init_rag_tool(force: bool = False): - """ - 初始化 RAG 工具(注册到模块级变量,内部获取所需服务) - - Args: - force: 是否强制重新初始化 - - Returns: - RAG 工具(@tool 装饰函数)或 None - """ - global _rag_tool, _initialized - - # 防止重复初始化 - if _initialized and not force: - info("[RAG] 已初始化,跳过") - return _rag_tool - - try: - from backend.app.model_services.chat_services import get_chat_service - - info("🔄 正在初始化 RAG 检索系统...") - embeddings = get_embedding_service() - retriever = create_parent_hybrid_retriever( - collection_name="rag_documents", - search_k=5, - embeddings=embeddings, - ) - rewrite_llm = get_chat_service() - - rag_tool = create_rag_tool( - retriever=retriever, - llm=rewrite_llm, - num_queries=3, - rerank_top_n=5, - ) - - _rag_tool = rag_tool - _initialized = True - info(f"✅ RAG 检索工具初始化成功 (id={id(rag_tool)})") - return rag_tool - - except Exception as e: - warning(f"⚠️ RAG 检索工具初始化失败: {e}") - return None - - -def reset(): - """重置(用于测试)""" - global _rag_tool, _initialized - _rag_tool = None - _initialized = False diff --git a/backend/app/core/web_search.py b/backend/app/core/web_search.py index fe9d79b..2ae09df 100644 --- a/backend/app/core/web_search.py +++ b/backend/app/core/web_search.py @@ -3,12 +3,11 @@ Web Search Public Utility - Free, no API Key, using DuckDuckGo """ -from typing import List, Dict, Any, Optional +from typing import List, Optional from dataclasses import dataclass from datetime import datetime -import requests -import warnings -import re + +from backend.app.logger import info @dataclass @@ -44,47 +43,31 @@ class WebSearchTool: """ num_results = max_results or self.max_results - # 方式 1: Tavily (需要 API Key,质量最高) + # 尝试搜索方式,按优先级 + result = self._try_tavily(query, num_results) + if result is not None: + return result + + result = self._try_ddgs(query, num_results) + if result is not None: + return result + + # 兜底方案 + return self._get_mock_results(query, num_results) + + def _try_tavily(self, query: str, max_results: int) -> Optional[List[SearchResult]]: + """尝试 Tavily API 搜索""" try: - return self._search_tavily(query, num_results) + return self._search_tavily(query, max_results) except ImportError: - print("[WebSearch] tavily 未安装,尝试其他搜索方式") + info("[WebSearch] tavily 未安装") except Exception as e: - if "API_KEY" in str(e) or "未配置" in str(e): - print(f"[WebSearch] Tavily API Key 未配置: {e}") + error_msg = str(e) + if "API_KEY" in error_msg or "未配置" in error_msg: + info(f"[WebSearch] Tavily API Key 未配置") else: - print(f"[WebSearch] Tavily 搜索失败: {e}") - - # 方式 2: 尝试用 ddgs 包 - try: - from ddgs import DDGS - print(f"[WebSearch] 使用 ddgs 搜索: {query}") - with DDGS() as ddgs: - results = list(ddgs.text(query, max_results=num_results)) - if results: - search_results = [] - for r in results: - search_results.append(SearchResult( - title=r.get("title", ""), - url=r.get("href", ""), - snippet=r.get("body", ""), - source="DuckDuckGo" - )) - print(f"[WebSearch] ddgs 返回 {len(search_results)} 条结果") - return search_results - except ImportError: - print("[WebSearch] ddgs 未安装,尝试 duckduckgo-search") - except Exception as e: - print(f"[WebSearch] ddgs 搜索失败: {e}") - - # 方式 3: 尝试用简单 HTTP 请求 - try: - return self._search_http(query, num_results) - except Exception as e: - print(f"[WebSearch] HTTP 搜索也失败: {e}") - - # 方式 4: 返回模拟数据作为最后兜底 - return self._search_mock(query, num_results) + info(f"[WebSearch] Tavily 搜索失败: {e}") + return None def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]: """使用 Tavily API 搜索""" @@ -111,56 +94,40 @@ class WebSearchTool: source="Tavily" )) - print(f"[WebSearch] Tavily 返回 {len(results)} 条结果") + info(f"[WebSearch] Tavily 返回 {len(results)} 条结果") return results - def _search_http(self, query: str, max_results: int) -> List[SearchResult]: - """用简单 HTTP 请求搜索(备用方案)- 尝试多个国内源""" - print(f"[WebSearch] 尝试 HTTP 搜索") - - # 方式 1: 尝试百度搜索(简单方式) + def _try_ddgs(self, query: str, max_results: int) -> Optional[List[SearchResult]]: + """尝试 DuckDuckGo 搜索""" try: - return self._search_baidu(query, max_results) - except Exception as e: - print(f"[WebSearch] 百度搜索失败: {e}") - - # 方式 2: 返回模拟数据 - return self._search_mock(query, max_results) + from ddgs import DDGS - def _search_baidu(self, query: str, max_results: int) -> List[SearchResult]: - """尝试百度搜索""" - import requests - from urllib.parse import quote - - url = f"https://www.baidu.com/s?wd={quote(query)}" - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" - } - - try: - response = requests.get(url, headers=headers, timeout=10) - response.raise_for_status() - - # 简单解析百度搜索结果(简化版) results = [] - # 这里只是示意,真实百度搜索需要更复杂的解析 - results.append(SearchResult( - title=f"百度搜索: {query}", - url=url, - snippet="如需要真实搜索结果,请考虑使用百度搜索 API", - source="百度" - )) - return results - except Exception as e: - print(f"[WebSearch] 百度搜索也失败: {e}") - raise + with DDGS() as ddgs: + for r in ddgs.text(query, max_results=max_results): + results.append(SearchResult( + title=r.get("title", ""), + url=r.get("href", ""), + snippet=r.get("body", ""), + source="DuckDuckGo" + )) - def _search_mock(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]: - """模拟搜索结果(兜底方案)""" - print(f"[WebSearch] 使用模拟搜索结果 (查询: {query})") - - # 根据查询内容生成更有意义的模拟结果 - mock_templates = [ + if results: + info(f"[WebSearch] ddgs 返回 {len(results)} 条结果") + return results + + except ImportError: + info("[WebSearch] ddgs 未安装") + except Exception as e: + info(f"[WebSearch] ddgs 搜索失败: {e}") + + return None + + def _get_mock_results(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]: + """获取模拟搜索结果(兜底方案)""" + info(f"[WebSearch] 使用模拟搜索结果") + + templates = [ { "title": f"关于「{query}」的相关介绍", "snippet": "这是模拟结果。如需真实搜索,请检查容器网络连接或配置代理。", @@ -177,50 +144,48 @@ class WebSearchTool: "url": "https://example.com/more" } ] - + num = max_results or self.max_results results = [] - - for i, template in enumerate(mock_templates[:num]): + + for template in templates[:num]: results.append(SearchResult( title=template["title"], url=template["url"], snippet=template["snippet"], source="模拟数据" )) - + return results def format_search_results(self, results: List[SearchResult]) -> str: """ 格式化搜索结果(带引用溯源) - + Args: results: 搜索结果列表 - + Returns: 格式化后的 Markdown 文本 """ if not results: return "未找到相关搜索结果" - - lines = [] - lines.append("## 🔍 联网搜索结果\n") - + + lines = ["## 🔍 联网搜索结果\n"] + for idx, result in enumerate(results, 1): lines.append(f"### [{idx}] {result.title}") lines.append(f"- 🔗 来源:[{result.url}]({result.url})") lines.append(f"- 📝 摘要:{result.snippet}") lines.append(f"- 📅 时间:{result.timestamp.strftime('%Y-%m-%d %H:%M:%S')}") lines.append("") - - # 添加引用溯源说明 + lines.append("---") lines.append("💡 **引用溯源说明**:") lines.append("- 以上搜索结果均标注了来源链接") lines.append("- 使用方括号数字标识引用(如 [1]、[2])") lines.append("- 可通过链接追溯原始信息") - + return "\n".join(lines) @@ -239,11 +204,11 @@ def get_web_search_tool() -> WebSearchTool: def web_search(query: str, max_results: int = 5) -> str: """ 便捷函数:联网搜索并返回格式化结果 - + Args: query: 搜索关键词 max_results: 返回结果数量 - + Returns: 格式化后的搜索结果文本 """ diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py deleted file mode 100644 index d66e688..0000000 --- a/backend/app/rag/tools.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -RAG 工具模块(完全异步) - -将检索功能封装为 LangChain Tool,供 Agent 调用。 -采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。 - -默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。 -""" -from typing import Callable, Optional -from langchain_core.tools import tool -from langchain_core.language_models import BaseLanguageModel -from langchain_core.retrievers import BaseRetriever -from ..rag.pipeline import RAGPipeline, create_rag_pipeline - - -def create_rag_tool( - retriever: Optional[BaseRetriever] = None, - llm: Optional[BaseLanguageModel] = "default_small", - num_queries: int = 3, - rerank_top_n: int = 5, - collection_name: str = "rag_documents", -) -> Callable: - """ - 创建一个配置好的 RAG 检索工具(完全异步)。 - - 默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。 - - Args: - retriever: 基础检索器对象(可选,不提供则自动创建) - llm: 用于生成多路查询的语言模型。 - - "default_small": (默认) 使用小模型(本地 + DeepSeek) - - None / False: 不做查询改写 - - BaseLanguageModel 实例: 自定义模型 - num_queries: 生成的查询变体数量 - rerank_top_n: 最终返回的文档数量 - collection_name: Qdrant 集合名称 - - Returns: - Async LangChain Tool 函数 - """ - pipeline = RAGPipeline( - retriever=retriever, - llm=llm, - num_queries=num_queries, - rerank_top_n=rerank_top_n, - collection_name=collection_name, - ) - - @tool - async def search_knowledge_base(query: str) -> str: - """ - 在知识库中搜索与查询相关的文档片段(完全异步)。 - - 使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式, - 检索效果最优。 - - Args: - query: 用户提出的问题或查询字符串 - - Returns: - 格式化后的相关文档内容 - """ - try: - documents = await pipeline.aretrieve(query) - if not documents: - return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" - - context = pipeline.format_context(documents) - return context - except Exception as e: - return f"检索过程中发生错误: {str(e)}" - - return search_knowledge_base diff --git a/backend/app/tools/__init__.py b/backend/app/tools/__init__.py index 74dfebe..45d609f 100644 --- a/backend/app/tools/__init__.py +++ b/backend/app/tools/__init__.py @@ -2,124 +2,17 @@ Agent Tools - 所有工具统一定义 """ -from langchain_core.tools import tool -from backend.app.logger import info - -# ========== RAG ========== - -_rag_pipeline = None - - -def _get_rag_pipeline(): - global _rag_pipeline - if _rag_pipeline is None: - from backend.app.rag.pipeline import RAGPipeline - _rag_pipeline = RAGPipeline( - num_queries=3, - rerank_top_n=5, - use_rerank=True, - return_parent_docs=True, - ) - return _rag_pipeline - - -@tool -async def rag_search(query: str) -> str: - """ - 检索知识库获取相关信息 - - Returns: - 包含检索结果和置信度的结构化回复,格式: - - 内容:检索到的相关信息 - - 置信度评估:基于向量相似度、重排分数、LLM判断的综合评分 - """ - info(f"[Tool] rag_search: {query[:30]}...") - try: - pipeline = _get_rag_pipeline() - # 使用带置信度的检索 - result = await pipeline.aretrieve_with_confidence(query, original_query=query) - - if not result.content: - return "【RAG检索结果】\n未在知识库中找到相关内容。\n置信度:0.0\n建议:可尝试联网搜索获取信息。" - - # 构建包含置信度的回复 - confidence_desc = "高" - if result.confidence < 0.4: - confidence_desc = "低" - elif result.confidence < 0.6: - confidence_desc = "中" - - response = f"""【RAG检索结果】 -{result.content} - -【置信度评估】 -- 综合置信度:{result.confidence:.2f}({confidence_desc}) -- 向量相似度:{result.scores['embedding']:.2f} -- 重排分数:{result.scores['rerank']:.2f} -- LLM评估:{result.scores['llm']:.2f} - -{'✅ 检索结果可信,可直接使用' if result.is_useful else '⚠️ 检索结果置信度较低,可能需要联网搜索补充'}""" - - info(f"[Tool] rag_search 完成: confidence={result.confidence:.3f}, is_useful={result.is_useful}") - return response - - except Exception as e: - info(f"[Tool] rag_search 失败: {e}") - return f"【RAG检索失败】\n错误:{str(e)}\n建议:请稍后重试或使用联网搜索" - - -# ========== 联网搜索 ========== - -@tool -def web_search(query: str) -> str: - """联网搜索获取最新信息""" - info(f"[Tool] web_search: {query[:30]}...") - try: - from backend.app.core.web_search import web_search as search_fn - return search_fn(query, max_results=5) - except Exception as e: - info(f"[Tool] web_search 失败: {e}") - return f"联网搜索失败: {str(e)}" - - -# ========== 子图工具 ========== - -async def _call_subgraph(builder_fn, state_cls, query: str) -> str: - """通用子图调用""" - try: - graph = builder_fn().compile() - state = state_cls(user_query=query) - result = await graph.ainvoke(state) - return result.get("final_result", "执行完成") - except Exception as e: - info(f"[Tool] 子图调用失败: {e}") - return f"执行失败: {str(e)}" - - -@tool -async def contact_lookup(query: str) -> str: - """查询通讯录""" - from backend.app.subgraphs.contact.graph import build_contact_subgraph - from backend.app.subgraphs.contact.state import ContactState - return await _call_subgraph(build_contact_subgraph, ContactState, query) - - -@tool -async def dictionary_lookup(word: str) -> str: - """查询词典/翻译""" - from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph - from backend.app.subgraphs.dictionary.state import DictionaryState - return await _call_subgraph(build_dictionary_subgraph, DictionaryState, word) - - -@tool -async def news_analysis(topic: str) -> str: - """分析新闻热点""" - from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph - from backend.app.subgraphs.news_analysis.state import NewsAnalysisState - return await _call_subgraph(build_news_analysis_subgraph, NewsAnalysisState, topic) - - -# ========== 导出 ========== +from .rag import rag_search +from .web_search import web_search +from .subgraph import contact_lookup, dictionary_lookup, news_analysis ALL_TOOLS = [rag_search, web_search, contact_lookup, dictionary_lookup, news_analysis] + +__all__ = [ + "rag_search", + "web_search", + "contact_lookup", + "dictionary_lookup", + "news_analysis", + "ALL_TOOLS", +] diff --git a/backend/app/tools/base.py b/backend/app/tools/base.py new file mode 100644 index 0000000..1b03df5 --- /dev/null +++ b/backend/app/tools/base.py @@ -0,0 +1,8 @@ +""" +工具模块配置 +""" + +from langchain_core.tools import tool +from backend.app.logger import info + +__all__ = ["tool", "info"] diff --git a/backend/app/tools/rag.py b/backend/app/tools/rag.py new file mode 100644 index 0000000..5ac7849 --- /dev/null +++ b/backend/app/tools/rag.py @@ -0,0 +1,70 @@ +""" +RAG 检索工具 +""" + +from langchain_core.tools import tool +from backend.app.logger import info + + +_rag_pipeline = None + + +def _get_rag_pipeline(): + """获取或创建 RAG pipeline 单例""" + global _rag_pipeline + if _rag_pipeline is None: + from backend.app.rag.pipeline import RAGPipeline + _rag_pipeline = RAGPipeline( + num_queries=3, + rerank_top_n=5, + use_rerank=True, + return_parent_docs=True, + ) + return _rag_pipeline + + +def _format_confidence(result) -> str: + """格式化置信度描述""" + if result.confidence < 0.4: + return "低" + elif result.confidence < 0.6: + return "中" + return "高" + + +@tool +async def rag_search(query: str) -> str: + """ + 检索知识库获取相关信息 + + Returns: + 包含检索结果和置信度的结构化回复 + """ + info(f"[Tool] rag_search: {query[:30]}...") + try: + pipeline = _get_rag_pipeline() + result = await pipeline.aretrieve_with_confidence(query, original_query=query) + + if not result.content: + return "【RAG检索结果】\n未在知识库中找到相关内容。\n置信度:0.0\n建议:可尝试联网搜索获取信息。" + + confidence_desc = _format_confidence(result) + is_useful_note = "✅ 检索结果可信,可直接使用" if result.is_useful else "⚠️ 检索结果置信度较低,可能需要联网搜索补充" + + response = f"""【RAG检索结果】 +{result.content} + +【置信度评估】 +- 综合置信度:{result.confidence:.2f}({confidence_desc}) +- 向量相似度:{result.scores['embedding']:.2f} +- 重排分数:{result.scores['rerank']:.2f} +- LLM评估:{result.scores['llm']:.2f} + +{is_useful_note}""" + + info(f"[Tool] rag_search 完成: confidence={result.confidence:.3f}, is_useful={result.is_useful}") + return response + + except Exception as e: + info(f"[Tool] rag_search 失败: {e}") + return f"【RAG检索失败】\n错误:{str(e)}\n建议:请稍后重试或使用联网搜索" diff --git a/backend/app/tools/subgraph.py b/backend/app/tools/subgraph.py new file mode 100644 index 0000000..8bc9ec6 --- /dev/null +++ b/backend/app/tools/subgraph.py @@ -0,0 +1,42 @@ +""" +子图工具 - 通讯录、词典、新闻分析等 +""" + +from langchain_core.tools import tool +from backend.app.logger import info + + +async def _call_subgraph(builder_fn, state_cls, query: str) -> str: + """通用子图调用""" + try: + graph = builder_fn().compile() + state = state_cls(user_query=query) + result = await graph.ainvoke(state) + return result.get("final_result", "执行完成") + except Exception as e: + info(f"[Tool] 子图调用失败: {e}") + return f"执行失败: {str(e)}" + + +@tool +async def contact_lookup(query: str) -> str: + """查询通讯录""" + from backend.app.subgraphs.contact.graph import build_contact_subgraph + from backend.app.subgraphs.contact.state import ContactState + return await _call_subgraph(build_contact_subgraph, ContactState, query) + + +@tool +async def dictionary_lookup(word: str) -> str: + """查询词典/翻译""" + from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph + from backend.app.subgraphs.dictionary.state import DictionaryState + return await _call_subgraph(build_dictionary_subgraph, DictionaryState, word) + + +@tool +async def news_analysis(topic: str) -> str: + """分析新闻热点""" + from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph + from backend.app.subgraphs.news_analysis.state import NewsAnalysisState + return await _call_subgraph(build_news_analysis_subgraph, NewsAnalysisState, topic) diff --git a/backend/app/tools/web_search.py b/backend/app/tools/web_search.py new file mode 100644 index 0000000..3866f86 --- /dev/null +++ b/backend/app/tools/web_search.py @@ -0,0 +1,18 @@ +""" +联网搜索工具 +""" + +from langchain_core.tools import tool +from backend.app.logger import info + + +@tool +def web_search(query: str) -> str: + """联网搜索获取最新信息""" + info(f"[Tool] web_search: {query[:30]}...") + try: + from backend.app.core.web_search import web_search as search_fn + return search_fn(query, max_results=5) + except Exception as e: + info(f"[Tool] web_search 失败: {e}") + return f"联网搜索失败: {str(e)}"