feat: 实现 BM25 稀疏 + 稠密向量混合检索功能
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
This commit is contained in:
@@ -6,6 +6,7 @@ RAG Core - 公共 RAG 组件包
|
||||
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .vector_store import QdrantVectorStore
|
||||
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
|
||||
from .store import PostgresDocStore, create_docstore
|
||||
from .retriever_factory import create_parent_retriever
|
||||
from .config import (
|
||||
@@ -21,6 +22,8 @@ from .config import (
|
||||
__all__ = [
|
||||
"LlamaCppEmbedder",
|
||||
"QdrantVectorStore",
|
||||
"BM25SparseEmbedder",
|
||||
"get_sparse_embedder",
|
||||
"QDRANT_URL",
|
||||
"QDRANT_API_KEY",
|
||||
"LLAMACPP_EMBEDDING_URL",
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
# rag_core/retriever_factory.py
|
||||
"""
|
||||
RAG 检索器工厂模块
|
||||
|
||||
提供创建各种检索器的工厂函数,包括:
|
||||
- 基础向量检索器
|
||||
- ParentDocumentRetriever(父子文档)
|
||||
- 混合检索器(稠密+稀疏)
|
||||
"""
|
||||
from typing import Optional
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from langchain_core.stores import BaseStore
|
||||
@@ -9,18 +18,18 @@ from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
|
||||
|
||||
def create_parent_retriever(
|
||||
collection_name: str = "rag_documents",
|
||||
parent_splitter: TextSplitter | None = None,
|
||||
child_splitter: TextSplitter | None = None,
|
||||
docstore: BaseStore | None = None,
|
||||
parent_splitter: Optional[TextSplitter] = None,
|
||||
child_splitter: Optional[TextSplitter] = None,
|
||||
docstore: Optional[BaseStore] = None,
|
||||
search_k: int = 5,
|
||||
parent_chunk_size: int = 1000,
|
||||
parent_chunk_overlap: int = 100,
|
||||
child_chunk_size: int = 200,
|
||||
child_chunk_overlap: int = 20,
|
||||
embeddings: Embeddings | None = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
) -> ParentDocumentRetriever:
|
||||
"""
|
||||
创建 ParentDocumentRetriever 实例。
|
||||
创建 ParentDocumentRetriever 实例(基础稠密向量版本)。
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称,默认 "rag_documents"
|
||||
@@ -44,7 +53,7 @@ def create_parent_retriever(
|
||||
|
||||
# 向量存储(只读)
|
||||
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||
|
||||
|
||||
# 切分器(若未提供则创建默认)
|
||||
if parent_splitter is None:
|
||||
parent_splitter = RecursiveCharacterTextSplitter(
|
||||
@@ -56,11 +65,11 @@ def create_parent_retriever(
|
||||
chunk_size=child_chunk_size,
|
||||
chunk_overlap=child_chunk_overlap,
|
||||
)
|
||||
|
||||
|
||||
# 文档存储
|
||||
if docstore is None:
|
||||
docstore, _ = create_docstore()
|
||||
|
||||
|
||||
return ParentDocumentRetriever(
|
||||
vectorstore=vector_store.get_langchain_vectorstore(),
|
||||
docstore=docstore,
|
||||
@@ -68,3 +77,34 @@ def create_parent_retriever(
|
||||
parent_splitter=parent_splitter,
|
||||
search_kwargs={"k": search_k},
|
||||
)
|
||||
|
||||
|
||||
def create_hybrid_retriever_factory(
    collection_name: str = "rag_documents",
    search_k: int = 5,
    embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
    """Placeholder factory for a hybrid retriever (incomplete by design).

    This function does not perform hybrid retrieval itself. According to the
    original author's note, the full dense + sparse logic is implemented in
    ``app/rag/retriever.py``; here a ``QdrantVectorStore`` is built and the
    retriever of its LangChain vector store is returned.

    Args:
        collection_name: Name of the Qdrant collection to query.
        search_k: Number of results the retriever is asked to return.
        embeddings: Embedding model to use. When ``None``, a
            ``LlamaCppEmbedder`` is created and adapted to the LangChain
            embeddings interface.

    Returns:
        The retriever obtained from the LangChain vector store, configured
        with ``search_kwargs={"k": search_k}``.
    """
    # Fall back to the default embedder only when the caller supplied none.
    dense_model = (
        embeddings
        if embeddings is not None
        else LlamaCppEmbedder().as_langchain_embeddings()
    )

    store = QdrantVectorStore(collection_name=collection_name, embeddings=dense_model)

    # Expose the store through the LangChain retriever interface.
    langchain_store = store.get_langchain_vectorstore()
    return langchain_store.as_retriever(search_kwargs={"k": search_k})
|
||||
|
||||
34
backend/rag_core/sparse_embedder.py
Normal file
34
backend/rag_core/sparse_embedder.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
BM25 稀疏嵌入器
|
||||
基于 FastEmbed 的 Qdrant/bm25 模型,完全离线运行
|
||||
"""
|
||||
from typing import List
|
||||
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
|
||||
from app.config import FASTEMBED_CACHE_PATH
|
||||
|
||||
class BM25SparseEmbedder:
    """Thin wrapper around FastEmbed's ``Qdrant/bm25`` sparse text model.

    The model is loaded from ``FASTEMBED_CACHE_PATH`` with
    ``local_files_only=True``, so it is expected to be present in the local
    cache already. Method names mirror the dense embedders used in this
    package (``embed_documents`` / ``embed_query``).
    """

    def __init__(self):
        """Load the BM25 sparse model from the local FastEmbed cache."""
        self.model = SparseTextEmbedding(
            model_name="Qdrant/bm25",
            cache_dir=FASTEMBED_CACHE_PATH,
            local_files_only=True,  # force offline mode, never go to the network
        )

    def embed_documents(self, texts: List[str]) -> List[dict]:
        """Embed texts for indexing.

        Args:
            texts: Document texts to encode.

        Returns:
            One sparse-vector dict (indices + values, as produced by
            ``as_object()``) per input text, in input order.
        """
        return [vec.as_object() for vec in self.model.embed(texts)]

    def embed_query(self, text: str) -> dict:
        """Embed a single search query.

        Queries go through the model's query-side encoder (``query_embed``)
        rather than the document-side ``embed``: FastEmbed weights BM25
        queries differently from documents, so encoding a query as if it were
        a document skews the retrieval scores.

        Args:
            text: The query string.

        Returns:
            The sparse-vector dict (indices + values) for the query.
        """
        # Only the first result is needed; avoid materialising the generator.
        return next(iter(self.model.query_embed([text]))).as_object()
|
||||
|
||||
# Module-wide singleton; populated lazily by get_sparse_embedder().
_sparse_embedder_instance = None


def get_sparse_embedder() -> BM25SparseEmbedder:
    """Return the shared ``BM25SparseEmbedder``, creating it on first use.

    The instance is cached in the module-level ``_sparse_embedder_instance``
    so that later calls reuse the same object instead of building a new one.
    """
    global _sparse_embedder_instance

    # Fast path: the embedder has already been built.
    if _sparse_embedder_instance is not None:
        return _sparse_embedder_instance

    _sparse_embedder_instance = BM25SparseEmbedder()
    return _sparse_embedder_instance
|
||||
@@ -1,41 +1,48 @@
|
||||
"""
|
||||
Qdrant 向量数据库包装器。
|
||||
支持稠密+稀疏双向量存储。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http.models import Distance, VectorParams
|
||||
from qdrant_client.http.models import (
|
||||
Distance, VectorParams, SparseVectorParams, SparseIndexParams,
|
||||
SparseIndexType, PointStruct, NamedSparseVector, NamedVector
|
||||
)
|
||||
from httpx import RemoteProtocolError
|
||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||
|
||||
from .client import create_qdrant_client
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QdrantVectorStore:
|
||||
"""Qdrant 向量数据库操作包装器。"""
|
||||
"""Qdrant 向量数据库操作包装器 - 支持稠密+稀疏双向量存储。"""
|
||||
|
||||
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None):
|
||||
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None, sparse_embedder: Optional[BM25SparseEmbedder] = None):
|
||||
"""
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称。
|
||||
embeddings: 嵌入模型实例,默认 None(使用内部默认的 LlamaCppEmbedder)。
|
||||
sparse_embedder: 稀疏嵌入模型实例,默认 None(自动加载BM25)。
|
||||
"""
|
||||
self.collection_name = collection_name
|
||||
self._client: Optional[QdrantClient] = None
|
||||
self._connection_attempts = 0
|
||||
self._last_connection_time: Optional[float] = None
|
||||
|
||||
# 嵌入模型
|
||||
# 稠密嵌入模型
|
||||
if embeddings is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
self.embeddings = embedder.as_langchain_embeddings()
|
||||
@@ -43,9 +50,13 @@ class QdrantVectorStore:
|
||||
else:
|
||||
self.embeddings = embeddings
|
||||
self._embedder = None
|
||||
|
||||
# 稀疏嵌入模型
|
||||
self.sparse_embedder = sparse_embedder or get_sparse_embedder()
|
||||
|
||||
self.create_collection()
|
||||
|
||||
# 保留 LangChain 向量存储实例(用于兼容)
|
||||
self.vector_store = LangchainQdrantVS(
|
||||
client=self.get_client(),
|
||||
collection_name=self.collection_name,
|
||||
@@ -97,7 +108,7 @@ class QdrantVectorStore:
|
||||
}
|
||||
|
||||
def create_collection(self, force_recreate: bool = False):
|
||||
"""创建集合,设置合适的向量维度。"""
|
||||
"""创建集合,支持稠密+稀疏双向量。"""
|
||||
if self._embedder is not None:
|
||||
# 使用内部的 embedder 获取维度
|
||||
vector_size = self._embedder.get_embedding_dimension()
|
||||
@@ -119,11 +130,31 @@ class QdrantVectorStore:
|
||||
exists = False
|
||||
|
||||
if not exists:
|
||||
# 向量配置:稠密向量
|
||||
vectors_config = {
|
||||
"dense": VectorParams(
|
||||
size=vector_size,
|
||||
distance=Distance.COSINE,
|
||||
optional=True
|
||||
)
|
||||
}
|
||||
|
||||
# 稀疏向量配置
|
||||
sparse_vectors_config = {
|
||||
"sparse": SparseVectorParams(
|
||||
index=SparseIndexParams(
|
||||
type=SparseIndexType.MUTABLE
|
||||
),
|
||||
optional=True
|
||||
)
|
||||
}
|
||||
|
||||
client.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
|
||||
vectors_config=vectors_config,
|
||||
sparse_vectors_config=sparse_vectors_config
|
||||
)
|
||||
logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size)
|
||||
logger.info("集合 '%s' 已创建(维度=%d,支持稠密+稀疏双向量)", self.collection_name, vector_size)
|
||||
else:
|
||||
logger.info("集合 '%s' 已存在", self.collection_name)
|
||||
return
|
||||
@@ -142,18 +173,54 @@ class QdrantVectorStore:
|
||||
time.sleep(wait_time)
|
||||
|
||||
def add_documents(self, documents: List[Document], batch_size: int = 100):
    """Embed documents and upsert them into Qdrant with dense + sparse vectors.

    Each document becomes one point carrying two named vectors: ``"dense"``
    (from ``self.embeddings``) and ``"sparse"`` (from ``self.sparse_embedder``).
    The payload stores the text under ``"text"`` with the document metadata
    spread alongside it at the top level.

    Args:
        documents: Documents to index.
        batch_size: Number of documents embedded and upserted per request.

    Returns:
        The point IDs in input order, or an empty list when ``documents`` is
        empty. An ID is taken from ``metadata["id"]`` when present, otherwise
        a random UUID string is generated.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    # Local import: the sparse point-vector model is not part of the
    # module-level qdrant_client imports.
    from qdrant_client.http.models import SparseVector

    if not documents:
        return []
    if batch_size <= 0:
        # A negative step would otherwise silently index nothing.
        raise ValueError("batch_size must be a positive integer")

    self.create_collection()
    client = self.get_client()
    doc_ids = []

    for start in range(0, len(documents), batch_size):
        batch_docs = documents[start:start + batch_size]
        texts = [doc.page_content for doc in batch_docs]

        # Produce both representations for the whole batch at once.
        dense_vectors = self.embeddings.embed_documents(texts)
        sparse_vectors = self.sparse_embedder.embed_documents(texts)

        points = []
        for doc, dense, sparse in zip(batch_docs, dense_vectors, sparse_vectors):
            point_id = doc.metadata.get("id")
            if point_id is None:
                point_id = str(uuid.uuid4())
            doc_ids.append(point_id)

            # Stored points take a plain SparseVector under the vector name;
            # NamedSparseVector is a search-request type and is not accepted
            # inside PointStruct.vector. Values are converted to built-in
            # ints/floats so array-like inputs validate as well.
            named_vectors = {
                "dense": dense,
                "sparse": SparseVector(
                    indices=[int(idx) for idx in sparse["indices"]],
                    values=[float(val) for val in sparse["values"]],
                ),
            }

            points.append(PointStruct(
                id=point_id,
                vector=named_vectors,
                payload={"text": doc.page_content, **doc.metadata},
            ))

        # One upsert request per batch.
        client.upsert(collection_name=self.collection_name, points=points)
        logger.info("已向 '%s' 添加 %d 个文档(稠密+稀疏双向量)", self.collection_name, len(points))

    return doc_ids
|
||||
|
||||
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
    """Run a basic similarity search (kept for backward compatibility).

    Delegates to the wrapped LangChain vector store held in
    ``self.vector_store``.

    Args:
        query: The query text.
        k: Number of documents to return.

    Returns:
        Whatever the underlying store's ``similarity_search`` returns.
    """
    langchain_store = self.vector_store
    return langchain_store.similarity_search(query, k=k)
|
||||
|
||||
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
    """Run a basic similarity search with scores (kept for backward compatibility).

    Delegates to the wrapped LangChain vector store held in
    ``self.vector_store``.

    Args:
        query: The query text.
        k: Number of results to return.

    Returns:
        Whatever the underlying store's ``similarity_search_with_score``
        returns.
    """
    langchain_store = self.vector_store
    return langchain_store.similarity_search_with_score(query, k=k)
|
||||
|
||||
def delete_collection(self):
|
||||
@@ -183,5 +250,5 @@ class QdrantVectorStore:
|
||||
return self.vector_store
|
||||
|
||||
def get_qdrant_client(self):
    """Return the native Qdrant client (for custom retrieval logic)."""
    native_client = self.get_client()
    return native_client
|
||||
|
||||
Reference in New Issue
Block a user