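"""
Qdrant vector retrievers.

Factory helpers for a Qdrant client, a basic dense-vector retriever, a
"hybrid" retriever, and an ensemble retriever that fuses several retrievers
with Reciprocal Rank Fusion (RRF).

Note: ``create_hybrid_retriever`` itself runs dense-vector search only; a
Dense + BM25 combination is obtained by passing a dense retriever and a BM25
retriever to ``create_ensemble_retriever``.
"""
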
import os
|
||
from typing import List, Dict, Any, Optional
|
||
from langchain_qdrant import Qdrant
|
||
from langchain.embeddings.base import Embeddings
|
||
from langchain.retrievers import ContextualCompressionRetriever
|
||
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
|
||
from langchain.retrievers import EnsembleRetriever
|
||
from qdrant_client import QdrantClient
|
||
from qdrant_client.http import models
|
||
|
||
|
||
def create_qdrant_client(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
    **client_kwargs: Any,
) -> QdrantClient:
    """Create a Qdrant client.

    Args:
        url: Qdrant server URL. When falsy, falls back to the ``QDRANT_URL``
            environment variable, then to ``"http://localhost:6333"``.
        api_key: API key. When falsy, falls back to the ``QDRANT_API_KEY``
            environment variable; if that is also unset or empty, no key is
            passed to the client.
        **client_kwargs: Extra keyword arguments forwarded unchanged to
            ``QdrantClient`` (for example ``timeout`` or ``prefer_grpc``).

    Returns:
        A ``QdrantClient`` instance.
    """
    # Using ``or`` after getenv (instead of getenv's default argument) also
    # treats an exported-but-empty QDRANT_URL as unset, instead of url="".
    resolved_url = url or os.getenv("QDRANT_URL") or "http://localhost:6333"
    resolved_key = api_key or os.getenv("QDRANT_API_KEY")

    client_args: Dict[str, Any] = {"url": resolved_url, **client_kwargs}
    # Only send an API key when one is actually configured.
    if resolved_key:
        client_args["api_key"] = resolved_key

    return QdrantClient(**client_args)


def create_base_retriever(
    collection_name: str,
    embeddings: Embeddings,
    search_kwargs: Optional[Dict[str, Any]] = None,
    client: Optional[QdrantClient] = None,
) -> Any:
    """Create a basic dense-vector retriever over an existing Qdrant collection.

    Args:
        collection_name: Name of the Qdrant collection.
        embeddings: Embedding model used to embed queries.
        search_kwargs: Search parameters passed to ``as_retriever``. When
            ``None`` or empty, ``{"k": 20}`` is used.
        client: Qdrant client the vector store should use; created with
            ``create_qdrant_client()`` when ``None``.

    Returns:
        The retriever returned by ``Qdrant.as_retriever`` (a LangChain
        vector-store retriever, not the ``Qdrant`` vector store itself).
    """
    if client is None:
        client = create_qdrant_client()

    if not search_kwargs:
        search_kwargs = {"k": 20}

    # Wrap the given client directly. ``Qdrant.from_existing_collection`` has
    # no ``client`` parameter (it builds its own client from connection
    # arguments such as url/api_key), so passing ``client=`` there does not
    # make the store use this client.
    vectorstore = Qdrant(
        client=client,
        collection_name=collection_name,
        embeddings=embeddings,
        content_payload_key="content",  # assumes document text is stored under "content"
        metadata_payload_key="metadata",  # payload field holding document metadata
    )

    return vectorstore.as_retriever(search_kwargs=search_kwargs)


def create_hybrid_retriever(
    collection_name: str,
    embeddings: Embeddings,
    dense_k: int = 10,
    sparse_k: int = 10,
    client: Optional[QdrantClient] = None,
    score_threshold: Optional[float] = 0.3,
) -> Any:
    """Create the "hybrid" retriever over an existing Qdrant collection.

    NOTE: despite its name, this retriever performs dense-vector similarity
    search only; no BM25 / sparse scoring happens here. ``sparse_k`` merely
    enlarges the number of returned documents (``k = dense_k + sparse_k``).
    For an actual Dense + BM25 fusion, combine a dense retriever with a BM25
    retriever via ``create_ensemble_retriever``.

    Args:
        collection_name: Name of the Qdrant collection.
        embeddings: Embedding model used to embed queries.
        dense_k: Number of results attributed to the dense search.
        sparse_k: Additional result count (added to ``dense_k``).
        client: Qdrant client the vector store should use; created with
            ``create_qdrant_client()`` when ``None``.
        score_threshold: Minimum similarity score forwarded to Qdrant; less
            similar results are dropped. ``None`` disables the threshold.

    Returns:
        The retriever returned by ``Qdrant.as_retriever`` (a LangChain
        vector-store retriever).
    """
    if client is None:
        client = create_qdrant_client()

    # Wrap the given client directly; ``Qdrant.from_existing_collection``
    # would construct its own client instead of using this one.
    vectorstore = Qdrant(
        client=client,
        collection_name=collection_name,
        embeddings=embeddings,
        content_payload_key="content",
        metadata_payload_key="metadata",
    )

    search_kwargs: Dict[str, Any] = {"k": dense_k + sparse_k}  # total result count
    if score_threshold is not None:
        search_kwargs["score_threshold"] = score_threshold

    return vectorstore.as_retriever(search_kwargs=search_kwargs)


def create_ensemble_retriever(
    retrievers: List[Any],
    weights: Optional[List[float]] = None,
    c: int = 60,
) -> EnsembleRetriever:
    """Create an ensemble retriever that fuses results with weighted RRF.

    ``EnsembleRetriever`` combines the ranked lists of its sub-retrievers
    using weighted Reciprocal Rank Fusion, roughly ``weight / (rank + c)``
    per list a document appears in.

    Args:
        retrievers: Retrievers to combine; must be non-empty.
        weights: One weight per retriever. Defaults to equal weights that
            sum to 1.
        c: RRF smoothing constant (60 is the conventional value).

    Returns:
        The configured ``EnsembleRetriever``.

    Raises:
        ValueError: If ``retrievers`` is empty, or if ``weights`` does not
            have exactly one entry per retriever.
    """
    retrievers = list(retrievers)
    if not retrievers:
        # Fail clearly instead of a ZeroDivisionError in the default weights.
        raise ValueError("retrievers must contain at least one retriever")

    if weights is None:
        weights = [1.0 / len(retrievers)] * len(retrievers)
    elif len(weights) != len(retrievers):
        # Validate now rather than failing later at query time.
        raise ValueError(
            f"expected {len(retrievers)} weights (one per retriever), "
            f"got {len(weights)}"
        )

    # ``search_type`` is not an EnsembleRetriever field (fusion is always
    # weighted RRF), so it is intentionally not passed.
    return EnsembleRetriever(
        retrievers=retrievers,
        weights=list(weights),
        c=c,
    )