From 60afa86ded4b18bddb50129e292a233abbde111f Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Mon, 4 May 2026 02:01:22 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=20BM25=20=E7=A8=80?= =?UTF-8?q?=E7=96=8F=20+=20=E7=A8=A0=E5=AF=86=E5=90=91=E9=87=8F=E6=B7=B7?= =?UTF-8?q?=E5=90=88=E6=A3=80=E7=B4=A2=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.docker | 134 +++--- .gitignore | 11 +- REACT_MODE_SUMMARY.md | 182 ------- backend/app/config.py | 12 +- backend/app/model_services/rerank_services.py | 97 +++- backend/app/rag/pipeline.py | 130 +++-- backend/app/rag/retriever.py | 449 +++++++++++++----- backend/app/rag/tools.py | 74 ++- backend/rag_core/__init__.py | 3 + backend/rag_core/retriever_factory.py | 58 ++- backend/rag_core/sparse_embedder.py | 34 ++ backend/rag_core/vector_store.py | 91 +++- backend/requirements.txt | 59 ++- docker/backend/Dockerfile | 6 + docker/docker-compose.yml | 2 + download_sparse_model.py | 73 --- rag_indexer/index_builder.py | 37 +- rag_indexer/requirements.txt | 34 -- requirement.txt | 47 -- requirements.txt | 6 + tools/download_bm25.py | 22 + {test => tools/test}/test_backend.py | 0 {test => tools/test}/test_dqrant.py | 0 {test => tools/test}/test_frontend.py | 0 {test => tools/test}/test_rag.py | 0 .../test}/test_rag_indexer_result.py | 0 26 files changed, 905 insertions(+), 656 deletions(-) delete mode 100644 REACT_MODE_SUMMARY.md create mode 100644 backend/rag_core/sparse_embedder.py delete mode 100644 download_sparse_model.py delete mode 100644 rag_indexer/requirements.txt delete mode 100644 requirement.txt create mode 100644 requirements.txt create mode 100644 tools/download_bm25.py rename {test => tools/test}/test_backend.py (100%) rename {test => tools/test}/test_dqrant.py (100%) rename {test => tools/test}/test_frontend.py (100%) rename {test => tools/test}/test_rag.py (100%) rename {test => tools/test}/test_rag_indexer_result.py (100%) diff --git a/.env.docker b/.env.docker index ebece03..070f717 100644 --- a/.env.docker +++ b/.env.docker @@ -1,86 +1,100 @@ # ============================================================================= -# Docker Compose 服务器部署配置模板 -# 用法: cp .env.docker .env 然后填入敏感密钥 +# Docker 部署环境配置文件 +# 用法: cp .env.docker .env 然后修改配置值用于Docker部署 # ============================================================================= # ----------------------------------------------------------------------------- -# AI 模型 API 密钥(⭐ 敏感配置 - 必须配置) -# 本地部署:在此文件中填入 -# CI/CD 部署:在仓库 Settings → Secrets 中配置 +# AI 模型 API 密钥(必需 - 请填入真实值) # ----------------------------------------------------------------------------- -ZHIPUAI_API_KEY=your_zhipuai_api_key_here # ⭐ 敏感密钥配置 -DEEPSEEK_API_KEY=your_deepseek_api_key_here # ⭐ 敏感密钥配置 -LLAMACPP_API_KEY=your_llamacpp_api_key_here # ⭐ 敏感密钥配置 +ZHIPUAI_API_KEY=你的智谱API密钥 +DEEPSEEK_API_KEY=你的深度求索API密钥 +LLAMACPP_API_KEY=huang1998 +SILICONFLOW_API_KEY=你的硅基流动API密钥(可选,本地服务故障时降级使用) # ----------------------------------------------------------------------------- -# PostgreSQL 数据库配置(分离配置,易于管理) +# llama.cpp 服务配置(Docker环境下使用host.docker.internal访问宿主服务) +# ----------------------------------------------------------------------------- +# 主 LLM 服务 (Gemma-4-E2B GGUF) - 宿主端口 18000 +VLLM_BASE_URL=http://host.docker.internal:18000/v1 + +# Embedding 服务 (Qwen3-Embedding-0.6B GGUF) - 宿主端口 18001 +LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1 + +# Reranker 服务 (bge-reranker-v2-m3) - 宿主端口 18002 +LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1 + +# ----------------------------------------------------------------------------- +# Qdrant 向量数据库配置(使用远程服务) +# ----------------------------------------------------------------------------- +QDRANT_URL=http://115.190.121.151:6333 +QDRANT_API_KEY=你的QdrantAPI密钥 +QDRANT_COLLECTION_NAME=mem0_user_memories + +# ----------------------------------------------------------------------------- +# PostgreSQL 数据库配置(使用远程服务) # ----------------------------------------------------------------------------- DB_HOST=115.190.121.151 DB_PORT=5432 DB_USER=postgres -DB_PASSWORD=your_db_password_here # ⭐ 敏感密钥配置 +DB_PASSWORD=你的PostgreSQL密码 DB_NAME=langgraph_db -# 完整连接字符串(也支持直接配置,优先使用分离配置) -DB_URI=postgresql://postgres:${DB_PASSWORD}@115.190.121.151:5432/langgraph_db?sslmode=disable +# 完整连接字符串(可选,优先使用分离配置) +DB_URI=postgresql://postgres:你的PostgreSQL密码@115.190.121.151:5432/langgraph_db?sslmode=disable # ----------------------------------------------------------------------------- -# Qdrant 向量数据库配置(URL + API密钥 配对) +# 后端服务配置 # ----------------------------------------------------------------------------- -QDRANT_URL=http://115.190.121.151:6333 -QDRANT_API_KEY=your_qdrant_api_key_here # ⭐ 敏感密钥配置 -QDRANT_COLLECTION_NAME=mem0_user_memories +BACKEND_PORT=8079 # ----------------------------------------------------------------------------- -# llama.cpp 服务配置(URL + API密钥 配对) +# 前端配置(Docker内部通信) # ----------------------------------------------------------------------------- -# 主 LLM 服务 (Gemma-4-E2B GGUF) - 端口 18000 (Docker host 映射) -VLLM_BASE_URL=http://host.docker.internal:18000/v1 - -# Embedding 服务 (Qwen3-Embedding-0.6B GGUF) - 端口 18001 -LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1 -# LLAMACPP_API_KEY=your_llamacpp_api_key_here (已在上面配置) - -# Reranker 服务 (bge-reranker-v2-m3) - 端口 18002 -LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1 - -# ----------------------------------------------------------------------------- -# RAG 索引构建配置(非敏感,可直接使用) -# ----------------------------------------------------------------------------- -RAG_COLLECTION_NAME=rag_documents -RAG_CHUNK_SIZE=500 -RAG_CHUNK_OVERLAP=50 -RAG_PARENT_CHUNK_SIZE=1000 -RAG_CHILD_CHUNK_SIZE=200 -RAG_PARENT_CHUNK_OVERLAP=100 -RAG_CHILD_CHUNK_OVERLAP=20 -RAG_STRATEGY=parent-child -RAG_STORAGE_TYPE=postgres - -# ----------------------------------------------------------------------------- -# 日志调试配置(部署时可灵活调整) -# ----------------------------------------------------------------------------- -# 日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL -# 生产环境推荐 WARNING,排查问题时改为 DEBUG -LOG_LEVEL=WARNING - -# 是否启用 DEBUG 模式 -# true: 输出详细调试信息,包含完整的工具调用、数据库查询等 -# false: 仅输出关键信息,适合生产环境 -DEBUG=false - -# 是否启用 Graph 流转追踪 -# true: 输出每个节点的输入输出状态,便于调试工作流 -# false: 关闭追踪,减少日志量 -ENABLE_GRAPH_TRACE=false +API_URL=http://backend:8079/chat # ----------------------------------------------------------------------------- # 应用行为配置 # ----------------------------------------------------------------------------- -BACKEND_PORT=8079 +# 记忆提取间隔:每 N 轮对话执行一次记忆提取 MEMORY_SUMMARIZE_INTERVAL=10 +# 是否启用 Graph 执行追踪(调试用) +ENABLE_GRAPH_TRACE=true + # ----------------------------------------------------------------------------- -# 前端配置 +# 稀疏模型配置 # ----------------------------------------------------------------------------- -# Docker Compose 内部网络,使用服务名 'backend' -API_URL=http://backend:8079/chat +FASTEMBED_CACHE_PATH=/app/fastembed_cache + +# ----------------------------------------------------------------------------- +# RAG 索引构建配置 +# ----------------------------------------------------------------------------- +# Qdrant 集合名称 +RAG_COLLECTION_NAME=rag_documents + +# 基础切分参数 +RAG_CHUNK_SIZE=500 +RAG_CHUNK_OVERLAP=50 + +# 父子块切分参数 +RAG_PARENT_CHUNK_SIZE=1000 +RAG_CHILD_CHUNK_SIZE=200 +RAG_PARENT_CHUNK_OVERLAP=100 +RAG_CHILD_CHUNK_OVERLAP=20 + +# 切分策略:basic(基础)、semantic(语义)、parent-child(父子块) +RAG_STRATEGY=parent-child + +# 存储类型:postgres(PostgreSQL)、local(本地文件) +RAG_STORAGE_TYPE=postgres + +# 文档加载器配置(可选) +# OCR 语言列表(逗号分隔) +RAG_OCR_LANGUAGES=chi_sim,eng +# 文档主语言列表(逗号分隔) +RAG_DOC_LANGUAGES=zh + +# ----------------------------------------------------------------------------- +# 日志配置 +# ----------------------------------------------------------------------------- +LOG_LEVEL=DEBUG +DEBUG=true \ No newline at end of file diff --git a/.gitignore b/.gitignore index ff42873..df920ad 100644 --- a/.gitignore +++ b/.gitignore @@ -21,7 +21,8 @@ !test/** !.gitea/ !.gitea/** -!download_sparse_model.py +!tools/ +!tools/** # 3. 放行必要的根目录文件 !.gitignore @@ -29,7 +30,7 @@ !QUICKSTART.md !REACT_MODE_SUMMARY.md !LICENSE -!requirement.txt +!requirements.txt !.env.docker # ========================================== @@ -41,12 +42,8 @@ __pycache__/ *.so .DS_Store -# 模型目录(不提交到 Git,在 Docker 构建时下载) -models/ - # 包含敏感信息的环境变量配置(绝对不能传) .env -.env.local # 日志 *.log @@ -54,4 +51,4 @@ app/*.log frontend/*.log # 测试和用户数据 -data/ +data/ \ No newline at end of file diff --git a/REACT_MODE_SUMMARY.md b/REACT_MODE_SUMMARY.md deleted file mode 100644 index 255dafe..0000000 --- a/REACT_MODE_SUMMARY.md +++ /dev/null @@ -1,182 +0,0 @@ -# React 模式架构总结 - ---- - -## ✅ 当前架构:混合路由 + React 循环 - -本项目采用 **两层混合架构**: - -``` -┌─────────────────────────────────────────────────────────────┐ -│ 第一层:前置混合路由(低延迟) │ -│ ├─ 规则快速分流(无 LLM) │ -│ ├─ 轻量级意图分类(smallLLM) │ -│ └─ 快速路径(fast_chitchat, fast_rag, fast_tool) │ -└───────────────────────┬─────────────────────────────────────┘ - ↓(自动升级:失败时) -┌─────────────────────────────────────────────────────────────┐ -│ 第二层:完整 React 循环(兜底,复杂任务处理) │ -│ └─ 推理 → 行动 → 观察(最多 40 步) │ -└─────────────────────────────────────────────────────────────┘ -``` - ---- - -## 🎯 第一层:前置混合路由(新) - -### 核心功能 - -| 功能 | 说明 | -|------|------| -| 规则快速分流 | 无 LLM,毫秒级响应,用于问候、感谢、子图关键词等 | -| 轻量级意图分类 | 使用 smallLLM,压缩到 4 类:chitchat, knowledge, tool, complex | -| 快速路径 | 三个快速处理节点:fast_chitchat, fast_rag, fast_tool | -| 自动升级 | 快速路径失败时,自动回到完整 React 循环 | -| SSE 事件增强 | intent_classified, path_decision, fast_path_*, escalation | - -### 快速流程图 - -``` -START - ↓ -init_state - ↓ -hybrid_router (前置路由) ←────────────┐ - ↓ │ - ├─ 规则分流 → fast_chitchat →────────┤ - │ ↓ │ - ├─ 模型分类 → fast_rag →────────────┤ - │ ↓ │ - ├─ fast_tool →────────┤ - │ ↓ │ - └─ react_loop →────────┤ - ↓ │ - 检查成功/升级? ──────────┘ - ↓ ↓ - finalize react_reason -``` - -### 关键文件 - -| 文件 | 说明 | -|------|------| -| `backend/app/main_graph/nodes/hybrid_router.py` | 混合路由完整实现 | -| `backend/app/model_services/chat_services.py` | get_chat_service() + get_small_llm_service() | -| `backend/app/main_graph/utils/main_graph_builder.py` | 集成混合路由到主图 | - -### 配置项 - -```python -# 构建图时可选择 -graph = build_react_main_graph(use_hybrid_router=True) # 启用混合路由(默认) -graph = build_react_main_graph(use_hybrid_router=False) # 禁用,纯 React 循环 -``` - ---- - -## 🎯 第二层:完整 React 循环(保留) - -### 核心特性 - -| 特性 | 说明 | -|------|------| -| 循环推理 | 每轮推理判断下一步,最多 40 步 | -| 结构化错误 | ErrorRecord + ErrorSeverity | -| 超时重试 | RAG 最多 2 次,子图最多 1 次 | -| 子图集成 | contact, dictionary, news_analysis | -| RAG 检索 | 支持重检索(re_retrieve) | - -### 流程图 - -``` -react_reason (推理) ←──────────────────┐ - ↓ │ -条件路由 │ - ├─→ rag_retrieve (带重试) →──────────┤ - ├─→ contact_subgraph →───────────────┤ - ├─→ dictionary_subgraph →────────────┤ - ├─→ news_analysis_subgraph →─────────┤ - ├─→ handle_error → (重试或降级) →────┤ - └─→ finalize - ↓ -END -``` - ---- - -## 📦 关键文件清单 - -| 文件 | 说明 | -|------|------| -| `backend/app/main_graph/utils/main_graph_builder.py` | 主图构建(支持混合路由开关) | -| `backend/app/main_graph/nodes/react_nodes.py` | React 循环节点 | -| `backend/app/main_graph/nodes/hybrid_router.py` | 混合路由节点(新) | -| `backend/app/main_graph/nodes/rag_nodes.py` | RAG 检索节点 | -| `backend/app/main_graph/utils/retry_utils.py` | 超时重试工具 | -| `backend/app/main_graph/state.py` | 主状态 | -| `backend/app/core/intent.py` | React 模式意图推理器 | -| `backend/app/model_services/chat_services.py` | 双模型服务(llm + smallLLM) | - ---- - -## 🛠️ 模型服务层 - -### 生成式大模型服务(Chat) - -| 函数 | 说明 | -|------|------| -| `get_chat_service()` | 获取大模型服务(用于复杂推理、生成) | -| `get_small_llm_service()` | 获取轻量级模型服务(用于简单意图分类、快速问答) | -| `get_all_chat_services()` | 获取所有可用的生成式大模型服务(用于多模型切换) | - -### 使用方法 - -```python -from app.model_services import get_chat_service, get_small_llm_service - -# 获取大模型服务(复杂任务) -llm = get_chat_service() -response = llm.invoke("什么是 LangGraph?") - -# 获取轻量级模型服务(简单任务) -small_llm = get_small_llm_service() -response = small_llm.invoke("分类用户意图:'你好'") -``` - -### 嵌入与重排模型服务 - -| 函数 | 说明 | -|------|------| -| `get_embedding_service()` | 获取嵌入模型服务(自动降级) | -| `get_rerank_service()` | 获取重排模型服务(自动降级) | - ---- - -## 🚀 快速使用 - -```python -from backend.app.main_graph.utils.main_graph_builder import build_react_main_graph - -# 构建图(默认启用混合路由) -graph = build_react_main_graph(use_hybrid_router=True) -compiled_graph = graph.compile() - -# 调用 -result = compiled_graph.invoke({"user_query": "你好", "user_id": "test"}) -print(result.final_result) -``` - ---- - -## 🎉 完整特性总结 - -✅ 双模型服务 (llm + smallLLM) -✅ 前置混合路由(规则快速分流 + 轻量级意图分类) -✅ 三个快速路径(fast_chitchat, fast_rag, fast_tool) -✅ 自动升级机制(快速路径失败 → 完整 React 循环) -✅ SSE 事件增强(intent_classified, path_decision, fast_path_*, escalation) -✅ 完整 React 循环(最多 40 步) -✅ 结构化错误处理 -✅ 超时和重试策略 -✅ 子图集成(contact, dictionary, news_analysis) -✅ 向后兼容(use_hybrid_router=True/False) diff --git a/backend/app/config.py b/backend/app/config.py index 3a43009..e70c829 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -37,8 +37,9 @@ def _get_bool(key: str) -> bool | None: # ========== 第三方 API 密钥 ========== -ZHIPUAI_API_KEY = _get_str("ZHIPUAI_API_KEY") -DEEPSEEK_API_KEY = _get_str("DEEPSEEK_API_KEY") +ZHIPUAI_API_KEY=_get_str("ZHIPUAI_API_KEY") +DEEPSEEK_API_KEY=_get_str("DEEPSEEK_API_KEY") +SILICONFLOW_API_KEY=_get_str("SILICONFLOW_API_KEY") # ========== 智谱 API 配置 ========== @@ -51,9 +52,16 @@ ZHIPU_RERANK_MODEL = _get_str("ZHIPU_RERANK_MODEL") or "rerank-2" ZHIPU_API_BASE = _get_str("ZHIPU_API_BASE") or "https://open.bigmodel.cn/api/paas/v4" +# ========== 硅基流动(SiliconFlow) API 配置 ========== +# 重排模型:BAAI/bge-reranker-v2-m3 +SILICONFLOW_RERANK_MODEL = _get_str("SILICONFLOW_RERANK_MODEL") or "BAAI/bge-reranker-v2-m3" +SILICONFLOW_API_BASE = _get_str("SILICONFLOW_API_BASE") or "https://api.siliconflow.cn/v1" + + # ========== 稀疏模型配置 ========== SPARSE_MODEL_PATH = _get_str("SPARSE_MODEL_PATH") or "./models/sparse" SPARSE_MODEL_NAME = _get_str("SPARSE_MODEL_NAME") or "Qdrant/bm25" +FASTEMBED_CACHE_PATH = _get_str("FASTEMBED_CACHE_PATH") or "./models/fastembed_cache" # ========== llama.cpp 服务配置(URL + API密钥 配对) ========== # 主 LLM 服务 diff --git a/backend/app/model_services/rerank_services.py b/backend/app/model_services/rerank_services.py index 4cc2052..1aabe00 100644 --- a/backend/app/model_services/rerank_services.py +++ b/backend/app/model_services/rerank_services.py @@ -3,11 +3,15 @@ 本模块提供统一的重排模型服务获取接口,支持自动降级: 1. 优先使用本地 llama.cpp 重排服务 -2. 本地服务不可用时,自动降级到智谱 API 重排服务 +2. 本地服务不可用时,自动降级到硅基流动(SiliconFlow) API 重排服务 +3. 硅基流动服务不可用时,自动降级到智谱 API 重排服务 +4. 所有API服务不可用时,自动降级到 LLM 评分重排服务 主要功能: - LocalLlamaCppRerankProvider:本地 llama.cpp 重排服务提供者 +- SiliconFlowRerankProvider:硅基流动 API 重排服务提供者 - ZhipuRerankProvider:智谱 API 重排服务提供者 +- LLMFallbackRerankProvider:LLM 评分降级重排服务提供者 - get_rerank_service():获取重排服务的统一接口 注意:本模块只负责调用 rerank server,不包含业务逻辑(文档处理、排序、top_n) @@ -28,7 +32,10 @@ from app.config import ( LLAMACPP_API_KEY, ZHIPUAI_API_KEY, ZHIPU_RERANK_MODEL, - ZHIPU_API_BASE + ZHIPU_API_BASE, + SILICONFLOW_API_KEY, + SILICONFLOW_RERANK_MODEL, + SILICONFLOW_API_BASE ) logger = logging.getLogger(__name__) @@ -136,6 +143,53 @@ class ZhipuRerankService(BaseRerankService): raise +class SiliconFlowRerankService(BaseRerankService): + """ + 硅基流动(SiliconFlow) API 重排服务 - 纯服务层 + """ + + def __init__(self, model: str | None = None, api_key: str | None = None, api_base: str | None = None): + self.model = model or SILICONFLOW_RERANK_MODEL + self.api_key = api_key or SILICONFLOW_API_KEY + self.api_base = api_base or SILICONFLOW_API_BASE + + def compute_scores(self, query: str, documents: List[str]) -> List[float]: + """ + 调用 SiliconFlow rerank API 计算得分 - 纯 API 调用 + """ + if not documents: + return [] + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + base = self.api_base.rstrip("/") + payload = { + "model": self.model, + "query": query, + "documents": documents, + "return_documents": False + } + + with httpx.Client(timeout=120) as client: + response = client.post( + f"{base}/rerank", + headers=headers, + json=payload, + ) + response.raise_for_status() + data = response.json() + + if isinstance(data, dict) and "results" in data: + results = data["results"] + results_sorted = sorted(results, key=lambda x: x["index"]) + return [item["relevance_score"] for item in results_sorted] + else: + raise ValueError(f"未知的 SiliconFlow rerank API 响应格式: {data}") + + class LLMFallbackRerankService(BaseRerankService): """ 使用 LLM 作为最后的降级方案进行重排 @@ -291,18 +345,53 @@ class ZhipuRerankProvider(BaseServiceProvider[BaseRerankService]): return self._service_instance +class SiliconFlowRerankProvider(BaseServiceProvider[BaseRerankService]): + """ + 硅基流动(SiliconFlow) API 重排服务提供者 + """ + + def __init__(self, model: str | None = None): + super().__init__("siliconflow_rerank") + self._model = model or SILICONFLOW_RERANK_MODEL + + def is_available(self) -> bool: + """ + 检查 SiliconFlow API 重排服务是否可用 + """ + if not SILICONFLOW_API_KEY: + logger.warning("SILICONFLOW_API_KEY 未配置") + return False + + try: + service = SiliconFlowRerankService(model=self._model) + test_scores = service.compute_scores("test query", ["test document"]) + logger.info("SiliconFlow 重排服务可用") + return True + except Exception as e: + logger.warning(f"SiliconFlow 重排服务不可用: {e}") + return False + + def get_service(self) -> BaseRerankService: + """ + 获取 SiliconFlow API 重排服务 + """ + if self._service_instance is None: + self._service_instance = SiliconFlowRerankService(model=self._model) + return self._service_instance + + def get_rerank_service() -> BaseRerankService: """ 获取重排服务(带自动降级)- 纯服务层 - 降级链: Local llama.cpp -> Zhipu Rerank -> LLM Fallback + 降级链: Local llama.cpp -> SiliconFlow Rerank -> Zhipu Rerank -> LLM Fallback Returns: BaseRerankService: 重排服务实例 """ def _create_chain(): primary = LocalLlamaCppRerankProvider() - fallbacks = [ZhipuRerankProvider(), LLMFallbackRerankProvider()] + fallbacks = [SiliconFlowRerankProvider(), ZhipuRerankProvider(), LLMFallbackRerankProvider()] return FallbackServiceChain(primary, fallbacks) chain = SingletonServiceManager.get_or_create("rerank_service_chain", _create_chain) diff --git a/backend/app/rag/pipeline.py b/backend/app/rag/pipeline.py index e714fb0..12c994c 100644 --- a/backend/app/rag/pipeline.py +++ b/backend/app/rag/pipeline.py @@ -1,4 +1,11 @@ -# rag/pipeline.py +""" +RAG 检索流水线模块 + +提供固定流程的 RAG 检索: +多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 + +默认使用混合检索(稠密+稀疏)+ 父子文档模式。 +""" import asyncio import os @@ -6,61 +13,86 @@ from typing import List from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel -from ..model_services import get_rerank_service -from .rerank import create_document_reranker -from .query_transform import MultiQueryGenerator -from .fusion import reciprocal_rank_fusion +from app.model_services import get_rerank_service +from app.rag.rerank import create_document_reranker +from app.rag.query_transform import MultiQueryGenerator +from app.rag.fusion import reciprocal_rank_fusion +from app.rag.retriever import create_parent_hybrid_retriever + class RAGPipeline: """ 固定流程的 RAG 检索流水线: - 多路改写 → 并行检索 → RRF融合 → 重排序 → 返回父文档 + 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 + + 默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。 """ def __init__( self, - retriever, # 基础检索器(应返回父文档,例如 ParentDocumentRetriever 实例) - llm: BaseLanguageModel, + retriever=None, + llm: Optional[BaseLanguageModel] = None, num_queries: int = 3, rerank_top_n: int = 5, + collection_name: str = "rag_documents", ): """ Args: - retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法 - llm: 用于生成多路查询的语言模型 - num_queries: 生成的查询变体数量 - rerank_top_n: 最终返回的文档数量 - rerank_model: 重排序模型名称 + retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。 + 如果不提供,会自动创建默认的父子文档混合检索器。 + llm: 用于生成多路查询的语言模型。 + num_queries: 生成的查询变体数量。 + rerank_top_n: 最终返回的文档数量。 + collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。 """ - self.retriever = retriever + # 如果没有提供 retriever,自动创建默认的混合检索器 + if retriever is None: + self.retriever = create_parent_hybrid_retriever( + collection_name=collection_name, + search_k=rerank_top_n * 2 # 多取一些给重排序用 + ) + else: + self.retriever = retriever + self.llm = llm self.num_queries = num_queries self.rerank_top_n = rerank_top_n # 初始化组件 - 使用统一的重排服务获取接口 - self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) + self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None self.reranker = create_document_reranker() async def aretrieve(self, query: str) -> List[Document]: """ 异步执行完整检索流程 + + Args: + query: 用户查询 + + Returns: + 检索到的相关文档列表 """ - # Step 1: 生成多路查询 - queries = await self.query_generator.agenerate(query) - # 包含原始查询,确保至少有一条 - if query not in queries: - queries.insert(0, query) + # 如果有 query_generator,做多路改写 + if self.query_generator and self.llm: + # Step 1: 生成多路查询 + queries = await self.query_generator.agenerate(query) + # 包含原始查询,确保至少有一条 + if query not in queries: + queries.insert(0, query) + else: + # 如果原始查询已在列表中,将其移至首位 + queries.remove(query) + queries.insert(0, query) + + # Step 2: 并行检索(每个查询获取文档列表) + tasks = [self.retriever.ainvoke(q) for q in queries] + doc_lists = await asyncio.gather(*tasks) + + # Step 3: RRF 融合 + fused_docs = reciprocal_rank_fusion(doc_lists) else: - # 如果原始查询已在列表中,将其移至首位 - queries.remove(query) - queries.insert(0, query) - - # Step 2: 并行检索(每个查询获取文档列表) - tasks = [self.retriever.ainvoke(q) for q in queries] - doc_lists = await asyncio.gather(*tasks) - - # Step 3: RRF 融合 - fused_docs = reciprocal_rank_fusion(doc_lists) + # 没有 LLM 做查询改写,直接用原始查询检索 + fused_docs = await self.retriever.ainvoke(query) # Step 4: 重排序 try: @@ -76,7 +108,15 @@ class RAGPipeline: return asyncio.run(self.aretrieve(query)) def format_context(self, documents: List[Document]) -> str: - """将文档列表格式化为上下文字符串""" + """ + 将文档列表格式化为上下文字符串 + + Args: + documents: 文档列表 + + Returns: + 格式化后的上下文字符串 + """ if not documents: return "" @@ -84,4 +124,30 @@ class RAGPipeline: for i, doc in enumerate(documents, 1): source = doc.metadata.get("source", "未知来源") parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n") - return "\n".join(parts) \ No newline at end of file + return "\n".join(parts) + + +def create_rag_pipeline( + collection_name: str = "rag_documents", + llm: Optional[BaseLanguageModel] = None, + num_queries: int = 3, + rerank_top_n: int = 5, +) -> RAGPipeline: + """ + 创建 RAG 检索流水线的便捷函数 + + Args: + collection_name: Qdrant 集合名称 + llm: 用于生成多路查询的语言模型 + num_queries: 生成的查询变体数量 + rerank_top_n: 最终返回的文档数量 + + Returns: + RAGPipeline 实例 + """ + return RAGPipeline( + llm=llm, + num_queries=num_queries, + rerank_top_n=rerank_top_n, + collection_name=collection_name + ) diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index 472b09e..1a36237 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -1,170 +1,379 @@ """ -Qdrant 向量检索器模块 +Qdrant 混合检索器模块 -提供基于 Qdrant 的混合检索(Dense + Sparse)功能。 +提供基于 Qdrant 的混合检索(Dense + Sparse)功能,包括: +- 纯混合检索(无子父文档) +- 父子文档混合检索(先检索子文档,再返回父文档) 核心原理: -- 使用 Qdrant 原生混合检索(langchain-qdrant 的 RetrievalMode.HYBRID) -- 同时存储稠密向量和稀疏向量 -- 语义理解 + 关键词匹配,效果最优 - -使用示例: - >>> from app.rag.retriever import create_hybrid_retriever - >>> retriever = create_hybrid_retriever(collection_name="rag_documents") - >>> docs = retriever.invoke("什么是 RAG?") +- 使用 Qdrant 原生 Fusion API (RRF) 做分数融合 +- 同时使用稠密向量(语义)和稀疏向量(BM25 关键词) """ -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, List from qdrant_client import QdrantClient from qdrant_client.http.exceptions import UnexpectedResponse -from langchain_qdrant import ( - QdrantVectorStore, - RetrievalMode, - FastEmbedSparse, +from qdrant_client.http.models import ( + SearchRequest, Fusion, FusionProtocol, NamedVector, NamedSparseVector ) +from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.retrievers import BaseRetriever +from langchain_core.retrievers import BaseRetriever, RetrieverOutput -from rag_core import QDRANT_URL, QDRANT_API_KEY +from rag_core import QdrantVectorStore, get_sparse_embedder, create_docstore from rag_core.client import create_qdrant_client as create_core_qdrant_client from app.model_services import get_embedding_service -from app.config import SPARSE_MODEL_PATH, SPARSE_MODEL_NAME -from app.logger import info, warning +from app.logger import info, warning, debug + # 模块级常量 DEFAULT_SEARCH_K = 20 -DEFAULT_SCORE_THRESHOLD = 0.3 +DEFAULT_PARENT_SEARCH_K = 5 -def create_base_retriever( - collection_name: str, - search_kwargs: Dict[str, Any] | None = None, - client: QdrantClient | None = None, - embeddings: Embeddings | None = None, -) -> BaseRetriever: +class HybridRetriever(BaseRetriever): """ - 创建基础向量检索器(仅稠密向量检索) - - Args: - collection_name: Qdrant 集合名称 - search_kwargs: 搜索参数 - client: 可选的 Qdrant 客户端 - embeddings: 可选的嵌入模型(默认使用 get_embedding_service()) - - Returns: - LangChain 兼容的检索器 + 混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合 + + 直接使用 Qdrant 原生 Fusion API,性能最优。 """ - # 默认使用统一嵌入服务(已内置降级机制) - if embeddings is None: - embeddings = get_embedding_service() - info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)") + + def __init__( + self, + collection_name: str, + vector_store: QdrantVectorStore, + search_k: int = DEFAULT_SEARCH_K, + ): + """ + Args: + collection_name: Qdrant 集合名称 + vector_store: QdrantVectorStore 实例 + search_k: 检索返回结果数 + """ + self.collection_name = collection_name + self.vector_store = vector_store + self.search_k = search_k + self.client = vector_store.get_qdrant_client() + self.sparse_embedder = get_sparse_embedder() + + def _get_relevant_documents( + self, query: str, *, run_manager: Optional[Any] = None + ) -> List[Document]: + """ + 同步检索相关文档 + + Args: + query: 查询字符串 + run_manager: LangChain 运行管理器(可选) + + Returns: + 相关文档列表 + """ + # 生成双向量 + dense_query = self.vector_store.embeddings.embed_query(query) + sparse_query = self.sparse_embedder.embed_query(query) + + # 构建双检索请求 + searches = [ + # 稠密检索 + SearchRequest( + vector=NamedVector(name="dense", vector=dense_query), + limit=self.search_k, + with_payload=True + ), + # 稀疏检索 + SearchRequest( + vector=NamedSparseVector(name="sparse", vector=sparse_query), + limit=self.search_k, + with_payload=True + ) + ] + + # RRF 分数融合 + fused_results = self.client.fusion( + collection_name=self.collection_name, + requests=searches, + fusion=Fusion(fusion=FusionProtocol.RRF) + ) + + # 转换为 Document 格式 + results = [] + for point in fused_results.points: + doc = Document( + page_content=point.payload.pop("text", ""), + metadata=point.payload + ) + results.append(doc) + + debug(f"混合检索返回 {len(results)} 个文档") + return results + + async def _aget_relevant_documents( + self, query: str, *, run_manager: Optional[Any] = None + ) -> List[Document]: + """异步检索(当前调用同步版本)""" + # Qdrant 客户端没有原生 async,这里用同步版本 + return self._get_relevant_documents(query, run_manager=run_manager) - # 合并默认搜索参数 - merged_search_kwargs = {"k": DEFAULT_SEARCH_K} - if search_kwargs: - merged_search_kwargs.update(search_kwargs) - # 创建或复用 Qdrant 客户端 - if client is None: - client = create_core_qdrant_client() - - # 验证集合是否存在 - try: - client.get_collection(collection_name) - except UnexpectedResponse as e: - if e.status_code == 404: - warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档") - raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在") - raise - - # 构建向量存储 - vector_store = QdrantVectorStore( - client=client, - collection_name=collection_name, - embedding=embeddings, - ) - - return vector_store.as_retriever(search_kwargs=merged_search_kwargs) +class ParentHybridRetriever(BaseRetriever): + """ + 父子文档混合检索器: + + 1. 先用混合检索找到相关子文档 + 2. 根据子文档的 parent_id 找到对应的父文档 + 3. 去重并返回父文档 + """ + + def __init__( + self, + collection_name: str, + vector_store: QdrantVectorStore, + search_k: int = DEFAULT_PARENT_SEARCH_K, + docstore: Optional[Any] = None, + ): + """ + Args: + collection_name: Qdrant 集合名称 + vector_store: QdrantVectorStore 实例 + search_k: 最终返回的父文档数 + docstore: 文档存储(如果父文档在 PostgreSQL),可选 + """ + self.collection_name = collection_name + self.vector_store = vector_store + self.search_k = search_k + self.client = vector_store.get_qdrant_client() + self.sparse_embedder = get_sparse_embedder() + self.docstore = docstore + + def _get_relevant_documents( + self, query: str, *, run_manager: Optional[Any] = None + ) -> List[Document]: + """ + 同步检索相关父文档 + + Args: + query: 查询字符串 + run_manager: LangChain 运行管理器(可选) + + Returns: + 相关父文档列表 + """ + # 1. 生成查询双向量 + dense_query = self.vector_store.embeddings.embed_query(query) + sparse_query = self.sparse_embedder.embed_query(query) + + # 2. 多取一些子文档,避免去重后数量不足 + search_limit = self.search_k * 2 + searches = [ + # 稠密检索 + SearchRequest( + vector=NamedVector(name="dense", vector=dense_query), + limit=search_limit, + with_payload=True + ), + # 稀疏检索 + SearchRequest( + vector=NamedSparseVector(name="sparse", vector=sparse_query), + limit=search_limit, + with_payload=True + ) + ] + + # 3. RRF 分数融合,拿到子文档命中结果 + fused_results = self.client.fusion( + collection_name=self.collection_name, + requests=searches, + fusion=Fusion(fusion=FusionProtocol.RRF) + ) + + if not fused_results.points: + debug("混合检索未找到任何文档") + return [] + + # 4. 收集 parent_id 和对应最高得分 + parent_score_map = {} + parent_ids = set() + child_point_map = {} # 保存子文档点用于降级 + + for point in fused_results.points: + parent_id = point.payload.get("parent_id", point.id) + score = point.score + + # 同一个 parent_id 只保留最高得分 + if parent_id not in parent_score_map or score > parent_score_map[parent_id]: + parent_score_map[parent_id] = score + parent_ids.add(parent_id) + child_point_map[parent_id] = point + + # 5. 批量查询父文档 + # 首先尝试从 Qdrant 直接查询(因为父文档可能也存在 Qdrant 中) + parent_docs = [] + found_parent_ids = set() + + try: + parent_points = self.client.retrieve( + collection_name=self.collection_name, + ids=list(parent_ids), + with_payload=True + ) + + # 处理找到的父文档 + for point in parent_points: + doc = Document( + page_content=point.payload.pop("text", ""), + metadata=point.payload + ) + parent_docs.append(doc) + found_parent_ids.add(point.id) + + except Exception as e: + warning(f"从 Qdrant 查询父文档失败: {e}") + + # 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档 + if self.docstore and len(found_parent_ids) < len(parent_ids): + missing_parent_ids = parent_ids - found_parent_ids + try: + docstore_docs = self.docstore.mget(missing_parent_ids) + for doc_id, doc in zip(missing_parent_ids, docstore_docs): + if doc is not None: + parent_docs.append(doc) + found_parent_ids.add(doc_id) + except Exception as e: + warning(f"从 docstore 查询父文档失败: {e}") + + # 7. 降级:对于仍未找到的父文档,用子文档本身代替 + missing_parent_ids = parent_ids - found_parent_ids + if missing_parent_ids: + warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}") + for parent_id in missing_parent_ids: + child_point = child_point_map.get(parent_id) + if child_point: + doc = Document( + page_content=child_point.payload.pop("text", ""), + metadata=child_point.payload + ) + parent_docs.append(doc) + + # 8. 按照得分降序排序,返回前 k 个 + parent_docs_with_scores = [ + (doc, parent_score_map.get(doc.metadata.get("id", doc.id), 0.0)) + for doc in parent_docs + ] + parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True) + + final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]] + debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档") + + return final_docs + + async def _aget_relevant_documents( + self, query: str, *, run_manager: Optional[Any] = None + ) -> List[Document]: + """异步检索(当前调用同步版本)""" + return self._get_relevant_documents(query, run_manager=run_manager) def create_hybrid_retriever( collection_name: str, - dense_k: int = 10, - sparse_k: int = 10, - score_threshold: float | None = DEFAULT_SCORE_THRESHOLD, - client: QdrantClient | None = None, - embeddings: Embeddings | None = None, + search_k: int = DEFAULT_SEARCH_K, + embeddings: Optional[Embeddings] = None, ) -> BaseRetriever: """ - 创建混合检索器(稠密向量 + BM25 稀疏向量,Qdrant 原生实现)。 - + 创建混合检索器(稠密向量 + BM25 稀疏向量)。 + + 这是默认推荐的检索方式,效果最优。 + Args: - collection_name: Qdrant 集合名称。 - dense_k: 稠密向量检索返回数量,默认 10。 - sparse_k: 稀疏向量检索返回数量,默认 10。 - score_threshold: 相似度阈值,默认 0.3。 - client: 可选的 Qdrant 客户端实例。 + collection_name: Qdrant 集合名称 + search_k: 检索返回结果数 embeddings: 可选的嵌入模型实例。若未提供,将自动获取统一嵌入服务。 - + Returns: - BaseRetriever 实例,配置了混合搜索参数。 + HybridRetriever 实例 """ - total_k = dense_k + sparse_k - - search_kwargs = { - "k": total_k, - "search_type": "similarity_score_threshold", - "score_threshold": score_threshold, - } - - # 默认使用统一嵌入服务(已内置降级机制) + # 默认使用统一嵌入服务 if embeddings is None: embeddings = get_embedding_service() info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)") - - # 创建或复用 Qdrant 客户端 - if client is None: - client = create_core_qdrant_client() - + + # 创建向量存储 + vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings) + # 验证集合是否存在 try: - client.get_collection(collection_name) + vector_store.get_client().get_collection(collection_name) except UnexpectedResponse as e: if e.status_code == 404: warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档") raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在") raise - - # 初始化稀疏嵌入(使用本地缓存目录) - sparse_embeddings = FastEmbedSparse( - model_name=SPARSE_MODEL_NAME, - cache_dir=SPARSE_MODEL_PATH - ) - info(f"✅ FastEmbedSparse 初始化成功 (cache_dir={SPARSE_MODEL_PATH})") - - # 创建混合模式的 QdrantVectorStore - vector_store = QdrantVectorStore( - client=client, + + info(f"✅ Qdrant 混合检索器初始化成功(search_k={search_k})") + return HybridRetriever( collection_name=collection_name, - embedding=embeddings, - sparse_embedding=sparse_embeddings, - retrieval_mode=RetrievalMode.HYBRID, + vector_store=vector_store, + search_k=search_k ) - info(f"✅ Qdrant 原生混合检索器初始化成功 (k={total_k})") - return vector_store.as_retriever(search_kwargs=search_kwargs) - -# 可选:提供异步友好的辅助函数 -async def acreate_base_retriever( +def create_parent_hybrid_retriever( collection_name: str, - search_kwargs: Dict[str, Any] | None = None, - client: QdrantClient | None = None, + search_k: int = DEFAULT_PARENT_SEARCH_K, + embeddings: Optional[Embeddings] = None, + use_docstore: bool = True, ) -> BaseRetriever: """ - 异步创建基础向量检索器(与同步版本功能相同)。 - - 适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。 + 创建父子文档混合检索器(默认推荐)。 + + 检索流程: + 1. 混合检索找到相关子文档 + 2. 根据 parent_id 找到对应的父文档 + 3. 去重并返回父文档 + + Args: + collection_name: Qdrant 集合名称 + search_k: 最终返回的父文档数 + embeddings: 可选的嵌入模型实例 + use_docstore: 是否使用 PostgreSQL docstore 存储父文档 + + Returns: + ParentHybridRetriever 实例 """ - # 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可 - return create_base_retriever(collection_name, search_kwargs, client) + # 默认使用统一嵌入服务 + if embeddings is None: + embeddings = get_embedding_service() + info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)") + + # 创建向量存储 + vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings) + + # 验证集合是否存在 + try: + vector_store.get_client().get_collection(collection_name) + except UnexpectedResponse as e: + if e.status_code == 404: + warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档") + raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在") + raise + + # 创建 docstore(如果需要) + docstore = None + if use_docstore: + try: + docstore, _ = create_docstore() + info("✅ 文档存储初始化成功(PostgreSQL)") + except Exception as e: + warning(f"⚠️ 文档存储初始化失败,将不使用 docstore: {e}") + + info(f"✅ Qdrant 父子文档混合检索器初始化成功(search_k={search_k})") + return ParentHybridRetriever( + collection_name=collection_name, + vector_store=vector_store, + search_k=search_k, + docstore=docstore + ) + + +# 别名:默认就是父子文档混合检索 +create_retriever = create_parent_hybrid_retriever diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py index 33f79c2..1daec9b 100644 --- a/backend/app/rag/tools.py +++ b/backend/app/rag/tools.py @@ -3,52 +3,94 @@ RAG 工具模块 将检索功能封装为 LangChain Tool,供 Agent 调用。 采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。 + +默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。 """ -from typing import Callable +from typing import Callable, Optional from langchain_core.tools import tool from langchain_core.language_models import BaseLanguageModel from langchain_core.retrievers import BaseRetriever -from .pipeline import RAGPipeline +from app.rag.pipeline import RAGPipeline, create_rag_pipeline + def create_rag_tool_sync( - retriever: BaseRetriever, - llm: BaseLanguageModel, + retriever: Optional[BaseRetriever] = None, + llm: Optional[BaseLanguageModel] = None, num_queries: int = 3, rerank_top_n: int = 5, collection_name: str = "rag_documents", ) -> Callable: """ - 创建一个配置好的 RAG 检索工具(同步版本,用于不支持异步的旧版 Agent)。 - - 参数同 create_rag_tool。 + 创建一个配置好的 RAG 检索工具(同步版本)。 + + 默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。 + + Args: + retriever: 基础检索器对象(可选,不提供则自动创建) + llm: 用于生成多路查询的语言模型(可选) + num_queries: 生成的查询变体数量 + rerank_top_n: 最终返回的文档数量 + collection_name: Qdrant 集合名称 + + Returns: + LangChain Tool 函数 """ pipeline = RAGPipeline( retriever=retriever, llm=llm, num_queries=num_queries, rerank_top_n=rerank_top_n, + collection_name=collection_name, ) @tool def search_knowledge_base_sync(query: str) -> str: - """在知识库中搜索与查询相关的文档片段(同步版本)。 - - 功能与异步版本相同:多路改写 → RRF融合 → 重排序 → 返回父文档。 - + """ + 在知识库中搜索与查询相关的文档片段。 + + 使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式, + 检索效果最优。 + Args: query: 用户提出的问题或查询字符串 - + Returns: - 格式化后的相关文档内容。 + 格式化后的相关文档内容 """ try: - documents = pipeline.retrieve(query) # 内部调用异步方法并等待 + documents = pipeline.retrieve(query) if not documents: return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" - + context = pipeline.format_context(documents) return context except Exception as e: return f"检索过程中发生错误: {str(e)}" + + return search_knowledge_base_sync - return search_knowledge_base_sync \ No newline at end of file + +def create_rag_tool( + collection_name: str = "rag_documents", + llm: Optional[BaseLanguageModel] = None, + num_queries: int = 3, + rerank_top_n: int = 5, +) -> Callable: + """ + 创建 RAG 检索工具的便捷函数(同步版本)。 + + Args: + collection_name: Qdrant 集合名称 + llm: 用于生成多路查询的语言模型(可选) + num_queries: 生成的查询变体数量 + rerank_top_n: 最终返回的文档数量 + + Returns: + LangChain Tool 函数 + """ + return create_rag_tool_sync( + collection_name=collection_name, + llm=llm, + num_queries=num_queries, + rerank_top_n=rerank_top_n, + ) diff --git a/backend/rag_core/__init__.py b/backend/rag_core/__init__.py index 6eb92ed..7fcdac8 100644 --- a/backend/rag_core/__init__.py +++ b/backend/rag_core/__init__.py @@ -6,6 +6,7 @@ RAG Core - 公共 RAG 组件包 from .embedders import LlamaCppEmbedder from .vector_store import QdrantVectorStore +from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder from .store import PostgresDocStore, create_docstore from .retriever_factory import create_parent_retriever from .config import ( @@ -21,6 +22,8 @@ from .config import ( __all__ = [ "LlamaCppEmbedder", "QdrantVectorStore", + "BM25SparseEmbedder", + "get_sparse_embedder", "QDRANT_URL", "QDRANT_API_KEY", "LLAMACPP_EMBEDDING_URL", diff --git a/backend/rag_core/retriever_factory.py b/backend/rag_core/retriever_factory.py index 9559797..03482f6 100644 --- a/backend/rag_core/retriever_factory.py +++ b/backend/rag_core/retriever_factory.py @@ -1,5 +1,14 @@ -# rag_core/retriever_factory.py +""" +RAG 检索器工厂模块 + +提供创建各种检索器的工厂函数,包括: +- 基础向量检索器 +- ParentDocumentRetriever(父子文档) +- 混合检索器(稠密+稀疏) +""" +from typing import Optional from langchain_core.embeddings import Embeddings +from langchain_core.retrievers import BaseRetriever from langchain_classic.retrievers import ParentDocumentRetriever from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter from langchain_core.stores import BaseStore @@ -9,18 +18,18 @@ from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore def create_parent_retriever( collection_name: str = "rag_documents", - parent_splitter: TextSplitter | None = None, - child_splitter: TextSplitter | None = None, - docstore: BaseStore | None = None, + parent_splitter: Optional[TextSplitter] = None, + child_splitter: Optional[TextSplitter] = None, + docstore: Optional[BaseStore] = None, search_k: int = 5, parent_chunk_size: int = 1000, parent_chunk_overlap: int = 100, child_chunk_size: int = 200, child_chunk_overlap: int = 20, - embeddings: Embeddings | None = None, + embeddings: Optional[Embeddings] = None, ) -> ParentDocumentRetriever: """ - 创建 ParentDocumentRetriever 实例。 + 创建 ParentDocumentRetriever 实例(基础稠密向量版本)。 Args: collection_name: Qdrant 集合名称,默认 "rag_documents" @@ -44,7 +53,7 @@ def create_parent_retriever( # 向量存储(只读) vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings) - + # 切分器(若未提供则创建默认) if parent_splitter is None: parent_splitter = RecursiveCharacterTextSplitter( @@ -56,11 +65,11 @@ def create_parent_retriever( chunk_size=child_chunk_size, chunk_overlap=child_chunk_overlap, ) - + # 文档存储 if docstore is None: docstore, _ = create_docstore() - + return ParentDocumentRetriever( vectorstore=vector_store.get_langchain_vectorstore(), docstore=docstore, @@ -68,3 +77,34 @@ def create_parent_retriever( parent_splitter=parent_splitter, search_kwargs={"k": search_k}, ) + + +def create_hybrid_retriever_factory( + collection_name: str = "rag_documents", + search_k: int = 5, + embeddings: Optional[Embeddings] = None, +) -> BaseRetriever: + """ + 【不完整,仅占位】创建混合检索器的工厂函数占位符。 + + 注意:完整的混合检索逻辑在 app/rag/retriever.py 中实现。 + 这里仅返回 QdrantVectorStore 作为基础。 + + Args: + collection_name: Qdrant 集合名称 + search_k: 检索返回结果数 + embeddings: 嵌入模型实例 + + Returns: + 基础的 QdrantVectorStore(仅稠密检索) + """ + # 嵌入模型 + if embeddings is None: + embedder = LlamaCppEmbedder() + embeddings = embedder.as_langchain_embeddings() + + # 创建向量存储 + vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings) + + # 返回 LangChain 兼容的 retriever + return vector_store.get_langchain_vectorstore().as_retriever(search_kwargs={"k": search_k}) diff --git a/backend/rag_core/sparse_embedder.py b/backend/rag_core/sparse_embedder.py new file mode 100644 index 0000000..54fb4d2 --- /dev/null +++ b/backend/rag_core/sparse_embedder.py @@ -0,0 +1,34 @@ +""" +BM25 稀疏嵌入器 +基于 FastEmbed 的 Qdrant/bm25 模型,完全离线运行 +""" +from typing import List +from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding +from app.config import FASTEMBED_CACHE_PATH + +class BM25SparseEmbedder: + """BM25 稀疏嵌入包装器,与现有嵌入器风格统一""" + + def __init__(self): + self.model = SparseTextEmbedding( + model_name="Qdrant/bm25", + cache_dir=FASTEMBED_CACHE_PATH, + local_files_only=True, # 强制离线,永不联网 + ) + + def embed_documents(self, texts: List[str]) -> List[dict]: + """返回稀疏向量列表,每个为 Qdrant 兼容的 dict(indices+values)""" + return [vec.as_object() for vec in self.model.embed(texts)] + + def embed_query(self, text: str) -> dict: + """返回单个稀疏向量""" + return list(self.model.embed([text]))[0].as_object() + +# 全局单例 +_sparse_embedder_instance = None + +def get_sparse_embedder() -> BM25SparseEmbedder: + global _sparse_embedder_instance + if _sparse_embedder_instance is None: + _sparse_embedder_instance = BM25SparseEmbedder() + return _sparse_embedder_instance \ No newline at end of file diff --git a/backend/rag_core/vector_store.py b/backend/rag_core/vector_store.py index 88cc518..5679035 100644 --- a/backend/rag_core/vector_store.py +++ b/backend/rag_core/vector_store.py @@ -1,41 +1,48 @@ """ Qdrant 向量数据库包装器。 +支持稠密+稀疏双向量存储。 """ import logging import os import time +import uuid from typing import List, Optional, Dict, Any from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS from qdrant_client import QdrantClient -from qdrant_client.http.models import Distance, VectorParams +from qdrant_client.http.models import ( + Distance, VectorParams, SparseVectorParams, SparseIndexParams, + SparseIndexType, PointStruct, NamedSparseVector, NamedVector +) from httpx import RemoteProtocolError from qdrant_client.http.exceptions import ResponseHandlingException from .client import create_qdrant_client from .embedders import LlamaCppEmbedder +from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder logger = logging.getLogger(__name__) class QdrantVectorStore: - """Qdrant 向量数据库操作包装器。""" + """Qdrant 向量数据库操作包装器 - 支持稠密+稀疏双向量存储。""" - def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None): + def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None, sparse_embedder: Optional[BM25SparseEmbedder] = None): """ Args: collection_name: Qdrant 集合名称。 embeddings: 嵌入模型实例,默认 None(使用内部默认的 LlamaCppEmbedder)。 + sparse_embedder: 稀疏嵌入模型实例,默认 None(自动加载BM25)。 """ self.collection_name = collection_name self._client: Optional[QdrantClient] = None self._connection_attempts = 0 self._last_connection_time: Optional[float] = None - # 嵌入模型 + # 稠密嵌入模型 if embeddings is None: embedder = LlamaCppEmbedder() self.embeddings = embedder.as_langchain_embeddings() @@ -43,9 +50,13 @@ class QdrantVectorStore: else: self.embeddings = embeddings self._embedder = None + + # 稀疏嵌入模型 + self.sparse_embedder = sparse_embedder or get_sparse_embedder() self.create_collection() + # 保留 LangChain 向量存储实例(用于兼容) self.vector_store = LangchainQdrantVS( client=self.get_client(), collection_name=self.collection_name, @@ -97,7 +108,7 @@ class QdrantVectorStore: } def create_collection(self, force_recreate: bool = False): - """创建集合,设置合适的向量维度。""" + """创建集合,支持稠密+稀疏双向量。""" if self._embedder is not None: # 使用内部的 embedder 获取维度 vector_size = self._embedder.get_embedding_dimension() @@ -119,11 +130,31 @@ class QdrantVectorStore: exists = False if not exists: + # 向量配置:稠密向量 + vectors_config = { + "dense": VectorParams( + size=vector_size, + distance=Distance.COSINE, + optional=True + ) + } + + # 稀疏向量配置 + sparse_vectors_config = { + "sparse": SparseVectorParams( + index=SparseIndexParams( + type=SparseIndexType.MUTABLE + ), + optional=True + ) + } + client.create_collection( collection_name=self.collection_name, - vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), + vectors_config=vectors_config, + sparse_vectors_config=sparse_vectors_config ) - logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size) + logger.info("集合 '%s' 已创建(维度=%d,支持稠密+稀疏双向量)", self.collection_name, vector_size) else: logger.info("集合 '%s' 已存在", self.collection_name) return @@ -142,18 +173,54 @@ class QdrantVectorStore: time.sleep(wait_time) def add_documents(self, documents: List[Document], batch_size: int = 100): - """将文档添加到向量数据库。""" + """将文档添加到向量数据库,自动生成稠密+稀疏双向量。""" if not documents: return [] self.create_collection() - ids = self.vector_store.add_documents(documents, batch_size=batch_size) - logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids)) - return ids + client = self.get_client() + doc_ids = [] + + # 分批处理 + for i in range(0, len(documents), batch_size): + batch_docs = documents[i:i+batch_size] + texts = [doc.page_content for doc in batch_docs] + + # 生成双向量 + dense_vectors = self.embeddings.embed_documents(texts) + sparse_vectors = self.sparse_embedder.embed_documents(texts) + + points = [] + for j, doc in enumerate(batch_docs): + point_id = doc.metadata.get("id", str(uuid.uuid4())) + doc_ids.append(point_id) + + # 构造双向量 + named_vectors = { + "dense": dense_vectors[j], + "sparse": NamedSparseVector( + name="sparse", + vector=sparse_vectors[j] + ) + } + + points.append(PointStruct( + id=point_id, + vector=named_vectors, + payload={"text": doc.page_content, **doc.metadata} + )) + + # 批量插入 + client.upsert(collection_name=self.collection_name, points=points) + logger.info("已向 '%s' 添加 %d 个文档(稠密+稀疏双向量)", self.collection_name, len(points)) + + return doc_ids def similarity_search(self, query: str, k: int = 5) -> List[Document]: + """基础稠密向量检索(兼容原有接口)。""" return self.vector_store.similarity_search(query, k=k) def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]: + """基础稠密向量检索带分数(兼容原有接口)。""" return self.vector_store.similarity_search_with_score(query, k=k) def delete_collection(self): @@ -183,5 +250,5 @@ class QdrantVectorStore: return self.vector_store def get_qdrant_client(self): - """返回原生 Qdrant 客户端(如需手动管理 collection)""" + """返回原生 Qdrant 客户端(用于自定义检索逻辑)""" return self.get_client() diff --git a/backend/requirements.txt b/backend/requirements.txt index 2edc6e8..fef6bd2 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,53 +1,52 @@ -# Core -pydantic==2.12.5 -python-dotenv==1.2.2 -typing-extensions==4.15.0 +typing-extensions>=4.15.0 +python-dotenv>=1.2.2 +pydantic>=2.12.5 +requests>=2.32.5 # LangChain -langchain==1.2.15 -langchain-community==0.4.1 -langchain-core==1.2.28 -langchain-openai==1.1.12 -langchain-qdrant==1.1.0 -langgraph==1.1.6 -langgraph-checkpoint-postgres==3.0.5 +langchain>=1.2.15 +langchain-community>=0.4.1 +langchain-core>=1.2.28 +langchain-openai>=1.1.12 +langchain-qdrant>=1.1.0 +langgraph>=1.1.6 +langgraph-checkpoint-postgres>=3.0.5 tiktoken>=0.12.0 # Zhipu AI -zhipuai==2.0.1 +zhipuai>=2.0.1 # Vector DB -qdrant-client==1.17.1 +qdrant-client>=1.17.1 fastembed>=0.3.0 # 用于 Qdrant BM25 稀疏向量 # Memory -mem0ai==1.0.11 +mem0ai>=1.0.11 # Backend -fastapi==0.135.3 -uvicorn[standard]==0.44.0 +fastapi>=0.135.3 +uvicorn[standard]>=0.44.0 # Database -asyncpg==0.31.0 -psycopg[binary]==3.3.3 +asyncpg>=0.31.0 +psycopg[binary]>=3.3.3 # HTTP -httpx==0.28.1 -aiohttp==3.13.5 +httpx>=0.28.1 +aiohttp>=3.13.5 # Utilities -tenacity==9.1.4 -rich==15.0.0 -PyYAML==6.0.3 +tenacity>=9.1.4 +rich>=15.0.0 +PyYAML>=6.0.3 numpy>=1.26.2 -pyjwt==2.8.0 +pyjwt>=2.8.0 ddgs>=6.0.0 # 免费联网搜索(原 duckduckgo-search 已重命名) matplotlib>=3.9.0 # 可视化图表 # Document Processing -unstructured==0.22.21 -pypdf==6.10.0 -beautifulsoup4==4.14.3 -lxml==6.1.0 -pandas==3.0.2 # 若需Excel保留,否则移除 -spacy==3.8.14 # unstructured 可能依赖 +unstructured>=0.22.21 +pypdf>=6.10.0 +beautifulsoup4>=4.14.3 +lxml>=6.1.0 +spacy>=3.8.14 # unstructured 可能依赖 \ No newline at end of file diff --git a/docker/backend/Dockerfile b/docker/backend/Dockerfile index ea56819..14c5c9c 100644 --- a/docker/backend/Dockerfile +++ b/docker/backend/Dockerfile @@ -55,6 +55,7 @@ ENV ENABLE_GRAPH_TRACE=false # ============================================================================= ENV SPARSE_MODEL_PATH=/app/models/sparse ENV SPARSE_MODEL_NAME=Qdrant/bm25 +ENV FASTEMBED_CACHE_PATH=/app/fastembed_cache # ============================================================================= # 日志配置(生产环境默认值) @@ -88,6 +89,11 @@ COPY download_sparse_model.py . RUN python download_sparse_model.py --cache-dir /app/models/sparse --model-name Qdrant/bm25 && \ rm -f download_sparse_model.py +# ============================================================================= +# 复制预下载的BM25模型缓存(FastEmbed) +# ============================================================================= +COPY models/fastembed_cache /app/fastembed_cache + # ============================================================================= # 复制项目代码 # ============================================================================= diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 513c56d..30b8d65 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -11,6 +11,7 @@ services: - ZHIPUAI_API_KEY=${ZHIPUAI_API_KEY:?请配置 ZHIPUAI_API_KEY(本地:.env 文件 | CI/CD:Secrets)} # ⭐ 敏感密钥配置 - DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY:?请配置 DEEPSEEK_API_KEY(本地:.env 文件 | CI/CD:Secrets)} # ⭐ 敏感密钥配置 - LLAMACPP_API_KEY=${LLAMACPP_API_KEY:?请配置 LLAMACPP_API_KEY(本地:.env 文件 | CI/CD:Secrets)} # ⭐ 敏感密钥配置 + - SILICONFLOW_API_KEY=${SILICONFLOW_API_KEY:-} # 硅基流动API密钥(可选,本地服务故障时降级使用) # ========================================================================= # PostgreSQL 数据库配置 @@ -63,6 +64,7 @@ services: # ========================================================================= - BACKEND_PORT=8079 - MEMORY_SUMMARIZE_INTERVAL=${MEMORY_SUMMARIZE_INTERVAL:-10} + - FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-/app/fastembed_cache} # ========================================================================= # 前端通信地址(Docker 内部网络) diff --git a/download_sparse_model.py b/download_sparse_model.py deleted file mode 100644 index 22ff2fe..0000000 --- a/download_sparse_model.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -""" -下载稀疏嵌入模型到本地目录。 -仅需在开发机或构建镜像时执行一次。 -""" - -import logging -import sys -from pathlib import Path - -# 配置日志 -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - -# 添加 backend 目录到路径 -sys.path.insert(0, str(Path(__file__).parent / "backend")) - - -def download_model(cache_dir: str = "./models/sparse", model_name: str = "Qdrant/bm25"): - """ - 下载稀疏嵌入模型到指定目录。 - - Args: - cache_dir: 模型缓存目录 - model_name: 模型名称 - """ - cache_path = Path(cache_dir) - cache_path.mkdir(parents=True, exist_ok=True) - logger.info(f"准备下载模型 {model_name} 到 {cache_path.absolute()}") - - try: - from fastembed import SparseTextEmbedding - - # 下载并缓存模型 - model = SparseTextEmbedding(model_name=model_name, cache_dir=str(cache_path)) - logger.info(f"✅ 模型 {model_name} 下载/加载成功") - - # 测试一下 - test_result = model.embed(["测试文本"]) - logger.info(f"✅ 模型测试成功,稀疏向量维度: {len(list(test_result)[0])}") - - logger.info("✅ 所有步骤完成!") - return True - - except Exception as e: - logger.error(f"❌ 模型下载失败: {e}") - import traceback - logger.error(traceback.format_exc()) - return False - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="下载稀疏嵌入模型") - parser.add_argument( - "--cache-dir", - default="./models/sparse", - help="模型缓存目录 (默认: ./models/sparse)" - ) - parser.add_argument( - "--model-name", - default="Qdrant/bm25", - help="模型名称 (默认: Qdrant/bm25)" - ) - - args = parser.parse_args() - - success = download_model(args.cache_dir, args.model_name) - sys.exit(0 if success else 1) diff --git a/rag_indexer/index_builder.py b/rag_indexer/index_builder.py index e17a56c..07d6e41 100644 --- a/rag_indexer/index_builder.py +++ b/rag_indexer/index_builder.py @@ -41,15 +41,6 @@ try: except ImportError: HAS_MODEL_SERVICES = False -# 尝试导入稀疏模型配置(如果可用) -try: - from app.config import SPARSE_MODEL_PATH, SPARSE_MODEL_NAME - HAS_SPARSE_CONFIG = True -except ImportError: - HAS_SPARSE_CONFIG = False - SPARSE_MODEL_PATH = "./models/sparse" - SPARSE_MODEL_NAME = "Qdrant/bm25" - logger = logging.getLogger(__name__) # ---------- 配置数据类 ---------- @@ -112,37 +103,27 @@ class IndexBuilder: # 设置嵌入模型 - 优先使用外部提供的,然后尝试使用新服务,最后回退到原来的方式 if embeddings is not None: self.embeddings = embeddings - self.embedder = None + self._embedder = None logger.info("使用外部提供的嵌入模型") elif HAS_MODEL_SERVICES: try: self.embeddings = get_embedding_service() - self.embedder = None + self._embedder = None logger.info("使用 model_services 提供的嵌入服务") except Exception as e: logger.warning(f"获取嵌入服务失败,回退到 LlamaCppEmbedder: {e}") - self.embedder = LlamaCppEmbedder() - self.embeddings = self.embedder.as_langchain_embeddings() + self._embedder = LlamaCppEmbedder() + self.embeddings = self._embedder.as_langchain_embeddings() else: - self.embedder = LlamaCppEmbedder() - self.embeddings = self.embedder.as_langchain_embeddings() + self._embedder = LlamaCppEmbedder() + self.embeddings = self._embedder.as_langchain_embeddings() - # 初始化稀疏嵌入(使用本地缓存目录) - from langchain_qdrant import FastEmbedSparse, RetrievalMode - self.sparse_embeddings = FastEmbedSparse( - model_name=SPARSE_MODEL_NAME, - cache_dir=SPARSE_MODEL_PATH - ) - logger.info(f"✅ FastEmbedSparse 初始化成功 (cache_dir={SPARSE_MODEL_PATH})") - - # 初始化向量存储(混合检索模式) + # 初始化向量存储(自动支持稠密+稀疏混合检索) self.vector_store = QdrantVectorStore( collection_name=config.collection_name, - embedding=self.embeddings if self.embedder is None else None, - sparse_embedding=self.sparse_embeddings, - retrieval_mode=RetrievalMode.HYBRID, + embedding=self.embeddings if self._embedder is None else None ) - logger.info("✅ 混合检索向量存储初始化成功") + logger.info("✅ 混合检索向量存储初始化成功(稠密+BM25稀疏)") # 根据切分类型初始化相关组件 self._init_splitters_and_retriever() diff --git a/rag_indexer/requirements.txt b/rag_indexer/requirements.txt deleted file mode 100644 index 1dc4327..0000000 --- a/rag_indexer/requirements.txt +++ /dev/null @@ -1,34 +0,0 @@ -# RAG Indexer - 本地索引工具依赖 -# 依赖 rag_core (从 ../backend/rag_core 导入) - -# Core -pydantic==2.12.5 -python-dotenv==1.2.2 -typing-extensions==4.15.0 - -# LangChain (用于文档处理) -langchain==1.2.15 -langchain-community==0.4.1 -langchain-core==1.2.28 -tiktoken>=0.12.0 - -# Vector DB -qdrant-client==1.17.1 -fastembed>=0.3.0 # 用于 Qdrant BM25 稀疏向量 - -# HTTP -httpx==0.28.1 - -# Utilities -tenacity==9.1.4 -rich==15.0.0 -PyYAML==6.0.3 -numpy>=1.26.2 - -# Document Processing -unstructured==0.22.21 -pypdf==6.10.0 -beautifulsoup4==4.14.3 -lxml==6.1.0 -pandas==3.0.2 -spacy==3.8.14 diff --git a/requirement.txt b/requirement.txt deleted file mode 100644 index a9d522a..0000000 --- a/requirement.txt +++ /dev/null @@ -1,47 +0,0 @@ -# Core -pydantic==2.12.5 -python-dotenv==1.2.2 -typing-extensions==4.15.0 - -# LangChain -langchain==1.2.15 -langchain-community==0.4.1 -langchain-core==1.2.28 -langchain-openai==1.1.12 -langchain-qdrant==1.1.0 -langgraph==1.1.6 -langgraph-checkpoint-postgres==3.0.5 -tiktoken>=0.12.0 - -# Vector DB -qdrant-client==1.17.1 - -# Memory -mem0ai==1.0.11 - -# Backend -fastapi==0.135.3 -uvicorn[standard]==0.44.0 - -# Database -asyncpg==0.31.0 -psycopg[binary]==3.3.3 - -# HTTP -httpx==0.28.1 -aiohttp==3.13.5 - -# Utilities -tenacity==9.1.4 -rich==15.0.0 -PyYAML==6.0.3 -numpy>=1.26.2 - -# Document Processing -unstructured==0.22.21 -pypdf==6.10.0 -beautifulsoup4==4.14.3 -lxml==6.1.0 -pandas==3.0.2 # 若需Excel保留,否则移除 -spacy==3.8.14 # unstructured 可能依赖 -duckduckgo-search>=6.0.0 # 联网搜索 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ef13505 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +# 根目录requirements - 仅本地运行零散脚本/工具使用 +# 完全不与前后端requirements重叠,前后端独立运行无需安装这里的依赖 +gitpython==3.1.43 # 本地git脚本工具 +tqdm==4.67.0 # 本地脚本进度条 +ipython==8.30.0 # 本地交互式调试 +pytest==8.3.4 # 本地单元测试 diff --git a/tools/download_bm25.py b/tools/download_bm25.py new file mode 100644 index 0000000..fcbdfe7 --- /dev/null +++ b/tools/download_bm25.py @@ -0,0 +1,22 @@ +""" +BM25模型预下载脚本 +执行后将模型缓存到 ./models/fastembed_cache 目录,打包进Docker镜像 +""" +import os +from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding + +if __name__ == "__main__": + # 指定缓存目录 + cache_dir = "./models/fastembed_cache" + os.makedirs(cache_dir, exist_ok=True) + + print("正在下载BM25稀疏向量模型...") + model = SparseTextEmbedding( + model_name="Qdrant/bm25", + cache_dir=cache_dir + ) + + # 触发一次推理,确保模型文件完整下载 + list(model.embed(["init trigger"])) + print(f"✅ BM25模型已成功缓存到: {cache_dir}") + print("请将该目录提交到项目仓库,打包进Docker镜像") diff --git a/test/test_backend.py b/tools/test/test_backend.py similarity index 100% rename from test/test_backend.py rename to tools/test/test_backend.py diff --git a/test/test_dqrant.py b/tools/test/test_dqrant.py similarity index 100% rename from test/test_dqrant.py rename to tools/test/test_dqrant.py diff --git a/test/test_frontend.py b/tools/test/test_frontend.py similarity index 100% rename from test/test_frontend.py rename to tools/test/test_frontend.py diff --git a/test/test_rag.py b/tools/test/test_rag.py similarity index 100% rename from test/test_rag.py rename to tools/test/test_rag.py diff --git a/test/test_rag_indexer_result.py b/tools/test/test_rag_indexer_result.py similarity index 100% rename from test/test_rag_indexer_result.py rename to tools/test/test_rag_indexer_result.py