Files
ailine/app/rag/retriever.py

139 lines
3.4 KiB
Python
Raw Normal View History

2026-04-18 16:31:48 +08:00
"""
Qdrant 向量检索器
提供基础向量检索混合检索Dense + BM25功能
"""
from typing import List, Dict, Any, Optional
2026-04-19 22:01:55 +08:00
from langchain_qdrant import QdrantVectorStore
2026-04-18 16:31:48 +08:00
from langchain.embeddings.base import Embeddings
2026-04-19 22:01:55 +08:00
# from langchain.retrievers import EnsembleRetriever
2026-04-18 16:31:48 +08:00
from qdrant_client import QdrantClient
2026-04-19 22:01:55 +08:00
from rag_core import QDRANT_URL, QDRANT_API_KEY
2026-04-18 16:31:48 +08:00
def create_qdrant_client(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> QdrantClient:
    """
    Build a Qdrant client.

    Args:
        url: Qdrant service address. When falsy, falls back to ``QDRANT_URL``
            imported from ``rag_core`` (presumably populated from the
            environment there — confirm in rag_core).
        api_key: API key. When falsy, falls back to ``QDRANT_API_KEY`` from
            ``rag_core``. The key is forwarded to the client only if the
            resolved value is truthy.

    Returns:
        A ``QdrantClient`` instance.
    """
    resolved_key = api_key or QDRANT_API_KEY
    # Only pass api_key when one is actually configured, so an unauthenticated
    # (e.g. local) Qdrant instance is reached without an empty credential.
    auth_options = {"api_key": resolved_key} if resolved_key else {}
    return QdrantClient(url=url or QDRANT_URL, **auth_options)


def create_base_retriever(
    collection_name: str,
    embeddings: Embeddings,
    search_kwargs: Optional[Dict[str, Any]] = None,
    client: Optional[QdrantClient] = None,
) -> Any:
    """
    Create a basic dense-vector retriever over a Qdrant collection.

    Args:
        collection_name: Name of the Qdrant collection to search.
        embeddings: Embedding model used by the vector store.
        search_kwargs: Search parameters forwarded to the retriever. When
            ``None`` or empty, defaults to ``{"k": 20}``.
        client: Qdrant client. When ``None``, one is created with
            ``create_qdrant_client()``.

    Returns:
        The retriever returned by ``QdrantVectorStore.as_retriever`` — a
        LangChain retriever object, not the ``QdrantVectorStore`` itself
        (the previous ``-> QdrantVectorStore`` annotation was incorrect).
    """
    # A fresh default dict is built per call, so no mutable default is shared.
    search_kwargs = search_kwargs or {"k": 20}

    if client is None:
        client = create_qdrant_client()

    vector_store = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embeddings,
    )
    return vector_store.as_retriever(search_kwargs=search_kwargs)


2026-04-18 16:31:48 +08:00
def create_hybrid_retriever(
    collection_name: str,
    embeddings: Embeddings,
    dense_k: int = 10,
    sparse_k: int = 10,
    client: Optional[QdrantClient] = None,
    score_threshold: Optional[float] = 0.3,
) -> Any:
    """
    Create a retriever intended for hybrid (dense vector + BM25) search.

    NOTE(review): despite its name, this currently performs dense-vector
    similarity search only. The underlying ``QdrantVectorStore`` is built
    without a sparse/BM25 embedding or a hybrid retrieval mode; ``sparse_k``
    is merely added to the number of dense results requested. Real hybrid
    search requires a collection configured with sparse vectors (see
    ``RetrievalMode.HYBRID`` in ``langchain_qdrant``) — TODO confirm the
    collection schema before switching.

    Args:
        collection_name: Name of the Qdrant collection to search.
        embeddings: Embedding model used by the vector store.
        dense_k: Number of results attributed to vector search.
        sparse_k: Number of results attributed to BM25 search (currently
            only added to the total ``k``; see note above).
        client: Qdrant client. When ``None``, one is created by
            ``create_base_retriever``.
        score_threshold: Forwarded as ``score_threshold`` in the search
            kwargs; ``None`` omits it. Defaults to 0.3, the value that was
            previously hard-coded. Whether it acts as a minimum score or a
            maximum distance depends on the collection's distance metric.

    Returns:
        The retriever returned by ``QdrantVectorStore.as_retriever`` (a
        LangChain retriever, not the ``QdrantVectorStore`` itself).

    Raises:
        ValueError: If ``dense_k`` or ``sparse_k`` is negative, or both are 0.
    """
    if dense_k < 0 or sparse_k < 0:
        raise ValueError(
            f"dense_k and sparse_k must be non-negative, got {dense_k} and {sparse_k}"
        )
    total_k = dense_k + sparse_k
    if total_k == 0:
        raise ValueError("dense_k + sparse_k must be at least 1")

    search_kwargs: Dict[str, Any] = {"k": total_k}
    if score_threshold is not None:
        search_kwargs["score_threshold"] = score_threshold

    # Reuse the base retriever factory instead of duplicating client/store setup.
    return create_base_retriever(
        collection_name,
        embeddings,
        search_kwargs=search_kwargs,
        client=client,
    )


2026-04-18 16:31:48 +08:00
2026-04-19 22:01:55 +08:00
# def create_ensemble_retriever(
# retrievers: List[Any],
# weights: Optional[List[float]] = None,
# c: int = 60,
# ) -> EnsembleRetriever:
# """
# 创建集成检索器,支持倒数排名融合 (RRF)
#
# Args:
# retrievers: 检索器列表
# weights: 检索器权重
# c: RRF 常数通常为60
#
# Returns:
# 集成检索器
# """
# if weights is None:
# weights = [1.0 / len(retrievers)] * len(retrievers)
#
# ensemble = EnsembleRetriever(
# retrievers=retrievers,
# weights=weights,
# c=c,
# search_type="rrf",
# )
#
# return ensemble