Files
ailine/app/rag/retriever.py
2026-04-18 16:31:48 +08:00

144 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Qdrant 向量检索器
提供基础向量检索、混合检索（Dense + BM25）功能。
"""
import os
from typing import List, Dict, Any, Optional
from langchain_qdrant import Qdrant
from langchain.embeddings.base import Embeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.retrievers import EnsembleRetriever
from qdrant_client import QdrantClient
from qdrant_client.http import models
def create_qdrant_client(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> QdrantClient:
    """
    Create a Qdrant client.

    Args:
        url: Qdrant server address. When not given (or empty), the
            QDRANT_URL environment variable is used, falling back to
            "http://localhost:6333".
        api_key: API key. When not given (or empty), the QDRANT_API_KEY
            environment variable is used. If neither yields a non-empty
            value, no api_key argument is passed to the client.

    Returns:
        A QdrantClient instance.
    """
    resolved_url = url if url else os.getenv("QDRANT_URL", "http://localhost:6333")
    resolved_key = api_key if api_key else os.getenv("QDRANT_API_KEY")

    # Only forward api_key when one is actually available.
    if not resolved_key:
        return QdrantClient(url=resolved_url)
    return QdrantClient(url=resolved_url, api_key=resolved_key)
def create_base_retriever(
    collection_name: str,
    embeddings: Embeddings,
    search_kwargs: Optional[Dict[str, Any]] = None,
    client: Optional[QdrantClient] = None,
) -> Any:
    """
    Create a basic dense-vector retriever over an existing Qdrant collection.

    Args:
        collection_name: Name of the Qdrant collection to search.
        embeddings: Embedding model used to embed queries.
        search_kwargs: Keyword arguments forwarded to the retriever's search
            call. Defaults to {"k": 20} when omitted or empty.
        client: Qdrant client to use. When None, one is created with
            create_qdrant_client() (QDRANT_URL / QDRANT_API_KEY).

    Returns:
        The retriever returned by ``Qdrant.as_retriever`` (a LangChain
        vector-store retriever), not the ``Qdrant`` store itself.
    """
    if client is None:
        client = create_qdrant_client()
    search_kwargs = search_kwargs or {"k": 20}

    # Wrap the given client directly. Qdrant.from_existing_collection() builds
    # its own QdrantClient from connection arguments (url, api_key, ...), so a
    # `client=` passed there is not used as the connection and the configured
    # URL / API key would be lost.
    vector_store = Qdrant(
        client=client,
        collection_name=collection_name,
        embeddings=embeddings,
        content_payload_key="content",  # assumes document text is stored under "content"
        metadata_payload_key="metadata",  # payload field holding document metadata
    )
    return vector_store.as_retriever(search_kwargs=search_kwargs)
def create_hybrid_retriever(
    collection_name: str,
    embeddings: Embeddings,
    dense_k: int = 10,
    sparse_k: int = 10,
    client: Optional[QdrantClient] = None,
    score_threshold: float = 0.3,
) -> Any:
    """
    Create the "hybrid" (Dense Vector + BM25) retriever.

    NOTE(review): despite the name, no BM25 / sparse retrieval is configured
    here. The returned retriever runs a single dense similarity search that
    returns up to ``dense_k + sparse_k`` documents, filtered by
    ``score_threshold``. Real hybrid search needs sparse vectors in the
    collection, or fusing with a BM25 retriever (e.g. via
    create_ensemble_retriever).

    Args:
        collection_name: Name of the Qdrant collection to search.
        embeddings: Embedding model used to embed queries.
        dense_k: Intended number of dense-vector results.
        sparse_k: Intended number of BM25 results. Currently only added to
            the total ``k``.
        client: Qdrant client to use. When None, one is created with
            create_qdrant_client().
        score_threshold: Minimum score passed to Qdrant's search (default
            0.3). Its exact meaning depends on the collection's distance
            metric.

    Returns:
        The retriever returned by ``Qdrant.as_retriever`` (a LangChain
        vector-store retriever).
    """
    if client is None:
        client = create_qdrant_client()

    # Wrap the given client directly. from_existing_collection() would build
    # its own client from connection arguments and not use `client`.
    vector_store = Qdrant(
        client=client,
        collection_name=collection_name,
        embeddings=embeddings,
        content_payload_key="content",
        metadata_payload_key="metadata",
    )

    search_kwargs = {
        "k": dense_k + sparse_k,  # total number of documents to return
        "score_threshold": score_threshold,
    }
    return vector_store.as_retriever(search_kwargs=search_kwargs)
def create_ensemble_retriever(
    retrievers: List[Any],
    weights: Optional[List[float]] = None,
    c: int = 60,
) -> EnsembleRetriever:
    """
    Create an ensemble retriever that fuses results with weighted
    Reciprocal Rank Fusion (RRF).

    Args:
        retrievers: Retrievers whose results are fused. Must be non-empty.
        weights: One weight per retriever. Defaults to equal weights
            (1 / len(retrievers) each).
        c: RRF constant (commonly 60).

    Returns:
        The configured EnsembleRetriever.

    Raises:
        ValueError: If ``retrievers`` is empty, or if ``weights`` is given
            with a different length than ``retrievers``.
    """
    if not retrievers:
        # Without this check the default-weight computation divides by zero.
        raise ValueError("retrievers must contain at least one retriever")
    if weights is None:
        weights = [1.0 / len(retrievers)] * len(retrievers)
    elif len(weights) != len(retrievers):
        # Fail at construction time instead of on the first query.
        raise ValueError(
            f"weights has {len(weights)} entries but there are "
            f"{len(retrievers)} retrievers"
        )

    # EnsembleRetriever always fuses with weighted RRF; it has no
    # `search_type` option, so none is passed.
    return EnsembleRetriever(
        retrievers=retrievers,
        weights=weights,
        c=c,
    )