feat: 实现 BM25 稀疏 + 稠密向量混合检索功能
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
This commit is contained in:
134
.env.docker
134
.env.docker
@@ -1,86 +1,100 @@
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Docker Compose 服务器部署配置模板
|
# Docker 部署环境配置文件
|
||||||
# 用法: cp .env.docker .env 然后填入敏感密钥
|
# 用法: cp .env.docker .env 然后修改配置值用于Docker部署
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# AI 模型 API 密钥(⭐ 敏感配置 - 必须配置)
|
# AI 模型 API 密钥(必需 - 请填入真实值)
|
||||||
# 本地部署:在此文件中填入
|
|
||||||
# CI/CD 部署:在仓库 Settings → Secrets 中配置
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
ZHIPUAI_API_KEY=your_zhipuai_api_key_here # ⭐ 敏感密钥配置
|
ZHIPUAI_API_KEY=你的智谱API密钥
|
||||||
DEEPSEEK_API_KEY=your_deepseek_api_key_here # ⭐ 敏感密钥配置
|
DEEPSEEK_API_KEY=你的深度求索API密钥
|
||||||
LLAMACPP_API_KEY=your_llamacpp_api_key_here # ⭐ 敏感密钥配置
|
LLAMACPP_API_KEY=huang1998
|
||||||
|
SILICONFLOW_API_KEY=你的硅基流动API密钥(可选,本地服务故障时降级使用)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# PostgreSQL 数据库配置(分离配置,易于管理)
|
# llama.cpp 服务配置(Docker环境下使用host.docker.internal访问宿主服务)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# 主 LLM 服务 (Gemma-4-E2B GGUF) - 宿主端口 18000
|
||||||
|
VLLM_BASE_URL=http://host.docker.internal:18000/v1
|
||||||
|
|
||||||
|
# Embedding 服务 (Qwen3-Embedding-0.6B GGUF) - 宿主端口 18001
|
||||||
|
LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
|
||||||
|
|
||||||
|
# Reranker 服务 (bge-reranker-v2-m3) - 宿主端口 18002
|
||||||
|
LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Qdrant 向量数据库配置(使用远程服务)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
QDRANT_URL=http://115.190.121.151:6333
|
||||||
|
QDRANT_API_KEY=你的QdrantAPI密钥
|
||||||
|
QDRANT_COLLECTION_NAME=mem0_user_memories
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# PostgreSQL 数据库配置(使用远程服务)
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
DB_HOST=115.190.121.151
|
DB_HOST=115.190.121.151
|
||||||
DB_PORT=5432
|
DB_PORT=5432
|
||||||
DB_USER=postgres
|
DB_USER=postgres
|
||||||
DB_PASSWORD=your_db_password_here # ⭐ 敏感密钥配置
|
DB_PASSWORD=你的PostgreSQL密码
|
||||||
DB_NAME=langgraph_db
|
DB_NAME=langgraph_db
|
||||||
# 完整连接字符串(也支持直接配置,优先使用分离配置)
|
# 完整连接字符串(可选,优先使用分离配置)
|
||||||
DB_URI=postgresql://postgres:${DB_PASSWORD}@115.190.121.151:5432/langgraph_db?sslmode=disable
|
DB_URI=postgresql://postgres:你的PostgreSQL密码@115.190.121.151:5432/langgraph_db?sslmode=disable
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Qdrant 向量数据库配置(URL + API密钥 配对)
|
# 后端服务配置
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
QDRANT_URL=http://115.190.121.151:6333
|
BACKEND_PORT=8079
|
||||||
QDRANT_API_KEY=your_qdrant_api_key_here # ⭐ 敏感密钥配置
|
|
||||||
QDRANT_COLLECTION_NAME=mem0_user_memories
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# llama.cpp 服务配置(URL + API密钥 配对)
|
# 前端配置(Docker内部通信)
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# 主 LLM 服务 (Gemma-4-E2B GGUF) - 端口 18000 (Docker host 映射)
|
API_URL=http://backend:8079/chat
|
||||||
VLLM_BASE_URL=http://host.docker.internal:18000/v1
|
|
||||||
|
|
||||||
# Embedding 服务 (Qwen3-Embedding-0.6B GGUF) - 端口 18001
|
|
||||||
LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
|
|
||||||
# LLAMACPP_API_KEY=your_llamacpp_api_key_here (已在上面配置)
|
|
||||||
|
|
||||||
# Reranker 服务 (bge-reranker-v2-m3) - 端口 18002
|
|
||||||
LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# RAG 索引构建配置(非敏感,可直接使用)
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
RAG_COLLECTION_NAME=rag_documents
|
|
||||||
RAG_CHUNK_SIZE=500
|
|
||||||
RAG_CHUNK_OVERLAP=50
|
|
||||||
RAG_PARENT_CHUNK_SIZE=1000
|
|
||||||
RAG_CHILD_CHUNK_SIZE=200
|
|
||||||
RAG_PARENT_CHUNK_OVERLAP=100
|
|
||||||
RAG_CHILD_CHUNK_OVERLAP=20
|
|
||||||
RAG_STRATEGY=parent-child
|
|
||||||
RAG_STORAGE_TYPE=postgres
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# 日志调试配置(部署时可灵活调整)
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# 日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL
|
|
||||||
# 生产环境推荐 WARNING,排查问题时改为 DEBUG
|
|
||||||
LOG_LEVEL=WARNING
|
|
||||||
|
|
||||||
# 是否启用 DEBUG 模式
|
|
||||||
# true: 输出详细调试信息,包含完整的工具调用、数据库查询等
|
|
||||||
# false: 仅输出关键信息,适合生产环境
|
|
||||||
DEBUG=false
|
|
||||||
|
|
||||||
# 是否启用 Graph 流转追踪
|
|
||||||
# true: 输出每个节点的输入输出状态,便于调试工作流
|
|
||||||
# false: 关闭追踪,减少日志量
|
|
||||||
ENABLE_GRAPH_TRACE=false
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# 应用行为配置
|
# 应用行为配置
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
BACKEND_PORT=8079
|
# 记忆提取间隔:每 N 轮对话执行一次记忆提取
|
||||||
MEMORY_SUMMARIZE_INTERVAL=10
|
MEMORY_SUMMARIZE_INTERVAL=10
|
||||||
|
|
||||||
|
# 是否启用 Graph 执行追踪(调试用)
|
||||||
|
ENABLE_GRAPH_TRACE=true
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# 前端配置
|
# 稀疏模型配置
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Docker Compose 内部网络,使用服务名 'backend'
|
FASTEMBED_CACHE_PATH=/app/fastembed_cache
|
||||||
API_URL=http://backend:8079/chat
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# RAG 索引构建配置
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Qdrant 集合名称
|
||||||
|
RAG_COLLECTION_NAME=rag_documents
|
||||||
|
|
||||||
|
# 基础切分参数
|
||||||
|
RAG_CHUNK_SIZE=500
|
||||||
|
RAG_CHUNK_OVERLAP=50
|
||||||
|
|
||||||
|
# 父子块切分参数
|
||||||
|
RAG_PARENT_CHUNK_SIZE=1000
|
||||||
|
RAG_CHILD_CHUNK_SIZE=200
|
||||||
|
RAG_PARENT_CHUNK_OVERLAP=100
|
||||||
|
RAG_CHILD_CHUNK_OVERLAP=20
|
||||||
|
|
||||||
|
# 切分策略:basic(基础)、semantic(语义)、parent-child(父子块)
|
||||||
|
RAG_STRATEGY=parent-child
|
||||||
|
|
||||||
|
# 存储类型:postgres(PostgreSQL)、local(本地文件)
|
||||||
|
RAG_STORAGE_TYPE=postgres
|
||||||
|
|
||||||
|
# 文档加载器配置(可选)
|
||||||
|
# OCR 语言列表(逗号分隔)
|
||||||
|
RAG_OCR_LANGUAGES=chi_sim,eng
|
||||||
|
# 文档主语言列表(逗号分隔)
|
||||||
|
RAG_DOC_LANGUAGES=zh
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# 日志配置
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
LOG_LEVEL=DEBUG
|
||||||
|
DEBUG=true
|
||||||
11
.gitignore
vendored
11
.gitignore
vendored
@@ -21,7 +21,8 @@
|
|||||||
!test/**
|
!test/**
|
||||||
!.gitea/
|
!.gitea/
|
||||||
!.gitea/**
|
!.gitea/**
|
||||||
!download_sparse_model.py
|
!tools/
|
||||||
|
!tools/**
|
||||||
|
|
||||||
# 3. 放行必要的根目录文件
|
# 3. 放行必要的根目录文件
|
||||||
!.gitignore
|
!.gitignore
|
||||||
@@ -29,7 +30,7 @@
|
|||||||
!QUICKSTART.md
|
!QUICKSTART.md
|
||||||
!REACT_MODE_SUMMARY.md
|
!REACT_MODE_SUMMARY.md
|
||||||
!LICENSE
|
!LICENSE
|
||||||
!requirement.txt
|
!requirements.txt
|
||||||
!.env.docker
|
!.env.docker
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
@@ -41,12 +42,8 @@ __pycache__/
|
|||||||
*.so
|
*.so
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
# 模型目录(不提交到 Git,在 Docker 构建时下载)
|
|
||||||
models/
|
|
||||||
|
|
||||||
# 包含敏感信息的环境变量配置(绝对不能传)
|
# 包含敏感信息的环境变量配置(绝对不能传)
|
||||||
.env
|
.env
|
||||||
.env.local
|
|
||||||
|
|
||||||
# 日志
|
# 日志
|
||||||
*.log
|
*.log
|
||||||
@@ -54,4 +51,4 @@ app/*.log
|
|||||||
frontend/*.log
|
frontend/*.log
|
||||||
|
|
||||||
# 测试和用户数据
|
# 测试和用户数据
|
||||||
data/
|
data/
|
||||||
@@ -1,182 +0,0 @@
|
|||||||
# React 模式架构总结
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ✅ 当前架构:混合路由 + React 循环
|
|
||||||
|
|
||||||
本项目采用 **两层混合架构**:
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────────────────────────────────┐
|
|
||||||
│ 第一层:前置混合路由(低延迟) │
|
|
||||||
│ ├─ 规则快速分流(无 LLM) │
|
|
||||||
│ ├─ 轻量级意图分类(smallLLM) │
|
|
||||||
│ └─ 快速路径(fast_chitchat, fast_rag, fast_tool) │
|
|
||||||
└───────────────────────┬─────────────────────────────────────┘
|
|
||||||
↓(自动升级:失败时)
|
|
||||||
┌─────────────────────────────────────────────────────────────┐
|
|
||||||
│ 第二层:完整 React 循环(兜底,复杂任务处理) │
|
|
||||||
│ └─ 推理 → 行动 → 观察(最多 40 步) │
|
|
||||||
└─────────────────────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🎯 第一层:前置混合路由(新)
|
|
||||||
|
|
||||||
### 核心功能
|
|
||||||
|
|
||||||
| 功能 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| 规则快速分流 | 无 LLM,毫秒级响应,用于问候、感谢、子图关键词等 |
|
|
||||||
| 轻量级意图分类 | 使用 smallLLM,压缩到 4 类:chitchat, knowledge, tool, complex |
|
|
||||||
| 快速路径 | 三个快速处理节点:fast_chitchat, fast_rag, fast_tool |
|
|
||||||
| 自动升级 | 快速路径失败时,自动回到完整 React 循环 |
|
|
||||||
| SSE 事件增强 | intent_classified, path_decision, fast_path_*, escalation |
|
|
||||||
|
|
||||||
### 快速流程图
|
|
||||||
|
|
||||||
```
|
|
||||||
START
|
|
||||||
↓
|
|
||||||
init_state
|
|
||||||
↓
|
|
||||||
hybrid_router (前置路由) ←────────────┐
|
|
||||||
↓ │
|
|
||||||
├─ 规则分流 → fast_chitchat →────────┤
|
|
||||||
│ ↓ │
|
|
||||||
├─ 模型分类 → fast_rag →────────────┤
|
|
||||||
│ ↓ │
|
|
||||||
├─ fast_tool →────────┤
|
|
||||||
│ ↓ │
|
|
||||||
└─ react_loop →────────┤
|
|
||||||
↓ │
|
|
||||||
检查成功/升级? ──────────┘
|
|
||||||
↓ ↓
|
|
||||||
finalize react_reason
|
|
||||||
```
|
|
||||||
|
|
||||||
### 关键文件
|
|
||||||
|
|
||||||
| 文件 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `backend/app/main_graph/nodes/hybrid_router.py` | 混合路由完整实现 |
|
|
||||||
| `backend/app/model_services/chat_services.py` | get_chat_service() + get_small_llm_service() |
|
|
||||||
| `backend/app/main_graph/utils/main_graph_builder.py` | 集成混合路由到主图 |
|
|
||||||
|
|
||||||
### 配置项
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 构建图时可选择
|
|
||||||
graph = build_react_main_graph(use_hybrid_router=True) # 启用混合路由(默认)
|
|
||||||
graph = build_react_main_graph(use_hybrid_router=False) # 禁用,纯 React 循环
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🎯 第二层:完整 React 循环(保留)
|
|
||||||
|
|
||||||
### 核心特性
|
|
||||||
|
|
||||||
| 特性 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| 循环推理 | 每轮推理判断下一步,最多 40 步 |
|
|
||||||
| 结构化错误 | ErrorRecord + ErrorSeverity |
|
|
||||||
| 超时重试 | RAG 最多 2 次,子图最多 1 次 |
|
|
||||||
| 子图集成 | contact, dictionary, news_analysis |
|
|
||||||
| RAG 检索 | 支持重检索(re_retrieve) |
|
|
||||||
|
|
||||||
### 流程图
|
|
||||||
|
|
||||||
```
|
|
||||||
react_reason (推理) ←──────────────────┐
|
|
||||||
↓ │
|
|
||||||
条件路由 │
|
|
||||||
├─→ rag_retrieve (带重试) →──────────┤
|
|
||||||
├─→ contact_subgraph →───────────────┤
|
|
||||||
├─→ dictionary_subgraph →────────────┤
|
|
||||||
├─→ news_analysis_subgraph →─────────┤
|
|
||||||
├─→ handle_error → (重试或降级) →────┤
|
|
||||||
└─→ finalize
|
|
||||||
↓
|
|
||||||
END
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 📦 关键文件清单
|
|
||||||
|
|
||||||
| 文件 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `backend/app/main_graph/utils/main_graph_builder.py` | 主图构建(支持混合路由开关) |
|
|
||||||
| `backend/app/main_graph/nodes/react_nodes.py` | React 循环节点 |
|
|
||||||
| `backend/app/main_graph/nodes/hybrid_router.py` | 混合路由节点(新) |
|
|
||||||
| `backend/app/main_graph/nodes/rag_nodes.py` | RAG 检索节点 |
|
|
||||||
| `backend/app/main_graph/utils/retry_utils.py` | 超时重试工具 |
|
|
||||||
| `backend/app/main_graph/state.py` | 主状态 |
|
|
||||||
| `backend/app/core/intent.py` | React 模式意图推理器 |
|
|
||||||
| `backend/app/model_services/chat_services.py` | 双模型服务(llm + smallLLM) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🛠️ 模型服务层
|
|
||||||
|
|
||||||
### 生成式大模型服务(Chat)
|
|
||||||
|
|
||||||
| 函数 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `get_chat_service()` | 获取大模型服务(用于复杂推理、生成) |
|
|
||||||
| `get_small_llm_service()` | 获取轻量级模型服务(用于简单意图分类、快速问答) |
|
|
||||||
| `get_all_chat_services()` | 获取所有可用的生成式大模型服务(用于多模型切换) |
|
|
||||||
|
|
||||||
### 使用方法
|
|
||||||
|
|
||||||
```python
|
|
||||||
from app.model_services import get_chat_service, get_small_llm_service
|
|
||||||
|
|
||||||
# 获取大模型服务(复杂任务)
|
|
||||||
llm = get_chat_service()
|
|
||||||
response = llm.invoke("什么是 LangGraph?")
|
|
||||||
|
|
||||||
# 获取轻量级模型服务(简单任务)
|
|
||||||
small_llm = get_small_llm_service()
|
|
||||||
response = small_llm.invoke("分类用户意图:'你好'")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 嵌入与重排模型服务
|
|
||||||
|
|
||||||
| 函数 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `get_embedding_service()` | 获取嵌入模型服务(自动降级) |
|
|
||||||
| `get_rerank_service()` | 获取重排模型服务(自动降级) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🚀 快速使用
|
|
||||||
|
|
||||||
```python
|
|
||||||
from backend.app.main_graph.utils.main_graph_builder import build_react_main_graph
|
|
||||||
|
|
||||||
# 构建图(默认启用混合路由)
|
|
||||||
graph = build_react_main_graph(use_hybrid_router=True)
|
|
||||||
compiled_graph = graph.compile()
|
|
||||||
|
|
||||||
# 调用
|
|
||||||
result = compiled_graph.invoke({"user_query": "你好", "user_id": "test"})
|
|
||||||
print(result.final_result)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🎉 完整特性总结
|
|
||||||
|
|
||||||
✅ 双模型服务 (llm + smallLLM)
|
|
||||||
✅ 前置混合路由(规则快速分流 + 轻量级意图分类)
|
|
||||||
✅ 三个快速路径(fast_chitchat, fast_rag, fast_tool)
|
|
||||||
✅ 自动升级机制(快速路径失败 → 完整 React 循环)
|
|
||||||
✅ SSE 事件增强(intent_classified, path_decision, fast_path_*, escalation)
|
|
||||||
✅ 完整 React 循环(最多 40 步)
|
|
||||||
✅ 结构化错误处理
|
|
||||||
✅ 超时和重试策略
|
|
||||||
✅ 子图集成(contact, dictionary, news_analysis)
|
|
||||||
✅ 向后兼容(use_hybrid_router=True/False)
|
|
||||||
@@ -37,8 +37,9 @@ def _get_bool(key: str) -> bool | None:
|
|||||||
|
|
||||||
|
|
||||||
# ========== 第三方 API 密钥 ==========
|
# ========== 第三方 API 密钥 ==========
|
||||||
ZHIPUAI_API_KEY = _get_str("ZHIPUAI_API_KEY")
|
ZHIPUAI_API_KEY=_get_str("ZHIPUAI_API_KEY")
|
||||||
DEEPSEEK_API_KEY = _get_str("DEEPSEEK_API_KEY")
|
DEEPSEEK_API_KEY=_get_str("DEEPSEEK_API_KEY")
|
||||||
|
SILICONFLOW_API_KEY=_get_str("SILICONFLOW_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
# ========== 智谱 API 配置 ==========
|
# ========== 智谱 API 配置 ==========
|
||||||
@@ -51,9 +52,16 @@ ZHIPU_RERANK_MODEL = _get_str("ZHIPU_RERANK_MODEL") or "rerank-2"
|
|||||||
ZHIPU_API_BASE = _get_str("ZHIPU_API_BASE") or "https://open.bigmodel.cn/api/paas/v4"
|
ZHIPU_API_BASE = _get_str("ZHIPU_API_BASE") or "https://open.bigmodel.cn/api/paas/v4"
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 硅基流动(SiliconFlow) API 配置 ==========
|
||||||
|
# 重排模型:BAAI/bge-reranker-v2-m3
|
||||||
|
SILICONFLOW_RERANK_MODEL = _get_str("SILICONFLOW_RERANK_MODEL") or "BAAI/bge-reranker-v2-m3"
|
||||||
|
SILICONFLOW_API_BASE = _get_str("SILICONFLOW_API_BASE") or "https://api.siliconflow.cn/v1"
|
||||||
|
|
||||||
|
|
||||||
# ========== 稀疏模型配置 ==========
|
# ========== 稀疏模型配置 ==========
|
||||||
SPARSE_MODEL_PATH = _get_str("SPARSE_MODEL_PATH") or "./models/sparse"
|
SPARSE_MODEL_PATH = _get_str("SPARSE_MODEL_PATH") or "./models/sparse"
|
||||||
SPARSE_MODEL_NAME = _get_str("SPARSE_MODEL_NAME") or "Qdrant/bm25"
|
SPARSE_MODEL_NAME = _get_str("SPARSE_MODEL_NAME") or "Qdrant/bm25"
|
||||||
|
FASTEMBED_CACHE_PATH = _get_str("FASTEMBED_CACHE_PATH") or "./models/fastembed_cache"
|
||||||
|
|
||||||
# ========== llama.cpp 服务配置(URL + API密钥 配对) ==========
|
# ========== llama.cpp 服务配置(URL + API密钥 配对) ==========
|
||||||
# 主 LLM 服务
|
# 主 LLM 服务
|
||||||
|
|||||||
@@ -3,11 +3,15 @@
|
|||||||
|
|
||||||
本模块提供统一的重排模型服务获取接口,支持自动降级:
|
本模块提供统一的重排模型服务获取接口,支持自动降级:
|
||||||
1. 优先使用本地 llama.cpp 重排服务
|
1. 优先使用本地 llama.cpp 重排服务
|
||||||
2. 本地服务不可用时,自动降级到智谱 API 重排服务
|
2. 本地服务不可用时,自动降级到硅基流动(SiliconFlow) API 重排服务
|
||||||
|
3. 硅基流动服务不可用时,自动降级到智谱 API 重排服务
|
||||||
|
4. 所有API服务不可用时,自动降级到 LLM 评分重排服务
|
||||||
|
|
||||||
主要功能:
|
主要功能:
|
||||||
- LocalLlamaCppRerankProvider:本地 llama.cpp 重排服务提供者
|
- LocalLlamaCppRerankProvider:本地 llama.cpp 重排服务提供者
|
||||||
|
- SiliconFlowRerankProvider:硅基流动 API 重排服务提供者
|
||||||
- ZhipuRerankProvider:智谱 API 重排服务提供者
|
- ZhipuRerankProvider:智谱 API 重排服务提供者
|
||||||
|
- LLMFallbackRerankProvider:LLM 评分降级重排服务提供者
|
||||||
- get_rerank_service():获取重排服务的统一接口
|
- get_rerank_service():获取重排服务的统一接口
|
||||||
|
|
||||||
注意:本模块只负责调用 rerank server,不包含业务逻辑(文档处理、排序、top_n)
|
注意:本模块只负责调用 rerank server,不包含业务逻辑(文档处理、排序、top_n)
|
||||||
@@ -28,7 +32,10 @@ from app.config import (
|
|||||||
LLAMACPP_API_KEY,
|
LLAMACPP_API_KEY,
|
||||||
ZHIPUAI_API_KEY,
|
ZHIPUAI_API_KEY,
|
||||||
ZHIPU_RERANK_MODEL,
|
ZHIPU_RERANK_MODEL,
|
||||||
ZHIPU_API_BASE
|
ZHIPU_API_BASE,
|
||||||
|
SILICONFLOW_API_KEY,
|
||||||
|
SILICONFLOW_RERANK_MODEL,
|
||||||
|
SILICONFLOW_API_BASE
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -136,6 +143,53 @@ class ZhipuRerankService(BaseRerankService):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class SiliconFlowRerankService(BaseRerankService):
|
||||||
|
"""
|
||||||
|
硅基流动(SiliconFlow) API 重排服务 - 纯服务层
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model: str | None = None, api_key: str | None = None, api_base: str | None = None):
|
||||||
|
self.model = model or SILICONFLOW_RERANK_MODEL
|
||||||
|
self.api_key = api_key or SILICONFLOW_API_KEY
|
||||||
|
self.api_base = api_base or SILICONFLOW_API_BASE
|
||||||
|
|
||||||
|
def compute_scores(self, query: str, documents: List[str]) -> List[float]:
|
||||||
|
"""
|
||||||
|
调用 SiliconFlow rerank API 计算得分 - 纯 API 调用
|
||||||
|
"""
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
base = self.api_base.rstrip("/")
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"query": query,
|
||||||
|
"documents": documents,
|
||||||
|
"return_documents": False
|
||||||
|
}
|
||||||
|
|
||||||
|
with httpx.Client(timeout=120) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{base}/rerank",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if isinstance(data, dict) and "results" in data:
|
||||||
|
results = data["results"]
|
||||||
|
results_sorted = sorted(results, key=lambda x: x["index"])
|
||||||
|
return [item["relevance_score"] for item in results_sorted]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"未知的 SiliconFlow rerank API 响应格式: {data}")
|
||||||
|
|
||||||
|
|
||||||
class LLMFallbackRerankService(BaseRerankService):
|
class LLMFallbackRerankService(BaseRerankService):
|
||||||
"""
|
"""
|
||||||
使用 LLM 作为最后的降级方案进行重排
|
使用 LLM 作为最后的降级方案进行重排
|
||||||
@@ -291,18 +345,53 @@ class ZhipuRerankProvider(BaseServiceProvider[BaseRerankService]):
|
|||||||
return self._service_instance
|
return self._service_instance
|
||||||
|
|
||||||
|
|
||||||
|
class SiliconFlowRerankProvider(BaseServiceProvider[BaseRerankService]):
|
||||||
|
"""
|
||||||
|
硅基流动(SiliconFlow) API 重排服务提供者
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model: str | None = None):
|
||||||
|
super().__init__("siliconflow_rerank")
|
||||||
|
self._model = model or SILICONFLOW_RERANK_MODEL
|
||||||
|
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
"""
|
||||||
|
检查 SiliconFlow API 重排服务是否可用
|
||||||
|
"""
|
||||||
|
if not SILICONFLOW_API_KEY:
|
||||||
|
logger.warning("SILICONFLOW_API_KEY 未配置")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
service = SiliconFlowRerankService(model=self._model)
|
||||||
|
test_scores = service.compute_scores("test query", ["test document"])
|
||||||
|
logger.info("SiliconFlow 重排服务可用")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"SiliconFlow 重排服务不可用: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_service(self) -> BaseRerankService:
|
||||||
|
"""
|
||||||
|
获取 SiliconFlow API 重排服务
|
||||||
|
"""
|
||||||
|
if self._service_instance is None:
|
||||||
|
self._service_instance = SiliconFlowRerankService(model=self._model)
|
||||||
|
return self._service_instance
|
||||||
|
|
||||||
|
|
||||||
def get_rerank_service() -> BaseRerankService:
|
def get_rerank_service() -> BaseRerankService:
|
||||||
"""
|
"""
|
||||||
获取重排服务(带自动降级)- 纯服务层
|
获取重排服务(带自动降级)- 纯服务层
|
||||||
|
|
||||||
降级链: Local llama.cpp -> Zhipu Rerank -> LLM Fallback
|
降级链: Local llama.cpp -> SiliconFlow Rerank -> Zhipu Rerank -> LLM Fallback
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BaseRerankService: 重排服务实例
|
BaseRerankService: 重排服务实例
|
||||||
"""
|
"""
|
||||||
def _create_chain():
|
def _create_chain():
|
||||||
primary = LocalLlamaCppRerankProvider()
|
primary = LocalLlamaCppRerankProvider()
|
||||||
fallbacks = [ZhipuRerankProvider(), LLMFallbackRerankProvider()]
|
fallbacks = [SiliconFlowRerankProvider(), ZhipuRerankProvider(), LLMFallbackRerankProvider()]
|
||||||
return FallbackServiceChain(primary, fallbacks)
|
return FallbackServiceChain(primary, fallbacks)
|
||||||
|
|
||||||
chain = SingletonServiceManager.get_or_create("rerank_service_chain", _create_chain)
|
chain = SingletonServiceManager.get_or_create("rerank_service_chain", _create_chain)
|
||||||
|
|||||||
@@ -1,4 +1,11 @@
|
|||||||
# rag/pipeline.py
|
"""
|
||||||
|
RAG 检索流水线模块
|
||||||
|
|
||||||
|
提供固定流程的 RAG 检索:
|
||||||
|
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
|
||||||
|
|
||||||
|
默认使用混合检索(稠密+稀疏)+ 父子文档模式。
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
@@ -6,61 +13,86 @@ from typing import List
|
|||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
|
||||||
from ..model_services import get_rerank_service
|
from app.model_services import get_rerank_service
|
||||||
from .rerank import create_document_reranker
|
from app.rag.rerank import create_document_reranker
|
||||||
from .query_transform import MultiQueryGenerator
|
from app.rag.query_transform import MultiQueryGenerator
|
||||||
from .fusion import reciprocal_rank_fusion
|
from app.rag.fusion import reciprocal_rank_fusion
|
||||||
|
from app.rag.retriever import create_parent_hybrid_retriever
|
||||||
|
|
||||||
|
|
||||||
class RAGPipeline:
|
class RAGPipeline:
|
||||||
"""
|
"""
|
||||||
固定流程的 RAG 检索流水线:
|
固定流程的 RAG 检索流水线:
|
||||||
多路改写 → 并行检索 → RRF融合 → 重排序 → 返回父文档
|
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
|
||||||
|
|
||||||
|
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
retriever, # 基础检索器(应返回父文档,例如 ParentDocumentRetriever 实例)
|
retriever=None,
|
||||||
llm: BaseLanguageModel,
|
llm: Optional[BaseLanguageModel] = None,
|
||||||
num_queries: int = 3,
|
num_queries: int = 3,
|
||||||
rerank_top_n: int = 5,
|
rerank_top_n: int = 5,
|
||||||
|
collection_name: str = "rag_documents",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法
|
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。
|
||||||
llm: 用于生成多路查询的语言模型
|
如果不提供,会自动创建默认的父子文档混合检索器。
|
||||||
num_queries: 生成的查询变体数量
|
llm: 用于生成多路查询的语言模型。
|
||||||
rerank_top_n: 最终返回的文档数量
|
num_queries: 生成的查询变体数量。
|
||||||
rerank_model: 重排序模型名称
|
rerank_top_n: 最终返回的文档数量。
|
||||||
|
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
|
||||||
"""
|
"""
|
||||||
self.retriever = retriever
|
# 如果没有提供 retriever,自动创建默认的混合检索器
|
||||||
|
if retriever is None:
|
||||||
|
self.retriever = create_parent_hybrid_retriever(
|
||||||
|
collection_name=collection_name,
|
||||||
|
search_k=rerank_top_n * 2 # 多取一些给重排序用
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.retriever = retriever
|
||||||
|
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.num_queries = num_queries
|
self.num_queries = num_queries
|
||||||
self.rerank_top_n = rerank_top_n
|
self.rerank_top_n = rerank_top_n
|
||||||
|
|
||||||
# 初始化组件 - 使用统一的重排服务获取接口
|
# 初始化组件 - 使用统一的重排服务获取接口
|
||||||
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
|
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None
|
||||||
self.reranker = create_document_reranker()
|
self.reranker = create_document_reranker()
|
||||||
|
|
||||||
async def aretrieve(self, query: str) -> List[Document]:
|
async def aretrieve(self, query: str) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
异步执行完整检索流程
|
异步执行完整检索流程
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户查询
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检索到的相关文档列表
|
||||||
"""
|
"""
|
||||||
# Step 1: 生成多路查询
|
# 如果有 query_generator,做多路改写
|
||||||
queries = await self.query_generator.agenerate(query)
|
if self.query_generator and self.llm:
|
||||||
# 包含原始查询,确保至少有一条
|
# Step 1: 生成多路查询
|
||||||
if query not in queries:
|
queries = await self.query_generator.agenerate(query)
|
||||||
queries.insert(0, query)
|
# 包含原始查询,确保至少有一条
|
||||||
|
if query not in queries:
|
||||||
|
queries.insert(0, query)
|
||||||
|
else:
|
||||||
|
# 如果原始查询已在列表中,将其移至首位
|
||||||
|
queries.remove(query)
|
||||||
|
queries.insert(0, query)
|
||||||
|
|
||||||
|
# Step 2: 并行检索(每个查询获取文档列表)
|
||||||
|
tasks = [self.retriever.ainvoke(q) for q in queries]
|
||||||
|
doc_lists = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Step 3: RRF 融合
|
||||||
|
fused_docs = reciprocal_rank_fusion(doc_lists)
|
||||||
else:
|
else:
|
||||||
# 如果原始查询已在列表中,将其移至首位
|
# 没有 LLM 做查询改写,直接用原始查询检索
|
||||||
queries.remove(query)
|
fused_docs = await self.retriever.ainvoke(query)
|
||||||
queries.insert(0, query)
|
|
||||||
|
|
||||||
# Step 2: 并行检索(每个查询获取文档列表)
|
|
||||||
tasks = [self.retriever.ainvoke(q) for q in queries]
|
|
||||||
doc_lists = await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
# Step 3: RRF 融合
|
|
||||||
fused_docs = reciprocal_rank_fusion(doc_lists)
|
|
||||||
|
|
||||||
# Step 4: 重排序
|
# Step 4: 重排序
|
||||||
try:
|
try:
|
||||||
@@ -76,7 +108,15 @@ class RAGPipeline:
|
|||||||
return asyncio.run(self.aretrieve(query))
|
return asyncio.run(self.aretrieve(query))
|
||||||
|
|
||||||
def format_context(self, documents: List[Document]) -> str:
|
def format_context(self, documents: List[Document]) -> str:
|
||||||
"""将文档列表格式化为上下文字符串"""
|
"""
|
||||||
|
将文档列表格式化为上下文字符串
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents: 文档列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化后的上下文字符串
|
||||||
|
"""
|
||||||
if not documents:
|
if not documents:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@@ -84,4 +124,30 @@ class RAGPipeline:
|
|||||||
for i, doc in enumerate(documents, 1):
|
for i, doc in enumerate(documents, 1):
|
||||||
source = doc.metadata.get("source", "未知来源")
|
source = doc.metadata.get("source", "未知来源")
|
||||||
parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
|
parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def create_rag_pipeline(
|
||||||
|
collection_name: str = "rag_documents",
|
||||||
|
llm: Optional[BaseLanguageModel] = None,
|
||||||
|
num_queries: int = 3,
|
||||||
|
rerank_top_n: int = 5,
|
||||||
|
) -> RAGPipeline:
|
||||||
|
"""
|
||||||
|
创建 RAG 检索流水线的便捷函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Qdrant 集合名称
|
||||||
|
llm: 用于生成多路查询的语言模型
|
||||||
|
num_queries: 生成的查询变体数量
|
||||||
|
rerank_top_n: 最终返回的文档数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RAGPipeline 实例
|
||||||
|
"""
|
||||||
|
return RAGPipeline(
|
||||||
|
llm=llm,
|
||||||
|
num_queries=num_queries,
|
||||||
|
rerank_top_n=rerank_top_n,
|
||||||
|
collection_name=collection_name
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,170 +1,379 @@
|
|||||||
"""
|
"""
|
||||||
Qdrant 向量检索器模块
|
Qdrant 混合检索器模块
|
||||||
|
|
||||||
提供基于 Qdrant 的混合检索(Dense + Sparse)功能。
|
提供基于 Qdrant 的混合检索(Dense + Sparse)功能,包括:
|
||||||
|
- 纯混合检索(无子父文档)
|
||||||
|
- 父子文档混合检索(先检索子文档,再返回父文档)
|
||||||
|
|
||||||
核心原理:
|
核心原理:
|
||||||
- 使用 Qdrant 原生混合检索(langchain-qdrant 的 RetrievalMode.HYBRID)
|
- 使用 Qdrant 原生 Fusion API (RRF) 做分数融合
|
||||||
- 同时存储稠密向量和稀疏向量
|
- 同时使用稠密向量(语义)和稀疏向量(BM25 关键词)
|
||||||
- 语义理解 + 关键词匹配,效果最优
|
|
||||||
|
|
||||||
使用示例:
|
|
||||||
>>> from app.rag.retriever import create_hybrid_retriever
|
|
||||||
>>> retriever = create_hybrid_retriever(collection_name="rag_documents")
|
|
||||||
>>> docs = retriever.invoke("什么是 RAG?")
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional, List
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
from langchain_qdrant import (
|
from qdrant_client.http.models import (
|
||||||
QdrantVectorStore,
|
SearchRequest, Fusion, FusionProtocol, NamedVector, NamedSparseVector
|
||||||
RetrievalMode,
|
|
||||||
FastEmbedSparse,
|
|
||||||
)
|
)
|
||||||
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
|
||||||
|
|
||||||
from rag_core import QDRANT_URL, QDRANT_API_KEY
|
from rag_core import QdrantVectorStore, get_sparse_embedder, create_docstore
|
||||||
from rag_core.client import create_qdrant_client as create_core_qdrant_client
|
from rag_core.client import create_qdrant_client as create_core_qdrant_client
|
||||||
from app.model_services import get_embedding_service
|
from app.model_services import get_embedding_service
|
||||||
from app.config import SPARSE_MODEL_PATH, SPARSE_MODEL_NAME
|
from app.logger import info, warning, debug
|
||||||
from app.logger import info, warning
|
|
||||||
|
|
||||||
# 模块级常量
|
# 模块级常量
|
||||||
DEFAULT_SEARCH_K = 20
|
DEFAULT_SEARCH_K = 20
|
||||||
DEFAULT_SCORE_THRESHOLD = 0.3
|
DEFAULT_PARENT_SEARCH_K = 5
|
||||||
|
|
||||||
|
|
||||||
def create_base_retriever(
|
class HybridRetriever(BaseRetriever):
|
||||||
collection_name: str,
|
|
||||||
search_kwargs: Dict[str, Any] | None = None,
|
|
||||||
client: QdrantClient | None = None,
|
|
||||||
embeddings: Embeddings | None = None,
|
|
||||||
) -> BaseRetriever:
|
|
||||||
"""
|
"""
|
||||||
创建基础向量检索器(仅稠密向量检索)
|
混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合
|
||||||
|
|
||||||
Args:
|
直接使用 Qdrant 原生 Fusion API,性能最优。
|
||||||
collection_name: Qdrant 集合名称
|
|
||||||
search_kwargs: 搜索参数
|
|
||||||
client: 可选的 Qdrant 客户端
|
|
||||||
embeddings: 可选的嵌入模型(默认使用 get_embedding_service())
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LangChain 兼容的检索器
|
|
||||||
"""
|
"""
|
||||||
# 默认使用统一嵌入服务(已内置降级机制)
|
|
||||||
if embeddings is None:
|
def __init__(
|
||||||
embeddings = get_embedding_service()
|
self,
|
||||||
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
collection_name: str,
|
||||||
|
vector_store: QdrantVectorStore,
|
||||||
|
search_k: int = DEFAULT_SEARCH_K,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
collection_name: Qdrant 集合名称
|
||||||
|
vector_store: QdrantVectorStore 实例
|
||||||
|
search_k: 检索返回结果数
|
||||||
|
"""
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.vector_store = vector_store
|
||||||
|
self.search_k = search_k
|
||||||
|
self.client = vector_store.get_qdrant_client()
|
||||||
|
self.sparse_embedder = get_sparse_embedder()
|
||||||
|
|
||||||
|
def _get_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: Optional[Any] = None
|
||||||
|
) -> List[Document]:
|
||||||
|
"""
|
||||||
|
同步检索相关文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询字符串
|
||||||
|
run_manager: LangChain 运行管理器(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相关文档列表
|
||||||
|
"""
|
||||||
|
# 生成双向量
|
||||||
|
dense_query = self.vector_store.embeddings.embed_query(query)
|
||||||
|
sparse_query = self.sparse_embedder.embed_query(query)
|
||||||
|
|
||||||
|
# 构建双检索请求
|
||||||
|
searches = [
|
||||||
|
# 稠密检索
|
||||||
|
SearchRequest(
|
||||||
|
vector=NamedVector(name="dense", vector=dense_query),
|
||||||
|
limit=self.search_k,
|
||||||
|
with_payload=True
|
||||||
|
),
|
||||||
|
# 稀疏检索
|
||||||
|
SearchRequest(
|
||||||
|
vector=NamedSparseVector(name="sparse", vector=sparse_query),
|
||||||
|
limit=self.search_k,
|
||||||
|
with_payload=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# RRF 分数融合
|
||||||
|
fused_results = self.client.fusion(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
requests=searches,
|
||||||
|
fusion=Fusion(fusion=FusionProtocol.RRF)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换为 Document 格式
|
||||||
|
results = []
|
||||||
|
for point in fused_results.points:
|
||||||
|
doc = Document(
|
||||||
|
page_content=point.payload.pop("text", ""),
|
||||||
|
metadata=point.payload
|
||||||
|
)
|
||||||
|
results.append(doc)
|
||||||
|
|
||||||
|
debug(f"混合检索返回 {len(results)} 个文档")
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _aget_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: Optional[Any] = None
|
||||||
|
) -> List[Document]:
|
||||||
|
"""异步检索(当前调用同步版本)"""
|
||||||
|
# Qdrant 客户端没有原生 async,这里用同步版本
|
||||||
|
return self._get_relevant_documents(query, run_manager=run_manager)
|
||||||
|
|
||||||
# 合并默认搜索参数
|
|
||||||
merged_search_kwargs = {"k": DEFAULT_SEARCH_K}
|
|
||||||
if search_kwargs:
|
|
||||||
merged_search_kwargs.update(search_kwargs)
|
|
||||||
|
|
||||||
# 创建或复用 Qdrant 客户端
|
class ParentHybridRetriever(BaseRetriever):
|
||||||
if client is None:
|
"""
|
||||||
client = create_core_qdrant_client()
|
父子文档混合检索器:
|
||||||
|
|
||||||
# 验证集合是否存在
|
1. 先用混合检索找到相关子文档
|
||||||
try:
|
2. 根据子文档的 parent_id 找到对应的父文档
|
||||||
client.get_collection(collection_name)
|
3. 去重并返回父文档
|
||||||
except UnexpectedResponse as e:
|
"""
|
||||||
if e.status_code == 404:
|
|
||||||
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
|
def __init__(
|
||||||
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
|
self,
|
||||||
raise
|
collection_name: str,
|
||||||
|
vector_store: QdrantVectorStore,
|
||||||
# 构建向量存储
|
search_k: int = DEFAULT_PARENT_SEARCH_K,
|
||||||
vector_store = QdrantVectorStore(
|
docstore: Optional[Any] = None,
|
||||||
client=client,
|
):
|
||||||
collection_name=collection_name,
|
"""
|
||||||
embedding=embeddings,
|
Args:
|
||||||
)
|
collection_name: Qdrant 集合名称
|
||||||
|
vector_store: QdrantVectorStore 实例
|
||||||
return vector_store.as_retriever(search_kwargs=merged_search_kwargs)
|
search_k: 最终返回的父文档数
|
||||||
|
docstore: 文档存储(如果父文档在 PostgreSQL),可选
|
||||||
|
"""
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.vector_store = vector_store
|
||||||
|
self.search_k = search_k
|
||||||
|
self.client = vector_store.get_qdrant_client()
|
||||||
|
self.sparse_embedder = get_sparse_embedder()
|
||||||
|
self.docstore = docstore
|
||||||
|
|
||||||
|
def _get_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: Optional[Any] = None
|
||||||
|
) -> List[Document]:
|
||||||
|
"""
|
||||||
|
同步检索相关父文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询字符串
|
||||||
|
run_manager: LangChain 运行管理器(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相关父文档列表
|
||||||
|
"""
|
||||||
|
# 1. 生成查询双向量
|
||||||
|
dense_query = self.vector_store.embeddings.embed_query(query)
|
||||||
|
sparse_query = self.sparse_embedder.embed_query(query)
|
||||||
|
|
||||||
|
# 2. 多取一些子文档,避免去重后数量不足
|
||||||
|
search_limit = self.search_k * 2
|
||||||
|
searches = [
|
||||||
|
# 稠密检索
|
||||||
|
SearchRequest(
|
||||||
|
vector=NamedVector(name="dense", vector=dense_query),
|
||||||
|
limit=search_limit,
|
||||||
|
with_payload=True
|
||||||
|
),
|
||||||
|
# 稀疏检索
|
||||||
|
SearchRequest(
|
||||||
|
vector=NamedSparseVector(name="sparse", vector=sparse_query),
|
||||||
|
limit=search_limit,
|
||||||
|
with_payload=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 3. RRF 分数融合,拿到子文档命中结果
|
||||||
|
fused_results = self.client.fusion(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
requests=searches,
|
||||||
|
fusion=Fusion(fusion=FusionProtocol.RRF)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not fused_results.points:
|
||||||
|
debug("混合检索未找到任何文档")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 4. 收集 parent_id 和对应最高得分
|
||||||
|
parent_score_map = {}
|
||||||
|
parent_ids = set()
|
||||||
|
child_point_map = {} # 保存子文档点用于降级
|
||||||
|
|
||||||
|
for point in fused_results.points:
|
||||||
|
parent_id = point.payload.get("parent_id", point.id)
|
||||||
|
score = point.score
|
||||||
|
|
||||||
|
# 同一个 parent_id 只保留最高得分
|
||||||
|
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
|
||||||
|
parent_score_map[parent_id] = score
|
||||||
|
parent_ids.add(parent_id)
|
||||||
|
child_point_map[parent_id] = point
|
||||||
|
|
||||||
|
# 5. 批量查询父文档
|
||||||
|
# 首先尝试从 Qdrant 直接查询(因为父文档可能也存在 Qdrant 中)
|
||||||
|
parent_docs = []
|
||||||
|
found_parent_ids = set()
|
||||||
|
|
||||||
|
try:
|
||||||
|
parent_points = self.client.retrieve(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
ids=list(parent_ids),
|
||||||
|
with_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理找到的父文档
|
||||||
|
for point in parent_points:
|
||||||
|
doc = Document(
|
||||||
|
page_content=point.payload.pop("text", ""),
|
||||||
|
metadata=point.payload
|
||||||
|
)
|
||||||
|
parent_docs.append(doc)
|
||||||
|
found_parent_ids.add(point.id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
warning(f"从 Qdrant 查询父文档失败: {e}")
|
||||||
|
|
||||||
|
# 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档
|
||||||
|
if self.docstore and len(found_parent_ids) < len(parent_ids):
|
||||||
|
missing_parent_ids = parent_ids - found_parent_ids
|
||||||
|
try:
|
||||||
|
docstore_docs = self.docstore.mget(missing_parent_ids)
|
||||||
|
for doc_id, doc in zip(missing_parent_ids, docstore_docs):
|
||||||
|
if doc is not None:
|
||||||
|
parent_docs.append(doc)
|
||||||
|
found_parent_ids.add(doc_id)
|
||||||
|
except Exception as e:
|
||||||
|
warning(f"从 docstore 查询父文档失败: {e}")
|
||||||
|
|
||||||
|
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
|
||||||
|
missing_parent_ids = parent_ids - found_parent_ids
|
||||||
|
if missing_parent_ids:
|
||||||
|
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
|
||||||
|
for parent_id in missing_parent_ids:
|
||||||
|
child_point = child_point_map.get(parent_id)
|
||||||
|
if child_point:
|
||||||
|
doc = Document(
|
||||||
|
page_content=child_point.payload.pop("text", ""),
|
||||||
|
metadata=child_point.payload
|
||||||
|
)
|
||||||
|
parent_docs.append(doc)
|
||||||
|
|
||||||
|
# 8. 按照得分降序排序,返回前 k 个
|
||||||
|
parent_docs_with_scores = [
|
||||||
|
(doc, parent_score_map.get(doc.metadata.get("id", doc.id), 0.0))
|
||||||
|
for doc in parent_docs
|
||||||
|
]
|
||||||
|
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
|
||||||
|
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
|
||||||
|
|
||||||
|
return final_docs
|
||||||
|
|
||||||
|
async def _aget_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: Optional[Any] = None
|
||||||
|
) -> List[Document]:
|
||||||
|
"""异步检索(当前调用同步版本)"""
|
||||||
|
return self._get_relevant_documents(query, run_manager=run_manager)
|
||||||
|
|
||||||
|
|
||||||
def create_hybrid_retriever(
|
def create_hybrid_retriever(
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
dense_k: int = 10,
|
search_k: int = DEFAULT_SEARCH_K,
|
||||||
sparse_k: int = 10,
|
embeddings: Optional[Embeddings] = None,
|
||||||
score_threshold: float | None = DEFAULT_SCORE_THRESHOLD,
|
|
||||||
client: QdrantClient | None = None,
|
|
||||||
embeddings: Embeddings | None = None,
|
|
||||||
) -> BaseRetriever:
|
) -> BaseRetriever:
|
||||||
"""
|
"""
|
||||||
创建混合检索器(稠密向量 + BM25 稀疏向量,Qdrant 原生实现)。
|
创建混合检索器(稠密向量 + BM25 稀疏向量)。
|
||||||
|
|
||||||
|
这是默认推荐的检索方式,效果最优。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection_name: Qdrant 集合名称。
|
collection_name: Qdrant 集合名称
|
||||||
dense_k: 稠密向量检索返回数量,默认 10。
|
search_k: 检索返回结果数
|
||||||
sparse_k: 稀疏向量检索返回数量,默认 10。
|
|
||||||
score_threshold: 相似度阈值,默认 0.3。
|
|
||||||
client: 可选的 Qdrant 客户端实例。
|
|
||||||
embeddings: 可选的嵌入模型实例。若未提供,将自动获取统一嵌入服务。
|
embeddings: 可选的嵌入模型实例。若未提供,将自动获取统一嵌入服务。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BaseRetriever 实例,配置了混合搜索参数。
|
HybridRetriever 实例
|
||||||
"""
|
"""
|
||||||
total_k = dense_k + sparse_k
|
# 默认使用统一嵌入服务
|
||||||
|
|
||||||
search_kwargs = {
|
|
||||||
"k": total_k,
|
|
||||||
"search_type": "similarity_score_threshold",
|
|
||||||
"score_threshold": score_threshold,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 默认使用统一嵌入服务(已内置降级机制)
|
|
||||||
if embeddings is None:
|
if embeddings is None:
|
||||||
embeddings = get_embedding_service()
|
embeddings = get_embedding_service()
|
||||||
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
||||||
|
|
||||||
# 创建或复用 Qdrant 客户端
|
# 创建向量存储
|
||||||
if client is None:
|
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||||
client = create_core_qdrant_client()
|
|
||||||
|
|
||||||
# 验证集合是否存在
|
# 验证集合是否存在
|
||||||
try:
|
try:
|
||||||
client.get_collection(collection_name)
|
vector_store.get_client().get_collection(collection_name)
|
||||||
except UnexpectedResponse as e:
|
except UnexpectedResponse as e:
|
||||||
if e.status_code == 404:
|
if e.status_code == 404:
|
||||||
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
|
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
|
||||||
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
|
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# 初始化稀疏嵌入(使用本地缓存目录)
|
info(f"✅ Qdrant 混合检索器初始化成功(search_k={search_k})")
|
||||||
sparse_embeddings = FastEmbedSparse(
|
return HybridRetriever(
|
||||||
model_name=SPARSE_MODEL_NAME,
|
|
||||||
cache_dir=SPARSE_MODEL_PATH
|
|
||||||
)
|
|
||||||
info(f"✅ FastEmbedSparse 初始化成功 (cache_dir={SPARSE_MODEL_PATH})")
|
|
||||||
|
|
||||||
# 创建混合模式的 QdrantVectorStore
|
|
||||||
vector_store = QdrantVectorStore(
|
|
||||||
client=client,
|
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
embedding=embeddings,
|
vector_store=vector_store,
|
||||||
sparse_embedding=sparse_embeddings,
|
search_k=search_k
|
||||||
retrieval_mode=RetrievalMode.HYBRID,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
info(f"✅ Qdrant 原生混合检索器初始化成功 (k={total_k})")
|
|
||||||
return vector_store.as_retriever(search_kwargs=search_kwargs)
|
|
||||||
|
|
||||||
|
def create_parent_hybrid_retriever(
|
||||||
# 可选:提供异步友好的辅助函数
|
|
||||||
async def acreate_base_retriever(
|
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
search_kwargs: Dict[str, Any] | None = None,
|
search_k: int = DEFAULT_PARENT_SEARCH_K,
|
||||||
client: QdrantClient | None = None,
|
embeddings: Optional[Embeddings] = None,
|
||||||
|
use_docstore: bool = True,
|
||||||
) -> BaseRetriever:
|
) -> BaseRetriever:
|
||||||
"""
|
"""
|
||||||
异步创建基础向量检索器(与同步版本功能相同)。
|
创建父子文档混合检索器(默认推荐)。
|
||||||
|
|
||||||
适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。
|
检索流程:
|
||||||
|
1. 混合检索找到相关子文档
|
||||||
|
2. 根据 parent_id 找到对应的父文档
|
||||||
|
3. 去重并返回父文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Qdrant 集合名称
|
||||||
|
search_k: 最终返回的父文档数
|
||||||
|
embeddings: 可选的嵌入模型实例
|
||||||
|
use_docstore: 是否使用 PostgreSQL docstore 存储父文档
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ParentHybridRetriever 实例
|
||||||
"""
|
"""
|
||||||
# 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可
|
# 默认使用统一嵌入服务
|
||||||
return create_base_retriever(collection_name, search_kwargs, client)
|
if embeddings is None:
|
||||||
|
embeddings = get_embedding_service()
|
||||||
|
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
||||||
|
|
||||||
|
# 创建向量存储
|
||||||
|
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||||
|
|
||||||
|
# 验证集合是否存在
|
||||||
|
try:
|
||||||
|
vector_store.get_client().get_collection(collection_name)
|
||||||
|
except UnexpectedResponse as e:
|
||||||
|
if e.status_code == 404:
|
||||||
|
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
|
||||||
|
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 创建 docstore(如果需要)
|
||||||
|
docstore = None
|
||||||
|
if use_docstore:
|
||||||
|
try:
|
||||||
|
docstore, _ = create_docstore()
|
||||||
|
info("✅ 文档存储初始化成功(PostgreSQL)")
|
||||||
|
except Exception as e:
|
||||||
|
warning(f"⚠️ 文档存储初始化失败,将不使用 docstore: {e}")
|
||||||
|
|
||||||
|
info(f"✅ Qdrant 父子文档混合检索器初始化成功(search_k={search_k})")
|
||||||
|
return ParentHybridRetriever(
|
||||||
|
collection_name=collection_name,
|
||||||
|
vector_store=vector_store,
|
||||||
|
search_k=search_k,
|
||||||
|
docstore=docstore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 别名:默认就是父子文档混合检索
|
||||||
|
create_retriever = create_parent_hybrid_retriever
|
||||||
|
|||||||
@@ -3,52 +3,94 @@ RAG 工具模块
|
|||||||
|
|
||||||
将检索功能封装为 LangChain Tool,供 Agent 调用。
|
将检索功能封装为 LangChain Tool,供 Agent 调用。
|
||||||
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
|
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
|
||||||
|
|
||||||
|
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||||
"""
|
"""
|
||||||
from typing import Callable
|
from typing import Callable, Optional
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
from .pipeline import RAGPipeline
|
from app.rag.pipeline import RAGPipeline, create_rag_pipeline
|
||||||
|
|
||||||
|
|
||||||
def create_rag_tool_sync(
|
def create_rag_tool_sync(
|
||||||
retriever: BaseRetriever,
|
retriever: Optional[BaseRetriever] = None,
|
||||||
llm: BaseLanguageModel,
|
llm: Optional[BaseLanguageModel] = None,
|
||||||
num_queries: int = 3,
|
num_queries: int = 3,
|
||||||
rerank_top_n: int = 5,
|
rerank_top_n: int = 5,
|
||||||
collection_name: str = "rag_documents",
|
collection_name: str = "rag_documents",
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""
|
"""
|
||||||
创建一个配置好的 RAG 检索工具(同步版本,用于不支持异步的旧版 Agent)。
|
创建一个配置好的 RAG 检索工具(同步版本)。
|
||||||
|
|
||||||
参数同 create_rag_tool。
|
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retriever: 基础检索器对象(可选,不提供则自动创建)
|
||||||
|
llm: 用于生成多路查询的语言模型(可选)
|
||||||
|
num_queries: 生成的查询变体数量
|
||||||
|
rerank_top_n: 最终返回的文档数量
|
||||||
|
collection_name: Qdrant 集合名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LangChain Tool 函数
|
||||||
"""
|
"""
|
||||||
pipeline = RAGPipeline(
|
pipeline = RAGPipeline(
|
||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
num_queries=num_queries,
|
num_queries=num_queries,
|
||||||
rerank_top_n=rerank_top_n,
|
rerank_top_n=rerank_top_n,
|
||||||
|
collection_name=collection_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def search_knowledge_base_sync(query: str) -> str:
|
def search_knowledge_base_sync(query: str) -> str:
|
||||||
"""在知识库中搜索与查询相关的文档片段(同步版本)。
|
"""
|
||||||
|
在知识库中搜索与查询相关的文档片段。
|
||||||
功能与异步版本相同:多路改写 → RRF融合 → 重排序 → 返回父文档。
|
|
||||||
|
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
|
||||||
|
检索效果最优。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 用户提出的问题或查询字符串
|
query: 用户提出的问题或查询字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
格式化后的相关文档内容。
|
格式化后的相关文档内容
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
documents = pipeline.retrieve(query) # 内部调用异步方法并等待
|
documents = pipeline.retrieve(query)
|
||||||
if not documents:
|
if not documents:
|
||||||
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
|
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
|
||||||
|
|
||||||
context = pipeline.format_context(documents)
|
context = pipeline.format_context(documents)
|
||||||
return context
|
return context
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"检索过程中发生错误: {str(e)}"
|
return f"检索过程中发生错误: {str(e)}"
|
||||||
|
|
||||||
|
return search_knowledge_base_sync
|
||||||
|
|
||||||
return search_knowledge_base_sync
|
|
||||||
|
def create_rag_tool(
|
||||||
|
collection_name: str = "rag_documents",
|
||||||
|
llm: Optional[BaseLanguageModel] = None,
|
||||||
|
num_queries: int = 3,
|
||||||
|
rerank_top_n: int = 5,
|
||||||
|
) -> Callable:
|
||||||
|
"""
|
||||||
|
创建 RAG 检索工具的便捷函数(同步版本)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Qdrant 集合名称
|
||||||
|
llm: 用于生成多路查询的语言模型(可选)
|
||||||
|
num_queries: 生成的查询变体数量
|
||||||
|
rerank_top_n: 最终返回的文档数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LangChain Tool 函数
|
||||||
|
"""
|
||||||
|
return create_rag_tool_sync(
|
||||||
|
collection_name=collection_name,
|
||||||
|
llm=llm,
|
||||||
|
num_queries=num_queries,
|
||||||
|
rerank_top_n=rerank_top_n,
|
||||||
|
)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ RAG Core - 公共 RAG 组件包
|
|||||||
|
|
||||||
from .embedders import LlamaCppEmbedder
|
from .embedders import LlamaCppEmbedder
|
||||||
from .vector_store import QdrantVectorStore
|
from .vector_store import QdrantVectorStore
|
||||||
|
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
|
||||||
from .store import PostgresDocStore, create_docstore
|
from .store import PostgresDocStore, create_docstore
|
||||||
from .retriever_factory import create_parent_retriever
|
from .retriever_factory import create_parent_retriever
|
||||||
from .config import (
|
from .config import (
|
||||||
@@ -21,6 +22,8 @@ from .config import (
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"LlamaCppEmbedder",
|
"LlamaCppEmbedder",
|
||||||
"QdrantVectorStore",
|
"QdrantVectorStore",
|
||||||
|
"BM25SparseEmbedder",
|
||||||
|
"get_sparse_embedder",
|
||||||
"QDRANT_URL",
|
"QDRANT_URL",
|
||||||
"QDRANT_API_KEY",
|
"QDRANT_API_KEY",
|
||||||
"LLAMACPP_EMBEDDING_URL",
|
"LLAMACPP_EMBEDDING_URL",
|
||||||
|
|||||||
@@ -1,5 +1,14 @@
|
|||||||
# rag_core/retriever_factory.py
|
"""
|
||||||
|
RAG 检索器工厂模块
|
||||||
|
|
||||||
|
提供创建各种检索器的工厂函数,包括:
|
||||||
|
- 基础向量检索器
|
||||||
|
- ParentDocumentRetriever(父子文档)
|
||||||
|
- 混合检索器(稠密+稀疏)
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
from langchain_classic.retrievers import ParentDocumentRetriever
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||||
from langchain_core.stores import BaseStore
|
from langchain_core.stores import BaseStore
|
||||||
@@ -9,18 +18,18 @@ from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
|
|||||||
|
|
||||||
def create_parent_retriever(
|
def create_parent_retriever(
|
||||||
collection_name: str = "rag_documents",
|
collection_name: str = "rag_documents",
|
||||||
parent_splitter: TextSplitter | None = None,
|
parent_splitter: Optional[TextSplitter] = None,
|
||||||
child_splitter: TextSplitter | None = None,
|
child_splitter: Optional[TextSplitter] = None,
|
||||||
docstore: BaseStore | None = None,
|
docstore: Optional[BaseStore] = None,
|
||||||
search_k: int = 5,
|
search_k: int = 5,
|
||||||
parent_chunk_size: int = 1000,
|
parent_chunk_size: int = 1000,
|
||||||
parent_chunk_overlap: int = 100,
|
parent_chunk_overlap: int = 100,
|
||||||
child_chunk_size: int = 200,
|
child_chunk_size: int = 200,
|
||||||
child_chunk_overlap: int = 20,
|
child_chunk_overlap: int = 20,
|
||||||
embeddings: Embeddings | None = None,
|
embeddings: Optional[Embeddings] = None,
|
||||||
) -> ParentDocumentRetriever:
|
) -> ParentDocumentRetriever:
|
||||||
"""
|
"""
|
||||||
创建 ParentDocumentRetriever 实例。
|
创建 ParentDocumentRetriever 实例(基础稠密向量版本)。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection_name: Qdrant 集合名称,默认 "rag_documents"
|
collection_name: Qdrant 集合名称,默认 "rag_documents"
|
||||||
@@ -44,7 +53,7 @@ def create_parent_retriever(
|
|||||||
|
|
||||||
# 向量存储(只读)
|
# 向量存储(只读)
|
||||||
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||||
|
|
||||||
# 切分器(若未提供则创建默认)
|
# 切分器(若未提供则创建默认)
|
||||||
if parent_splitter is None:
|
if parent_splitter is None:
|
||||||
parent_splitter = RecursiveCharacterTextSplitter(
|
parent_splitter = RecursiveCharacterTextSplitter(
|
||||||
@@ -56,11 +65,11 @@ def create_parent_retriever(
|
|||||||
chunk_size=child_chunk_size,
|
chunk_size=child_chunk_size,
|
||||||
chunk_overlap=child_chunk_overlap,
|
chunk_overlap=child_chunk_overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 文档存储
|
# 文档存储
|
||||||
if docstore is None:
|
if docstore is None:
|
||||||
docstore, _ = create_docstore()
|
docstore, _ = create_docstore()
|
||||||
|
|
||||||
return ParentDocumentRetriever(
|
return ParentDocumentRetriever(
|
||||||
vectorstore=vector_store.get_langchain_vectorstore(),
|
vectorstore=vector_store.get_langchain_vectorstore(),
|
||||||
docstore=docstore,
|
docstore=docstore,
|
||||||
@@ -68,3 +77,34 @@ def create_parent_retriever(
|
|||||||
parent_splitter=parent_splitter,
|
parent_splitter=parent_splitter,
|
||||||
search_kwargs={"k": search_k},
|
search_kwargs={"k": search_k},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_hybrid_retriever_factory(
|
||||||
|
collection_name: str = "rag_documents",
|
||||||
|
search_k: int = 5,
|
||||||
|
embeddings: Optional[Embeddings] = None,
|
||||||
|
) -> BaseRetriever:
|
||||||
|
"""
|
||||||
|
【不完整,仅占位】创建混合检索器的工厂函数占位符。
|
||||||
|
|
||||||
|
注意:完整的混合检索逻辑在 app/rag/retriever.py 中实现。
|
||||||
|
这里仅返回 QdrantVectorStore 作为基础。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Qdrant 集合名称
|
||||||
|
search_k: 检索返回结果数
|
||||||
|
embeddings: 嵌入模型实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
基础的 QdrantVectorStore(仅稠密检索)
|
||||||
|
"""
|
||||||
|
# 嵌入模型
|
||||||
|
if embeddings is None:
|
||||||
|
embedder = LlamaCppEmbedder()
|
||||||
|
embeddings = embedder.as_langchain_embeddings()
|
||||||
|
|
||||||
|
# 创建向量存储
|
||||||
|
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||||
|
|
||||||
|
# 返回 LangChain 兼容的 retriever
|
||||||
|
return vector_store.get_langchain_vectorstore().as_retriever(search_kwargs={"k": search_k})
|
||||||
|
|||||||
34
backend/rag_core/sparse_embedder.py
Normal file
34
backend/rag_core/sparse_embedder.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""
|
||||||
|
BM25 稀疏嵌入器
|
||||||
|
基于 FastEmbed 的 Qdrant/bm25 模型,完全离线运行
|
||||||
|
"""
|
||||||
|
from typing import List
|
||||||
|
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
|
||||||
|
from app.config import FASTEMBED_CACHE_PATH
|
||||||
|
|
||||||
|
class BM25SparseEmbedder:
|
||||||
|
"""BM25 稀疏嵌入包装器,与现有嵌入器风格统一"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model = SparseTextEmbedding(
|
||||||
|
model_name="Qdrant/bm25",
|
||||||
|
cache_dir=FASTEMBED_CACHE_PATH,
|
||||||
|
local_files_only=True, # 强制离线,永不联网
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed_documents(self, texts: List[str]) -> List[dict]:
|
||||||
|
"""返回稀疏向量列表,每个为 Qdrant 兼容的 dict(indices+values)"""
|
||||||
|
return [vec.as_object() for vec in self.model.embed(texts)]
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> dict:
|
||||||
|
"""返回单个稀疏向量"""
|
||||||
|
return list(self.model.embed([text]))[0].as_object()
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_sparse_embedder_instance = None
|
||||||
|
|
||||||
|
def get_sparse_embedder() -> BM25SparseEmbedder:
|
||||||
|
global _sparse_embedder_instance
|
||||||
|
if _sparse_embedder_instance is None:
|
||||||
|
_sparse_embedder_instance = BM25SparseEmbedder()
|
||||||
|
return _sparse_embedder_instance
|
||||||
@@ -1,41 +1,48 @@
|
|||||||
"""
|
"""
|
||||||
Qdrant 向量数据库包装器。
|
Qdrant 向量数据库包装器。
|
||||||
|
支持稠密+稀疏双向量存储。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from qdrant_client.http.models import Distance, VectorParams
|
from qdrant_client.http.models import (
|
||||||
|
Distance, VectorParams, SparseVectorParams, SparseIndexParams,
|
||||||
|
SparseIndexType, PointStruct, NamedSparseVector, NamedVector
|
||||||
|
)
|
||||||
from httpx import RemoteProtocolError
|
from httpx import RemoteProtocolError
|
||||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||||
|
|
||||||
from .client import create_qdrant_client
|
from .client import create_qdrant_client
|
||||||
from .embedders import LlamaCppEmbedder
|
from .embedders import LlamaCppEmbedder
|
||||||
|
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class QdrantVectorStore:
|
class QdrantVectorStore:
|
||||||
"""Qdrant 向量数据库操作包装器。"""
|
"""Qdrant 向量数据库操作包装器 - 支持稠密+稀疏双向量存储。"""
|
||||||
|
|
||||||
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None):
|
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None, sparse_embedder: Optional[BM25SparseEmbedder] = None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
collection_name: Qdrant 集合名称。
|
collection_name: Qdrant 集合名称。
|
||||||
embeddings: 嵌入模型实例,默认 None(使用内部默认的 LlamaCppEmbedder)。
|
embeddings: 嵌入模型实例,默认 None(使用内部默认的 LlamaCppEmbedder)。
|
||||||
|
sparse_embedder: 稀疏嵌入模型实例,默认 None(自动加载BM25)。
|
||||||
"""
|
"""
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self._client: Optional[QdrantClient] = None
|
self._client: Optional[QdrantClient] = None
|
||||||
self._connection_attempts = 0
|
self._connection_attempts = 0
|
||||||
self._last_connection_time: Optional[float] = None
|
self._last_connection_time: Optional[float] = None
|
||||||
|
|
||||||
# 嵌入模型
|
# 稠密嵌入模型
|
||||||
if embeddings is None:
|
if embeddings is None:
|
||||||
embedder = LlamaCppEmbedder()
|
embedder = LlamaCppEmbedder()
|
||||||
self.embeddings = embedder.as_langchain_embeddings()
|
self.embeddings = embedder.as_langchain_embeddings()
|
||||||
@@ -43,9 +50,13 @@ class QdrantVectorStore:
|
|||||||
else:
|
else:
|
||||||
self.embeddings = embeddings
|
self.embeddings = embeddings
|
||||||
self._embedder = None
|
self._embedder = None
|
||||||
|
|
||||||
|
# 稀疏嵌入模型
|
||||||
|
self.sparse_embedder = sparse_embedder or get_sparse_embedder()
|
||||||
|
|
||||||
self.create_collection()
|
self.create_collection()
|
||||||
|
|
||||||
|
# 保留 LangChain 向量存储实例(用于兼容)
|
||||||
self.vector_store = LangchainQdrantVS(
|
self.vector_store = LangchainQdrantVS(
|
||||||
client=self.get_client(),
|
client=self.get_client(),
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
@@ -97,7 +108,7 @@ class QdrantVectorStore:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def create_collection(self, force_recreate: bool = False):
|
def create_collection(self, force_recreate: bool = False):
|
||||||
"""创建集合,设置合适的向量维度。"""
|
"""创建集合,支持稠密+稀疏双向量。"""
|
||||||
if self._embedder is not None:
|
if self._embedder is not None:
|
||||||
# 使用内部的 embedder 获取维度
|
# 使用内部的 embedder 获取维度
|
||||||
vector_size = self._embedder.get_embedding_dimension()
|
vector_size = self._embedder.get_embedding_dimension()
|
||||||
@@ -119,11 +130,31 @@ class QdrantVectorStore:
|
|||||||
exists = False
|
exists = False
|
||||||
|
|
||||||
if not exists:
|
if not exists:
|
||||||
|
# 向量配置:稠密向量
|
||||||
|
vectors_config = {
|
||||||
|
"dense": VectorParams(
|
||||||
|
size=vector_size,
|
||||||
|
distance=Distance.COSINE,
|
||||||
|
optional=True
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 稀疏向量配置
|
||||||
|
sparse_vectors_config = {
|
||||||
|
"sparse": SparseVectorParams(
|
||||||
|
index=SparseIndexParams(
|
||||||
|
type=SparseIndexType.MUTABLE
|
||||||
|
),
|
||||||
|
optional=True
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
client.create_collection(
|
client.create_collection(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
|
vectors_config=vectors_config,
|
||||||
|
sparse_vectors_config=sparse_vectors_config
|
||||||
)
|
)
|
||||||
logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size)
|
logger.info("集合 '%s' 已创建(维度=%d,支持稠密+稀疏双向量)", self.collection_name, vector_size)
|
||||||
else:
|
else:
|
||||||
logger.info("集合 '%s' 已存在", self.collection_name)
|
logger.info("集合 '%s' 已存在", self.collection_name)
|
||||||
return
|
return
|
||||||
@@ -142,18 +173,54 @@ class QdrantVectorStore:
|
|||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
|
|
||||||
def add_documents(self, documents: List[Document], batch_size: int = 100):
|
def add_documents(self, documents: List[Document], batch_size: int = 100):
|
||||||
"""将文档添加到向量数据库。"""
|
"""将文档添加到向量数据库,自动生成稠密+稀疏双向量。"""
|
||||||
if not documents:
|
if not documents:
|
||||||
return []
|
return []
|
||||||
self.create_collection()
|
self.create_collection()
|
||||||
ids = self.vector_store.add_documents(documents, batch_size=batch_size)
|
client = self.get_client()
|
||||||
logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
|
doc_ids = []
|
||||||
return ids
|
|
||||||
|
# 分批处理
|
||||||
|
for i in range(0, len(documents), batch_size):
|
||||||
|
batch_docs = documents[i:i+batch_size]
|
||||||
|
texts = [doc.page_content for doc in batch_docs]
|
||||||
|
|
||||||
|
# 生成双向量
|
||||||
|
dense_vectors = self.embeddings.embed_documents(texts)
|
||||||
|
sparse_vectors = self.sparse_embedder.embed_documents(texts)
|
||||||
|
|
||||||
|
points = []
|
||||||
|
for j, doc in enumerate(batch_docs):
|
||||||
|
point_id = doc.metadata.get("id", str(uuid.uuid4()))
|
||||||
|
doc_ids.append(point_id)
|
||||||
|
|
||||||
|
# 构造双向量
|
||||||
|
named_vectors = {
|
||||||
|
"dense": dense_vectors[j],
|
||||||
|
"sparse": NamedSparseVector(
|
||||||
|
name="sparse",
|
||||||
|
vector=sparse_vectors[j]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
points.append(PointStruct(
|
||||||
|
id=point_id,
|
||||||
|
vector=named_vectors,
|
||||||
|
payload={"text": doc.page_content, **doc.metadata}
|
||||||
|
))
|
||||||
|
|
||||||
|
# 批量插入
|
||||||
|
client.upsert(collection_name=self.collection_name, points=points)
|
||||||
|
logger.info("已向 '%s' 添加 %d 个文档(稠密+稀疏双向量)", self.collection_name, len(points))
|
||||||
|
|
||||||
|
return doc_ids
|
||||||
|
|
||||||
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
|
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
|
||||||
|
"""基础稠密向量检索(兼容原有接口)。"""
|
||||||
return self.vector_store.similarity_search(query, k=k)
|
return self.vector_store.similarity_search(query, k=k)
|
||||||
|
|
||||||
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
|
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
|
||||||
|
"""基础稠密向量检索带分数(兼容原有接口)。"""
|
||||||
return self.vector_store.similarity_search_with_score(query, k=k)
|
return self.vector_store.similarity_search_with_score(query, k=k)
|
||||||
|
|
||||||
def delete_collection(self):
|
def delete_collection(self):
|
||||||
@@ -183,5 +250,5 @@ class QdrantVectorStore:
|
|||||||
return self.vector_store
|
return self.vector_store
|
||||||
|
|
||||||
def get_qdrant_client(self):
|
def get_qdrant_client(self):
|
||||||
"""返回原生 Qdrant 客户端(如需手动管理 collection)"""
|
"""返回原生 Qdrant 客户端(用于自定义检索逻辑)"""
|
||||||
return self.get_client()
|
return self.get_client()
|
||||||
|
|||||||
@@ -1,53 +1,52 @@
|
|||||||
# Core
|
typing-extensions>=4.15.0
|
||||||
pydantic==2.12.5
|
python-dotenv>=1.2.2
|
||||||
python-dotenv==1.2.2
|
pydantic>=2.12.5
|
||||||
typing-extensions==4.15.0
|
requests>=2.32.5
|
||||||
|
|
||||||
# LangChain
|
# LangChain
|
||||||
langchain==1.2.15
|
langchain>=1.2.15
|
||||||
langchain-community==0.4.1
|
langchain-community>=0.4.1
|
||||||
langchain-core==1.2.28
|
langchain-core>=1.2.28
|
||||||
langchain-openai==1.1.12
|
langchain-openai>=1.1.12
|
||||||
langchain-qdrant==1.1.0
|
langchain-qdrant>=1.1.0
|
||||||
langgraph==1.1.6
|
langgraph>=1.1.6
|
||||||
langgraph-checkpoint-postgres==3.0.5
|
langgraph-checkpoint-postgres>=3.0.5
|
||||||
tiktoken>=0.12.0
|
tiktoken>=0.12.0
|
||||||
|
|
||||||
# Zhipu AI
|
# Zhipu AI
|
||||||
zhipuai==2.0.1
|
zhipuai>=2.0.1
|
||||||
|
|
||||||
# Vector DB
|
# Vector DB
|
||||||
qdrant-client==1.17.1
|
qdrant-client>=1.17.1
|
||||||
fastembed>=0.3.0 # 用于 Qdrant BM25 稀疏向量
|
fastembed>=0.3.0 # 用于 Qdrant BM25 稀疏向量
|
||||||
|
|
||||||
# Memory
|
# Memory
|
||||||
mem0ai==1.0.11
|
mem0ai>=1.0.11
|
||||||
|
|
||||||
# Backend
|
# Backend
|
||||||
fastapi==0.135.3
|
fastapi>=0.135.3
|
||||||
uvicorn[standard]==0.44.0
|
uvicorn[standard]>=0.44.0
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
asyncpg==0.31.0
|
asyncpg>=0.31.0
|
||||||
psycopg[binary]==3.3.3
|
psycopg[binary]>=3.3.3
|
||||||
|
|
||||||
# HTTP
|
# HTTP
|
||||||
httpx==0.28.1
|
httpx>=0.28.1
|
||||||
aiohttp==3.13.5
|
aiohttp>=3.13.5
|
||||||
|
|
||||||
# Utilities
|
# Utilities
|
||||||
tenacity==9.1.4
|
tenacity>=9.1.4
|
||||||
rich==15.0.0
|
rich>=15.0.0
|
||||||
PyYAML==6.0.3
|
PyYAML>=6.0.3
|
||||||
numpy>=1.26.2
|
numpy>=1.26.2
|
||||||
pyjwt==2.8.0
|
pyjwt>=2.8.0
|
||||||
ddgs>=6.0.0 # 免费联网搜索(原 duckduckgo-search 已重命名)
|
ddgs>=6.0.0 # 免费联网搜索(原 duckduckgo-search 已重命名)
|
||||||
matplotlib>=3.9.0 # 可视化图表
|
matplotlib>=3.9.0 # 可视化图表
|
||||||
|
|
||||||
# Document Processing
|
# Document Processing
|
||||||
unstructured==0.22.21
|
unstructured>=0.22.21
|
||||||
pypdf==6.10.0
|
pypdf>=6.10.0
|
||||||
beautifulsoup4==4.14.3
|
beautifulsoup4>=4.14.3
|
||||||
lxml==6.1.0
|
lxml>=6.1.0
|
||||||
pandas==3.0.2 # 若需Excel保留,否则移除
|
spacy>=3.8.14 # unstructured 可能依赖
|
||||||
spacy==3.8.14 # unstructured 可能依赖
|
|
||||||
@@ -55,6 +55,7 @@ ENV ENABLE_GRAPH_TRACE=false
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
ENV SPARSE_MODEL_PATH=/app/models/sparse
|
ENV SPARSE_MODEL_PATH=/app/models/sparse
|
||||||
ENV SPARSE_MODEL_NAME=Qdrant/bm25
|
ENV SPARSE_MODEL_NAME=Qdrant/bm25
|
||||||
|
ENV FASTEMBED_CACHE_PATH=/app/fastembed_cache
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 日志配置(生产环境默认值)
|
# 日志配置(生产环境默认值)
|
||||||
@@ -88,6 +89,11 @@ COPY download_sparse_model.py .
|
|||||||
RUN python download_sparse_model.py --cache-dir /app/models/sparse --model-name Qdrant/bm25 && \
|
RUN python download_sparse_model.py --cache-dir /app/models/sparse --model-name Qdrant/bm25 && \
|
||||||
rm -f download_sparse_model.py
|
rm -f download_sparse_model.py
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 复制预下载的BM25模型缓存(FastEmbed)
|
||||||
|
# =============================================================================
|
||||||
|
COPY models/fastembed_cache /app/fastembed_cache
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 复制项目代码
|
# 复制项目代码
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ services:
|
|||||||
- ZHIPUAI_API_KEY=${ZHIPUAI_API_KEY:?请配置 ZHIPUAI_API_KEY(本地:.env 文件 | CI/CD:Secrets)} # ⭐ 敏感密钥配置
|
- ZHIPUAI_API_KEY=${ZHIPUAI_API_KEY:?请配置 ZHIPUAI_API_KEY(本地:.env 文件 | CI/CD:Secrets)} # ⭐ 敏感密钥配置
|
||||||
- DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY:?请配置 DEEPSEEK_API_KEY(本地:.env 文件 | CI/CD:Secrets)} # ⭐ 敏感密钥配置
|
- DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY:?请配置 DEEPSEEK_API_KEY(本地:.env 文件 | CI/CD:Secrets)} # ⭐ 敏感密钥配置
|
||||||
- LLAMACPP_API_KEY=${LLAMACPP_API_KEY:?请配置 LLAMACPP_API_KEY(本地:.env 文件 | CI/CD:Secrets)} # ⭐ 敏感密钥配置
|
- LLAMACPP_API_KEY=${LLAMACPP_API_KEY:?请配置 LLAMACPP_API_KEY(本地:.env 文件 | CI/CD:Secrets)} # ⭐ 敏感密钥配置
|
||||||
|
- SILICONFLOW_API_KEY=${SILICONFLOW_API_KEY:-} # 硅基流动API密钥(可选,本地服务故障时降级使用)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# PostgreSQL 数据库配置
|
# PostgreSQL 数据库配置
|
||||||
@@ -63,6 +64,7 @@ services:
|
|||||||
# =========================================================================
|
# =========================================================================
|
||||||
- BACKEND_PORT=8079
|
- BACKEND_PORT=8079
|
||||||
- MEMORY_SUMMARIZE_INTERVAL=${MEMORY_SUMMARIZE_INTERVAL:-10}
|
- MEMORY_SUMMARIZE_INTERVAL=${MEMORY_SUMMARIZE_INTERVAL:-10}
|
||||||
|
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-/app/fastembed_cache}
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 前端通信地址(Docker 内部网络)
|
# 前端通信地址(Docker 内部网络)
|
||||||
|
|||||||
@@ -1,73 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
下载稀疏嵌入模型到本地目录。
|
|
||||||
仅需在开发机或构建镜像时执行一次。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s - %(levelname)s - %(message)s"
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 添加 backend 目录到路径
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent / "backend"))
|
|
||||||
|
|
||||||
|
|
||||||
def download_model(cache_dir: str = "./models/sparse", model_name: str = "Qdrant/bm25"):
|
|
||||||
"""
|
|
||||||
下载稀疏嵌入模型到指定目录。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cache_dir: 模型缓存目录
|
|
||||||
model_name: 模型名称
|
|
||||||
"""
|
|
||||||
cache_path = Path(cache_dir)
|
|
||||||
cache_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
logger.info(f"准备下载模型 {model_name} 到 {cache_path.absolute()}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from fastembed import SparseTextEmbedding
|
|
||||||
|
|
||||||
# 下载并缓存模型
|
|
||||||
model = SparseTextEmbedding(model_name=model_name, cache_dir=str(cache_path))
|
|
||||||
logger.info(f"✅ 模型 {model_name} 下载/加载成功")
|
|
||||||
|
|
||||||
# 测试一下
|
|
||||||
test_result = model.embed(["测试文本"])
|
|
||||||
logger.info(f"✅ 模型测试成功,稀疏向量维度: {len(list(test_result)[0])}")
|
|
||||||
|
|
||||||
logger.info("✅ 所有步骤完成!")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 模型下载失败: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="下载稀疏嵌入模型")
|
|
||||||
parser.add_argument(
|
|
||||||
"--cache-dir",
|
|
||||||
default="./models/sparse",
|
|
||||||
help="模型缓存目录 (默认: ./models/sparse)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model-name",
|
|
||||||
default="Qdrant/bm25",
|
|
||||||
help="模型名称 (默认: Qdrant/bm25)"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
success = download_model(args.cache_dir, args.model_name)
|
|
||||||
sys.exit(0 if success else 1)
|
|
||||||
@@ -41,15 +41,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
HAS_MODEL_SERVICES = False
|
HAS_MODEL_SERVICES = False
|
||||||
|
|
||||||
# 尝试导入稀疏模型配置(如果可用)
|
|
||||||
try:
|
|
||||||
from app.config import SPARSE_MODEL_PATH, SPARSE_MODEL_NAME
|
|
||||||
HAS_SPARSE_CONFIG = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_SPARSE_CONFIG = False
|
|
||||||
SPARSE_MODEL_PATH = "./models/sparse"
|
|
||||||
SPARSE_MODEL_NAME = "Qdrant/bm25"
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ---------- 配置数据类 ----------
|
# ---------- 配置数据类 ----------
|
||||||
@@ -112,37 +103,27 @@ class IndexBuilder:
|
|||||||
# 设置嵌入模型 - 优先使用外部提供的,然后尝试使用新服务,最后回退到原来的方式
|
# 设置嵌入模型 - 优先使用外部提供的,然后尝试使用新服务,最后回退到原来的方式
|
||||||
if embeddings is not None:
|
if embeddings is not None:
|
||||||
self.embeddings = embeddings
|
self.embeddings = embeddings
|
||||||
self.embedder = None
|
self._embedder = None
|
||||||
logger.info("使用外部提供的嵌入模型")
|
logger.info("使用外部提供的嵌入模型")
|
||||||
elif HAS_MODEL_SERVICES:
|
elif HAS_MODEL_SERVICES:
|
||||||
try:
|
try:
|
||||||
self.embeddings = get_embedding_service()
|
self.embeddings = get_embedding_service()
|
||||||
self.embedder = None
|
self._embedder = None
|
||||||
logger.info("使用 model_services 提供的嵌入服务")
|
logger.info("使用 model_services 提供的嵌入服务")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取嵌入服务失败,回退到 LlamaCppEmbedder: {e}")
|
logger.warning(f"获取嵌入服务失败,回退到 LlamaCppEmbedder: {e}")
|
||||||
self.embedder = LlamaCppEmbedder()
|
self._embedder = LlamaCppEmbedder()
|
||||||
self.embeddings = self.embedder.as_langchain_embeddings()
|
self.embeddings = self._embedder.as_langchain_embeddings()
|
||||||
else:
|
else:
|
||||||
self.embedder = LlamaCppEmbedder()
|
self._embedder = LlamaCppEmbedder()
|
||||||
self.embeddings = self.embedder.as_langchain_embeddings()
|
self.embeddings = self._embedder.as_langchain_embeddings()
|
||||||
|
|
||||||
# 初始化稀疏嵌入(使用本地缓存目录)
|
# 初始化向量存储(自动支持稠密+稀疏混合检索)
|
||||||
from langchain_qdrant import FastEmbedSparse, RetrievalMode
|
|
||||||
self.sparse_embeddings = FastEmbedSparse(
|
|
||||||
model_name=SPARSE_MODEL_NAME,
|
|
||||||
cache_dir=SPARSE_MODEL_PATH
|
|
||||||
)
|
|
||||||
logger.info(f"✅ FastEmbedSparse 初始化成功 (cache_dir={SPARSE_MODEL_PATH})")
|
|
||||||
|
|
||||||
# 初始化向量存储(混合检索模式)
|
|
||||||
self.vector_store = QdrantVectorStore(
|
self.vector_store = QdrantVectorStore(
|
||||||
collection_name=config.collection_name,
|
collection_name=config.collection_name,
|
||||||
embedding=self.embeddings if self.embedder is None else None,
|
embedding=self.embeddings if self._embedder is None else None
|
||||||
sparse_embedding=self.sparse_embeddings,
|
|
||||||
retrieval_mode=RetrievalMode.HYBRID,
|
|
||||||
)
|
)
|
||||||
logger.info("✅ 混合检索向量存储初始化成功")
|
logger.info("✅ 混合检索向量存储初始化成功(稠密+BM25稀疏)")
|
||||||
|
|
||||||
# 根据切分类型初始化相关组件
|
# 根据切分类型初始化相关组件
|
||||||
self._init_splitters_and_retriever()
|
self._init_splitters_and_retriever()
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
# RAG Indexer - 本地索引工具依赖
|
|
||||||
# 依赖 rag_core (从 ../backend/rag_core 导入)
|
|
||||||
|
|
||||||
# Core
|
|
||||||
pydantic==2.12.5
|
|
||||||
python-dotenv==1.2.2
|
|
||||||
typing-extensions==4.15.0
|
|
||||||
|
|
||||||
# LangChain (用于文档处理)
|
|
||||||
langchain==1.2.15
|
|
||||||
langchain-community==0.4.1
|
|
||||||
langchain-core==1.2.28
|
|
||||||
tiktoken>=0.12.0
|
|
||||||
|
|
||||||
# Vector DB
|
|
||||||
qdrant-client==1.17.1
|
|
||||||
fastembed>=0.3.0 # 用于 Qdrant BM25 稀疏向量
|
|
||||||
|
|
||||||
# HTTP
|
|
||||||
httpx==0.28.1
|
|
||||||
|
|
||||||
# Utilities
|
|
||||||
tenacity==9.1.4
|
|
||||||
rich==15.0.0
|
|
||||||
PyYAML==6.0.3
|
|
||||||
numpy>=1.26.2
|
|
||||||
|
|
||||||
# Document Processing
|
|
||||||
unstructured==0.22.21
|
|
||||||
pypdf==6.10.0
|
|
||||||
beautifulsoup4==4.14.3
|
|
||||||
lxml==6.1.0
|
|
||||||
pandas==3.0.2
|
|
||||||
spacy==3.8.14
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
# Core
|
|
||||||
pydantic==2.12.5
|
|
||||||
python-dotenv==1.2.2
|
|
||||||
typing-extensions==4.15.0
|
|
||||||
|
|
||||||
# LangChain
|
|
||||||
langchain==1.2.15
|
|
||||||
langchain-community==0.4.1
|
|
||||||
langchain-core==1.2.28
|
|
||||||
langchain-openai==1.1.12
|
|
||||||
langchain-qdrant==1.1.0
|
|
||||||
langgraph==1.1.6
|
|
||||||
langgraph-checkpoint-postgres==3.0.5
|
|
||||||
tiktoken>=0.12.0
|
|
||||||
|
|
||||||
# Vector DB
|
|
||||||
qdrant-client==1.17.1
|
|
||||||
|
|
||||||
# Memory
|
|
||||||
mem0ai==1.0.11
|
|
||||||
|
|
||||||
# Backend
|
|
||||||
fastapi==0.135.3
|
|
||||||
uvicorn[standard]==0.44.0
|
|
||||||
|
|
||||||
# Database
|
|
||||||
asyncpg==0.31.0
|
|
||||||
psycopg[binary]==3.3.3
|
|
||||||
|
|
||||||
# HTTP
|
|
||||||
httpx==0.28.1
|
|
||||||
aiohttp==3.13.5
|
|
||||||
|
|
||||||
# Utilities
|
|
||||||
tenacity==9.1.4
|
|
||||||
rich==15.0.0
|
|
||||||
PyYAML==6.0.3
|
|
||||||
numpy>=1.26.2
|
|
||||||
|
|
||||||
# Document Processing
|
|
||||||
unstructured==0.22.21
|
|
||||||
pypdf==6.10.0
|
|
||||||
beautifulsoup4==4.14.3
|
|
||||||
lxml==6.1.0
|
|
||||||
pandas==3.0.2 # 若需Excel保留,否则移除
|
|
||||||
spacy==3.8.14 # unstructured 可能依赖
|
|
||||||
duckduckgo-search>=6.0.0 # 联网搜索
|
|
||||||
6
requirements.txt
Normal file
6
requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# 根目录requirements - 仅本地运行零散脚本/工具使用
|
||||||
|
# 完全不与前后端requirements重叠,前后端独立运行无需安装这里的依赖
|
||||||
|
gitpython==3.1.43 # 本地git脚本工具
|
||||||
|
tqdm==4.67.0 # 本地脚本进度条
|
||||||
|
ipython==8.30.0 # 本地交互式调试
|
||||||
|
pytest==8.3.4 # 本地单元测试
|
||||||
22
tools/download_bm25.py
Normal file
22
tools/download_bm25.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""
|
||||||
|
BM25模型预下载脚本
|
||||||
|
执行后将模型缓存到 ./models/fastembed_cache 目录,打包进Docker镜像
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 指定缓存目录
|
||||||
|
cache_dir = "./models/fastembed_cache"
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
print("正在下载BM25稀疏向量模型...")
|
||||||
|
model = SparseTextEmbedding(
|
||||||
|
model_name="Qdrant/bm25",
|
||||||
|
cache_dir=cache_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
# 触发一次推理,确保模型文件完整下载
|
||||||
|
list(model.embed(["init trigger"]))
|
||||||
|
print(f"✅ BM25模型已成功缓存到: {cache_dir}")
|
||||||
|
print("请将该目录提交到项目仓库,打包进Docker镜像")
|
||||||
Reference in New Issue
Block a user