Files
ailine/backend/app/rag/retriever.py
root 60afa86ded
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
feat: 实现 BM25 稀疏 + 稠密向量混合检索功能
2026-05-04 02:01:22 +08:00

380 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Qdrant hybrid retriever module.

Provides hybrid (dense + sparse) retrieval on top of Qdrant, in two flavours:
- plain hybrid retrieval (no parent/child documents)
- parent/child hybrid retrieval (search child documents, then return their
  parent documents)

Core idea:
- score fusion is requested from the Qdrant client using RRF
- both dense (semantic) vectors and sparse (BM25 keyword) vectors are queried
"""
from typing import Dict, Any, Optional, List
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import (
SearchRequest, Fusion, FusionProtocol, NamedVector, NamedSparseVector
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
from rag_core import QdrantVectorStore, get_sparse_embedder, create_docstore
from rag_core.client import create_qdrant_client as create_core_qdrant_client
from app.model_services import get_embedding_service
from app.logger import info, warning, debug
# Module-level constants
# Default `search_k` for HybridRetriever / create_hybrid_retriever.
DEFAULT_SEARCH_K = 20
# Default `search_k` (number of parent documents returned) for
# ParentHybridRetriever / create_parent_hybrid_retriever.
DEFAULT_PARENT_SEARCH_K = 5
class HybridRetriever(BaseRetriever):
    """Hybrid retriever: dense vectors + BM25 sparse vectors fused with RRF.

    One dense and one sparse search request are built for the query and
    handed to the Qdrant client for score fusion; the fused points are
    converted to LangChain ``Document`` objects.

    NOTE(review): ``self.client.fusion`` and ``FusionProtocol`` are kept as
    written in the original. Confirm that they exist on the client returned
    by ``vector_store.get_qdrant_client()``; the public qdrant-client hybrid
    API appears to be ``query_points`` with ``prefetch`` + ``FusionQuery``.
    """

    # BaseRetriever is a pydantic model, so instance attributes must be
    # declared as fields; assigning undeclared attributes in __init__ fails.
    collection_name: str
    # Expected to be a QdrantVectorStore; typed Any so pydantic passes the
    # instance through untouched.
    vector_store: Any
    search_k: int = DEFAULT_SEARCH_K
    client: Any = None
    sparse_embedder: Any = None

    def __init__(
        self,
        collection_name: str,
        vector_store: QdrantVectorStore,
        search_k: int = DEFAULT_SEARCH_K,
    ):
        """
        Args:
            collection_name: Name of the Qdrant collection to search.
            vector_store: QdrantVectorStore instance providing the embeddings
                and the Qdrant client.
            search_k: Per-request result limit passed to both searches.
        """
        super().__init__(
            collection_name=collection_name,
            vector_store=vector_store,
            search_k=search_k,
            client=vector_store.get_qdrant_client(),
            sparse_embedder=get_sparse_embedder(),
        )

    @staticmethod
    def _point_to_document(point: Any) -> Document:
        """Convert a Qdrant point into a Document without mutating the point."""
        # Copy so that popping "text" does not alter the point's own payload;
        # a missing payload is treated as empty.
        payload = dict(point.payload or {})
        text = payload.pop("text", "")
        return Document(page_content=text, metadata=payload)

    def _get_relevant_documents(
        self, query: str, *, run_manager: Optional[Any] = None
    ) -> List[Document]:
        """Retrieve documents relevant to ``query`` (synchronous).

        Args:
            query: Query string.
            run_manager: LangChain run manager (optional, unused here).

        Returns:
            List of documents built from the fused search results.
        """
        # Embed the query twice: dense and sparse.
        dense_query = self.vector_store.embeddings.embed_query(query)
        sparse_query = self.sparse_embedder.embed_query(query)

        # One search request per vector kind.
        searches = [
            # Dense search
            SearchRequest(
                vector=NamedVector(name="dense", vector=dense_query),
                limit=self.search_k,
                with_payload=True,
            ),
            # Sparse search
            SearchRequest(
                vector=NamedSparseVector(name="sparse", vector=sparse_query),
                limit=self.search_k,
                with_payload=True,
            ),
        ]

        # RRF score fusion (see NOTE(review) in the class docstring).
        fused_results = self.client.fusion(
            collection_name=self.collection_name,
            requests=searches,
            fusion=Fusion(fusion=FusionProtocol.RRF),
        )

        results = [self._point_to_document(point) for point in fused_results.points]
        debug(f"混合检索返回 {len(results)} 个文档")
        return results

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: Optional[Any] = None
    ) -> List[Document]:
        """Retrieve documents asynchronously by running the sync path in a thread.

        The synchronous implementation performs blocking embedding and Qdrant
        calls, so it is off-loaded to the default executor instead of being
        run directly on the event loop. ``run_manager`` is not forwarded
        because the synchronous implementation does not use it.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._get_relevant_documents, query)
class ParentHybridRetriever(BaseRetriever):
    """Parent/child hybrid retriever.

    1. Run the hybrid (dense + sparse) search to find relevant child documents.
    2. Resolve each child's ``parent_id`` to its parent document.
    3. De-duplicate and return the parent documents, best score first.

    NOTE(review): ``self.client.fusion`` and ``FusionProtocol`` are kept as
    written in the original. Confirm that they exist on the client returned
    by ``vector_store.get_qdrant_client()``; the public qdrant-client hybrid
    API appears to be ``query_points`` with ``prefetch`` + ``FusionQuery``.
    """

    # BaseRetriever is a pydantic model, so instance attributes must be
    # declared as fields; assigning undeclared attributes in __init__ fails.
    collection_name: str
    # Expected to be a QdrantVectorStore; typed Any so pydantic passes the
    # instance through untouched.
    vector_store: Any
    search_k: int = DEFAULT_PARENT_SEARCH_K
    client: Any = None
    sparse_embedder: Any = None
    docstore: Optional[Any] = None

    def __init__(
        self,
        collection_name: str,
        vector_store: QdrantVectorStore,
        search_k: int = DEFAULT_PARENT_SEARCH_K,
        docstore: Optional[Any] = None,
    ):
        """
        Args:
            collection_name: Name of the Qdrant collection to search.
            vector_store: QdrantVectorStore instance providing the embeddings
                and the Qdrant client.
            search_k: Number of parent documents finally returned.
            docstore: Optional document store consulted (via ``mget``) for
                parents that are not found in Qdrant.
        """
        super().__init__(
            collection_name=collection_name,
            vector_store=vector_store,
            search_k=search_k,
            client=vector_store.get_qdrant_client(),
            sparse_embedder=get_sparse_embedder(),
            docstore=docstore,
        )

    @staticmethod
    def _point_to_document(point: Any) -> Document:
        """Convert a Qdrant point into a Document without mutating the point."""
        # Copy so that popping "text" does not alter the point's own payload;
        # a missing payload is treated as empty.
        payload = dict(point.payload or {})
        text = payload.pop("text", "")
        return Document(page_content=text, metadata=payload)

    def _get_relevant_documents(
        self, query: str, *, run_manager: Optional[Any] = None
    ) -> List[Document]:
        """Retrieve parent documents relevant to ``query`` (synchronous).

        Args:
            query: Query string.
            run_manager: LangChain run manager (optional, unused here).

        Returns:
            Up to ``search_k`` parent documents, ordered by the best fused
            score among their child hits. A child document is returned in
            place of any parent that cannot be resolved.
        """
        # 1. Embed the query twice: dense and sparse.
        dense_query = self.vector_store.embeddings.embed_query(query)
        sparse_query = self.sparse_embedder.embed_query(query)

        # 2. Over-fetch children so enough parents remain after de-duplication.
        search_limit = self.search_k * 2
        searches = [
            # Dense search
            SearchRequest(
                vector=NamedVector(name="dense", vector=dense_query),
                limit=search_limit,
                with_payload=True,
            ),
            # Sparse search
            SearchRequest(
                vector=NamedSparseVector(name="sparse", vector=sparse_query),
                limit=search_limit,
                with_payload=True,
            ),
        ]

        # 3. RRF score fusion to get the child hits (see class NOTE(review)).
        fused_results = self.client.fusion(
            collection_name=self.collection_name,
            requests=searches,
            fusion=Fusion(fusion=FusionProtocol.RRF),
        )
        if not fused_results.points:
            debug("混合检索未找到任何文档")
            return []

        # 4. Per parent_id keep the best score and the child that achieved it
        #    (the child is the fallback if the parent cannot be resolved).
        best_score: Dict[Any, float] = {}
        best_child: Dict[Any, Any] = {}
        for point in fused_results.points:
            payload = point.payload or {}
            # A child without parent_id is treated as its own parent.
            parent_id = payload.get("parent_id", point.id)
            if parent_id not in best_score or point.score > best_score[parent_id]:
                best_score[parent_id] = point.score
                best_child[parent_id] = point
        parent_ids = list(best_score)

        # Resolved documents keyed by parent id, so the final ranking can look
        # scores up by the same key that best_score uses.
        docs_by_parent: Dict[Any, Document] = {}

        # 5. First try Qdrant itself (parents may live in the same collection).
        try:
            parent_points = self.client.retrieve(
                collection_name=self.collection_name,
                ids=parent_ids,
                with_payload=True,
            )
            for point in parent_points:
                docs_by_parent[point.id] = self._point_to_document(point)
        except Exception as e:
            # Best effort: fall through to the docstore / child fallback.
            warning(f"从 Qdrant 查询父文档失败: {e}")

        # 6. Ask the docstore for whatever Qdrant did not return.
        missing_parent_ids = [pid for pid in parent_ids if pid not in docs_by_parent]
        if self.docstore is not None and missing_parent_ids:
            try:
                # mget is given an ordered list so results can be zipped back
                # onto their ids.
                docstore_docs = self.docstore.mget(missing_parent_ids)
                for doc_id, doc in zip(missing_parent_ids, docstore_docs):
                    if doc is not None:
                        docs_by_parent[doc_id] = doc
            except Exception as e:
                warning(f"从 docstore 查询父文档失败: {e}")

        # 7. Fallback: substitute the child document for unresolved parents.
        missing_parent_ids = {pid for pid in parent_ids if pid not in docs_by_parent}
        if missing_parent_ids:
            warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
            for parent_id in missing_parent_ids:
                child_point = best_child.get(parent_id)
                if child_point is not None:
                    docs_by_parent[parent_id] = self._point_to_document(child_point)

        # 8. Rank by the parent's best child score and keep the top k.
        ranked = sorted(
            docs_by_parent.items(),
            key=lambda item: best_score.get(item[0], 0.0),
            reverse=True,
        )
        final_docs = [doc for _, doc in ranked[: self.search_k]]
        debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
        return final_docs

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: Optional[Any] = None
    ) -> List[Document]:
        """Retrieve parent documents asynchronously via the sync path in a thread.

        The synchronous implementation performs blocking embedding, Qdrant and
        docstore calls, so it is off-loaded to the default executor instead of
        being run directly on the event loop. ``run_manager`` is not forwarded
        because the synchronous implementation does not use it.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._get_relevant_documents, query)
def create_hybrid_retriever(
    collection_name: str,
    search_k: int = DEFAULT_SEARCH_K,
    embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
    """Create a hybrid retriever (dense vectors + BM25 sparse vectors).

    Args:
        collection_name: Name of the Qdrant collection.
        search_k: Number of results requested per search.
        embeddings: Optional embedding model. When omitted, the shared
            embedding service from ``get_embedding_service()`` is used.

    Returns:
        A ``HybridRetriever`` instance.

    Raises:
        ValueError: If the collection does not exist (Qdrant answered 404).
        UnexpectedResponse: For any other unexpected Qdrant response.
    """
    # Fall back to the shared embedding service.
    if embeddings is None:
        embeddings = get_embedding_service()
        info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")

    vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)

    # Fail fast if the collection is missing.
    # NOTE(review): this uses vector_store.get_client() while HybridRetriever
    # uses vector_store.get_qdrant_client() — confirm both accessors exist.
    try:
        vector_store.get_client().get_collection(collection_name)
    except UnexpectedResponse as e:
        if e.status_code == 404:
            warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
            # Chain the original error so the Qdrant response stays visible
            # in the traceback.
            raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在") from e
        raise

    info(f"✅ Qdrant 混合检索器初始化成功search_k={search_k}")
    return HybridRetriever(
        collection_name=collection_name,
        vector_store=vector_store,
        search_k=search_k,
    )
def create_parent_hybrid_retriever(
    collection_name: str,
    search_k: int = DEFAULT_PARENT_SEARCH_K,
    embeddings: Optional[Embeddings] = None,
    use_docstore: bool = True,
) -> BaseRetriever:
    """Create a parent/child hybrid retriever.

    Retrieval flow of the returned retriever:
    1. Hybrid search finds relevant child documents.
    2. Each child's ``parent_id`` is resolved to its parent document.
    3. Parents are de-duplicated and returned.

    Args:
        collection_name: Name of the Qdrant collection.
        search_k: Number of parent documents finally returned.
        embeddings: Optional embedding model. When omitted, the shared
            embedding service from ``get_embedding_service()`` is used.
        use_docstore: Whether to create a docstore (via ``create_docstore()``)
            for looking up parent documents.

    Returns:
        A ``ParentHybridRetriever`` instance.

    Raises:
        ValueError: If the collection does not exist (Qdrant answered 404).
        UnexpectedResponse: For any other unexpected Qdrant response.
    """
    # Fall back to the shared embedding service.
    if embeddings is None:
        embeddings = get_embedding_service()
        info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")

    vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)

    # Fail fast if the collection is missing.
    # NOTE(review): this uses vector_store.get_client() while
    # ParentHybridRetriever uses vector_store.get_qdrant_client() — confirm
    # both accessors exist.
    try:
        vector_store.get_client().get_collection(collection_name)
    except UnexpectedResponse as e:
        if e.status_code == 404:
            warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
            # Chain the original error so the Qdrant response stays visible
            # in the traceback.
            raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在") from e
        raise

    # The docstore is optional: on failure the retriever is still created
    # and simply runs without one.
    docstore = None
    if use_docstore:
        try:
            docstore, _ = create_docstore()
            info("✅ 文档存储初始化成功PostgreSQL")
        except Exception as e:
            warning(f"⚠️ 文档存储初始化失败,将不使用 docstore: {e}")

    info(f"✅ Qdrant 父子文档混合检索器初始化成功search_k={search_k}")
    return ParentHybridRetriever(
        collection_name=collection_name,
        vector_store=vector_store,
        search_k=search_k,
        docstore=docstore,
    )
# Alias: the default retriever factory is the parent/child hybrid retriever.
create_retriever = create_parent_hybrid_retriever