Files
ailine/backend/app/rag/retriever.py
root 8af82f8f7f
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 5m4s
feat: RAG混合检索系统完整实现 + 启动脚本修复
- 实现了稠密+稀疏混合检索,使用 Qdrant 原生 RRF 融合
- 修复了 retriever.py 的 BaseRetriever 继承和稀疏向量包装问题
- 修复了 pipeline.py 的 Optional 导入问题
- 添加了稀疏 embedder 的缓存配置
- 简化了 vector_store.py,移除不必要的逻辑
- 修复了 start.sh 的 PROJECT_DIR 硬编码和端口配置问题
- 完善了 RAG 检索的测试文件
2026-05-04 02:54:37 +08:00

429 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Qdrant 混合检索器模块
提供基于 Qdrant 的混合检索Dense + Sparse功能包括
- 纯混合检索(无子父文档)
- 父子文档混合检索(先检索子文档,再返回父文档)
核心原理:
- 使用 Qdrant Universal Query API (query_points)
- 使用 Prefetch 并行检索多个源
- 使用 RRF 分数融合
"""
from typing import Dict, Any, Optional, List
from qdrant_client import QdrantClient, models
from qdrant_client.http.exceptions import UnexpectedResponse
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
from pydantic import Field, PrivateAttr
from rag_core import QdrantVectorStore, get_sparse_embedder, create_docstore
from rag_core.client import create_qdrant_client as create_core_qdrant_client
from app.model_services import get_embedding_service
from app.logger import info, warning, debug
# 模块级常量
DEFAULT_SEARCH_K = 20
DEFAULT_PARENT_SEARCH_K = 5
class HybridRetriever(BaseRetriever):
"""
混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合
使用 Qdrant Universal Query API (query_points)
"""
collection_name: str = Field(description="Qdrant 集合名称")
search_k: int = Field(default=DEFAULT_SEARCH_K, description="检索返回结果数")
_vector_store: Any = PrivateAttr()
_client: Any = PrivateAttr()
_sparse_embedder: Any = PrivateAttr()
def __init__(
self,
collection_name: str,
vector_store: QdrantVectorStore,
search_k: int = DEFAULT_SEARCH_K,
):
"""
Args:
collection_name: Qdrant 集合名称
vector_store: QdrantVectorStore 实例
search_k: 检索返回结果数
"""
super().__init__(
collection_name=collection_name,
search_k=search_k
)
self._vector_store = vector_store
self._client = vector_store.get_qdrant_client()
self._sparse_embedder = get_sparse_embedder()
def _get_relevant_documents(
self, query: str, **kwargs
) -> List[Document]:
"""
同步检索相关文档
Args:
query: 查询字符串
Returns:
相关文档列表
"""
# 1. 生成双向量
dense_query = self._vector_store.embeddings.embed_query(query)
sparse_query = self._sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 2. 使用官方的 query_points API推荐方式
response = self._client.query_points(
collection_name=self.collection_name,
prefetch=[ # 并行预取多个检索源
models.Prefetch(
query=dense_query,
using="dense", # 使用稠密向量进行语义搜索
limit=self.search_k
),
models.Prefetch(
query=sparse_vec,
using="sparse", # 使用稀疏向量进行关键词搜索
limit=self.search_k
)
],
query=models.FusionQuery(fusion=models.Fusion.RRF), # 指定融合算法为 RRF
limit=self.search_k, # 最终返回的结果数量
with_payload=True
)
# 3. 转换结果
results = []
for point in response.points:
doc = Document(
page_content=point.payload.pop("page_content", point.payload.pop("text", "")),
metadata=point.payload
)
results.append(doc)
debug(f"混合检索返回 {len(results)} 个文档")
return results
async def _aget_relevant_documents(
self, query: str, **kwargs
) -> List[Document]:
"""异步检索(当前调用同步版本)"""
# Qdrant 客户端没有原生 async这里用同步版本
return self._get_relevant_documents(query, **kwargs)
class ParentHybridRetriever(BaseRetriever):
"""
父子文档混合检索器:
1. 先用混合检索找到相关子文档
2. 根据子文档的 parent_id 找到对应的父文档
3. 去重并返回父文档
"""
collection_name: str = Field(description="Qdrant 集合名称")
search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")
_vector_store: Any = PrivateAttr()
_client: Any = PrivateAttr()
_sparse_embedder: Any = PrivateAttr()
_docstore: Any = PrivateAttr()
def __init__(
self,
collection_name: str,
vector_store: QdrantVectorStore,
search_k: int = DEFAULT_PARENT_SEARCH_K,
docstore: Optional[Any] = None,
):
"""
Args:
collection_name: Qdrant 集合名称
vector_store: QdrantVectorStore 实例
search_k: 最终返回的父文档数量
docstore: 文档存储(如果父文档在 PostgreSQL可选
"""
super().__init__(
collection_name=collection_name,
search_k=search_k
)
self._vector_store = vector_store
self._client = vector_store.get_qdrant_client()
self._sparse_embedder = get_sparse_embedder()
self._docstore = docstore
def _get_relevant_documents(
self, query: str, **kwargs
) -> List[Document]:
"""
同步检索相关父文档
Args:
query: 查询字符串
Returns:
相关父文档列表
"""
# 1. 生成查询双向量
dense_query = self._vector_store.embeddings.embed_query(query)
sparse_query = self._sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 2. 多取一些子文档,避免去重后数量不足
search_limit = self.search_k * 2
# 3. 使用 query_points API 进行混合检索
response = self._client.query_points(
collection_name=self.collection_name,
prefetch=[
models.Prefetch(
query=dense_query,
using="dense",
limit=search_limit
),
models.Prefetch(
query=sparse_vec,
using="sparse",
limit=search_limit
)
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
limit=search_limit,
with_payload=True
)
if not response.points:
debug("混合检索未找到任何文档")
return []
# 4. 收集 parent_id 和对应最高得分
parent_score_map = {}
parent_ids = set()
child_point_map = {} # 保存子文档点用于降级
for point in response.points:
# 先复制 payload避免修改原始对象
payload_copy = point.payload.copy()
parent_id = payload_copy.get("parent_id", point.id)
score = point.score
# 同一个 parent_id 只保留最高得分
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
parent_score_map[parent_id] = score
parent_ids.add(parent_id)
child_point_map[parent_id] = point
# 5. 批量查询父文档
# 首先尝试从 Qdrant 直接查询(因为父文档可能也在 Qdrant 中)
parent_docs = []
found_parent_ids = set()
try:
parent_points = self._client.retrieve(
collection_name=self.collection_name,
ids=list(parent_ids),
with_payload=True
)
# 处理找到的父文档
for point in parent_points:
payload_copy = point.payload.copy()
doc = Document(
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
metadata=payload_copy
)
parent_docs.append(doc)
found_parent_ids.add(point.id)
except Exception as e:
warning(f"从 Qdrant 查询父文档失败: {e}")
# 6. 如果有 docstore尝试从 docstore 查询剩余的父文档
if self._docstore and len(found_parent_ids) < len(parent_ids):
missing_parent_ids = parent_ids - found_parent_ids
try:
docstore_docs = self._docstore.mget(missing_parent_ids)
for doc_id, doc in zip(missing_parent_ids, docstore_docs):
if doc is not None:
parent_docs.append(doc)
found_parent_ids.add(doc_id)
except Exception as e:
warning(f"从 docstore 查询父文档失败: {e}")
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
missing_parent_ids = parent_ids - found_parent_ids
if missing_parent_ids:
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
for parent_id in missing_parent_ids:
child_point = child_point_map.get(parent_id)
if child_point:
payload_copy = child_point.payload.copy()
doc = Document(
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
metadata=payload_copy
)
parent_docs.append(doc)
# 8. 按照得分降序排序,返回前 k 个
parent_docs_with_scores = [
(doc, parent_score_map.get(doc.metadata.get("id", doc.id), 0.0))
for doc in parent_docs
]
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
return final_docs
async def _aget_relevant_documents(
self, query: str, **kwargs
) -> List[Document]:
"""异步检索(当前调用同步版本)"""
return self._get_relevant_documents(query, **kwargs)
def create_hybrid_retriever(
collection_name: str,
search_k: int = DEFAULT_SEARCH_K,
embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
"""
创建混合检索器(稠密向量 + BM25 稀疏向量)。
这是默认推荐的检索方式,效果最优。
Args:
collection_name: Qdrant 集合名称
search_k: 检索返回结果数
embeddings: 可选的嵌入模型实例。若未提供,将自动获取统一嵌入服务。
Returns:
HybridRetriever 实例
"""
# 默认使用统一嵌入服务
if embeddings is None:
embeddings = get_embedding_service()
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
# 创建向量存储
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
# 验证集合是否存在
try:
vector_store.get_client().get_collection(collection_name)
except UnexpectedResponse as e:
if e.status_code == 404:
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
raise
info(f"✅ Qdrant 混合检索器初始化成功search_k={search_k}")
return HybridRetriever(
collection_name=collection_name,
vector_store=vector_store,
search_k=search_k
)
def create_parent_hybrid_retriever(
collection_name: str,
search_k: int = DEFAULT_PARENT_SEARCH_K,
embeddings: Optional[Embeddings] = None,
use_docstore: bool = True,
) -> BaseRetriever:
"""
创建父子文档混合检索器(默认推荐)。
检索流程:
1. 混合检索找到相关子文档
2. 根据 parent_id 找到对应的父文档
3. 去重并返回父文档
Args:
collection_name: Qdrant 集合名称
search_k: 最终返回的父文档数量
embeddings: 可选的嵌入模型实例
use_docstore: 是否使用 PostgreSQL docstore 存储父文档
Returns:
ParentHybridRetriever 实例
"""
# 默认使用统一嵌入服务
if embeddings is None:
embeddings = get_embedding_service()
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
# 创建向量存储
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
# 验证集合是否存在
try:
vector_store.get_client().get_collection(collection_name)
except UnexpectedResponse as e:
if e.status_code == 404:
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
raise
# 创建 docstore如果需要
docstore = None
if use_docstore:
try:
docstore, _ = create_docstore()
info("✅ 文档存储初始化成功PostgreSQL")
except Exception as e:
warning(f"⚠️ 文档存储初始化失败,将不使用 docstore: {e}")
info(f"✅ Qdrant 父子文档混合检索器初始化成功search_k={search_k}")
return ParentHybridRetriever(
collection_name=collection_name,
vector_store=vector_store,
search_k=search_k,
docstore=docstore
)
def create_base_retriever(
collection_name: str,
search_k: int = DEFAULT_SEARCH_K,
embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
"""
创建基础稠密检索器(向后兼容)。
Args:
collection_name: Qdrant 集合名称
search_k: 检索返回结果数
embeddings: 可选的嵌入模型实例
Returns:
LangChain 的 BaseRetriever 实例
"""
# 默认使用统一嵌入服务
if embeddings is None:
embeddings = get_embedding_service()
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
info(f"✅ Qdrant 基础稠密检索器初始化成功search_k={search_k}")
return vector_store.as_langchain_vectorstore().as_retriever(k=search_k)
# 别名:默认就是父子文档混合检索
create_retriever = create_parent_hybrid_retriever