feat: 实现 BM25 稀疏 + 稠密向量混合检索功能
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled

This commit is contained in:
2026-05-04 02:01:22 +08:00
parent 2183c901b4
commit 60afa86ded
26 changed files with 905 additions and 656 deletions

View File

@@ -6,6 +6,7 @@ RAG Core - 公共 RAG 组件包
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
from .store import PostgresDocStore, create_docstore
from .retriever_factory import create_parent_retriever
from .config import (
@@ -21,6 +22,8 @@ from .config import (
__all__ = [
"LlamaCppEmbedder",
"QdrantVectorStore",
"BM25SparseEmbedder",
"get_sparse_embedder",
"QDRANT_URL",
"QDRANT_API_KEY",
"LLAMACPP_EMBEDDING_URL",

View File

@@ -1,5 +1,14 @@
# rag_core/retriever_factory.py
"""
RAG 检索器工厂模块
提供创建各种检索器的工厂函数,包括:
- 基础向量检索器
- ParentDocumentRetriever父子文档
- 混合检索器(稠密+稀疏)
"""
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_core.stores import BaseStore
@@ -9,18 +18,18 @@ from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
def create_parent_retriever(
collection_name: str = "rag_documents",
parent_splitter: TextSplitter | None = None,
child_splitter: TextSplitter | None = None,
docstore: BaseStore | None = None,
parent_splitter: Optional[TextSplitter] = None,
child_splitter: Optional[TextSplitter] = None,
docstore: Optional[BaseStore] = None,
search_k: int = 5,
parent_chunk_size: int = 1000,
parent_chunk_overlap: int = 100,
child_chunk_size: int = 200,
child_chunk_overlap: int = 20,
embeddings: Embeddings | None = None,
embeddings: Optional[Embeddings] = None,
) -> ParentDocumentRetriever:
"""
创建 ParentDocumentRetriever 实例。
创建 ParentDocumentRetriever 实例(基础稠密向量版本)
Args:
collection_name: Qdrant 集合名称,默认 "rag_documents"
@@ -44,7 +53,7 @@ def create_parent_retriever(
# 向量存储(只读)
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
# 切分器(若未提供则创建默认)
if parent_splitter is None:
parent_splitter = RecursiveCharacterTextSplitter(
@@ -56,11 +65,11 @@ def create_parent_retriever(
chunk_size=child_chunk_size,
chunk_overlap=child_chunk_overlap,
)
# 文档存储
if docstore is None:
docstore, _ = create_docstore()
return ParentDocumentRetriever(
vectorstore=vector_store.get_langchain_vectorstore(),
docstore=docstore,
@@ -68,3 +77,34 @@ def create_parent_retriever(
parent_splitter=parent_splitter,
search_kwargs={"k": search_k},
)
def create_hybrid_retriever_factory(
    collection_name: str = "rag_documents",
    search_k: int = 5,
    embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
    """Placeholder factory that currently builds a dense-only retriever.

    This function is intentionally incomplete: per the original author's
    note, the full hybrid (dense + sparse) retrieval logic is implemented in
    ``app/rag/retriever.py``. Here a ``QdrantVectorStore`` is created and its
    LangChain vector store is exposed through ``as_retriever``.

    Args:
        collection_name: Name of the Qdrant collection.
        search_k: Value passed to the retriever as ``search_kwargs["k"]``.
        embeddings: Embedding model to use. When ``None``, a
            ``LlamaCppEmbedder`` is created and adapted with
            ``as_langchain_embeddings()``.

    Returns:
        The retriever produced by ``as_retriever`` on the LangChain vector
        store of the underlying ``QdrantVectorStore``.
    """
    # Fall back to the default embedding model when the caller gave none.
    dense_model = embeddings
    if dense_model is None:
        dense_model = LlamaCppEmbedder().as_langchain_embeddings()

    # Build the project wrapper, then unwrap its LangChain vector store.
    store = QdrantVectorStore(collection_name=collection_name, embeddings=dense_model)
    langchain_store = store.get_langchain_vectorstore()

    # Expose it through the LangChain retriever interface.
    return langchain_store.as_retriever(search_kwargs={"k": search_k})

View File

@@ -0,0 +1,34 @@
"""
BM25 稀疏嵌入器
基于 FastEmbed 的 Qdrant/bm25 模型,完全离线运行
"""
from typing import List
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
from app.config import FASTEMBED_CACHE_PATH
class BM25SparseEmbedder:
    """BM25 sparse-embedding wrapper, styled like the other embedders.

    Wraps FastEmbed's ``SparseTextEmbedding`` configured for the
    ``Qdrant/bm25`` model.
    """

    def __init__(self):
        # local_files_only=True: the model is loaded from the cache directory
        # only; the original author's intent is that it never goes online.
        model_options = {
            "model_name": "Qdrant/bm25",
            "cache_dir": FASTEMBED_CACHE_PATH,
            "local_files_only": True,
        }
        self.model = SparseTextEmbedding(**model_options)

    def embed_documents(self, texts: List[str]) -> List[dict]:
        """Return one sparse vector per text.

        Each element is the result of ``as_object()`` on the embedding
        (described by the original author as a Qdrant-compatible dict of
        indices and values).
        """
        embeddings = self.model.embed(texts)
        return [embedding.as_object() for embedding in embeddings]

    def embed_query(self, text: str) -> dict:
        """Return the sparse vector for a single query string."""
        embeddings = list(self.model.embed([text]))
        first = embeddings[0]
        return first.as_object()
# Module-level cached embedder instance; stays None until first requested.
_sparse_embedder_instance = None


def get_sparse_embedder() -> BM25SparseEmbedder:
    """Return the shared ``BM25SparseEmbedder``, creating it on first call."""
    global _sparse_embedder_instance
    if _sparse_embedder_instance is not None:
        return _sparse_embedder_instance
    _sparse_embedder_instance = BM25SparseEmbedder()
    return _sparse_embedder_instance

View File

@@ -1,41 +1,48 @@
"""
Qdrant 向量数据库包装器。
支持稠密+稀疏双向量存储。
"""
import logging
import os
import time
import uuid
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http.models import (
Distance, VectorParams, SparseVectorParams, SparseIndexParams,
SparseIndexType, PointStruct, NamedSparseVector, NamedVector
)
from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException
from .client import create_qdrant_client
from .embedders import LlamaCppEmbedder
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
logger = logging.getLogger(__name__)
class QdrantVectorStore:
"""Qdrant 向量数据库操作包装器。"""
"""Qdrant 向量数据库操作包装器 - 支持稠密+稀疏双向量存储"""
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None):
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None, sparse_embedder: Optional[BM25SparseEmbedder] = None):
"""
Args:
collection_name: Qdrant 集合名称。
embeddings: 嵌入模型实例,默认 None使用内部默认的 LlamaCppEmbedder
sparse_embedder: 稀疏嵌入模型实例,默认 None自动加载BM25
"""
self.collection_name = collection_name
self._client: Optional[QdrantClient] = None
self._connection_attempts = 0
self._last_connection_time: Optional[float] = None
# 嵌入模型
# 稠密嵌入模型
if embeddings is None:
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
@@ -43,9 +50,13 @@ class QdrantVectorStore:
else:
self.embeddings = embeddings
self._embedder = None
# 稀疏嵌入模型
self.sparse_embedder = sparse_embedder or get_sparse_embedder()
self.create_collection()
# 保留 LangChain 向量存储实例(用于兼容)
self.vector_store = LangchainQdrantVS(
client=self.get_client(),
collection_name=self.collection_name,
@@ -97,7 +108,7 @@ class QdrantVectorStore:
}
def create_collection(self, force_recreate: bool = False):
"""创建集合,设置合适的向量维度"""
"""创建集合,支持稠密+稀疏双向量"""
if self._embedder is not None:
# 使用内部的 embedder 获取维度
vector_size = self._embedder.get_embedding_dimension()
@@ -119,11 +130,31 @@ class QdrantVectorStore:
exists = False
if not exists:
# 向量配置:稠密向量
vectors_config = {
"dense": VectorParams(
size=vector_size,
distance=Distance.COSINE,
optional=True
)
}
# 稀疏向量配置
sparse_vectors_config = {
"sparse": SparseVectorParams(
index=SparseIndexParams(
type=SparseIndexType.MUTABLE
),
optional=True
)
}
client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
vectors_config=vectors_config,
sparse_vectors_config=sparse_vectors_config
)
logger.info("集合 '%s' 已创建(维度=%d", self.collection_name, vector_size)
logger.info("集合 '%s' 已创建(维度=%d,支持稠密+稀疏双向量", self.collection_name, vector_size)
else:
logger.info("集合 '%s' 已存在", self.collection_name)
return
@@ -142,18 +173,54 @@ class QdrantVectorStore:
time.sleep(wait_time)
def add_documents(self, documents: List[Document], batch_size: int = 100):
    """Embed documents and upsert them with dense and sparse named vectors.

    Every document's ``page_content`` is embedded twice, once with
    ``self.embeddings`` and once with ``self.sparse_embedder``, and stored
    as one point carrying the named vectors ``"dense"`` and ``"sparse"``.

    Args:
        documents: Documents to store.
        batch_size: Number of documents embedded and upserted per request.

    Returns:
        The point ids in input order. ``metadata["id"]`` is used when it is
        present and not ``None``; otherwise a random UUID string is
        generated. An empty ``documents`` list returns ``[]`` without
        touching the collection.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if not documents:
        return []
    if batch_size <= 0:
        # A negative step would make range() empty and silently store nothing.
        raise ValueError("batch_size must be a positive integer")

    # Local import: SparseVector is the point-level sparse type; the
    # module-level import list only carries the search-time Named* types.
    from qdrant_client.http.models import SparseVector

    self.create_collection()
    client = self.get_client()
    doc_ids = []

    for start in range(0, len(documents), batch_size):
        batch_docs = documents[start:start + batch_size]
        texts = [doc.page_content for doc in batch_docs]

        # One dense and one sparse embedding per text, in the same order.
        dense_vectors = self.embeddings.embed_documents(texts)
        sparse_vectors = self.sparse_embedder.embed_documents(texts)

        points = []
        for j, doc in enumerate(batch_docs):
            point_id = doc.metadata.get("id")
            if point_id is None:
                point_id = str(uuid.uuid4())
            doc_ids.append(point_id)

            # PointStruct expects a SparseVector under the sparse vector's
            # name (NamedSparseVector is a search-request type). The sparse
            # embedder returns a dict with "indices" and "values"; convert
            # the entries to plain Python numbers so array-like containers
            # (e.g. NumPy arrays) validate as lists.
            sparse = sparse_vectors[j]
            sparse_vector = SparseVector(
                indices=[int(index) for index in sparse["indices"]],
                values=[float(value) for value in sparse["values"]],
            )

            # NOTE(review): the payload stores the text under "text" with the
            # metadata flattened beside it. Confirm that the LangChain store
            # used by similarity_search() is configured to read this layout.
            points.append(PointStruct(
                id=point_id,
                vector={"dense": dense_vectors[j], "sparse": sparse_vector},
                payload={"text": doc.page_content, **doc.metadata},
            ))

        # Write the whole batch in a single request.
        client.upsert(collection_name=self.collection_name, points=points)
        logger.info("已向 '%s' 添加 %d 个文档(稠密+稀疏双向量)", self.collection_name, len(points))

    return doc_ids
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
    """Basic dense-vector search, kept for compatibility with the old API.

    Delegates to ``self.vector_store.similarity_search``.

    Args:
        query: Query text.
        k: Number of results requested; forwarded as ``k``.
    """
    langchain_store = self.vector_store
    hits = langchain_store.similarity_search(query, k=k)
    return hits
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
    """Basic dense-vector search with scores, kept for API compatibility.

    Delegates to ``self.vector_store.similarity_search_with_score``.

    Args:
        query: Query text.
        k: Number of results requested; forwarded as ``k``.
    """
    langchain_store = self.vector_store
    scored_hits = langchain_store.similarity_search_with_score(query, k=k)
    return scored_hits
def delete_collection(self):
@@ -183,5 +250,5 @@ class QdrantVectorStore:
return self.vector_store
def get_qdrant_client(self):
    """Return the native Qdrant client (for custom retrieval logic)."""
    native_client = self.get_client()
    return native_client