refactor!: 完全异步化 RAG 系统,移除 LangChain ParentDocumentRetriever 依赖
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m34s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m34s
- 重写 rag_core/vector_store.py:完全异步实现 aadd_documents、asimilarity_search
- 重写 app/rag/retriever.py:异步混合检索,移除同步兼容代码
- 修改 rag_indexer/index_builder.py:全链路异步调用
- 删除 rag_core/retriever_factory.py:不再使用 LangChain ParentDocumentRetriever
- 清理冗余导入和代码:移除 model_services 兼容、不需要的异常导入
- 更新 rag_indexer/README.md:反映新架构

核心改进:
- 完全异步化:索引构建和检索全链路 async/await
- 自定义实现:不再依赖 LangChain 的 ParentDocumentRetriever
- 双向量支持:子文档同时存储 dense + sparse 向量到 Qdrant
- 架构清晰:rag_core 公共组件、rag_indexer 索引、app/rag 检索
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Qdrant 混合检索器模块
|
||||
Qdrant 混合检索器模块(完全异步)
|
||||
|
||||
提供基于 Qdrant 的混合检索(Dense + Sparse)功能,包括:
|
||||
- 纯混合检索(无子父文档)
|
||||
@@ -12,15 +12,15 @@ Qdrant 混合检索器模块
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from qdrant_client import QdrantClient, models
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from rag_core import QdrantVectorStore, get_sparse_embedder, create_docstore
|
||||
from rag_core.client import create_qdrant_client as create_core_qdrant_client
|
||||
from rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
|
||||
from rag_core.client import create_async_qdrant_client
|
||||
from app.model_services import get_embedding_service
|
||||
from app.logger import info, warning, debug
|
||||
|
||||
@@ -32,13 +32,13 @@ DEFAULT_PARENT_SEARCH_K = 5
|
||||
|
||||
class HybridRetriever(BaseRetriever):
|
||||
"""
|
||||
混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合
|
||||
混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合(异步)
|
||||
|
||||
使用 Qdrant Universal Query API (query_points)
|
||||
"""
|
||||
collection_name: str = Field(description="Qdrant 集合名称")
|
||||
search_k: int = Field(default=DEFAULT_SEARCH_K, description="检索返回结果数")
|
||||
|
||||
|
||||
_vector_store: Any = PrivateAttr()
|
||||
_client: Any = PrivateAttr()
|
||||
_sparse_embedder: Any = PrivateAttr()
|
||||
@@ -46,13 +46,13 @@ class HybridRetriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
vector_store: QdrantVectorStore,
|
||||
vector_store: QdrantHybridStore,
|
||||
search_k: int = DEFAULT_SEARCH_K,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
vector_store: QdrantVectorStore 实例
|
||||
vector_store: QdrantHybridStore 实例
|
||||
search_k: 检索返回结果数
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -60,46 +60,40 @@ class HybridRetriever(BaseRetriever):
|
||||
search_k=search_k
|
||||
)
|
||||
self._vector_store = vector_store
|
||||
self._client = vector_store.get_qdrant_client()
|
||||
self._client = vector_store.get_async_qdrant_client()
|
||||
self._sparse_embedder = get_sparse_embedder()
|
||||
|
||||
def _get_relevant_documents(
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, **kwargs
|
||||
) -> List[Document]:
|
||||
"""
|
||||
同步检索相关文档
|
||||
|
||||
Args:
|
||||
query: 查询字符串
|
||||
|
||||
Returns:
|
||||
相关文档列表
|
||||
异步混合检索相关文档
|
||||
"""
|
||||
# 1. 生成双向量
|
||||
dense_query = self._vector_store.embeddings.embed_query(query)
|
||||
# 1. 生成查询向量
|
||||
dense_query = await self._vector_store._aembed_query(query)
|
||||
sparse_query = self._sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
# 2. 使用官方的 query_points API(推荐方式)
|
||||
response = self._client.query_points(
|
||||
# 2. 使用 Qdrant 的 query_points API
|
||||
response = await self._client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
prefetch=[ # 并行预取多个检索源
|
||||
prefetch=[
|
||||
models.Prefetch(
|
||||
query=dense_query,
|
||||
using="dense", # 使用稠密向量进行语义搜索
|
||||
using="dense",
|
||||
limit=self.search_k
|
||||
),
|
||||
models.Prefetch(
|
||||
query=sparse_vec,
|
||||
using="sparse", # 使用稀疏向量进行关键词搜索
|
||||
using="sparse",
|
||||
limit=self.search_k
|
||||
)
|
||||
],
|
||||
query=models.FusionQuery(fusion=models.Fusion.RRF), # 指定融合算法为 RRF
|
||||
limit=self.search_k, # 最终返回的结果数量
|
||||
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||
limit=self.search_k,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
@@ -112,20 +106,13 @@ class HybridRetriever(BaseRetriever):
|
||||
)
|
||||
results.append(doc)
|
||||
|
||||
debug(f"混合检索返回 {len(results)} 个文档")
|
||||
debug(f"混合检索返回 %d 个文档", len(results))
|
||||
return results
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, **kwargs
|
||||
) -> List[Document]:
|
||||
"""异步检索(当前调用同步版本)"""
|
||||
# Qdrant 客户端没有原生 async,这里用同步版本
|
||||
return self._get_relevant_documents(query, **kwargs)
|
||||
|
||||
|
||||
class ParentHybridRetriever(BaseRetriever):
|
||||
"""
|
||||
父子文档混合检索器:
|
||||
父子文档混合检索器(异步):
|
||||
|
||||
1. 先用混合检索找到相关子文档
|
||||
2. 根据子文档的 parent_id 找到对应的父文档
|
||||
@@ -134,7 +121,7 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
|
||||
collection_name: str = Field(description="Qdrant 集合名称")
|
||||
search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")
|
||||
|
||||
|
||||
_vector_store: Any = PrivateAttr()
|
||||
_client: Any = PrivateAttr()
|
||||
_sparse_embedder: Any = PrivateAttr()
|
||||
@@ -143,14 +130,14 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
vector_store: QdrantVectorStore,
|
||||
vector_store: QdrantHybridStore,
|
||||
search_k: int = DEFAULT_PARENT_SEARCH_K,
|
||||
docstore: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
vector_store: QdrantVectorStore 实例
|
||||
vector_store: QdrantHybridStore 实例
|
||||
search_k: 最终返回的父文档数量
|
||||
docstore: 文档存储(如果父文档在 PostgreSQL),可选
|
||||
"""
|
||||
@@ -159,24 +146,18 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
search_k=search_k
|
||||
)
|
||||
self._vector_store = vector_store
|
||||
self._client = vector_store.get_qdrant_client()
|
||||
self._client = vector_store.get_async_qdrant_client()
|
||||
self._sparse_embedder = get_sparse_embedder()
|
||||
self._docstore = docstore
|
||||
|
||||
def _get_relevant_documents(
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, **kwargs
|
||||
) -> List[Document]:
|
||||
"""
|
||||
同步检索相关父文档
|
||||
|
||||
Args:
|
||||
query: 查询字符串
|
||||
|
||||
Returns:
|
||||
相关父文档列表
|
||||
异步检索相关父文档
|
||||
"""
|
||||
# 1. 生成查询双向量
|
||||
dense_query = self._vector_store.embeddings.embed_query(query)
|
||||
# 1. 生成查询向量
|
||||
dense_query = await self._vector_store._aembed_query(query)
|
||||
sparse_query = self._sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
@@ -187,7 +168,7 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
search_limit = self.search_k * 2
|
||||
|
||||
# 3. 使用 query_points API 进行混合检索
|
||||
response = self._client.query_points(
|
||||
response = await self._client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
prefetch=[
|
||||
models.Prefetch(
|
||||
@@ -216,30 +197,27 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
child_point_map = {} # 保存子文档点用于降级
|
||||
|
||||
for point in response.points:
|
||||
# 先复制 payload,避免修改原始对象
|
||||
payload_copy = point.payload.copy()
|
||||
parent_id = payload_copy.get("parent_id", point.id)
|
||||
score = point.score
|
||||
|
||||
# 同一个 parent_id 只保留最高得分
|
||||
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
|
||||
parent_score_map[parent_id] = score
|
||||
parent_ids.add(parent_id)
|
||||
child_point_map[parent_id] = point
|
||||
|
||||
# 5. 批量查询父文档
|
||||
# 首先尝试从 Qdrant 直接查询(因为父文档可能也在 Qdrant 中)
|
||||
parent_docs = []
|
||||
found_parent_ids = set()
|
||||
|
||||
# 先尝试从 Qdrant 直接查询(如果父文档也在 Qdrant 中)
|
||||
try:
|
||||
parent_points = self._client.retrieve(
|
||||
parent_points = await self._client.retrieve(
|
||||
collection_name=self.collection_name,
|
||||
ids=list(parent_ids),
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# 处理找到的父文档
|
||||
for point in parent_points:
|
||||
payload_copy = point.payload.copy()
|
||||
doc = Document(
|
||||
@@ -250,24 +228,24 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
found_parent_ids.add(point.id)
|
||||
|
||||
except Exception as e:
|
||||
warning(f"从 Qdrant 查询父文档失败: {e}")
|
||||
warning(f"从 Qdrant 查询父文档失败: %s", e)
|
||||
|
||||
# 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档
|
||||
if self._docstore and len(found_parent_ids) < len(parent_ids):
|
||||
missing_parent_ids = parent_ids - found_parent_ids
|
||||
try:
|
||||
docstore_docs = self._docstore.mget(missing_parent_ids)
|
||||
docstore_docs = await self._docstore.amget(missing_parent_ids)
|
||||
for doc_id, doc in zip(missing_parent_ids, docstore_docs):
|
||||
if doc is not None:
|
||||
parent_docs.append(doc)
|
||||
found_parent_ids.add(doc_id)
|
||||
except Exception as e:
|
||||
warning(f"从 docstore 查询父文档失败: {e}")
|
||||
warning(f"从 docstore 查询父文档失败: %s", e)
|
||||
|
||||
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
|
||||
missing_parent_ids = parent_ids - found_parent_ids
|
||||
if missing_parent_ids:
|
||||
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
|
||||
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: %s", missing_parent_ids)
|
||||
for parent_id in missing_parent_ids:
|
||||
child_point = child_point_map.get(parent_id)
|
||||
if child_point:
|
||||
@@ -280,22 +258,16 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
|
||||
# 8. 按照得分降序排序,返回前 k 个
|
||||
parent_docs_with_scores = [
|
||||
(doc, parent_score_map.get(doc.metadata.get("id", doc.id), 0.0))
|
||||
(doc, parent_score_map.get(doc.metadata.get("id", doc.id if hasattr(doc, "id") else ""), 0.0))
|
||||
for doc in parent_docs
|
||||
]
|
||||
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
|
||||
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
|
||||
debug(f"父子文档混合检索返回 %d 个父文档", len(final_docs))
|
||||
|
||||
return final_docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, **kwargs
|
||||
) -> List[Document]:
|
||||
"""异步检索(当前调用同步版本)"""
|
||||
return self._get_relevant_documents(query, **kwargs)
|
||||
|
||||
|
||||
def create_hybrid_retriever(
|
||||
collection_name: str,
|
||||
@@ -303,7 +275,7 @@ def create_hybrid_retriever(
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
) -> BaseRetriever:
|
||||
"""
|
||||
创建混合检索器(稠密向量 + BM25 稀疏向量)。
|
||||
创建混合检索器(稠密向量 + BM25 稀疏向量)- 异步版本。
|
||||
|
||||
这是默认推荐的检索方式,效果最优。
|
||||
|
||||
@@ -315,15 +287,12 @@ def create_hybrid_retriever(
|
||||
Returns:
|
||||
HybridRetriever 实例
|
||||
"""
|
||||
# 默认使用统一嵌入服务
|
||||
if embeddings is None:
|
||||
embeddings = get_embedding_service()
|
||||
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
||||
info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
||||
|
||||
# 创建向量存储
|
||||
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||
vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings)
|
||||
|
||||
# 验证集合是否存在
|
||||
try:
|
||||
vector_store.get_client().get_collection(collection_name)
|
||||
except UnexpectedResponse as e:
|
||||
@@ -347,7 +316,7 @@ def create_parent_hybrid_retriever(
|
||||
use_docstore: bool = True,
|
||||
) -> BaseRetriever:
|
||||
"""
|
||||
创建父子文档混合检索器(默认推荐)。
|
||||
创建父子文档混合检索器(默认推荐)- 异步版本。
|
||||
|
||||
检索流程:
|
||||
1. 混合检索找到相关子文档
|
||||
@@ -363,15 +332,12 @@ def create_parent_hybrid_retriever(
|
||||
Returns:
|
||||
ParentHybridRetriever 实例
|
||||
"""
|
||||
# 默认使用统一嵌入服务
|
||||
if embeddings is None:
|
||||
embeddings = get_embedding_service()
|
||||
info("✅ 使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
||||
info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
||||
|
||||
# 创建向量存储
|
||||
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||
vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings)
|
||||
|
||||
# 验证集合是否存在
|
||||
try:
|
||||
vector_store.get_client().get_collection(collection_name)
|
||||
except UnexpectedResponse as e:
|
||||
@@ -380,14 +346,13 @@ def create_parent_hybrid_retriever(
|
||||
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
|
||||
raise
|
||||
|
||||
# 创建 docstore(如果需要)
|
||||
docstore = None
|
||||
if use_docstore:
|
||||
try:
|
||||
docstore, _ = create_docstore()
|
||||
info("✅ 文档存储初始化成功(PostgreSQL)")
|
||||
except Exception as e:
|
||||
warning(f"⚠️ 文档存储初始化失败,将不使用 docstore: {e}")
|
||||
warning(f"⚠️ 文档存储初始化失败,将不使用 docstore: %s", e)
|
||||
|
||||
info(f"✅ Qdrant 父子文档混合检索器初始化成功(search_k={search_k})")
|
||||
return ParentHybridRetriever(
|
||||
@@ -404,24 +369,9 @@ def create_base_retriever(
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
) -> BaseRetriever:
|
||||
"""
|
||||
创建基础稠密检索器(向后兼容)。
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
search_k: 检索返回结果数
|
||||
embeddings: 可选的嵌入模型实例
|
||||
|
||||
Returns:
|
||||
LangChain 的 BaseRetriever 实例
|
||||
创建基础检索器(向后兼容)- 实际上返回混合检索器。
|
||||
"""
|
||||
# 默认使用统一嵌入服务
|
||||
if embeddings is None:
|
||||
embeddings = get_embedding_service()
|
||||
|
||||
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||
|
||||
info(f"✅ Qdrant 基础稠密检索器初始化成功(search_k={search_k})")
|
||||
return vector_store.as_langchain_vectorstore().as_retriever(k=search_k)
|
||||
return create_hybrid_retriever(collection_name, search_k, embeddings)
|
||||
|
||||
|
||||
# 别名:默认就是父子文档混合检索
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
"""
|
||||
RAG Core - 公共 RAG 组件包
|
||||
|
||||
提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。
|
||||
"""
|
||||
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .vector_store import QdrantVectorStore
|
||||
from .embedders import get_embeddings, get_embedding_dimension
|
||||
from .vector_store import QdrantHybridStore
|
||||
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
|
||||
from .store import PostgresDocStore, create_docstore
|
||||
from .retriever_factory import create_parent_retriever
|
||||
from .client import create_qdrant_client, create_async_qdrant_client
|
||||
from .config import (
|
||||
QDRANT_URL,
|
||||
QDRANT_API_KEY,
|
||||
@@ -20,8 +19,9 @@ from .config import (
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LlamaCppEmbedder",
|
||||
"QdrantVectorStore",
|
||||
"get_embeddings",
|
||||
"get_embedding_dimension",
|
||||
"QdrantHybridStore",
|
||||
"BM25SparseEmbedder",
|
||||
"get_sparse_embedder",
|
||||
"QDRANT_URL",
|
||||
@@ -32,5 +32,6 @@ __all__ = [
|
||||
"DOCSTORE_URI",
|
||||
"PostgresDocStore",
|
||||
"create_docstore",
|
||||
"create_parent_retriever",
|
||||
"create_qdrant_client",
|
||||
"create_async_qdrant_client",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# rag_core/client.py
|
||||
import os
|
||||
from .config import QDRANT_URL, QDRANT_API_KEY
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client import QdrantClient, AsyncQdrantClient
|
||||
|
||||
|
||||
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
|
||||
@@ -28,3 +28,29 @@ def create_qdrant_client(timeout: int = 300) -> QdrantClient:
|
||||
client_kwargs["api_key"] = QDRANT_API_KEY
|
||||
|
||||
return QdrantClient(**client_kwargs)
|
||||
|
||||
|
||||
def create_async_qdrant_client(timeout: int = 300) -> AsyncQdrantClient:
    """Build an ``AsyncQdrantClient`` from the module-level Qdrant settings.

    Args:
        timeout: Request timeout in seconds handed to the client
            (defaults to 300).

    Returns:
        An ``AsyncQdrantClient`` constructed with ``QDRANT_URL``, the given
        timeout and, when one is configured, ``QDRANT_API_KEY``.

    Raises:
        ValueError: If ``QDRANT_URL`` is empty or unset.
    """
    if not QDRANT_URL:
        raise ValueError("Qdrant URL 未配置")

    # Forward the API key only when one is configured, so the keyword is
    # omitted entirely otherwise.
    auth_kwargs = {"api_key": QDRANT_API_KEY} if QDRANT_API_KEY else {}
    return AsyncQdrantClient(url=QDRANT_URL, timeout=timeout, **auth_kwargs)
|
||||
|
||||
@@ -1,121 +1,37 @@
|
||||
"""
|
||||
嵌入模型包装器 - 直接使用统一嵌入服务
|
||||
支持自动降级(本地 llama.cpp → 智谱),由 get_embedding_service() 内部处理
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaCppEmbedder:
|
||||
def get_embeddings() -> Embeddings:
|
||||
"""
|
||||
嵌入器包装类 - 直接使用统一的 get_embedding_service()
|
||||
降级逻辑完全由 app.model_services 处理
|
||||
获取统一的嵌入服务实例。
|
||||
|
||||
Returns:
|
||||
LangChain 兼容的 Embeddings 实例
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0", use_fallback: bool = True):
|
||||
"""
|
||||
Args:
|
||||
model: 嵌入模型名称(向后兼容,现在实际使用统一服务)
|
||||
use_fallback: 是否使用降级机制(保留参数,现在始终为 True)
|
||||
"""
|
||||
self.model = model
|
||||
self._fallback_embeddings = None
|
||||
|
||||
# 直接获取统一嵌入服务
|
||||
try:
|
||||
from backend.app.model_services import get_embedding_service
|
||||
self._fallback_embeddings = get_embedding_service()
|
||||
logger.info("✅ 统一嵌入服务加载成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 无法加载统一嵌入服务: {e}")
|
||||
# 保留向后兼容的初始化
|
||||
self.base_url = LLAMACPP_EMBEDDING_URL
|
||||
self.api_key = LLAMACPP_API_KEY
|
||||
|
||||
def as_langchain_embeddings(self) -> Embeddings:
|
||||
"""创建 LangChain 兼容的嵌入实例"""
|
||||
if self._fallback_embeddings:
|
||||
logger.info("✅ 使用统一嵌入服务(已内置降级机制)")
|
||||
return self._fallback_embeddings
|
||||
|
||||
# 向后兼容,仅在统一服务不可用时使用传统方式
|
||||
logger.warning("⚠️ 统一服务不可用,使用传统模式(不推荐)")
|
||||
return _LlamaCppLangchainAdapter(self)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""嵌入一批文档"""
|
||||
if self._fallback_embeddings:
|
||||
return self._fallback_embeddings.embed_documents(texts)
|
||||
|
||||
# 向后兼容
|
||||
return self._call_embedding_api(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""嵌入单个查询"""
|
||||
if self._fallback_embeddings:
|
||||
return self._fallback_embeddings.embed_query(text)
|
||||
|
||||
# 向后兼容
|
||||
return self._call_embedding_api([text])[0]
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""通过嵌入测试字符串获取嵌入维度"""
|
||||
test_embedding = self.embed_query("test")
|
||||
return len(test_embedding)
|
||||
|
||||
def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
|
||||
"""仅作为向后兼容的备用方法"""
|
||||
import httpx
|
||||
|
||||
if not hasattr(self, 'base_url') or not self.base_url:
|
||||
raise ValueError("LLAMACPP_EMBEDDING_URL 未配置且统一服务不可用")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
base = self.base_url.rstrip("/")
|
||||
if not base.endswith("/v1"):
|
||||
base = base + "/v1"
|
||||
|
||||
payload = {
|
||||
"input": texts,
|
||||
"model": self.model,
|
||||
}
|
||||
|
||||
with httpx.Client(timeout=120) as client:
|
||||
response = client.post(
|
||||
f"{base}/embeddings",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
return [item["embedding"] for item in data]
|
||||
elif isinstance(data, dict) and "data" in data:
|
||||
return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])]
|
||||
else:
|
||||
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
|
||||
from backend.app.model_services import get_embedding_service
|
||||
return get_embedding_service()
|
||||
|
||||
|
||||
class _LlamaCppLangchainAdapter(Embeddings):
|
||||
"""仅作为向后兼容的适配器"""
|
||||
|
||||
def __init__(self, embedder: LlamaCppEmbedder):
|
||||
self._embedder = embedder
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return self._embedder.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._embedder.embed_query(text)
|
||||
def get_embedding_dimension(embeddings: Optional["Embeddings"] = None) -> int:
    """Return the vector size produced by an embedding model.

    The dimension is measured by embedding the probe string ``"test"`` and
    taking the length of the resulting vector.

    Args:
        embeddings: Object exposing ``embed_query(text)``. When ``None``, the
            instance returned by ``get_embeddings()`` is used.

    Returns:
        The number of components in the probe embedding (always >= 1).

    Raises:
        ValueError: If the probe embedding is empty. A zero dimension cannot
            describe a usable vector field, so fail here rather than hand a
            size of 0 to the caller.
    """
    if embeddings is None:
        embeddings = get_embeddings()

    probe_vector = embeddings.embed_query("test")
    dimension = len(probe_vector)
    # Compare the length instead of relying on truthiness so array-like
    # return values (e.g. numpy arrays) are handled as well.
    if dimension == 0:
        raise ValueError("嵌入服务返回了空向量,无法确定嵌入维度")
    return dimension
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
"""
|
||||
RAG 检索器工厂模块
|
||||
|
||||
提供创建各种检索器的工厂函数,包括:
|
||||
- 基础向量检索器
|
||||
- ParentDocumentRetriever(父子文档)
|
||||
- 混合检索器(稠密+稀疏)
|
||||
"""
|
||||
from typing import Optional
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from langchain_core.stores import BaseStore
|
||||
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .vector_store import QdrantVectorStore
|
||||
from .store import create_docstore
|
||||
|
||||
|
||||
def create_parent_retriever(
|
||||
collection_name: str = "rag_documents",
|
||||
parent_splitter: Optional[TextSplitter] = None,
|
||||
child_splitter: Optional[TextSplitter] = None,
|
||||
docstore: Optional[BaseStore] = None,
|
||||
search_k: int = 5,
|
||||
parent_chunk_size: int = 1000,
|
||||
parent_chunk_overlap: int = 100,
|
||||
child_chunk_size: int = 200,
|
||||
child_chunk_overlap: int = 20,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
) -> ParentDocumentRetriever:
|
||||
"""
|
||||
创建 ParentDocumentRetriever 实例(基础稠密向量版本)。
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称,默认 "rag_documents"
|
||||
parent_splitter: 父文档切分器,默认 None(使用默认参数创建)
|
||||
child_splitter: 子文档切分器,默认 None(使用默认参数创建)
|
||||
docstore: 文档存储实例,默认 None(使用默认参数创建)
|
||||
search_k: 检索时返回的结果数,默认 5
|
||||
parent_chunk_size: 父文档块大小,默认 1000
|
||||
parent_chunk_overlap: 父文档块重叠大小,默认 100
|
||||
child_chunk_size: 子文档块大小,默认 200
|
||||
child_chunk_overlap: 子文档块重叠大小,默认 20
|
||||
embeddings: 嵌入模型实例,默认 None(使用内部默认的 LocalLlamaCppEmbedder)
|
||||
|
||||
Returns:
|
||||
ParentDocumentRetriever 实例
|
||||
"""
|
||||
# 嵌入模型
|
||||
if embeddings is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
embeddings = embedder.as_langchain_embeddings()
|
||||
|
||||
# 向量存储(只读)
|
||||
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||
|
||||
# 切分器(若未提供则创建默认)
|
||||
if parent_splitter is None:
|
||||
parent_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=parent_chunk_size,
|
||||
chunk_overlap=parent_chunk_overlap,
|
||||
)
|
||||
if child_splitter is None:
|
||||
child_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=child_chunk_size,
|
||||
chunk_overlap=child_chunk_overlap,
|
||||
)
|
||||
|
||||
# 文档存储
|
||||
if docstore is None:
|
||||
docstore, _ = create_docstore()
|
||||
|
||||
return ParentDocumentRetriever(
|
||||
vectorstore=vector_store.get_langchain_vectorstore(),
|
||||
docstore=docstore,
|
||||
child_splitter=child_splitter,
|
||||
parent_splitter=parent_splitter,
|
||||
search_kwargs={"k": search_k},
|
||||
)
|
||||
|
||||
|
||||
def create_hybrid_retriever_factory(
|
||||
collection_name: str = "rag_documents",
|
||||
search_k: int = 5,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
) -> BaseRetriever:
|
||||
"""
|
||||
【不完整,仅占位】创建混合检索器的工厂函数占位符。
|
||||
|
||||
注意:完整的混合检索逻辑在 app/rag/retriever.py 中实现。
|
||||
这里仅返回 QdrantVectorStore 作为基础。
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
search_k: 检索返回结果数
|
||||
embeddings: 嵌入模型实例
|
||||
|
||||
Returns:
|
||||
基础的 QdrantVectorStore(仅稠密检索)
|
||||
"""
|
||||
# 嵌入模型
|
||||
if embeddings is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
embeddings = embedder.as_langchain_embeddings()
|
||||
|
||||
# 创建向量存储
|
||||
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||
|
||||
# 返回 LangChain 兼容的 retriever
|
||||
return vector_store.get_langchain_vectorstore().as_retriever(search_kwargs={"k": search_k})
|
||||
@@ -1,6 +1,5 @@
|
||||
"""
|
||||
Qdrant 向量数据库包装器。
|
||||
支持稠密+稀疏双向量存储。
|
||||
Qdrant 向量数据库包装器(完全异步实现)。
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -11,111 +10,91 @@ from typing import List, Optional, Dict, Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient
|
||||
from qdrant_client.http.models import (
|
||||
Distance, VectorParams, SparseVectorParams, PointStruct
|
||||
Distance, VectorParams, SparseVectorParams, PointStruct, models
|
||||
)
|
||||
from httpx import RemoteProtocolError
|
||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||
|
||||
from .client import create_qdrant_client
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .client import create_qdrant_client, create_async_qdrant_client
|
||||
from .embedders import get_embeddings, get_embedding_dimension
|
||||
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QdrantVectorStore:
|
||||
"""Qdrant 向量数据库操作包装器 - 支持稠密+稀疏双向量存储。"""
|
||||
class QdrantHybridStore:
|
||||
"""
|
||||
Qdrant 向量数据库操作包装器 - 稠密+稀疏混合检索(完全异步)。
|
||||
直接使用 Qdrant 异步客户端实现,不依赖 LangChain。
|
||||
"""
|
||||
|
||||
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None, sparse_embedder: Optional[BM25SparseEmbedder] = None):
|
||||
"""
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称。
|
||||
embeddings: 嵌入模型实例,默认 None(使用内部默认的 LlamaCppEmbedder)。
|
||||
sparse_embedder: 稀疏嵌入模型实例,默认 None(自动加载BM25)。
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
sparse_embedder: Optional[BM25SparseEmbedder] = None,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self._client: Optional[QdrantClient] = None
|
||||
self._async_client: Optional[AsyncQdrantClient] = None
|
||||
self._connection_attempts = 0
|
||||
self._last_connection_time: Optional[float] = None
|
||||
|
||||
|
||||
# 稠密嵌入模型
|
||||
if embeddings is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
self.embeddings = embedder.as_langchain_embeddings()
|
||||
self._embedder = embedder
|
||||
self.embeddings = get_embeddings()
|
||||
else:
|
||||
self.embeddings = embeddings
|
||||
self._embedder = None
|
||||
|
||||
|
||||
# 稀疏嵌入模型
|
||||
self.sparse_embedder = sparse_embedder or get_sparse_embedder()
|
||||
|
||||
# 集合初始化
|
||||
self.create_collection()
|
||||
|
||||
# 保留 LangChain 向量存储实例(用于兼容)
|
||||
self.vector_store = LangchainQdrantVS(
|
||||
client=self.get_client(),
|
||||
collection_name=self.collection_name,
|
||||
embedding=self.embeddings,
|
||||
vector_name="dense",
|
||||
)
|
||||
|
||||
# ---------- 同步连接管理 ----------
|
||||
def get_client(self) -> QdrantClient:
|
||||
if self._client is None:
|
||||
self._client = create_qdrant_client(timeout=300)
|
||||
self._connection_attempts += 1
|
||||
self._last_connection_time = time.time()
|
||||
logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts)
|
||||
logger.debug("Qdrant 同步客户端已创建 (第 %d 次连接)", self._connection_attempts)
|
||||
return self._client
|
||||
|
||||
def refresh_client(self):
|
||||
"""关闭旧连接,创建新连接。"""
|
||||
if self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
logger.debug("Qdrant 旧连接已关闭")
|
||||
logger.debug("Qdrant 旧同步连接已关闭")
|
||||
except Exception as e:
|
||||
logger.warning("关闭 Qdrant 连接时出现异常: %s", e)
|
||||
logger.warning("关闭 Qdrant 同步连接时出现异常: %s", e)
|
||||
finally:
|
||||
self._client = None
|
||||
self._last_connection_time = None
|
||||
|
||||
def check_connection_health(self) -> bool:
|
||||
"""检查连接健康状态,如果连接已失效则自动重建。"""
|
||||
if self._client is None:
|
||||
logger.info("Qdrant 客户端未初始化,将创建新连接")
|
||||
return False
|
||||
|
||||
try:
|
||||
client = self.get_client()
|
||||
client.get_collections()
|
||||
logger.debug("Qdrant 连接健康检查通过")
|
||||
return True
|
||||
except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
|
||||
logger.warning("Qdrant 连接健康检查失败: %s", e)
|
||||
self.refresh_client()
|
||||
return False
|
||||
# ---------- 异步连接管理 ----------
|
||||
def get_async_client(self) -> AsyncQdrantClient:
    """Return this store's async Qdrant client, creating it on first use.

    The client is built with ``create_async_qdrant_client(timeout=300)`` the
    first time it is requested and kept on ``self._async_client`` for later
    calls.

    Returns:
        The ``AsyncQdrantClient`` held by this instance.
    """
    client = self._async_client
    if client is None:
        client = create_async_qdrant_client(timeout=300)
        self._async_client = client
        logger.debug("Qdrant 异步客户端已创建")
    return client
|
||||
|
||||
def get_connection_stats(self) -> Dict[str, Any]:
|
||||
"""获取连接统计信息。"""
|
||||
return {
|
||||
"connection_attempts": self._connection_attempts,
|
||||
"last_connection_time": self._last_connection_time,
|
||||
"client_initialized": self._client is not None,
|
||||
}
|
||||
async def close_async_client(self):
    """Close the async Qdrant client held by this store, if there is one.

    Errors raised while closing are logged as warnings and not propagated.
    Whether or not closing succeeds, ``self._async_client`` is reset to
    ``None`` afterwards.
    """
    if self._async_client is None:
        # Nothing was ever opened (or it has already been closed).
        return

    try:
        await self._async_client.close()
        logger.debug("Qdrant 异步连接已关闭")
    except Exception as exc:
        # Best-effort shutdown: report the failure but keep going.
        logger.warning("关闭 Qdrant 异步连接时出现异常: %s", exc)
    finally:
        self._async_client = None
|
||||
|
||||
# ---------- 集合创建(同步,用于初始化) ----------
|
||||
def create_collection(self, force_recreate: bool = False):
|
||||
"""创建集合,支持稠密+稀疏双向量。"""
|
||||
if self._embedder is not None:
|
||||
# 使用内部的 embedder 获取维度
|
||||
vector_size = self._embedder.get_embedding_dimension()
|
||||
else:
|
||||
# 使用外部传入的 embeddings,通过测试获取维度
|
||||
test_embedding = self.embeddings.embed_query("test")
|
||||
vector_size = len(test_embedding)
|
||||
"""创建集合,确保有 'dense' 和 'sparse' 两个命名向量字段。"""
|
||||
vector_size = get_embedding_dimension(self.embeddings)
|
||||
|
||||
max_retries = 3
|
||||
base_delay = 2
|
||||
@@ -130,90 +109,168 @@ class QdrantVectorStore:
|
||||
exists = False
|
||||
|
||||
if not exists:
|
||||
# 向量配置:稠密向量
|
||||
vectors_config = {
|
||||
"dense": VectorParams(
|
||||
size=vector_size,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
}
|
||||
|
||||
# 稀疏向量配置(简化版,不使用特殊索引类型)
|
||||
sparse_vectors_config = {
|
||||
"sparse": SparseVectorParams()
|
||||
}
|
||||
|
||||
client.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
vectors_config=vectors_config,
|
||||
sparse_vectors_config=sparse_vectors_config
|
||||
sparse_vectors_config=sparse_vectors_config,
|
||||
)
|
||||
logger.info(
|
||||
"集合 '%s' 已创建(维度=%d,稠密+稀疏双向量)",
|
||||
self.collection_name, vector_size,
|
||||
)
|
||||
logger.info("集合 '%s' 已创建(维度=%d,支持稠密+稀疏双向量)", self.collection_name, vector_size)
|
||||
else:
|
||||
logger.info("集合 '%s' 已存在", self.collection_name)
|
||||
return
|
||||
except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
|
||||
if attempt == max_retries - 1:
|
||||
logger.error("创建集合 '%s' 重试 %d 次后仍然失败: %s", self.collection_name, max_retries, e)
|
||||
raise
|
||||
wait_time = base_delay * (2 ** attempt)
|
||||
error_type = type(e).__name__
|
||||
logger.warning(
|
||||
"创建集合 '%s' 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s",
|
||||
self.collection_name, error_type, wait_time, attempt + 1, max_retries, e
|
||||
self.collection_name, type(e).__name__, wait_time, attempt + 1, max_retries, e,
|
||||
)
|
||||
self.refresh_client()
|
||||
logger.debug("已刷新 Qdrant 客户端连接")
|
||||
time.sleep(wait_time)
|
||||
|
||||
def add_documents(self, documents: List[Document], batch_size: int = 100):
|
||||
"""将文档添加到向量数据库,自动生成稠密+稀疏双向量。"""
|
||||
# ---------- 异步索引方法 ----------
|
||||
async def aadd_documents(self, documents: List[Document], batch_size: int = 100) -> List[str]:
|
||||
"""
|
||||
异步添加文档(自动生成稠密+稀疏向量并批量写入)。
|
||||
"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
# 确保集合存在
|
||||
self.create_collection()
|
||||
client = self.get_client()
|
||||
doc_ids = []
|
||||
|
||||
# 分批处理
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i:i+batch_size]
|
||||
texts = [doc.page_content for doc in batch_docs]
|
||||
|
||||
# 生成双向量
|
||||
dense_vectors = self.embeddings.embed_documents(texts)
|
||||
sparse_vectors = self.sparse_embedder.embed_documents(texts)
|
||||
all_ids = []
|
||||
total_docs = len(documents)
|
||||
|
||||
points = []
|
||||
for j, doc in enumerate(batch_docs):
|
||||
point_id = doc.metadata.get("id", str(uuid.uuid4()))
|
||||
doc_ids.append(point_id)
|
||||
for i in range(0, total_docs, batch_size):
|
||||
batch = documents[i:i+batch_size]
|
||||
batch_ids = await self._aadd_batch(batch)
|
||||
all_ids.extend(batch_ids)
|
||||
logger.info("已向 '%s' 添加批次 %d/%d,共 %d 个文档",
|
||||
self.collection_name,
|
||||
i//batch_size + 1,
|
||||
(total_docs + batch_size - 1)//batch_size,
|
||||
len(batch))
|
||||
|
||||
# 构造双向量
|
||||
named_vectors = {
|
||||
"dense": dense_vectors[j],
|
||||
"sparse": sparse_vectors[j]
|
||||
}
|
||||
logger.info("已向 '%s' 总共添加 %d 个文档(混合模式)", self.collection_name, len(all_ids))
|
||||
return all_ids
|
||||
|
||||
points.append(PointStruct(
|
||||
id=point_id,
|
||||
vector=named_vectors,
|
||||
payload={"text": doc.page_content, **doc.metadata}
|
||||
))
|
||||
async def _aadd_batch(self, documents: List[Document]) -> List[str]:
|
||||
"""异步添加单个批次的文档"""
|
||||
client = self.get_async_client()
|
||||
|
||||
# 批量插入
|
||||
client.upsert(collection_name=self.collection_name, points=points)
|
||||
logger.info("已向 '%s' 添加 %d 个文档(稠密+稀疏双向量)", self.collection_name, len(points))
|
||||
# 提取文本
|
||||
texts = [doc.page_content for doc in documents]
|
||||
|
||||
return doc_ids
|
||||
# 生成稠密向量
|
||||
dense_vectors = await self._aembed_texts(texts)
|
||||
|
||||
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
|
||||
"""基础稠密向量检索(兼容原有接口)。"""
|
||||
return self.vector_store.similarity_search(query, k=k)
|
||||
# 生成稀疏向量
|
||||
sparse_vectors = self.sparse_embedder.embed_documents(texts)
|
||||
|
||||
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
|
||||
"""基础稠密向量检索带分数(兼容原有接口)。"""
|
||||
return self.vector_store.similarity_search_with_score(query, k=k)
|
||||
# 构建点结构
|
||||
points = []
|
||||
for doc, dense_vec, sparse_vec in zip(documents, dense_vectors, sparse_vectors):
|
||||
point_id = str(uuid.uuid4())
|
||||
payload = {
|
||||
"page_content": doc.page_content,
|
||||
**doc.metadata
|
||||
}
|
||||
point = PointStruct(
|
||||
id=point_id,
|
||||
vector={
|
||||
"dense": dense_vec,
|
||||
"sparse": models.SparseVector(
|
||||
indices=sparse_vec["indices"],
|
||||
values=sparse_vec["values"]
|
||||
)
|
||||
},
|
||||
payload=payload
|
||||
)
|
||||
points.append(point)
|
||||
|
||||
# 写入 Qdrant
|
||||
await client.upsert(
|
||||
collection_name=self.collection_name,
|
||||
points=points
|
||||
)
|
||||
|
||||
return [p.id for p in points]
|
||||
|
||||
async def _aembed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""异步生成稠密向量(适配同步 Embeddings 接口)"""
|
||||
# 注意:LangChain 的 Embeddings 接口目前主要是同步的
|
||||
# 使用线程池或直接调用(如果 embedding 内部有异步支持)
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self.embeddings.embed_documents, texts)
|
||||
|
||||
# ---------- 异步检索方法 ----------
|
||||
async def asimilarity_search(self, query: str, k: int = 5) -> List[Document]:
|
||||
"""
|
||||
异步混合检索(稠密 + 稀疏),返回文档列表。
|
||||
使用 Qdrant 的 Universal Query API + RRF 融合。
|
||||
"""
|
||||
client = self.get_async_client()
|
||||
|
||||
# 生成查询向量
|
||||
dense_query = await self._aembed_query(query)
|
||||
sparse_query = self.sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
# 使用 Qdrant 的 query_points API
|
||||
response = await client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
prefetch=[
|
||||
models.Prefetch(
|
||||
query=dense_query,
|
||||
using="dense",
|
||||
limit=k
|
||||
),
|
||||
models.Prefetch(
|
||||
query=sparse_vec,
|
||||
using="sparse",
|
||||
limit=k
|
||||
)
|
||||
],
|
||||
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||
limit=k,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# 转换结果
|
||||
results = []
|
||||
for point in response.points:
|
||||
page_content = point.payload.pop("page_content", "")
|
||||
doc = Document(page_content=page_content, metadata=point.payload)
|
||||
results.append(doc)
|
||||
|
||||
logger.debug("混合检索返回 %d 个文档", len(results))
|
||||
return results
|
||||
|
||||
async def _aembed_query(self, text: str) -> List[float]:
|
||||
"""异步生成查询稠密向量"""
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self.embeddings.embed_query, text)
|
||||
|
||||
# ---------- 同步管理方法(保留,用于初始化和管理) ----------
|
||||
def delete_collection(self):
|
||||
self.get_client().delete_collection(self.collection_name)
|
||||
logger.info("集合 '%s' 已删除", self.collection_name)
|
||||
@@ -233,13 +290,10 @@ class QdrantVectorStore:
|
||||
"vector_size": vector_size,
|
||||
}
|
||||
|
||||
def as_langchain_vectorstore(self):
|
||||
return self.vector_store
|
||||
|
||||
def get_langchain_vectorstore(self):
|
||||
"""返回 LangChain Qdrant 向量存储对象(别名)"""
|
||||
return self.vector_store
|
||||
|
||||
def get_qdrant_client(self):
|
||||
"""返回原生 Qdrant 客户端(用于自定义检索逻辑)"""
|
||||
"""返回原生 Qdrant 同步客户端(用于管理操作)。"""
|
||||
return self.get_client()
|
||||
|
||||
def get_async_qdrant_client(self):
|
||||
"""返回原生 Qdrant 异步客户端(用于索引和检索)。"""
|
||||
return self.get_async_client()
|
||||
|
||||
@@ -11,10 +11,11 @@
|
||||
| **文档解析** | `unstructured` | 0.22+ | 多格式文档解析(PDF/DOCX/TXT等) |
|
||||
| **文本切分** | `langchain-text-splitters` | 内置 | 递归字符切分 + 语义切分 |
|
||||
| **语义切分** | `langchain-experimental` | 内置 | `SemanticChunker` 基于句子相似度 |
|
||||
| **嵌入模型** | `llama.cpp` | 本地服务 | `embeddinggemma-300M` GGUF 模型 |
|
||||
| **向量数据库** | `Qdrant` | 1.17+ | HNSW 索引,支持稠密/稀疏向量 |
|
||||
| **嵌入模型** | `llama.cpp` | 本地服务 | `Qwen3-Embedding-0.6B` GGUF 模型 |
|
||||
| **稀疏嵌入** | `fastembed` | 内置 | BM25 关键词检索 |
|
||||
| **向量数据库** | `Qdrant` | 1.17+ | HNSW 索引,支持稠密/稀疏向量 + RRF 融合 |
|
||||
| **文档存储** | `PostgreSQL` | 16+ | 异步连接池,持久化父块 |
|
||||
| **编排框架** | `asyncio` | Python 3.10+ | 异步批量处理与重试 |
|
||||
| **编排框架** | `asyncio` | Python 3.10+ | 全异步批量处理 |
|
||||
|
||||
### 数据流向总览
|
||||
|
||||
@@ -33,27 +34,28 @@
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ ParentDocumentRetriever │
|
||||
│ 自定义父子块索引实现 │
|
||||
│ ┌─────────────────────────────────────────────────────┐ │
|
||||
│ │ parent_splitter (粗切) │ │
|
||||
│ │ 父块 ~1000 字符 │ │
|
||||
│ └──────────────────────┬──────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌──────────────────────▼──────────────────────────────┐ │
|
||||
│ │ 父文档存入 PostgreSQL (UUID 映射) │ │
|
||||
│ └──────────────────────┬──────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌──────────────────────▼──────────────────────────────┐ │
|
||||
│ │ child_splitter (细切) │ │
|
||||
│ │ 子块 ~200 字符 │ │
|
||||
│ └──────────────────────┬──────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌─────────────┴─────────────┐ │
|
||||
│ ▼ ▼ │
|
||||
│ ┌─────────────┐ ┌─────────────────┐ │
|
||||
│ │ 子块向量 │ │ 父块原始内容 │ │
|
||||
│ │ │ │ │ │
|
||||
│ ▼ │ ▼ │ │
|
||||
│ ┌────────────┐ │ ┌─────────────────┐ │ │
|
||||
│ │vector_store│ │ │ store/ │ │ │
|
||||
│ │ (Qdrant) │ │ │ (PostgreSQL) │ │ │
|
||||
│ └──────────── │ └─────────────────┘ │ │
|
||||
│ ┌──────────────────────▼──────────────────────────────┐ │
|
||||
│ │ 子文档生成 dense + sparse 双向量 │ │
|
||||
│ └──────────────────────┬──────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌──────────────────────▼──────────────────────────────┐ │
|
||||
│ │ 子文档存入 Qdrant (payload 含 parent_id) │ │
|
||||
│ └─────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
@@ -63,7 +65,9 @@
|
||||
- ✅ **三种切分策略**:递归字符切分、语义切分、父子块策略
|
||||
- ✅ **Parent-Child 架构**:子块精准检索,父块完整上下文
|
||||
- ✅ **PostgreSQL DocStore**:持久化存储父块,支持异步连接池
|
||||
- ✅ **批量写入与重试**:自动处理网络波动,确保索引完整性
|
||||
- ✅ **混合检索**:稠密向量(语义)+ 稀疏向量(关键词),Qdrant 原生 RRF 融合
|
||||
- ✅ **完全异步化**:索引构建、检索全链路 async / await
|
||||
- ✅ **批量写入**:高效批量处理,自动分批
|
||||
- ✅ **上下文管理器**:支持同步/异步资源管理
|
||||
|
||||
## 📂 架构与文件结构
|
||||
@@ -71,17 +75,26 @@
|
||||
```
|
||||
rag_indexer/
|
||||
├── __init__.py
|
||||
├── cli.py # 命令行入口
|
||||
├── index_builder.py # 索引构建主流水线
|
||||
├── index_builder.py # 索引构建主流水线(自定义父子块实现)
|
||||
├── loaders.py # 文档加载器(多格式支持)
|
||||
├── splitters.py # 文本切分器(递归/语义/父子块)
|
||||
└── README.md # 本文档
|
||||
```
|
||||
|
||||
```
|
||||
backend/rag_core/
|
||||
├── __init__.py
|
||||
├── vector_store.py # Qdrant 混合存储(异步)
|
||||
├── sparse_embedder.py # BM25 稀疏嵌入
|
||||
├── embedders.py # 嵌入模型封装
|
||||
├── vector_store.py # Qdrant 向量存储
|
||||
├── store/
|
||||
│ ├── __init__.py
|
||||
│ ├── factory.py # DocStore 工厂函数
|
||||
│ └── postgres.py # PostgreSQL DocStore 实现
|
||||
└── test/ # 测试脚本
|
||||
├── store.py # PostgreSQL 文档存储
|
||||
├── client.py # Qdrant 同步/异步客户端工厂
|
||||
└── config.py # 配置管理
|
||||
```
|
||||
|
||||
```
|
||||
backend/app/rag/
|
||||
└── retriever.py # 混合检索器(异步)
|
||||
```
|
||||
|
||||
## 🎯 演进路线与核心算法 (Roadmap)
|
||||
@@ -133,26 +146,30 @@ chunks = chunker.split_documents(documents)
|
||||
|
||||
### Level 3: 高级父子块策略 (Parent-Child / Auto-merging)
|
||||
|
||||
- **核心算法**: 层次化双重存储与映射。
|
||||
- **核心算法**: 层次化双重存储与映射(自定义实现)。
|
||||
- **切分机制**: 首先将文档粗切为较大的"父块 (Parent Chunk, 约 1000 字符)",随后将父块细切为较小的"子块 (Child Chunk, 约 200 字符)"
|
||||
- **存储机制**: 仅仅将**子块**的向量存入 Qdrant 用于精准计算距离;将**父块**的原始内容存在 PostgreSQL DocStore 中,通过 UUID 相互映射
|
||||
- **存储机制**:
|
||||
- **子块**: 存入 Qdrant,同时生成 dense 向量(语义)和 sparse 向量(关键词),payload 中包含 `parent_id`
|
||||
- **父块**: 存入 PostgreSQL,通过 UUID 与子块映射
|
||||
- **核心思路**: 解决 RAG 领域经典的矛盾——检索时块越小越容易精确命中(去除噪声);但生成回答时,块越大越能给大模型提供充足的上下文背景。
|
||||
- **实现指南**:
|
||||
- 使用 `langchain_classic.retrievers` 中的 `ParentDocumentRetriever` 模块
|
||||
- 在写入时,需要同时准备一个底层的 `VectorStore` (即 Qdrant) 和一个 `BaseStore`
|
||||
- **推荐方案**: 使用 `PostgresDocStore` 作为 docstore,支持持久化存储
|
||||
- 将两种不同的 `TextSplitter` 分别赋值给检索器的 `child_splitter` 和 `parent_splitter`,然后调用 `.add_documents()` 即可让系统自动完成映射
|
||||
- **实现**:
|
||||
- 完全自定义实现,不依赖 LangChain 的 `ParentDocumentRetriever`
|
||||
- 支持异步批量写入
|
||||
- 支持双向量混合检索
|
||||
|
||||
```python
|
||||
from langchain.retrievers import ParentDocumentRetriever
|
||||
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
|
||||
from rag_indexer.splitters import SplitterType
|
||||
|
||||
retriever = ParentDocumentRetriever(
|
||||
vectorstore=qdrant_store,
|
||||
docstore=postgres_docstore,
|
||||
parent_splitter=parent_splitter,
|
||||
child_splitter=child_splitter,
|
||||
config = IndexBuilderConfig(
|
||||
collection_name="rag_documents",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
parent_chunk_size=1000,
|
||||
child_chunk_size=200,
|
||||
)
|
||||
await retriever.aadd_documents(documents)
|
||||
|
||||
builder = IndexBuilder(config)
|
||||
await builder.build_from_file("document.pdf")
|
||||
```
|
||||
|
||||
### Level 3.1: PostgreSQL DocStore 集成
|
||||
@@ -191,11 +208,232 @@ config = IndexBuilderConfig(
|
||||
child_chunk_size=200,
|
||||
child_splitter_type=SplitterType.SEMANTIC, # 子块使用语义切分
|
||||
docstore=DocstoreConfig(
|
||||
connection_string="postgresql://user:pass@host:5432/db",
|
||||
connection_string="postgresql://user:***@host:5432/db",
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
### Level 3.3: 混合检索架构(稠密 + 稀疏)
|
||||
|
||||
- **核心算法**: Qdrant 原生双向量存储 + RRF 分数融合
|
||||
- **稠密向量 (Dense)**: 语义相似度检索,捕捉深层含义
|
||||
- **稀疏向量 (Sparse)**: BM25 关键词检索,精确匹配术语
|
||||
- **RRF 融合 (Reciprocal Rank Fusion)**: 服务端基于排名的结果融合(按各路检索结果的名次倒数累加,而非直接比较原始分数),无需客户端后处理
|
||||
- **核心思路**: 结合语义理解和关键词匹配的双重优势,大幅提升召回率
|
||||
- **实现原理**:
|
||||
- 每个子文档同时生成 dense 向量和 sparse 向量
|
||||
- 使用 Qdrant 的 `query_points` API + `Prefetch` 并行检索
|
||||
- 通过 `FusionQuery` 自动进行 RRF 分数融合
|
||||
|
||||
```python
|
||||
from app.rag.retriever import create_parent_hybrid_retriever
|
||||
|
||||
# 创建父子文档混合检索器
|
||||
retriever = create_parent_hybrid_retriever(
|
||||
collection_name="rag_documents",
|
||||
search_k=5
|
||||
)
|
||||
|
||||
# 异步检索相关文档
|
||||
docs = await retriever.ainvoke("用户查询")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📦 存储结构详解
|
||||
|
||||
### 整体数据流向
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
│ 原始文档 │
|
||||
│ (Document + Metadata) │
|
||||
└───────────────┬─────────────────────────┘
|
||||
│ 切分
|
||||
┌───────────────▼─────────────────────────┐
|
||||
│ 父文档块 (Parent Chunks) │
|
||||
│ 大粒度:1000-2000字符/块 │
|
||||
│ 存:PostgreSQL JSONB │
|
||||
└───────────────┬─────────────────────────┘
|
||||
│ 再切分
|
||||
┌───────────────▼─────────────────────────┐
|
||||
│ 子文档块 (Child Chunks) │
|
||||
│ 小粒度:200-400字符/块 │
|
||||
│ 存:Qdrant (稠密+稀疏双向量) │
|
||||
└─────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### PostgreSQL 存储结构(父文档)
|
||||
|
||||
#### 表结构
|
||||
|
||||
```sql
|
||||
CREATE TABLE parent_documents (
|
||||
key TEXT PRIMARY KEY,
|
||||
value JSONB NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
```
|
||||
|
||||
#### 数据格式(JSONB)
|
||||
|
||||
```json
|
||||
{
|
||||
"page_content": "这是一个父文档块,包含完整的上下文信息,用于最终给 LLM 生成回答...",
|
||||
"metadata": {
|
||||
"source": "file_name.pdf",
|
||||
"page": 10,
|
||||
"chunk_id": "parent-12345",
|
||||
"timestamp": "2024-05-04T12:34:56Z"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Qdrant 存储结构(子文档)
|
||||
|
||||
#### 集合配置
|
||||
|
||||
```python
|
||||
vectors_config = {
|
||||
"dense": VectorParams(
|
||||
size=2048, # 或 1024、4096,取决于嵌入模型
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
}
|
||||
|
||||
sparse_vectors_config = {
|
||||
"sparse": SparseVectorParams()
|
||||
}
|
||||
```
|
||||
|
||||
#### 点结构(Point)
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "7f9c2b1e-4d3a-4e8b-9a6f-2c5d8e1b0a34",
|
||||
"vector": {
|
||||
"dense": [0.123, 0.456, ...],
|
||||
"sparse": {
|
||||
"indices": [10, 50, 234, ...],
|
||||
"values": [0.8, 0.5, 0.3, ...]
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"page_content": "这是一个子文档块,用于检索...",
|
||||
"parent_id": "parent-12345",
|
||||
"source": "file_name.pdf",
|
||||
"page": 10,
|
||||
"chunk_index": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔄 完整数据流
|
||||
|
||||
### 索引构建阶段
|
||||
|
||||
```
|
||||
原始文档
|
||||
↓
|
||||
切分为父块(1000字符/块)
|
||||
↓
|
||||
为每个父块分配唯一 ID (parent_id)
|
||||
↓
|
||||
存父块到 PostgreSQL (key=parent_id, value=Document)
|
||||
↓
|
||||
每个父块再切分为子块(200字符/块)
|
||||
↓
|
||||
为每个子块生成:
|
||||
- dense 向量
|
||||
- sparse 向量
|
||||
- payload 中加入 parent_id
|
||||
↓
|
||||
存子块到 Qdrant
|
||||
```
|
||||
|
||||
### 检索阶段
|
||||
|
||||
```
|
||||
用户查询
|
||||
↓
|
||||
生成查询的 dense + sparse 向量
|
||||
↓
|
||||
Qdrant 混合检索(RRF 分数融合)
|
||||
↓
|
||||
得到相关子文档列表
|
||||
↓
|
||||
收集子文档的 parent_id(去重)
|
||||
↓
|
||||
用 parent_id 批量查询 PostgreSQL
|
||||
↓
|
||||
得到完整的父文档
|
||||
↓
|
||||
返回给 LLM
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📊 存储消耗分析(估算)
|
||||
|
||||
假设我们有 **100 个 PDF 文档,平均每个文档 100,000 字符**,总字符数 10,000,000。
|
||||
|
||||
| 存储类型 | 数量 | 单条大小 | 总大小 |
|
||||
|---------|------|---------|--------|
|
||||
| **PostgreSQL 父文档** | ~10,000 块 | 1KB (text) + 0.5KB (metadata) | **15MB** |
|
||||
| **Qdrant 子文档** | ~50,000 块 | 见下文 | **~450-500MB** |
|
||||
|
||||
### Qdrant 单条子文档详细分解
|
||||
|
||||
| 项 | 说明 | 大小 |
|
||||
|---|-------|------|
|
||||
| dense 向量 | float32[2048] | 8,192 bytes (~8KB) |
|
||||
| sparse 向量 | 平均 50-100 非零维 | 400-800 bytes |
|
||||
| payload | 子文本 + metadata | 200-500 bytes |
|
||||
| **合计** | | **~9-10KB / 条** |
|
||||
|
||||
对于 50,000 条子文档:**~450-500MB**
|
||||
|
||||
---
|
||||
|
||||
## ⚡ 优化策略
|
||||
|
||||
### 1. 分层存储
|
||||
|
||||
- **热数据(频繁访问)**:父文档 + 子文档都在 Qdrant(更快)
|
||||
- **冷数据(不常访问)**:父文档在 PostgreSQL,子文档在 Qdrant(更省)
|
||||
|
||||
### 2. 向量压缩
|
||||
|
||||
- Qdrant 支持 Scalar Quantization (SQ) 或 Product Quantization (PQ)
|
||||
- 可将 dense 向量从 8KB 压缩到 2-4KB,节省 50-75%
|
||||
|
||||
### 3. 稀疏向量优化
|
||||
|
||||
- BM25 可以剪枝(prune)低权重的词
|
||||
- 保留 top 50 关键词即可,不用全量
|
||||
|
||||
### 4. 父子块大小调整
|
||||
|
||||
- 父块:1000-2000(平衡上下文完整性)
|
||||
- 子块:100-300(平衡检索精度)
|
||||
|
||||
---
|
||||
|
||||
## ✨ 核心优势总结
|
||||
|
||||
| 特性 | 说明 |
|
||||
|------|------|
|
||||
| **检索精度** | 子块小 → 语义更精准 |
|
||||
| **回答质量** | 父块大 → 上下文完整 |
|
||||
| **混合检索** | dense(语义)+ sparse(关键词)= 召回率高 |
|
||||
| **存储效率** | 父子分离 → 不用重复存储大段文本 |
|
||||
|
||||
### Level 4: GraphRAG(基于图和关系的 RAG)
|
||||
|
||||
- **核心算法**: LLM 实体关系抽取 (NER & Relation Extraction)。
|
||||
@@ -329,9 +567,9 @@ async with IndexBuilder(config) as builder:
|
||||
封装 Qdrant 向量数据库操作。
|
||||
|
||||
```python
|
||||
from rag_core import QdrantVectorStore
|
||||
from rag_core import QdrantHybridStore
|
||||
|
||||
vector_store = QdrantVectorStore(
|
||||
vector_store = QdrantHybridStore(
|
||||
collection_name="rag_documents",
|
||||
embeddings=embeddings,
|
||||
)
|
||||
|
||||
@@ -39,8 +39,9 @@ from .config import (
|
||||
|
||||
# 从 rag_core 重新导出常用组件
|
||||
from backend.rag_core import (
|
||||
LlamaCppEmbedder,
|
||||
QdrantVectorStore,
|
||||
get_embeddings,
|
||||
get_embedding_dimension,
|
||||
QdrantHybridStore,
|
||||
PostgresDocStore,
|
||||
create_docstore,
|
||||
)
|
||||
@@ -52,14 +53,14 @@ __all__ = [
|
||||
"IndexBuilder",
|
||||
"IndexBuilderConfig",
|
||||
"DocstoreConfig",
|
||||
|
||||
|
||||
# 加载器
|
||||
"DocumentLoader",
|
||||
|
||||
|
||||
# 切分相关
|
||||
"SplitterType",
|
||||
"get_splitter",
|
||||
|
||||
|
||||
# 配置
|
||||
"QDRANT_URL",
|
||||
"QDRANT_API_KEY",
|
||||
@@ -69,11 +70,12 @@ __all__ = [
|
||||
"DOCSTORE_URI",
|
||||
"RAG_OCR_LANGUAGES",
|
||||
"RAG_DOC_LANGUAGES",
|
||||
|
||||
|
||||
# 嵌入与向量存储
|
||||
"LlamaCppEmbedder",
|
||||
"QdrantVectorStore",
|
||||
|
||||
"get_embeddings",
|
||||
"get_embedding_dimension",
|
||||
"QdrantHybridStore",
|
||||
|
||||
# 文档存储
|
||||
"PostgresDocStore",
|
||||
"create_docstore",
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""
|
||||
离线 RAG 索引构建核心流水线。
|
||||
|
||||
使用 LangChain 的 ParentDocumentRetriever 实现父子块策略。
|
||||
支持 Qdrant 混合检索(Dense + Sparse)。
|
||||
自定义实现父子块策略,支持 Qdrant 混合检索(Dense + Sparse)。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -12,33 +11,22 @@ from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Union, Optional, Any, Dict
|
||||
|
||||
from httpx import RemoteProtocolError
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.stores import BaseStore
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http.models import SparseVectorParams
|
||||
|
||||
from .loaders import DocumentLoader
|
||||
from .splitters import SplitterType, get_splitter
|
||||
|
||||
from backend.rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
|
||||
|
||||
# 尝试导入新的 model_services(如果可用)
|
||||
try:
|
||||
from backend.app.model_services import get_embedding_service
|
||||
HAS_MODEL_SERVICES = True
|
||||
except ImportError:
|
||||
HAS_MODEL_SERVICES = False
|
||||
from backend.rag_core import get_embeddings, QdrantHybridStore, create_docstore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------- 配置数据类 ----------
|
||||
@dataclass
|
||||
class DocstoreConfig:
|
||||
"""文档存储配置(用于父块存储)。"""
|
||||
"""文档存储配置(用于父文档存储)。"""
|
||||
pool_config: Dict[str, Any] | None = None
|
||||
max_concurrency: int | None = None
|
||||
# 若要从外部注入已创建好的 docstore,可直接设置此字段
|
||||
@@ -71,11 +59,10 @@ class IndexBuilderConfig:
|
||||
class IndexBuilder:
|
||||
"""RAG 索引构建主流水线,支持单块切分与父子块切分,支持混合检索。"""
|
||||
|
||||
def __init__(self, config: Optional[IndexBuilderConfig] = None, embeddings: Optional[Embeddings] = None, **kwargs):
|
||||
def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
config: 索引构建器配置对象,优先级高于 kwargs
|
||||
embeddings: 可选的外部嵌入模型实例,如果提供则使用它
|
||||
**kwargs: 可直接传入配置参数,会合并到 config 中(为方便使用保留)
|
||||
"""
|
||||
if config is None:
|
||||
@@ -91,29 +78,15 @@ class IndexBuilder:
|
||||
|
||||
# 初始化基础组件
|
||||
self.loader = DocumentLoader()
|
||||
|
||||
# 设置嵌入模型 - 优先使用外部提供的,然后尝试使用新服务,最后回退到原来的方式
|
||||
if embeddings is not None:
|
||||
self.embeddings = embeddings
|
||||
self._embedder = None
|
||||
logger.info("使用外部提供的嵌入模型")
|
||||
elif HAS_MODEL_SERVICES:
|
||||
try:
|
||||
self.embeddings = get_embedding_service()
|
||||
self._embedder = None
|
||||
logger.info("使用 model_services 提供的嵌入服务")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取嵌入服务失败,回退到 LlamaCppEmbedder: {e}")
|
||||
self._embedder = LlamaCppEmbedder()
|
||||
self.embeddings = self._embedder.as_langchain_embeddings()
|
||||
else:
|
||||
self._embedder = LlamaCppEmbedder()
|
||||
self.embeddings = self._embedder.as_langchain_embeddings()
|
||||
|
||||
# 设置嵌入模型 - 完全使用服务内部提供
|
||||
self.embeddings = get_embeddings()
|
||||
logger.info("使用统一嵌入服务")
|
||||
|
||||
# 初始化向量存储(自动支持稠密+稀疏混合检索)
|
||||
self.vector_store = QdrantVectorStore(
|
||||
self.vector_store = QdrantHybridStore(
|
||||
collection_name=config.collection_name,
|
||||
embeddings=self.embeddings if self._embedder is None else None
|
||||
embeddings=self.embeddings,
|
||||
)
|
||||
logger.info("✅ 混合检索向量存储初始化成功(稠密+BM25稀疏)")
|
||||
|
||||
@@ -141,13 +114,13 @@ class IndexBuilder:
|
||||
def _init_parent_child_mode(self) -> None:
|
||||
cfg = self.config
|
||||
|
||||
# 父块切分器(索引构建需要,必须保留)
|
||||
# 父块切分器
|
||||
self.parent_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=cfg.parent_chunk_size,
|
||||
chunk_overlap=cfg.parent_chunk_overlap,
|
||||
)
|
||||
|
||||
# 子块切分器(索引构建需要)
|
||||
# 子块切分器
|
||||
if cfg.child_splitter_type == SplitterType.SEMANTIC:
|
||||
self.child_splitter = get_splitter(
|
||||
SplitterType.SEMANTIC,
|
||||
@@ -163,16 +136,10 @@ class IndexBuilder:
|
||||
# 文档存储
|
||||
self.docstore = self._create_or_use_docstore()
|
||||
|
||||
# 使用工厂函数创建检索器,避免重复代码
|
||||
self.retriever = create_parent_retriever(
|
||||
collection_name=cfg.collection_name,
|
||||
parent_splitter=self.parent_splitter,
|
||||
child_splitter=self.child_splitter,
|
||||
docstore=self.docstore,
|
||||
search_k=cfg.search_k,
|
||||
embeddings=self.embeddings if self._embedder is None else None,
|
||||
)
|
||||
logger.info("ParentDocumentRetriever 初始化完成")
|
||||
# 注意:不再使用 LangChain 的 ParentDocumentRetriever
|
||||
# 改为自定义实现,以支持稀疏向量
|
||||
self.retriever = None
|
||||
logger.info("父子文档模式初始化完成(使用自定义索引逻辑)")
|
||||
|
||||
def _create_or_use_docstore(self) -> BaseStore:
|
||||
"""创建或获取文档存储实例。"""
|
||||
@@ -217,54 +184,71 @@ class IndexBuilder:
|
||||
return await self._index_with_single_splitter(documents)
|
||||
|
||||
async def _index_with_single_splitter(self, documents: List[Document]) -> int:
|
||||
"""单一切分模式:切分后直接写入向量库。"""
|
||||
"""单一切分模式:切分后直接写入向量库(异步)。"""
|
||||
chunks = self.splitter.split_documents(documents)
|
||||
logger.info("已切分为 %d 个块", len(chunks))
|
||||
|
||||
self.vector_store.create_collection()
|
||||
self.vector_store.add_documents(chunks)
|
||||
await self.vector_store.aadd_documents(chunks)
|
||||
return len(chunks)
|
||||
|
||||
async def _index_with_parent_child(self, documents: List[Document]) -> int:
|
||||
"""父子块模式:使用 ParentDocumentRetriever 批量添加。"""
|
||||
"""父子块模式:自定义实现,支持稠密+稀疏双向量。"""
|
||||
self.vector_store.create_collection()
|
||||
assert self.retriever is not None
|
||||
assert self.docstore is not None
|
||||
|
||||
batch_size = 10
|
||||
total = len(documents)
|
||||
processed = 0
|
||||
import uuid
|
||||
total_chunks = 0
|
||||
|
||||
for i in range(0, total, batch_size):
|
||||
batch = documents[i:i+batch_size]
|
||||
await self._add_batch_with_retry(batch, i // batch_size + 1)
|
||||
processed += len(batch)
|
||||
logger.info("批次 %d: 已处理 %d/%d", i // batch_size + 1, processed, total)
|
||||
# 1. 切分父块
|
||||
parent_chunks = self.parent_splitter.split_documents(documents)
|
||||
logger.info("切分出 %d 个父块", len(parent_chunks))
|
||||
|
||||
logger.info("ParentDocumentRetriever 索引完成,共处理 %d 个文档", processed)
|
||||
return processed
|
||||
# 2. 为每个父块生成 UUID 并存储
|
||||
parent_docs_with_ids = []
|
||||
for parent_chunk in parent_chunks:
|
||||
parent_id = str(uuid.uuid4())
|
||||
parent_chunk.metadata["id"] = parent_id
|
||||
parent_chunk.metadata["is_parent"] = True
|
||||
parent_docs_with_ids.append((parent_id, parent_chunk))
|
||||
|
||||
# 3. 父文档批量存入 PostgreSQL
|
||||
await self.docstore.amset(parent_docs_with_ids)
|
||||
logger.info("已存入 %d 个父文档到 PostgreSQL", len(parent_docs_with_ids))
|
||||
|
||||
# 4. 切分子块并添加 parent_id
|
||||
all_child_chunks = []
|
||||
for parent_id, parent_chunk in parent_docs_with_ids:
|
||||
child_chunks = self.child_splitter.split_documents([parent_chunk])
|
||||
for child_chunk in child_chunks:
|
||||
child_chunk.metadata["parent_id"] = parent_id
|
||||
child_chunk.metadata["is_parent"] = False
|
||||
# 继承父文档的重要元数据
|
||||
child_chunk.metadata["source"] = parent_chunk.metadata.get("source")
|
||||
child_chunk.metadata["page"] = parent_chunk.metadata.get("page")
|
||||
child_chunk.metadata["file_path"] = parent_chunk.metadata.get("file_path")
|
||||
all_child_chunks.append(child_chunk)
|
||||
|
||||
total_chunks = len(all_child_chunks)
|
||||
logger.info("切分出 %d 个子块", total_chunks)
|
||||
|
||||
# 5. 子文档分批存入 Qdrant(双向量,异步)
|
||||
batch_size = 100
|
||||
for i in range(0, total_chunks, batch_size):
|
||||
batch = all_child_chunks[i:i+batch_size]
|
||||
await self.vector_store.aadd_documents(batch)
|
||||
logger.info("已向 Qdrant 存入子文档批次 %d/%d",
|
||||
i // batch_size + 1,
|
||||
(total_chunks + batch_size - 1) // batch_size)
|
||||
|
||||
logger.info("父子文档索引完成:%d 父文档,%d 子文档",
|
||||
len(parent_docs_with_ids), total_chunks)
|
||||
return total_chunks
|
||||
|
||||
async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
|
||||
"""添加批次,失败时自动重试(处理网络波动)。"""
|
||||
max_retries = 5
|
||||
base_delay = 2
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
await self.retriever.aadd_documents(batch)
|
||||
logger.info("批次 %d 成功添加 %d 个文档", batch_no, len(batch))
|
||||
return
|
||||
except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
|
||||
if attempt == max_retries - 1:
|
||||
logger.error("批次 %d 重试 %d 次后仍然失败: %s", batch_no, max_retries, e)
|
||||
raise
|
||||
wait_time = base_delay * (2 ** attempt)
|
||||
error_type = type(e).__name__
|
||||
logger.warning(
|
||||
"批次 %d 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s",
|
||||
batch_no, error_type, wait_time, attempt + 1, max_retries, e
|
||||
)
|
||||
self.vector_store.refresh_client()
|
||||
logger.debug("批次 %d 已刷新 Qdrant 客户端连接", batch_no)
|
||||
await asyncio.sleep(wait_time)
|
||||
"""这个方法不再使用,保留只是为了兼容(不再被调用)"""
|
||||
# 这个方法现在不需要了,因为我们重写了 _index_with_parent_child
|
||||
pass
|
||||
|
||||
# ---------- 信息获取方法 ----------
|
||||
def get_collection_info(self) -> Any:
|
||||
|
||||
@@ -7,7 +7,7 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from backend.rag_core import QdrantVectorStore
|
||||
from backend.rag_core import QdrantHybridStore
|
||||
from backend.app.model_services import get_embedding_service
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ def check_qdrant_data():
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
|
||||
# 先获取几个点看看 payload 结构
|
||||
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
import sys
|
||||
|
||||
from qdrant_client import models
|
||||
from backend.rag_core import QdrantVectorStore, get_sparse_embedder
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
|
||||
from backend.app.model_services import get_embedding_service
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ def test_dense_retrieval():
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
|
||||
query = "黄双银" # 用文档里真正有的名字查询
|
||||
print(f"\n查询: {query}")
|
||||
|
||||
@@ -7,7 +7,7 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from backend.rag_core import QdrantVectorStore
|
||||
from backend.rag_core import QdrantHybridStore
|
||||
from backend.app.model_services import get_embedding_service
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ async def delete_and_recreate():
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
|
||||
# 删除旧集合
|
||||
try:
|
||||
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
import sys
|
||||
|
||||
from qdrant_client import models
|
||||
from backend.rag_core import QdrantVectorStore, get_sparse_embedder
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
|
||||
from backend.app.model_services import get_embedding_service
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ def check_qdrant_content():
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
|
||||
# 滚动获取前 5 个点
|
||||
@@ -51,7 +51,7 @@ def test_dense_retrieval():
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
|
||||
query = "蚂蚁" # 用中文查询
|
||||
print(f"\n查询: {query}")
|
||||
@@ -72,7 +72,7 @@ def test_sparse_retrieval():
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
@@ -109,7 +109,7 @@ def test_hybrid_retrieval():
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import os
|
||||
from rag_indexer.index_builder import IndexBuilder
|
||||
from rag_indexer.splitters import SplitterType
|
||||
|
||||
from backend.rag_core import QdrantVectorStore, get_sparse_embedder
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
|
||||
from backend.app.model_services import get_embedding_service
|
||||
from qdrant_client import models
|
||||
|
||||
@@ -61,7 +61,7 @@ def test_dense_retrieval():
|
||||
embeddings = get_embedding_service()
|
||||
|
||||
# 创建向量存储
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
|
||||
# 测试查询
|
||||
query = "The Ant and the Grasshopper"
|
||||
@@ -87,7 +87,7 @@ def test_sparse_retrieval_simple():
|
||||
|
||||
# 获取嵌入服务和稀疏嵌入器
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
@@ -133,7 +133,7 @@ def test_hybrid_retrieval_simple():
|
||||
|
||||
# 获取嵌入服务和稀疏嵌入器
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
@@ -189,7 +189,7 @@ def test_parent_child_retrieval_simple():
|
||||
|
||||
# 获取嵌入服务和稀疏嵌入器
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user