Files
ailine/backend/app/rag/retriever.py
root 2183c901b4
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 13m54s
添加稀疏模型本地缓存功能
- 创建 download_sparse_model.py 脚本用于下载稀疏模型到本地
- 添加 SPARSE_MODEL_PATH 和 SPARSE_MODEL_NAME 配置
- 修改 retriever.py 和 index_builder.py 使用 cache_dir
- 更新 .gitignore 排除 models/ 目录
- 更新 Dockerfile 在构建时下载稀疏模型
2026-05-03 18:55:39 +08:00

171 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Qdrant 向量检索器模块
提供基于 Qdrant 的混合检索Dense + Sparse功能。
核心原理:
- 使用 Qdrant 原生混合检索langchain-qdrant 的 RetrievalMode.HYBRID
- 同时存储稠密向量和稀疏向量
- 语义理解 + 关键词匹配,效果最优
使用示例:
>>> from app.rag.retriever import create_hybrid_retriever
>>> retriever = create_hybrid_retriever(collection_name="rag_documents")
>>> docs = retriever.invoke("什么是 RAG")
"""
from typing import Dict, Any, Optional
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from langchain_qdrant import (
QdrantVectorStore,
RetrievalMode,
FastEmbedSparse,
)
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from rag_core import QDRANT_URL, QDRANT_API_KEY
from rag_core.client import create_qdrant_client as create_core_qdrant_client
from app.model_services import get_embedding_service
from app.config import SPARSE_MODEL_PATH, SPARSE_MODEL_NAME
from app.logger import info, warning
# 模块级常量
DEFAULT_SEARCH_K = 20
DEFAULT_SCORE_THRESHOLD = 0.3
def create_base_retriever(
collection_name: str,
search_kwargs: Dict[str, Any] | None = None,
client: QdrantClient | None = None,
embeddings: Embeddings | None = None,
) -> BaseRetriever:
"""
创建基础向量检索器(仅稠密向量检索)
Args:
collection_name: Qdrant 集合名称
search_kwargs: 搜索参数
client: 可选的 Qdrant 客户端
embeddings: 可选的嵌入模型(默认使用 get_embedding_service()
Returns:
LangChain 兼容的检索器
"""
# 默认使用统一嵌入服务(已内置降级机制)
if embeddings is None:
embeddings = get_embedding_service()
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
# 合并默认搜索参数
merged_search_kwargs = {"k": DEFAULT_SEARCH_K}
if search_kwargs:
merged_search_kwargs.update(search_kwargs)
# 创建或复用 Qdrant 客户端
if client is None:
client = create_core_qdrant_client()
# 验证集合是否存在
try:
client.get_collection(collection_name)
except UnexpectedResponse as e:
if e.status_code == 404:
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
raise
# 构建向量存储
vector_store = QdrantVectorStore(
client=client,
collection_name=collection_name,
embedding=embeddings,
)
return vector_store.as_retriever(search_kwargs=merged_search_kwargs)
def create_hybrid_retriever(
collection_name: str,
dense_k: int = 10,
sparse_k: int = 10,
score_threshold: float | None = DEFAULT_SCORE_THRESHOLD,
client: QdrantClient | None = None,
embeddings: Embeddings | None = None,
) -> BaseRetriever:
"""
创建混合检索器(稠密向量 + BM25 稀疏向量Qdrant 原生实现)。
Args:
collection_name: Qdrant 集合名称。
dense_k: 稠密向量检索返回数量,默认 10。
sparse_k: 稀疏向量检索返回数量,默认 10。
score_threshold: 相似度阈值,默认 0.3。
client: 可选的 Qdrant 客户端实例。
embeddings: 可选的嵌入模型实例。若未提供,将自动获取统一嵌入服务。
Returns:
BaseRetriever 实例,配置了混合搜索参数。
"""
total_k = dense_k + sparse_k
search_kwargs = {
"k": total_k,
"search_type": "similarity_score_threshold",
"score_threshold": score_threshold,
}
# 默认使用统一嵌入服务(已内置降级机制)
if embeddings is None:
embeddings = get_embedding_service()
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
# 创建或复用 Qdrant 客户端
if client is None:
client = create_core_qdrant_client()
# 验证集合是否存在
try:
client.get_collection(collection_name)
except UnexpectedResponse as e:
if e.status_code == 404:
warning(f"⚠️ Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档")
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
raise
# 初始化稀疏嵌入(使用本地缓存目录)
sparse_embeddings = FastEmbedSparse(
model_name=SPARSE_MODEL_NAME,
cache_dir=SPARSE_MODEL_PATH
)
info(f"✅ FastEmbedSparse 初始化成功 (cache_dir={SPARSE_MODEL_PATH})")
# 创建混合模式的 QdrantVectorStore
vector_store = QdrantVectorStore(
client=client,
collection_name=collection_name,
embedding=embeddings,
sparse_embedding=sparse_embeddings,
retrieval_mode=RetrievalMode.HYBRID,
)
info(f"✅ Qdrant 原生混合检索器初始化成功 (k={total_k})")
return vector_store.as_retriever(search_kwargs=search_kwargs)
# 可选:提供异步友好的辅助函数
async def acreate_base_retriever(
collection_name: str,
search_kwargs: Dict[str, Any] | None = None,
client: QdrantClient | None = None,
) -> BaseRetriever:
"""
异步创建基础向量检索器(与同步版本功能相同)。
适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。
"""
# 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可
return create_base_retriever(collection_name, search_kwargs, client)