Files
ailine/backend/app/rag/retriever.py
root 3ae9daa01a
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m44s
导入方式修改
2026-05-05 23:17:00 +08:00

353 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Qdrant 混合检索器模块(完全异步)
提供基于 Qdrant 的混合检索(Dense + Sparse)功能,包括:
- 纯混合检索(无子父文档)
- 父子文档混合检索(先检索子文档,再返回父文档)
核心原理:
- 使用 Qdrant Universal Query API (query_points)
- 使用 Prefetch 并行检索多个源
- 使用 RRF 分数融合
"""
from typing import Dict, Any, Optional, List
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.http.exceptions import UnexpectedResponse
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from pydantic import Field, PrivateAttr
from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
from backend.rag_core.client import create_async_qdrant_client
from ..model_services import get_embedding_service
from ..logger import info, warning, debug
# Module-level constants: default `search_k` values used by HybridRetriever
# (DEFAULT_SEARCH_K) and ParentHybridRetriever (DEFAULT_PARENT_SEARCH_K).
DEFAULT_SEARCH_K = 20
DEFAULT_PARENT_SEARCH_K = 5
class HybridRetriever(BaseRetriever):
    """
    Asynchronous hybrid retriever: dense + sparse vectors fused with RRF.

    Both vector kinds are queried through Qdrant's Universal Query API
    (``query_points``) using two ``Prefetch`` branches (named vectors
    ``"dense"`` and ``"sparse"``), and the branches are merged with
    Reciprocal Rank Fusion.
    """

    collection_name: str = Field(description="Qdrant 集合名称")
    search_k: int = Field(default=DEFAULT_SEARCH_K, description="检索返回结果数")

    # Private (non-pydantic-field) state, populated in __init__.
    _vector_store: Any = PrivateAttr()
    _client: Any = PrivateAttr()
    _sparse_embedder: Any = PrivateAttr()

    def __init__(
        self,
        collection_name: str,
        vector_store: QdrantHybridStore,
        search_k: int = DEFAULT_SEARCH_K,
    ):
        """
        Args:
            collection_name: Name of the Qdrant collection to query.
            vector_store: QdrantHybridStore instance; supplies the async
                Qdrant client and the dense query embedding.
            search_k: Number of results to request from Qdrant.
        """
        super().__init__(
            collection_name=collection_name,
            search_k=search_k,
        )
        self._vector_store = vector_store
        self._client = vector_store.get_async_qdrant_client()
        self._sparse_embedder = get_sparse_embedder()

    def _get_relevant_documents(
        self, query: str, *, run_manager: Any = None
    ) -> List[Document]:
        """
        Synchronous retrieval (compatibility shim; prefer ``ainvoke``).

        When no event loop is running in the current thread the coroutine is
        driven with ``asyncio.run``. When a loop *is* already running,
        neither ``run_until_complete`` nor ``asyncio.run`` may be used on
        this thread (both raise ``RuntimeError``), so the coroutine is run
        on a fresh loop in a short-lived worker thread instead. Note that
        the calling thread (and therefore its event loop) is blocked until
        the result is available.

        Args:
            query: The search text.
            run_manager: Accepted for interface compatibility; unused.

        Returns:
            The documents produced by ``_aget_relevant_documents``.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: safe to start one here.
            return asyncio.run(self._aget_relevant_documents(query))

        def _run_in_fresh_loop() -> List[Document]:
            return asyncio.run(self._aget_relevant_documents(query))

        # A loop is already running here; use a separate thread with its
        # own loop rather than re-entering the running one.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(_run_in_fresh_loop).result()

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: Any = None
    ) -> List[Document]:
        """
        Run the hybrid (dense + sparse, RRF-fused) query asynchronously.

        Args:
            query: The search text.
            run_manager: Accepted for interface compatibility; unused.

        Returns:
            One ``Document`` per returned point. ``page_content`` is taken
            from the payload key ``"page_content"``, falling back to
            ``"text"`` and then to ``""``; the remaining payload keys become
            the document metadata.
        """
        # 1. Build the dense and sparse query vectors.
        dense_query = await self._vector_store.aembed_query(query)
        sparse_query = self._sparse_embedder.embed_query(query)
        sparse_vec = models.SparseVector(
            indices=sparse_query["indices"],
            values=sparse_query["values"],
        )

        # 2. Prefetch from both named vectors, then fuse the candidates.
        response = await self._client.query_points(
            collection_name=self.collection_name,
            prefetch=[
                models.Prefetch(
                    query=dense_query,
                    using="dense",
                    limit=self.search_k,
                ),
                models.Prefetch(
                    query=sparse_vec,
                    using="sparse",
                    limit=self.search_k,
                ),
            ],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=self.search_k,
            with_payload=True,
        )

        # 3. Convert points to Documents. Work on a copy so the response
        #    objects are not mutated, and tolerate points without a payload.
        results = []
        for point in response.points:
            payload = dict(point.payload or {})
            # Both keys are always removed from the metadata, whichever one
            # ends up supplying the content.
            fallback_text = payload.pop("text", "")
            content = payload.pop("page_content", fallback_text)
            results.append(Document(page_content=content, metadata=payload))

        debug(f"混合检索返回 {len(results)} 个文档")
        return results
class ParentHybridRetriever(BaseRetriever):
    """
    Asynchronous hybrid retriever for collections indexed as child chunks.

    A dense + sparse query is issued through Qdrant ``query_points`` and
    fused with RRF, requesting ``search_k * 2`` candidates.

    NOTE(review): the original description promised "look up parents by
    parent_id, de-duplicate, return parent documents". The code in this
    class returns the matched *child* documents (annotated with
    ``child_id`` and ``score``) and never reads ``_docstore``; parent
    resolution presumably happens elsewhere or is not yet implemented.
    """

    collection_name: str = Field(description="Qdrant 集合名称")
    search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")

    # Private (non-pydantic-field) state, populated in __init__.
    _vector_store: Any = PrivateAttr()
    _client: Any = PrivateAttr()
    _sparse_embedder: Any = PrivateAttr()
    _docstore: Any = PrivateAttr()

    def __init__(
        self,
        collection_name: str,
        vector_store: QdrantHybridStore,
        search_k: int = DEFAULT_PARENT_SEARCH_K,
        docstore: Optional[Any] = None,
    ):
        """
        Args:
            collection_name: Name of the Qdrant collection to query.
            vector_store: QdrantHybridStore instance; supplies the async
                Qdrant client and the dense query embedding.
            search_k: Base result count; the query requests twice this many
                child documents.
            docstore: Optional document store for parent documents. It is
                stored on the instance but not used by this class.
        """
        super().__init__(
            collection_name=collection_name,
            search_k=search_k,
        )
        self._vector_store = vector_store
        self._client = vector_store.get_async_qdrant_client()
        self._sparse_embedder = get_sparse_embedder()
        self._docstore = docstore

    def _get_relevant_documents(
        self, query: str, *, run_manager: Any = None
    ) -> List[Document]:
        """
        Synchronous retrieval (compatibility shim; prefer ``ainvoke``).

        When no event loop is running in the current thread the coroutine is
        driven with ``asyncio.run``. When a loop *is* already running,
        neither ``run_until_complete`` nor ``asyncio.run`` may be used on
        this thread (both raise ``RuntimeError``), so the coroutine is run
        on a fresh loop in a short-lived worker thread instead. The calling
        thread is blocked until the result is available.

        Args:
            query: The search text.
            run_manager: Accepted for interface compatibility; unused.

        Returns:
            The documents produced by ``_aget_relevant_documents``.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: safe to start one here.
            return asyncio.run(self._aget_relevant_documents(query))

        def _run_in_fresh_loop() -> List[Document]:
            return asyncio.run(self._aget_relevant_documents(query))

        # A loop is already running here; use a separate thread with its
        # own loop rather than re-entering the running one.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(_run_in_fresh_loop).result()

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: Any = None
    ) -> List[Document]:
        """
        Retrieve matching child documents asynchronously.

        Args:
            query: The search text.
            run_manager: Accepted for interface compatibility; unused.

        Returns:
            Up to ``search_k * 2`` child ``Document`` objects, or ``[]`` when
            Qdrant returns no points. ``page_content`` comes from the payload
            key ``"page_content"`` (fallback ``"text"``, then ``""``); the
            metadata holds the remaining payload plus ``child_id`` (the point
            id) and ``score`` (the fused score).
        """
        # 1. Build the dense and sparse query vectors.
        dense_query = await self._vector_store.aembed_query(query)
        sparse_query = self._sparse_embedder.embed_query(query)
        sparse_vec = models.SparseVector(
            indices=sparse_query["indices"],
            values=sparse_query["values"],
        )

        # 2. Over-fetch children (original intent: leave headroom for
        #    de-duplication by parent).
        search_limit = self.search_k * 2

        # 3. Prefetch from both named vectors, then fuse the candidates.
        response = await self._client.query_points(
            collection_name=self.collection_name,
            prefetch=[
                models.Prefetch(
                    query=dense_query,
                    using="dense",
                    limit=search_limit,
                ),
                models.Prefetch(
                    query=sparse_vec,
                    using="sparse",
                    limit=search_limit,
                ),
            ],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=search_limit,
            with_payload=True,
        )

        if not response.points:
            debug("混合检索未找到任何文档")
            return []

        # 4. Convert points to child Documents. Copy the payload so the
        #    response is not mutated, and tolerate points without a payload.
        child_docs = []
        for point in response.points:
            payload = dict(point.payload or {})
            # Both keys are always removed from the metadata, whichever one
            # ends up supplying the content.
            fallback_text = payload.pop("text", "")
            content = payload.pop("page_content", fallback_text)
            child_docs.append(
                Document(
                    page_content=content,
                    metadata={
                        **payload,
                        "child_id": point.id,
                        "score": point.score,
                    },
                )
            )

        debug(f"父子文档混合检索返回 {len(child_docs)} 个子文档")
        return child_docs
def create_hybrid_retriever(
    collection_name: str,
    search_k: int = DEFAULT_SEARCH_K,
    embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
    """
    Build a HybridRetriever (dense + sparse vectors) for a Qdrant collection.

    Args:
        collection_name: Name of the Qdrant collection.
        search_k: Number of results the retriever requests.
        embeddings: Optional embedding model. When omitted, the unified
            embedding service is fetched via ``get_embedding_service()``.
            NOTE(review): the value is not passed on to ``QdrantHybridStore``
            or the retriever here; confirm whether that is intentional.

    Returns:
        A HybridRetriever bound to ``collection_name``.

    Raises:
        ValueError: If Qdrant reports (HTTP 404) that the collection does not
            exist. The original ``UnexpectedResponse`` is attached as the
            cause.
        UnexpectedResponse: For any other unexpected Qdrant response.
    """
    if embeddings is None:
        embeddings = get_embedding_service()
        info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")

    vector_store = QdrantHybridStore(collection_name=collection_name)

    # Fail fast if the collection is missing instead of at first query.
    try:
        vector_store.get_client().get_collection(collection_name)
    except UnexpectedResponse as e:
        if e.status_code == 404:
            warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
            # Chain the cause so the underlying HTTP error stays visible.
            raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在") from e
        raise

    info(f"✅ Qdrant 混合检索器初始化成功search_k={search_k}")
    return HybridRetriever(
        collection_name=collection_name,
        vector_store=vector_store,
        search_k=search_k,
    )
def create_parent_hybrid_retriever(
    collection_name: str,
    search_k: int = DEFAULT_PARENT_SEARCH_K,
    embeddings: Optional[Embeddings] = None,
    use_docstore: bool = True,
) -> BaseRetriever:
    """
    Build a ParentHybridRetriever for a Qdrant collection.

    Args:
        collection_name: Name of the Qdrant collection.
        search_k: Result count passed to the retriever.
        embeddings: Optional embedding model. When omitted, the unified
            embedding service is fetched via ``get_embedding_service()``.
            NOTE(review): the value is not passed on to ``QdrantHybridStore``
            or the retriever here; confirm whether that is intentional.
        use_docstore: When true, try to create a docstore via
            ``create_docstore()`` and hand it to the retriever. A failure is
            logged and the retriever is built without a docstore.

    Returns:
        A ParentHybridRetriever bound to ``collection_name``.

    Raises:
        ValueError: If Qdrant reports (HTTP 404) that the collection does not
            exist. The original ``UnexpectedResponse`` is attached as the
            cause.
        UnexpectedResponse: For any other unexpected Qdrant response.
    """
    if embeddings is None:
        embeddings = get_embedding_service()
        info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")

    vector_store = QdrantHybridStore(collection_name=collection_name)

    # Fail fast if the collection is missing instead of at first query.
    try:
        vector_store.get_client().get_collection(collection_name)
    except UnexpectedResponse as e:
        if e.status_code == 404:
            warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
            # Chain the cause so the underlying HTTP error stays visible.
            raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在") from e
        raise

    docstore = None
    if use_docstore:
        # Deliberately best-effort: a docstore failure must not prevent the
        # retriever from being created.
        try:
            docstore, _ = create_docstore()
            info("✅ 文档存储初始化成功PostgreSQL")
        except Exception as e:
            # Interpolate the error directly, matching every other logger
            # call in this module (single pre-formatted message argument).
            warning(f"⚠️ 文档存储初始化失败,将不使用 docstore: {e}")

    info(f"✅ Qdrant 父子文档混合检索器初始化成功search_k={search_k}")
    return ParentHybridRetriever(
        collection_name=collection_name,
        vector_store=vector_store,
        search_k=search_k,
        docstore=docstore,
    )
def create_base_retriever(
    collection_name: str,
    search_k: int = DEFAULT_SEARCH_K,
    embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
    """
    Backward-compatible entry point for the "base" retriever.

    Delegates unchanged to ``create_hybrid_retriever``, so the result is a
    hybrid retriever.

    Args:
        collection_name: Name of the Qdrant collection.
        search_k: Number of results the retriever requests.
        embeddings: Optional embedding model, forwarded as-is.

    Returns:
        Whatever ``create_hybrid_retriever`` returns for the same arguments.
    """
    return create_hybrid_retriever(
        collection_name=collection_name,
        search_k=search_k,
        embeddings=embeddings,
    )
# Alias: the default retriever factory is the parent/child hybrid one.
create_retriever = create_parent_hybrid_retriever