feat: 实现 BM25 稀疏 + 稠密向量混合检索功能
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled

This commit is contained in:
2026-05-04 02:01:22 +08:00
parent 2183c901b4
commit 60afa86ded
26 changed files with 905 additions and 656 deletions

View File

@@ -37,8 +37,9 @@ def _get_bool(key: str) -> bool | None:
# ========== 第三方 API 密钥 ==========
ZHIPUAI_API_KEY = _get_str("ZHIPUAI_API_KEY")
DEEPSEEK_API_KEY = _get_str("DEEPSEEK_API_KEY")
ZHIPUAI_API_KEY=_get_str("ZHIPUAI_API_KEY")
DEEPSEEK_API_KEY=_get_str("DEEPSEEK_API_KEY")
SILICONFLOW_API_KEY=_get_str("SILICONFLOW_API_KEY")
# ========== 智谱 API 配置 ==========
@@ -51,9 +52,16 @@ ZHIPU_RERANK_MODEL = _get_str("ZHIPU_RERANK_MODEL") or "rerank-2"
ZHIPU_API_BASE = _get_str("ZHIPU_API_BASE") or "https://open.bigmodel.cn/api/paas/v4"
# ========== 硅基流动(SiliconFlow) API 配置 ==========
# 重排模型BAAI/bge-reranker-v2-m3
SILICONFLOW_RERANK_MODEL = _get_str("SILICONFLOW_RERANK_MODEL") or "BAAI/bge-reranker-v2-m3"
SILICONFLOW_API_BASE = _get_str("SILICONFLOW_API_BASE") or "https://api.siliconflow.cn/v1"
# ========== 稀疏模型配置 ==========
SPARSE_MODEL_PATH = _get_str("SPARSE_MODEL_PATH") or "./models/sparse"
SPARSE_MODEL_NAME = _get_str("SPARSE_MODEL_NAME") or "Qdrant/bm25"
FASTEMBED_CACHE_PATH = _get_str("FASTEMBED_CACHE_PATH") or "./models/fastembed_cache"
# ========== llama.cpp 服务配置URL + API密钥 配对) ==========
# 主 LLM 服务

View File

@@ -3,11 +3,15 @@
本模块提供统一的重排模型服务获取接口,支持自动降级:
1. 优先使用本地 llama.cpp 重排服务
2. 本地服务不可用时,自动降级到智谱 API 重排服务
2. 本地服务不可用时,自动降级到硅基流动(SiliconFlow) API 重排服务
3. 硅基流动服务不可用时,自动降级到智谱 API 重排服务
4. 所有API服务不可用时自动降级到 LLM 评分重排服务
主要功能:
- LocalLlamaCppRerankProvider本地 llama.cpp 重排服务提供者
- SiliconFlowRerankProvider硅基流动 API 重排服务提供者
- ZhipuRerankProvider智谱 API 重排服务提供者
- LLMFallbackRerankProviderLLM 评分降级重排服务提供者
- get_rerank_service():获取重排服务的统一接口
注意:本模块只负责调用 rerank server不包含业务逻辑文档处理、排序、top_n
@@ -28,7 +32,10 @@ from app.config import (
LLAMACPP_API_KEY,
ZHIPUAI_API_KEY,
ZHIPU_RERANK_MODEL,
ZHIPU_API_BASE
ZHIPU_API_BASE,
SILICONFLOW_API_KEY,
SILICONFLOW_RERANK_MODEL,
SILICONFLOW_API_BASE
)
logger = logging.getLogger(__name__)
@@ -136,6 +143,53 @@ class ZhipuRerankService(BaseRerankService):
raise
class SiliconFlowRerankService(BaseRerankService):
"""
硅基流动(SiliconFlow) API 重排服务 - 纯服务层
"""
def __init__(self, model: str | None = None, api_key: str | None = None, api_base: str | None = None):
self.model = model or SILICONFLOW_RERANK_MODEL
self.api_key = api_key or SILICONFLOW_API_KEY
self.api_base = api_base or SILICONFLOW_API_BASE
def compute_scores(self, query: str, documents: List[str]) -> List[float]:
"""
调用 SiliconFlow rerank API 计算得分 - 纯 API 调用
"""
if not documents:
return []
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
base = self.api_base.rstrip("/")
payload = {
"model": self.model,
"query": query,
"documents": documents,
"return_documents": False
}
with httpx.Client(timeout=120) as client:
response = client.post(
f"{base}/rerank",
headers=headers,
json=payload,
)
response.raise_for_status()
data = response.json()
if isinstance(data, dict) and "results" in data:
results = data["results"]
results_sorted = sorted(results, key=lambda x: x["index"])
return [item["relevance_score"] for item in results_sorted]
else:
raise ValueError(f"未知的 SiliconFlow rerank API 响应格式: {data}")
class LLMFallbackRerankService(BaseRerankService):
"""
使用 LLM 作为最后的降级方案进行重排
@@ -291,18 +345,53 @@ class ZhipuRerankProvider(BaseServiceProvider[BaseRerankService]):
return self._service_instance
class SiliconFlowRerankProvider(BaseServiceProvider[BaseRerankService]):
"""
硅基流动(SiliconFlow) API 重排服务提供者
"""
def __init__(self, model: str | None = None):
super().__init__("siliconflow_rerank")
self._model = model or SILICONFLOW_RERANK_MODEL
def is_available(self) -> bool:
"""
检查 SiliconFlow API 重排服务是否可用
"""
if not SILICONFLOW_API_KEY:
logger.warning("SILICONFLOW_API_KEY 未配置")
return False
try:
service = SiliconFlowRerankService(model=self._model)
test_scores = service.compute_scores("test query", ["test document"])
logger.info("SiliconFlow 重排服务可用")
return True
except Exception as e:
logger.warning(f"SiliconFlow 重排服务不可用: {e}")
return False
def get_service(self) -> BaseRerankService:
"""
获取 SiliconFlow API 重排服务
"""
if self._service_instance is None:
self._service_instance = SiliconFlowRerankService(model=self._model)
return self._service_instance
def get_rerank_service() -> BaseRerankService:
"""
获取重排服务(带自动降级)- 纯服务层
降级链: Local llama.cpp -> Zhipu Rerank -> LLM Fallback
降级链: Local llama.cpp -> SiliconFlow Rerank -> Zhipu Rerank -> LLM Fallback
Returns:
BaseRerankService: 重排服务实例
"""
def _create_chain():
primary = LocalLlamaCppRerankProvider()
fallbacks = [ZhipuRerankProvider(), LLMFallbackRerankProvider()]
fallbacks = [SiliconFlowRerankProvider(), ZhipuRerankProvider(), LLMFallbackRerankProvider()]
return FallbackServiceChain(primary, fallbacks)
chain = SingletonServiceManager.get_or_create("rerank_service_chain", _create_chain)

View File

@@ -1,4 +1,11 @@
# rag/pipeline.py
"""
RAG 检索流水线模块
提供固定流程的 RAG 检索:
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
默认使用混合检索(稠密+稀疏)+ 父子文档模式。
"""
import asyncio
import os
@@ -6,61 +13,86 @@ from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from ..model_services import get_rerank_service
from .rerank import create_document_reranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from app.model_services import get_rerank_service
from app.rag.rerank import create_document_reranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
from app.rag.retriever import create_parent_hybrid_retriever
class RAGPipeline:
"""
固定流程的 RAG 检索流水线:
多路改写 → 并行检索 → RRF融合 → 重排序 → 返回父文档
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
"""
def __init__(
self,
retriever, # 基础检索器(应返回父文档,例如 ParentDocumentRetriever 实例)
llm: BaseLanguageModel,
retriever=None,
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
):
"""
Args:
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法
llm: 用于生成多路查询的语言模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
rerank_model: 重排序模型名称
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法
如果不提供,会自动创建默认的父子文档混合检索器。
llm: 用于生成多路查询的语言模型。
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量。
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
"""
self.retriever = retriever
# 如果没有提供 retriever自动创建默认的混合检索器
if retriever is None:
self.retriever = create_parent_hybrid_retriever(
collection_name=collection_name,
search_k=rerank_top_n * 2 # 多取一些给重排序用
)
else:
self.retriever = retriever
self.llm = llm
self.num_queries = num_queries
self.rerank_top_n = rerank_top_n
# 初始化组件 - 使用统一的重排服务获取接口
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None
self.reranker = create_document_reranker()
async def aretrieve(self, query: str) -> List[Document]:
"""
异步执行完整检索流程
Args:
query: 用户查询
Returns:
检索到的相关文档列表
"""
# Step 1: 生成多路查询
queries = await self.query_generator.agenerate(query)
# 包含原始查询,确保至少有一条
if query not in queries:
queries.insert(0, query)
# 如果有 query_generator做多路改写
if self.query_generator and self.llm:
# Step 1: 生成多路查询
queries = await self.query_generator.agenerate(query)
# 包含原始查询,确保至少有一条
if query not in queries:
queries.insert(0, query)
else:
# 如果原始查询已在列表中,将其移至首位
queries.remove(query)
queries.insert(0, query)
# Step 2: 并行检索(每个查询获取文档列表)
tasks = [self.retriever.ainvoke(q) for q in queries]
doc_lists = await asyncio.gather(*tasks)
# Step 3: RRF 融合
fused_docs = reciprocal_rank_fusion(doc_lists)
else:
# 如果原始查询已在列表中,将其移至首位
queries.remove(query)
queries.insert(0, query)
# Step 2: 并行检索(每个查询获取文档列表)
tasks = [self.retriever.ainvoke(q) for q in queries]
doc_lists = await asyncio.gather(*tasks)
# Step 3: RRF 融合
fused_docs = reciprocal_rank_fusion(doc_lists)
# 没有 LLM 做查询改写,直接用原始查询检索
fused_docs = await self.retriever.ainvoke(query)
# Step 4: 重排序
try:
@@ -76,7 +108,15 @@ class RAGPipeline:
return asyncio.run(self.aretrieve(query))
def format_context(self, documents: List[Document]) -> str:
"""将文档列表格式化为上下文字符串"""
"""
将文档列表格式化为上下文字符串
Args:
documents: 文档列表
Returns:
格式化后的上下文字符串
"""
if not documents:
return ""
@@ -84,4 +124,30 @@ class RAGPipeline:
for i, doc in enumerate(documents, 1):
source = doc.metadata.get("source", "未知来源")
parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
return "\n".join(parts)
return "\n".join(parts)
def create_rag_pipeline(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
) -> RAGPipeline:
"""
创建 RAG 检索流水线的便捷函数
Args:
collection_name: Qdrant 集合名称
llm: 用于生成多路查询的语言模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
Returns:
RAGPipeline 实例
"""
return RAGPipeline(
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name
)

View File

@@ -1,170 +1,379 @@
"""
Qdrant 向量检索器模块
Qdrant 混合检索器模块
提供基于 Qdrant 的混合检索Dense + Sparse功能
提供基于 Qdrant 的混合检索Dense + Sparse功能,包括:
- 纯混合检索(无子父文档)
- 父子文档混合检索(先检索子文档,再返回父文档)
核心原理:
- 使用 Qdrant 原生混合检索langchain-qdrant 的 RetrievalMode.HYBRID
- 同时存储稠密向量和稀疏向量
- 语义理解 + 关键词匹配,效果最优
使用示例:
>>> from app.rag.retriever import create_hybrid_retriever
>>> retriever = create_hybrid_retriever(collection_name="rag_documents")
>>> docs = retriever.invoke("什么是 RAG")
- 使用 Qdrant 原生 Fusion API (RRF) 做分数融合
- 同时使用稠密向量(语义)和稀疏向量BM25 关键词)
"""
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, List
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from langchain_qdrant import (
QdrantVectorStore,
RetrievalMode,
FastEmbedSparse,
from qdrant_client.http.models import (
SearchRequest, Fusion, FusionProtocol, NamedVector, NamedSparseVector
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
from rag_core import QDRANT_URL, QDRANT_API_KEY
from rag_core import QdrantVectorStore, get_sparse_embedder, create_docstore
from rag_core.client import create_qdrant_client as create_core_qdrant_client
from app.model_services import get_embedding_service
from app.config import SPARSE_MODEL_PATH, SPARSE_MODEL_NAME
from app.logger import info, warning
from app.logger import info, warning, debug
# 模块级常量
DEFAULT_SEARCH_K = 20
DEFAULT_SCORE_THRESHOLD = 0.3
DEFAULT_PARENT_SEARCH_K = 5
def create_base_retriever(
collection_name: str,
search_kwargs: Dict[str, Any] | None = None,
client: QdrantClient | None = None,
embeddings: Embeddings | None = None,
) -> BaseRetriever:
class HybridRetriever(BaseRetriever):
"""
创建基础向量检索器(仅稠密向量检索)
Args:
collection_name: Qdrant 集合名称
search_kwargs: 搜索参数
client: 可选的 Qdrant 客户端
embeddings: 可选的嵌入模型(默认使用 get_embedding_service()
Returns:
LangChain 兼容的检索器
混合检索器稠密向量 + BM25 稀疏向量 RRF 分数融合
直接使用 Qdrant 原生 Fusion API性能最优。
"""
# 默认使用统一嵌入服务(已内置降级机制)
if embeddings is None:
embeddings = get_embedding_service()
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
def __init__(
self,
collection_name: str,
vector_store: QdrantVectorStore,
search_k: int = DEFAULT_SEARCH_K,
):
"""
Args:
collection_name: Qdrant 集合名称
vector_store: QdrantVectorStore 实例
search_k: 检索返回结果数
"""
self.collection_name = collection_name
self.vector_store = vector_store
self.search_k = search_k
self.client = vector_store.get_qdrant_client()
self.sparse_embedder = get_sparse_embedder()
def _get_relevant_documents(
self, query: str, *, run_manager: Optional[Any] = None
) -> List[Document]:
"""
同步检索相关文档
Args:
query: 查询字符串
run_manager: LangChain 运行管理器(可选)
Returns:
相关文档列表
"""
# 生成双向量
dense_query = self.vector_store.embeddings.embed_query(query)
sparse_query = self.sparse_embedder.embed_query(query)
# 构建双检索请求
searches = [
# 稠密检索
SearchRequest(
vector=NamedVector(name="dense", vector=dense_query),
limit=self.search_k,
with_payload=True
),
# 稀疏检索
SearchRequest(
vector=NamedSparseVector(name="sparse", vector=sparse_query),
limit=self.search_k,
with_payload=True
)
]
# RRF 分数融合
fused_results = self.client.fusion(
collection_name=self.collection_name,
requests=searches,
fusion=Fusion(fusion=FusionProtocol.RRF)
)
# 转换为 Document 格式
results = []
for point in fused_results.points:
doc = Document(
page_content=point.payload.pop("text", ""),
metadata=point.payload
)
results.append(doc)
debug(f"混合检索返回 {len(results)} 个文档")
return results
async def _aget_relevant_documents(
self, query: str, *, run_manager: Optional[Any] = None
) -> List[Document]:
"""异步检索(当前调用同步版本)"""
# Qdrant 客户端没有原生 async这里用同步版本
return self._get_relevant_documents(query, run_manager=run_manager)
# 合并默认搜索参数
merged_search_kwargs = {"k": DEFAULT_SEARCH_K}
if search_kwargs:
merged_search_kwargs.update(search_kwargs)
# 创建或复用 Qdrant 客户端
if client is None:
client = create_core_qdrant_client()
# 验证集合是否存在
try:
client.get_collection(collection_name)
except UnexpectedResponse as e:
if e.status_code == 404:
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
raise
# 构建向量存储
vector_store = QdrantVectorStore(
client=client,
collection_name=collection_name,
embedding=embeddings,
)
return vector_store.as_retriever(search_kwargs=merged_search_kwargs)
class ParentHybridRetriever(BaseRetriever):
"""
父子文档混合检索器:
1. 先用混合检索找到相关子文档
2. 根据子文档的 parent_id 找到对应的父文档
3. 去重并返回父文档
"""
def __init__(
self,
collection_name: str,
vector_store: QdrantVectorStore,
search_k: int = DEFAULT_PARENT_SEARCH_K,
docstore: Optional[Any] = None,
):
"""
Args:
collection_name: Qdrant 集合名称
vector_store: QdrantVectorStore 实例
search_k: 最终返回的父文档数
docstore: 文档存储(如果父文档在 PostgreSQL可选
"""
self.collection_name = collection_name
self.vector_store = vector_store
self.search_k = search_k
self.client = vector_store.get_qdrant_client()
self.sparse_embedder = get_sparse_embedder()
self.docstore = docstore
def _get_relevant_documents(
self, query: str, *, run_manager: Optional[Any] = None
) -> List[Document]:
"""
同步检索相关父文档
Args:
query: 查询字符串
run_manager: LangChain 运行管理器(可选)
Returns:
相关父文档列表
"""
# 1. 生成查询双向量
dense_query = self.vector_store.embeddings.embed_query(query)
sparse_query = self.sparse_embedder.embed_query(query)
# 2. 多取一些子文档,避免去重后数量不足
search_limit = self.search_k * 2
searches = [
# 稠密检索
SearchRequest(
vector=NamedVector(name="dense", vector=dense_query),
limit=search_limit,
with_payload=True
),
# 稀疏检索
SearchRequest(
vector=NamedSparseVector(name="sparse", vector=sparse_query),
limit=search_limit,
with_payload=True
)
]
# 3. RRF 分数融合,拿到子文档命中结果
fused_results = self.client.fusion(
collection_name=self.collection_name,
requests=searches,
fusion=Fusion(fusion=FusionProtocol.RRF)
)
if not fused_results.points:
debug("混合检索未找到任何文档")
return []
# 4. 收集 parent_id 和对应最高得分
parent_score_map = {}
parent_ids = set()
child_point_map = {} # 保存子文档点用于降级
for point in fused_results.points:
parent_id = point.payload.get("parent_id", point.id)
score = point.score
# 同一个 parent_id 只保留最高得分
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
parent_score_map[parent_id] = score
parent_ids.add(parent_id)
child_point_map[parent_id] = point
# 5. 批量查询父文档
# 首先尝试从 Qdrant 直接查询(因为父文档可能也存在 Qdrant 中)
parent_docs = []
found_parent_ids = set()
try:
parent_points = self.client.retrieve(
collection_name=self.collection_name,
ids=list(parent_ids),
with_payload=True
)
# 处理找到的父文档
for point in parent_points:
doc = Document(
page_content=point.payload.pop("text", ""),
metadata=point.payload
)
parent_docs.append(doc)
found_parent_ids.add(point.id)
except Exception as e:
warning(f"从 Qdrant 查询父文档失败: {e}")
# 6. 如果有 docstore尝试从 docstore 查询剩余的父文档
if self.docstore and len(found_parent_ids) < len(parent_ids):
missing_parent_ids = parent_ids - found_parent_ids
try:
docstore_docs = self.docstore.mget(missing_parent_ids)
for doc_id, doc in zip(missing_parent_ids, docstore_docs):
if doc is not None:
parent_docs.append(doc)
found_parent_ids.add(doc_id)
except Exception as e:
warning(f"从 docstore 查询父文档失败: {e}")
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
missing_parent_ids = parent_ids - found_parent_ids
if missing_parent_ids:
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
for parent_id in missing_parent_ids:
child_point = child_point_map.get(parent_id)
if child_point:
doc = Document(
page_content=child_point.payload.pop("text", ""),
metadata=child_point.payload
)
parent_docs.append(doc)
# 8. 按照得分降序排序,返回前 k 个
parent_docs_with_scores = [
(doc, parent_score_map.get(doc.metadata.get("id", doc.id), 0.0))
for doc in parent_docs
]
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
return final_docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: Optional[Any] = None
) -> List[Document]:
"""异步检索(当前调用同步版本)"""
return self._get_relevant_documents(query, run_manager=run_manager)
def create_hybrid_retriever(
collection_name: str,
dense_k: int = 10,
sparse_k: int = 10,
score_threshold: float | None = DEFAULT_SCORE_THRESHOLD,
client: QdrantClient | None = None,
embeddings: Embeddings | None = None,
search_k: int = DEFAULT_SEARCH_K,
embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
"""
创建混合检索器(稠密向量 + BM25 稀疏向量Qdrant 原生实现)。
创建混合检索器(稠密向量 + BM25 稀疏向量)。
这是默认推荐的检索方式,效果最优。
Args:
collection_name: Qdrant 集合名称
dense_k: 稠密向量检索返回数量,默认 10。
sparse_k: 稀疏向量检索返回数量,默认 10。
score_threshold: 相似度阈值,默认 0.3。
client: 可选的 Qdrant 客户端实例。
collection_name: Qdrant 集合名称
search_k: 检索返回结果数
embeddings: 可选的嵌入模型实例。若未提供,将自动获取统一嵌入服务。
Returns:
BaseRetriever 实例,配置了混合搜索参数。
HybridRetriever 实例
"""
total_k = dense_k + sparse_k
search_kwargs = {
"k": total_k,
"search_type": "similarity_score_threshold",
"score_threshold": score_threshold,
}
# 默认使用统一嵌入服务(已内置降级机制)
# 默认使用统一嵌入服务
if embeddings is None:
embeddings = get_embedding_service()
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
# 创建或复用 Qdrant 客户端
if client is None:
client = create_core_qdrant_client()
# 创建向量存储
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
# 验证集合是否存在
try:
client.get_collection(collection_name)
vector_store.get_client().get_collection(collection_name)
except UnexpectedResponse as e:
if e.status_code == 404:
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
raise
# 初始化稀疏嵌入(使用本地缓存目录)
sparse_embeddings = FastEmbedSparse(
model_name=SPARSE_MODEL_NAME,
cache_dir=SPARSE_MODEL_PATH
)
info(f"✅ FastEmbedSparse 初始化成功 (cache_dir={SPARSE_MODEL_PATH})")
# 创建混合模式的 QdrantVectorStore
vector_store = QdrantVectorStore(
client=client,
info(f"✅ Qdrant 混合检索器初始化成功search_k={search_k}")
return HybridRetriever(
collection_name=collection_name,
embedding=embeddings,
sparse_embedding=sparse_embeddings,
retrieval_mode=RetrievalMode.HYBRID,
vector_store=vector_store,
search_k=search_k
)
info(f"✅ Qdrant 原生混合检索器初始化成功 (k={total_k})")
return vector_store.as_retriever(search_kwargs=search_kwargs)
# 可选:提供异步友好的辅助函数
async def acreate_base_retriever(
def create_parent_hybrid_retriever(
collection_name: str,
search_kwargs: Dict[str, Any] | None = None,
client: QdrantClient | None = None,
search_k: int = DEFAULT_PARENT_SEARCH_K,
embeddings: Optional[Embeddings] = None,
use_docstore: bool = True,
) -> BaseRetriever:
"""
异步创建基础向量检索器(与同步版本功能相同)。
适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。
创建父子文档混合检索器(默认推荐)。
检索流程:
1. 混合检索找到相关子文档
2. 根据 parent_id 找到对应的父文档
3. 去重并返回父文档
Args:
collection_name: Qdrant 集合名称
search_k: 最终返回的父文档数
embeddings: 可选的嵌入模型实例
use_docstore: 是否使用 PostgreSQL docstore 存储父文档
Returns:
ParentHybridRetriever 实例
"""
# 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可
return create_base_retriever(collection_name, search_kwargs, client)
# 默认使用统一嵌入服务
if embeddings is None:
embeddings = get_embedding_service()
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
# 创建向量存储
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
# 验证集合是否存在
try:
vector_store.get_client().get_collection(collection_name)
except UnexpectedResponse as e:
if e.status_code == 404:
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
raise
# 创建 docstore如果需要
docstore = None
if use_docstore:
try:
docstore, _ = create_docstore()
info("✅ 文档存储初始化成功PostgreSQL")
except Exception as e:
warning(f"⚠️ 文档存储初始化失败,将不使用 docstore: {e}")
info(f"✅ Qdrant 父子文档混合检索器初始化成功search_k={search_k}")
return ParentHybridRetriever(
collection_name=collection_name,
vector_store=vector_store,
search_k=search_k,
docstore=docstore
)
# 别名:默认就是父子文档混合检索
create_retriever = create_parent_hybrid_retriever

View File

@@ -3,52 +3,94 @@ RAG 工具模块
将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
"""
from typing import Callable
from typing import Callable, Optional
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from .pipeline import RAGPipeline
from app.rag.pipeline import RAGPipeline, create_rag_pipeline
def create_rag_tool_sync(
retriever: BaseRetriever,
llm: BaseLanguageModel,
retriever: Optional[BaseRetriever] = None,
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(同步版本,用于不支持异步的旧版 Agent)。
参数同 create_rag_tool
创建一个配置好的 RAG 检索工具(同步版本)。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式
Args:
retriever: 基础检索器对象(可选,不提供则自动创建)
llm: 用于生成多路查询的语言模型(可选)
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: Qdrant 集合名称
Returns:
LangChain Tool 函数
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name,
)
@tool
def search_knowledge_base_sync(query: str) -> str:
"""在知识库中搜索与查询相关的文档片段(同步版本)。
功能与异步版本相同:多路改写 → RRF融合 → 重排序 → 返回父文档。
"""
在知识库中搜索与查询相关的文档片段。
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
检索效果最优。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容
格式化后的相关文档内容
"""
try:
documents = pipeline.retrieve(query) # 内部调用异步方法并等待
documents = pipeline.retrieve(query)
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_sync
return search_knowledge_base_sync
def create_rag_tool(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
) -> Callable:
"""
创建 RAG 检索工具的便捷函数(同步版本)。
Args:
collection_name: Qdrant 集合名称
llm: 用于生成多路查询的语言模型(可选)
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
Returns:
LangChain Tool 函数
"""
return create_rag_tool_sync(
collection_name=collection_name,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
)

View File

@@ -6,6 +6,7 @@ RAG Core - 公共 RAG 组件包
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
from .store import PostgresDocStore, create_docstore
from .retriever_factory import create_parent_retriever
from .config import (
@@ -21,6 +22,8 @@ from .config import (
__all__ = [
"LlamaCppEmbedder",
"QdrantVectorStore",
"BM25SparseEmbedder",
"get_sparse_embedder",
"QDRANT_URL",
"QDRANT_API_KEY",
"LLAMACPP_EMBEDDING_URL",

View File

@@ -1,5 +1,14 @@
# rag_core/retriever_factory.py
"""
RAG 检索器工厂模块
提供创建各种检索器的工厂函数,包括:
- 基础向量检索器
- ParentDocumentRetriever父子文档
- 混合检索器(稠密+稀疏)
"""
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_core.stores import BaseStore
@@ -9,18 +18,18 @@ from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
def create_parent_retriever(
collection_name: str = "rag_documents",
parent_splitter: TextSplitter | None = None,
child_splitter: TextSplitter | None = None,
docstore: BaseStore | None = None,
parent_splitter: Optional[TextSplitter] = None,
child_splitter: Optional[TextSplitter] = None,
docstore: Optional[BaseStore] = None,
search_k: int = 5,
parent_chunk_size: int = 1000,
parent_chunk_overlap: int = 100,
child_chunk_size: int = 200,
child_chunk_overlap: int = 20,
embeddings: Embeddings | None = None,
embeddings: Optional[Embeddings] = None,
) -> ParentDocumentRetriever:
"""
创建 ParentDocumentRetriever 实例。
创建 ParentDocumentRetriever 实例(基础稠密向量版本)
Args:
collection_name: Qdrant 集合名称,默认 "rag_documents"
@@ -44,7 +53,7 @@ def create_parent_retriever(
# 向量存储(只读)
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
# 切分器(若未提供则创建默认)
if parent_splitter is None:
parent_splitter = RecursiveCharacterTextSplitter(
@@ -56,11 +65,11 @@ def create_parent_retriever(
chunk_size=child_chunk_size,
chunk_overlap=child_chunk_overlap,
)
# 文档存储
if docstore is None:
docstore, _ = create_docstore()
return ParentDocumentRetriever(
vectorstore=vector_store.get_langchain_vectorstore(),
docstore=docstore,
@@ -68,3 +77,34 @@ def create_parent_retriever(
parent_splitter=parent_splitter,
search_kwargs={"k": search_k},
)
def create_hybrid_retriever_factory(
collection_name: str = "rag_documents",
search_k: int = 5,
embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
"""
【不完整,仅占位】创建混合检索器的工厂函数占位符。
注意:完整的混合检索逻辑在 app/rag/retriever.py 中实现。
这里仅返回 QdrantVectorStore 作为基础。
Args:
collection_name: Qdrant 集合名称
search_k: 检索返回结果数
embeddings: 嵌入模型实例
Returns:
基础的 QdrantVectorStore仅稠密检索
"""
# 嵌入模型
if embeddings is None:
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 创建向量存储
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
# 返回 LangChain 兼容的 retriever
return vector_store.get_langchain_vectorstore().as_retriever(search_kwargs={"k": search_k})

View File

@@ -0,0 +1,34 @@
"""
BM25 稀疏嵌入器
基于 FastEmbed 的 Qdrant/bm25 模型,完全离线运行
"""
from typing import List
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
from app.config import FASTEMBED_CACHE_PATH
class BM25SparseEmbedder:
"""BM25 稀疏嵌入包装器,与现有嵌入器风格统一"""
def __init__(self):
self.model = SparseTextEmbedding(
model_name="Qdrant/bm25",
cache_dir=FASTEMBED_CACHE_PATH,
local_files_only=True, # 强制离线,永不联网
)
def embed_documents(self, texts: List[str]) -> List[dict]:
"""返回稀疏向量列表,每个为 Qdrant 兼容的 dictindices+values"""
return [vec.as_object() for vec in self.model.embed(texts)]
def embed_query(self, text: str) -> dict:
"""返回单个稀疏向量"""
return list(self.model.embed([text]))[0].as_object()
# 全局单例
_sparse_embedder_instance = None
def get_sparse_embedder() -> BM25SparseEmbedder:
global _sparse_embedder_instance
if _sparse_embedder_instance is None:
_sparse_embedder_instance = BM25SparseEmbedder()
return _sparse_embedder_instance

View File

@@ -1,41 +1,48 @@
"""
Qdrant 向量数据库包装器。
支持稠密+稀疏双向量存储。
"""
import logging
import os
import time
import uuid
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http.models import (
Distance, VectorParams, SparseVectorParams, SparseIndexParams,
SparseIndexType, PointStruct, NamedSparseVector, NamedVector
)
from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException
from .client import create_qdrant_client
from .embedders import LlamaCppEmbedder
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
logger = logging.getLogger(__name__)
class QdrantVectorStore:
"""Qdrant 向量数据库操作包装器。"""
"""Qdrant 向量数据库操作包装器 - 支持稠密+稀疏双向量存储"""
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None):
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None, sparse_embedder: Optional[BM25SparseEmbedder] = None):
"""
Args:
collection_name: Qdrant 集合名称。
embeddings: 嵌入模型实例,默认 None使用内部默认的 LlamaCppEmbedder
sparse_embedder: 稀疏嵌入模型实例,默认 None自动加载BM25
"""
self.collection_name = collection_name
self._client: Optional[QdrantClient] = None
self._connection_attempts = 0
self._last_connection_time: Optional[float] = None
# 嵌入模型
# 稠密嵌入模型
if embeddings is None:
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
@@ -43,9 +50,13 @@ class QdrantVectorStore:
else:
self.embeddings = embeddings
self._embedder = None
# 稀疏嵌入模型
self.sparse_embedder = sparse_embedder or get_sparse_embedder()
self.create_collection()
# 保留 LangChain 向量存储实例(用于兼容)
self.vector_store = LangchainQdrantVS(
client=self.get_client(),
collection_name=self.collection_name,
@@ -97,7 +108,7 @@ class QdrantVectorStore:
}
def create_collection(self, force_recreate: bool = False):
"""创建集合,设置合适的向量维度"""
"""创建集合,支持稠密+稀疏双向量"""
if self._embedder is not None:
# 使用内部的 embedder 获取维度
vector_size = self._embedder.get_embedding_dimension()
@@ -119,11 +130,31 @@ class QdrantVectorStore:
exists = False
if not exists:
# 向量配置:稠密向量
vectors_config = {
"dense": VectorParams(
size=vector_size,
distance=Distance.COSINE,
optional=True
)
}
# 稀疏向量配置
sparse_vectors_config = {
"sparse": SparseVectorParams(
index=SparseIndexParams(
type=SparseIndexType.MUTABLE
),
optional=True
)
}
client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
vectors_config=vectors_config,
sparse_vectors_config=sparse_vectors_config
)
logger.info("集合 '%s' 已创建(维度=%d", self.collection_name, vector_size)
logger.info("集合 '%s' 已创建(维度=%d,支持稠密+稀疏双向量", self.collection_name, vector_size)
else:
logger.info("集合 '%s' 已存在", self.collection_name)
return
@@ -142,18 +173,54 @@ class QdrantVectorStore:
time.sleep(wait_time)
def add_documents(self, documents: List[Document], batch_size: int = 100):
"""将文档添加到向量数据库。"""
"""将文档添加到向量数据库,自动生成稠密+稀疏双向量"""
if not documents:
return []
self.create_collection()
ids = self.vector_store.add_documents(documents, batch_size=batch_size)
logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
return ids
client = self.get_client()
doc_ids = []
# 分批处理
for i in range(0, len(documents), batch_size):
batch_docs = documents[i:i+batch_size]
texts = [doc.page_content for doc in batch_docs]
# 生成双向量
dense_vectors = self.embeddings.embed_documents(texts)
sparse_vectors = self.sparse_embedder.embed_documents(texts)
points = []
for j, doc in enumerate(batch_docs):
point_id = doc.metadata.get("id", str(uuid.uuid4()))
doc_ids.append(point_id)
# 构造双向量
named_vectors = {
"dense": dense_vectors[j],
"sparse": NamedSparseVector(
name="sparse",
vector=sparse_vectors[j]
)
}
points.append(PointStruct(
id=point_id,
vector=named_vectors,
payload={"text": doc.page_content, **doc.metadata}
))
# 批量插入
client.upsert(collection_name=self.collection_name, points=points)
logger.info("已向 '%s' 添加 %d 个文档(稠密+稀疏双向量)", self.collection_name, len(points))
return doc_ids
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
"""基础稠密向量检索(兼容原有接口)。"""
return self.vector_store.similarity_search(query, k=k)
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
"""基础稠密向量检索带分数(兼容原有接口)。"""
return self.vector_store.similarity_search_with_score(query, k=k)
def delete_collection(self):
@@ -183,5 +250,5 @@ class QdrantVectorStore:
return self.vector_store
def get_qdrant_client(self):
"""返回原生 Qdrant 客户端(如需手动管理 collection"""
"""返回原生 Qdrant 客户端(用于自定义检索逻辑"""
return self.get_client()

View File

@@ -1,53 +1,52 @@
# Core
pydantic==2.12.5
python-dotenv==1.2.2
typing-extensions==4.15.0
typing-extensions>=4.15.0
python-dotenv>=1.2.2
pydantic>=2.12.5
requests>=2.32.5
# LangChain
langchain==1.2.15
langchain-community==0.4.1
langchain-core==1.2.28
langchain-openai==1.1.12
langchain-qdrant==1.1.0
langgraph==1.1.6
langgraph-checkpoint-postgres==3.0.5
langchain>=1.2.15
langchain-community>=0.4.1
langchain-core>=1.2.28
langchain-openai>=1.1.12
langchain-qdrant>=1.1.0
langgraph>=1.1.6
langgraph-checkpoint-postgres>=3.0.5
tiktoken>=0.12.0
# Zhipu AI
zhipuai==2.0.1
zhipuai>=2.0.1
# Vector DB
qdrant-client==1.17.1
qdrant-client>=1.17.1
fastembed>=0.3.0 # 用于 Qdrant BM25 稀疏向量
# Memory
mem0ai==1.0.11
mem0ai>=1.0.11
# Backend
fastapi==0.135.3
uvicorn[standard]==0.44.0
fastapi>=0.135.3
uvicorn[standard]>=0.44.0
# Database
asyncpg==0.31.0
psycopg[binary]==3.3.3
asyncpg>=0.31.0
psycopg[binary]>=3.3.3
# HTTP
httpx==0.28.1
aiohttp==3.13.5
httpx>=0.28.1
aiohttp>=3.13.5
# Utilities
tenacity==9.1.4
rich==15.0.0
PyYAML==6.0.3
tenacity>=9.1.4
rich>=15.0.0
PyYAML>=6.0.3
numpy>=1.26.2
pyjwt==2.8.0
pyjwt>=2.8.0
ddgs>=6.0.0 # 免费联网搜索(原 duckduckgo-search 已重命名)
matplotlib>=3.9.0 # 可视化图表
# Document Processing
unstructured==0.22.21
pypdf==6.10.0
beautifulsoup4==4.14.3
lxml==6.1.0
pandas==3.0.2 # 若需Excel保留否则移除
spacy==3.8.14 # unstructured 可能依赖
unstructured>=0.22.21
pypdf>=6.10.0
beautifulsoup4>=4.14.3
lxml>=6.1.0
spacy>=3.8.14 # unstructured 可能依赖