Files
ailine/app/rag/retriever.py
root 933d418d77
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s
检索器重构
2026-04-19 22:01:55 +08:00

139 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Qdrant vector retrievers.

Provides basic vector retrieval and hybrid retrieval (Dense + BM25).
"""
from typing import List, Dict, Any, Optional
from langchain_qdrant import QdrantVectorStore
from langchain.embeddings.base import Embeddings
# from langchain.retrievers import EnsembleRetriever
from qdrant_client import QdrantClient
from rag_core import QDRANT_URL, QDRANT_API_KEY
def create_qdrant_client(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> QdrantClient:
    """Build a Qdrant client.

    Args:
        url: Qdrant service address. A falsy value falls back to
            ``QDRANT_URL`` imported from ``rag_core``.
        api_key: API key. A falsy value falls back to ``QDRANT_API_KEY``
            imported from ``rag_core``.

    Returns:
        A ``QdrantClient`` instance.
    """
    resolved_url = url or QDRANT_URL
    resolved_key = api_key or QDRANT_API_KEY

    # Only forward ``api_key`` when one is actually set, so an empty key is
    # never handed to the client constructor.
    if not resolved_key:
        return QdrantClient(url=resolved_url)
    return QdrantClient(url=resolved_url, api_key=resolved_key)
def create_base_retriever(
    collection_name: str,
    embeddings: Embeddings,
    search_kwargs: Optional[Dict[str, Any]] = None,
    client: Optional[QdrantClient] = None,
) -> Any:
    """Create a basic vector retriever over a Qdrant collection.

    Args:
        collection_name: Name of the Qdrant collection to search.
        embeddings: Embedding model handed to the vector store.
        search_kwargs: Search parameters forwarded to ``as_retriever``.
            A missing or empty value falls back to ``{"k": 20}``.
        client: Qdrant client to use. When ``None``, one is created with
            ``create_qdrant_client()``.

    Returns:
        The retriever produced by ``QdrantVectorStore.as_retriever``. This is
        the retriever wrapper, not the vector store itself, so the return
        value is not annotated as ``QdrantVectorStore``.
    """
    if client is None:
        client = create_qdrant_client()

    vector_store = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embeddings,
    )
    return vector_store.as_retriever(search_kwargs=search_kwargs or {"k": 20})
def create_hybrid_retriever(
    collection_name: str,
    embeddings: Embeddings,
    dense_k: int = 10,
    sparse_k: int = 10,
    client: Optional[QdrantClient] = None,
    score_threshold: Optional[float] = 0.3,
) -> Any:
    """Create the retriever used for "hybrid" (Dense Vector + BM25) search.

    NOTE(review): no sparse embedding or retrieval mode is passed to
    ``QdrantVectorStore`` here, so ``dense_k`` and ``sparse_k`` only
    contribute their sum as a single result limit. Confirm whether a real
    Dense + BM25 hybrid search still needs to be wired in.

    Args:
        collection_name: Name of the Qdrant collection to search.
        embeddings: Embedding model handed to the vector store.
        dense_k: Number of results budgeted for dense vector retrieval.
        sparse_k: Number of results budgeted for BM25 retrieval.
        client: Qdrant client to use. When ``None``, one is created with
            ``create_qdrant_client()``.
        score_threshold: Value forwarded as the ``score_threshold`` search
            parameter. Defaults to ``0.3``; pass ``None`` to omit the
            parameter entirely.

    Returns:
        The retriever produced by ``QdrantVectorStore.as_retriever``. This is
        the retriever wrapper, not the vector store itself, so the return
        value is not annotated as ``QdrantVectorStore``.
    """
    if client is None:
        client = create_qdrant_client()

    vector_store = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embeddings,
    )

    # A single search is issued; its limit is the combined budget.
    search_kwargs: Dict[str, Any] = {"k": dense_k + sparse_k}
    if score_threshold is not None:
        search_kwargs["score_threshold"] = score_threshold
    return vector_store.as_retriever(search_kwargs=search_kwargs)
# def create_ensemble_retriever(
# retrievers: List[Any],
# weights: Optional[List[float]] = None,
# c: int = 60,
# ) -> EnsembleRetriever:
# """
# 创建集成检索器,支持倒数排名融合 (RRF)
#
# Args:
# retrievers: 检索器列表
# weights: 检索器权重
# c: RRF 常数通常为60
#
# Returns:
# 集成检索器
# """
# if weights is None:
# weights = [1.0 / len(retrievers)] * len(retrievers)
#
# ensemble = EnsembleRetriever(
# retrievers=retrievers,
# weights=weights,
# c=c,
# search_type="rrf",
# )
#
# return ensemble