Files
ailine/rag_indexer/index_builder.py
root 9841f47432
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
refactor: 重构RAG核心组件,简化代码结构和测试文件
2026-05-04 17:58:10 +08:00

320 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
离线 RAG 索引构建核心流水线。
自定义实现父子块策略，支持 Qdrant 混合检索（Dense + Sparse）。
"""
import asyncio
import logging
import sys
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Union, Optional, Any, Dict
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter
from backend.rag_core import get_embeddings, QdrantHybridStore, create_docstore
logger = logging.getLogger(__name__)
# ---------- 配置数据类 ----------
@dataclass
class DocstoreConfig:
    """Settings for the document store that holds parent documents."""

    # Forwarded as ``pool_config`` to ``create_docstore`` when the builder
    # has to create the store itself.
    pool_config: Optional[Dict[str, Any]] = None
    # Forwarded as ``max_concurrency`` to ``create_docstore``.
    max_concurrency: Optional[int] = None
    # Set this to inject an already-created docstore; when it is not None
    # the builder uses it as-is instead of creating a new one.
    instance: Optional[BaseStore] = None
@dataclass
class IndexBuilderConfig:
    """Configuration for :class:`IndexBuilder`."""

    # Name of the vector-store collection the chunks are written to.
    collection_name: str = "rag_documents"
    # Top-level splitting strategy; PARENT_CHILD enables the two-level mode.
    splitter_type: SplitterType = SplitterType.PARENT_CHILD

    # --- Parent chunk splitting (only used when splitter_type is PARENT_CHILD)
    parent_chunk_size: int = 1000
    parent_chunk_overlap: int = 100

    # --- Child chunk splitting
    # Size/overlap apply when the child splitter is not the semantic one.
    child_chunk_size: int = 200
    child_chunk_overlap: int = 20
    # Child chunks use semantic splitting by default.
    child_splitter_type: SplitterType = SplitterType.SEMANTIC

    # --- Semantic splitting parameters for child chunks
    child_buffer_size: int = 1
    child_breakpoint_threshold_type: str = "percentile"
    # Lowered threshold so that splitting is more aggressive.
    child_breakpoint_threshold_amount: float = 90
    # Lowered minimum chunk size.
    child_min_chunk_size: int = 50

    # --- Retrieval parameter
    search_k: int = 5

    # --- Document store settings (only needed in parent/child mode)
    docstore: DocstoreConfig = field(default_factory=DocstoreConfig)

    # --- Extra splitter keyword arguments (used when splitter_type is not
    # PARENT_CHILD)
    extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)
# ---------- 索引构建器 ----------
class IndexBuilder:
    """Main offline RAG indexing pipeline.

    Two indexing strategies are supported, selected by
    ``config.splitter_type``:

    * single-splitter mode: documents are split once and the resulting
      chunks are written straight to the vector store;
    * parent/child mode: documents are split into parent chunks, which are
      saved in a document store, and every parent is split again into child
      chunks that are written to the vector store carrying a ``parent_id``
      back-reference.

    Instances work as both a synchronous and an asynchronous context
    manager; leaving the context closes the document store when it exposes
    an ``aclose`` method.
    """

    def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs):
        """Create the loader, embeddings, vector store and splitters.

        Args:
            config: Builder configuration. When omitted, a new
                ``IndexBuilderConfig`` is built from ``kwargs``.
            **kwargs: Configuration fields passed directly for convenience.
                When ``config`` is also given, fields that exist on it are
                overwritten in place (keyword values win, and the caller's
                config object is modified). Names that are not attributes
                of the config are skipped and reported with a warning.
        """
        if config is None:
            config = IndexBuilderConfig(**kwargs)
        elif kwargs:
            for key, value in kwargs.items():
                if hasattr(config, key):
                    setattr(config, key, value)
                else:
                    # Previously dropped silently; surface likely typos.
                    logger.warning("忽略未知配置参数: %s", key)
        self.config = config

        # Connection info returned by create_docstore(); stays None when the
        # docstore is injected from outside or no docstore is created.
        self._docstore_conn: Optional[str] = None
        # Task scheduled by close() when it is called under a running loop.
        self._close_task: Optional[asyncio.Task] = None

        # Basic components.
        self.loader = DocumentLoader()
        # The embedding model always comes from the shared internal service.
        self.embeddings = get_embeddings()
        logger.info("使用统一嵌入服务")

        # Vector store (per the original notes: dense + sparse hybrid).
        self.vector_store = QdrantHybridStore(
            collection_name=config.collection_name,
        )
        logger.info("✅ 混合检索向量存储初始化成功(稠密+BM25稀疏")

        # Splitters (and docstore, if needed) depend on the splitter type.
        self._init_splitters_and_retriever()

    # ---------- Private initialisation helpers ----------
    def _init_splitters_and_retriever(self) -> None:
        """Initialise splitters according to ``config.splitter_type``."""
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            self._init_parent_child_mode()
        else:
            self._init_single_splitter_mode()

    def _init_single_splitter_mode(self) -> None:
        """Set up a single splitter; no docstore is used in this mode."""
        # Copy so the caller's dict is not modified when embeddings is added.
        splitter_kwargs = self.config.extra_splitter_kwargs.copy()
        if self.config.splitter_type == SplitterType.SEMANTIC:
            splitter_kwargs["embeddings"] = self.embeddings
        self.splitter = get_splitter(self.config.splitter_type, **splitter_kwargs)
        self.retriever = None
        self.docstore = None
        logger.info("使用单一 %s 切分器", self.config.splitter_type.value)

    def _init_parent_child_mode(self) -> None:
        """Set up the parent splitter, child splitter and document store."""
        cfg = self.config

        # Parent chunks are always produced by a recursive character splitter.
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=cfg.parent_chunk_size,
            chunk_overlap=cfg.parent_chunk_overlap,
        )

        # Child chunks: semantic splitter, or recursive character splitter
        # for any other child_splitter_type.
        if cfg.child_splitter_type == SplitterType.SEMANTIC:
            # NOTE(review): a key in extra_splitter_kwargs that repeats one
            # of the explicit keywords below raises TypeError at call time.
            self.child_splitter = get_splitter(
                SplitterType.SEMANTIC,
                embeddings=self.embeddings,
                buffer_size=cfg.child_buffer_size,
                breakpoint_threshold_type=cfg.child_breakpoint_threshold_type,
                breakpoint_threshold_amount=cfg.child_breakpoint_threshold_amount,
                min_chunk_size=cfg.child_min_chunk_size,
                **cfg.extra_splitter_kwargs
            )
        else:
            self.child_splitter = RecursiveCharacterTextSplitter(
                chunk_size=cfg.child_chunk_size,
                chunk_overlap=cfg.child_chunk_overlap,
            )

        # Store for the parent documents.
        self.docstore = self._create_or_use_docstore()

        # LangChain's ParentDocumentRetriever is intentionally not used; the
        # indexing logic is implemented here so sparse vectors are supported.
        self.retriever = None
        logger.info("父子文档模式初始化完成(使用自定义索引逻辑)")

    def _create_or_use_docstore(self) -> BaseStore:
        """Return the injected docstore, or create one via ``create_docstore``.

        When a store is created here, the connection info returned alongside
        it is remembered in ``self._docstore_conn``.
        """
        docstore_cfg = self.config.docstore
        if docstore_cfg.instance is not None:
            logger.debug("使用外部注入的文档存储")
            return docstore_cfg.instance

        # Per the original notes, create_docstore builds a PostgreSQL store.
        docstore, conn_info = create_docstore(
            pool_config=docstore_cfg.pool_config,
            max_concurrency=docstore_cfg.max_concurrency,
        )
        self._docstore_conn = conn_info
        logger.info("文档存储已创建PostgreSQL")
        return docstore

    # ---------- Public build methods ----------
    async def build_from_file(self, file_path: Union[str, Path]) -> int:
        """Load one file and index it.

        Returns:
            The number of chunks written to the vector store.
        """
        logger.info("加载文件: %s", file_path)
        documents = self.loader.load_file(file_path)
        logger.info("已加载 %d 个文档", len(documents))
        return await self._process_documents(documents)

    async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
        """Load the documents of a directory and index them.

        Args:
            directory_path: Directory to load from.
            recursive: Passed through to the loader's ``load_directory``.

        Returns:
            The number of chunks written to the vector store.
        """
        logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
        documents = self.loader.load_directory(directory_path, recursive=recursive)
        logger.info("已从目录加载 %d 个文档", len(documents))
        return await self._process_documents(documents)

    async def _process_documents(self, documents: List[Document]) -> int:
        """Dispatch *documents* to the indexing routine for the current mode.

        Returns 0 without touching any store when *documents* is empty.
        """
        if not documents:
            logger.warning("没有文档需要处理")
            return 0
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            return await self._index_with_parent_child(documents)
        return await self._index_with_single_splitter(documents)

    async def _index_with_single_splitter(self, documents: List[Document]) -> int:
        """Split once and write all chunks to the vector store.

        Returns:
            The number of chunks produced by the splitter.
        """
        chunks = self.splitter.split_documents(documents)
        logger.info("已切分为 %d 个块", len(chunks))
        self.vector_store.create_collection()
        await self.vector_store.aadd_documents(chunks)
        return len(chunks)

    async def _index_with_parent_child(self, documents: List[Document]) -> int:
        """Index *documents* using the custom parent/child strategy.

        Parent chunks get a fresh UUID and are saved to the docstore in one
        call; child chunks are tagged with their ``parent_id`` and written
        to the vector store in batches of 100.

        Returns:
            The number of child chunks written to the vector store.

        Raises:
            RuntimeError: If no document store has been initialised.
        """
        import uuid

        # Explicit check (an assert would disappear under ``python -O``),
        # done before any store is touched.
        if self.docstore is None:
            raise RuntimeError("父子块模式下文档存储未初始化")

        self.vector_store.create_collection()

        # 1. Split into parent chunks.
        parent_chunks = self.parent_splitter.split_documents(documents)
        logger.info("切分出 %d 个父块", len(parent_chunks))

        # 2. Give every parent chunk a UUID and mark it as a parent.
        parent_docs_with_ids = []
        for parent_chunk in parent_chunks:
            parent_id = str(uuid.uuid4())
            parent_chunk.metadata["id"] = parent_id
            parent_chunk.metadata["is_parent"] = True
            parent_docs_with_ids.append((parent_id, parent_chunk))

        # 3. Save all parents to the docstore in a single call.
        await self.docstore.amset(parent_docs_with_ids)
        logger.info("已存入 %d 个父文档到 PostgreSQL", len(parent_docs_with_ids))

        # 4. Split each parent into children and link them back to it.
        all_child_chunks = []
        for parent_id, parent_chunk in parent_docs_with_ids:
            parent_meta = parent_chunk.metadata
            for child_chunk in self.child_splitter.split_documents([parent_chunk]):
                child_meta = child_chunk.metadata
                child_meta["parent_id"] = parent_id
                child_meta["is_parent"] = False
                # Carry over key parent metadata; a key the parent lacks is
                # stored as None.
                child_meta["source"] = parent_meta.get("source")
                child_meta["page"] = parent_meta.get("page")
                child_meta["file_path"] = parent_meta.get("file_path")
                all_child_chunks.append(child_chunk)

        total_chunks = len(all_child_chunks)
        logger.info("切分出 %d 个子块", total_chunks)

        # 5. Write the children to the vector store batch by batch.
        batch_size = 100
        batch_count = (total_chunks + batch_size - 1) // batch_size
        for batch_no, start in enumerate(range(0, total_chunks, batch_size), start=1):
            batch = all_child_chunks[start:start + batch_size]
            await self.vector_store.aadd_documents(batch)
            logger.info("已向 Qdrant 存入子文档批次 %d/%d",
                        batch_no,
                        batch_count)

        logger.info("父子文档索引完成:%d 父文档,%d 子文档",
                    len(parent_docs_with_ids), total_chunks)
        return total_chunks

    async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
        """No-op kept only for compatibility; nothing in this class calls it."""
        # Superseded by the rewritten _index_with_parent_child.
        pass

    # ---------- Accessors ----------
    def get_collection_info(self) -> Any:
        """Return whatever the vector store reports for its collection."""
        return self.vector_store.get_collection_info()

    def get_child_splitter(self) -> TextSplitter:
        """Return the child splitter, or the single splitter outside parent/child mode."""
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            return self.child_splitter
        return self.splitter

    def get_parent_splitter(self) -> RecursiveCharacterTextSplitter:
        """Return the parent splitter.

        Raises:
            RuntimeError: If the builder is not in parent/child mode.
        """
        if self.config.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError("父块切分器仅在父子块模式下可用")
        return self.parent_splitter

    def get_docstore(self) -> BaseStore:
        """Return the document store.

        Raises:
            RuntimeError: If the builder is not in parent/child mode, or no
                document store has been initialised.
        """
        if self.config.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError("文档存储仅在父子块模式下可用")
        if self.docstore is None:
            raise RuntimeError("父子块模式下文档存储未初始化")
        return self.docstore

    # ---------- Resource management ----------
    def close(self) -> None:
        """Close resources synchronously (used by the sync context manager).

        Nothing is closed unless the docstore exists and has an ``aclose``
        method. Without a running event loop, a temporary loop runs
        ``aclose()`` to completion and is always closed afterwards, even if
        ``aclose()`` raises. With a running loop, the close is only
        scheduled on that loop and this method returns immediately; the task
        is kept on ``self._close_task`` so it stays referenced and the
        caller can await it.
        """
        docstore = self.docstore
        if docstore is not None and hasattr(docstore, "aclose"):
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop: drive the coroutine on a temporary one.
                loop = asyncio.new_event_loop()
                try:
                    loop.run_until_complete(docstore.aclose())
                finally:
                    loop.close()
            else:
                # Running loop: schedule the close and keep a reference, as
                # the loop itself does not keep tasks alive.
                self._close_task = loop.create_task(docstore.aclose())
        logger.info("IndexBuilder 资源已关闭")

    async def aclose(self) -> None:
        """Close resources asynchronously, awaiting the docstore's ``aclose``."""
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            await self.docstore.aclose()
        logger.info("IndexBuilder 资源已异步关闭")

    def __enter__(self):
        """Enter the synchronous context; returns the builder itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close resources on exit; exceptions are not suppressed."""
        self.close()
        return False

    async def __aenter__(self):
        """Enter the asynchronous context; returns the builder itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close resources asynchronously on exit; exceptions are not suppressed."""
        await self.aclose()
        return False