Files
ailine/rag_indexer/index_builder.py

301 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
离线 RAG 索引构建核心流水线。
使用 LangChain 的 ParentDocumentRetriever 实现父子块策略。
"""
import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from qdrant_client.http.exceptions import ResponseHandlingException
from rag_indexer.loaders import DocumentLoader
from rag_indexer.splitters import SplitterType, get_splitter
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
logger = logging.getLogger(__name__)
# ---------- Configuration dataclasses ----------
@dataclass
class DocstoreConfig:
    """Configuration for the document store (used to persist parent chunks)."""
    # Connection string forwarded to create_docstore (PostgreSQL-backed per
    # _create_or_use_docstore); None lets create_docstore use its own default.
    connection_string: Optional[str] = None
    # Optional connection-pool settings forwarded to create_docstore.
    pool_config: Optional[Dict[str, Any]] = None
    # Optional cap on concurrent store operations, forwarded to create_docstore.
    max_concurrency: Optional[int] = None
    # To inject an externally created docstore, set this field directly;
    # when set, it takes precedence and create_docstore is not called.
    instance: Optional[BaseStore] = None
@dataclass
class IndexBuilderConfig:
    """Configuration for the index builder."""
    collection_name: str = "rag_documents"
    splitter_type: SplitterType = SplitterType.PARENT_CHILD
    # Parent-chunk splitting parameters (only used when splitter_type is PARENT_CHILD)
    parent_chunk_size: int = 1000
    parent_chunk_overlap: int = 100
    # Child-chunk splitting parameters
    child_chunk_size: int = 200
    child_chunk_overlap: int = 20
    child_splitter_type: SplitterType = SplitterType.SEMANTIC  # children default to semantic splitting
    # Retrieval parameter: number of results requested from the retriever
    search_k: int = 5
    # Document-store configuration (needed only in parent/child mode)
    docstore: DocstoreConfig = field(default_factory=DocstoreConfig)
    # Extra splitter kwargs: used when splitter_type is not parent/child, and
    # also forwarded to the semantic child splitter in parent/child mode
    extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)
# ---------- Index builder ----------
class IndexBuilder:
    """Main RAG index-building pipeline.

    Supports single-splitter chunking (recursive or semantic) and the
    parent/child chunk strategy backed by LangChain's ParentDocumentRetriever.
    Usable as a sync (``with``) or async (``async with``) context manager.
    """

    def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs):
        """
        Args:
            config: Builder configuration object; takes precedence over kwargs.
            **kwargs: Individual configuration values merged into ``config``
                (kept for convenience). Keys that are not existing fields of
                the config are ignored — now with a warning instead of
                silently, so typos are visible.
        """
        if config is None:
            config = IndexBuilderConfig(**kwargs)
        elif kwargs:
            # Merge kwargs into the config, updating only fields that exist.
            for key, value in kwargs.items():
                if hasattr(config, key):
                    setattr(config, key, value)
                else:
                    logger.warning("忽略未知配置项: %s", key)
        self.config = config
        # Connection info recorded when create_docstore builds the store.
        self._docstore_conn: Optional[str] = None
        # Reference to an async close task scheduled by close() from within a
        # running event loop; kept so the task is not garbage-collected.
        self._close_task: Optional[asyncio.Task] = None
        # Base components.
        self.loader = DocumentLoader()
        self.embedder = LlamaCppEmbedder()
        self.embeddings: Embeddings = self.embedder.as_langchain_embeddings()
        # Vector store.
        self.vector_store = QdrantVectorStore(
            collection_name=config.collection_name,
            embeddings=self.embeddings,
        )
        # Splitters and retriever depend on the configured splitter type.
        self._init_splitters_and_retriever()

    # ---------- private initialisation ----------
    def _init_splitters_and_retriever(self) -> None:
        """Initialise splitters and retriever according to the configuration."""
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            self._init_parent_child_mode()
        else:
            self._init_single_splitter_mode()

    def _init_single_splitter_mode(self) -> None:
        """Single-splitter mode (recursive or semantic): no retriever/docstore."""
        splitter_kwargs = self.config.extra_splitter_kwargs.copy()
        if self.config.splitter_type == SplitterType.SEMANTIC:
            # The semantic splitter needs the embedding model.
            splitter_kwargs["embeddings"] = self.embeddings
        self.splitter = get_splitter(self.config.splitter_type, **splitter_kwargs)
        self.retriever = None
        self.docstore = None
        logger.info("使用单一 %s 切分器", self.config.splitter_type.value)

    def _init_parent_child_mode(self) -> None:
        """Parent/child mode: build both splitters, the docstore and the retriever."""
        cfg = self.config
        # Parent splitter (required for index building).
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=cfg.parent_chunk_size,
            chunk_overlap=cfg.parent_chunk_overlap,
        )
        # Child splitter (required for index building).
        if cfg.child_splitter_type == SplitterType.SEMANTIC:
            self.child_splitter = get_splitter(
                SplitterType.SEMANTIC,
                embeddings=self.embeddings,
                **cfg.extra_splitter_kwargs,
            )
        else:
            self.child_splitter = RecursiveCharacterTextSplitter(
                chunk_size=cfg.child_chunk_size,
                chunk_overlap=cfg.child_chunk_overlap,
            )
        # Document store for parent chunks.
        self.docstore = self._create_or_use_docstore()
        # Factory function keeps retriever wiring in one place.
        self.retriever = create_parent_retriever(
            collection_name=cfg.collection_name,
            embeddings=self.embeddings,
            parent_splitter=self.parent_splitter,
            child_splitter=self.child_splitter,
            docstore=self.docstore,
            search_k=cfg.search_k,
        )
        logger.info("ParentDocumentRetriever 初始化完成")

    def _create_or_use_docstore(self) -> BaseStore:
        """Return an externally injected docstore, or create a new one."""
        cfg = self.config.docstore
        if cfg.instance is not None:
            logger.debug("使用外部注入的文档存储")
            return cfg.instance
        # Build a PostgreSQL-backed store via create_docstore.
        docstore, conn_info = create_docstore(
            connection_string=cfg.connection_string,
            pool_config=cfg.pool_config,
            max_concurrency=cfg.max_concurrency,
        )
        self._docstore_conn = conn_info
        logger.info("文档存储已创建PostgreSQL")
        return docstore

    # ---------- public build methods ----------
    async def build_from_file(self, file_path: Union[str, Path]) -> int:
        """Build the index from a single file; returns the number of indexed items."""
        logger.info("加载文件: %s", file_path)
        documents = self.loader.load_file(file_path)
        logger.info("已加载 %d 个文档", len(documents))
        return await self._process_documents(documents)

    async def build_from_directory(
        self, directory_path: Union[str, Path], recursive: bool = True
    ) -> int:
        """Build the index from a directory (optionally recursing into subdirs)."""
        logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
        documents = self.loader.load_directory(directory_path, recursive=recursive)
        logger.info("已从目录加载 %d 个文档", len(documents))
        return await self._process_documents(documents)

    async def _process_documents(self, documents: List[Document]) -> int:
        """Dispatch the document list to the indexing strategy for this config."""
        if not documents:
            logger.warning("没有文档需要处理")
            return 0
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            return await self._index_with_parent_child(documents)
        return await self._index_with_single_splitter(documents)

    async def _index_with_single_splitter(self, documents: List[Document]) -> int:
        """Single mode: split, then write chunks straight to the vector store."""
        chunks = self.splitter.split_documents(documents)  # type: ignore[union-attr]
        logger.info("已切分为 %d 个块", len(chunks))
        self.vector_store.create_collection()
        self.vector_store.add_documents(chunks)
        return len(chunks)

    async def _index_with_parent_child(self, documents: List[Document]) -> int:
        """Parent/child mode: add documents in batches via ParentDocumentRetriever."""
        self.vector_store.create_collection()
        assert self.retriever is not None
        batch_size = 10
        total = len(documents)
        processed = 0
        for i in range(0, total, batch_size):
            batch = documents[i:i + batch_size]
            await self._add_batch_with_retry(batch, i // batch_size + 1)
            processed += len(batch)
            logger.info("批次 %d: 已处理 %d/%d", i // batch_size + 1, processed, total)
        logger.info("ParentDocumentRetriever 索引完成,共处理 %d 个文档", processed)
        return processed

    async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
        """Add one batch, retrying transient network errors with exponential backoff.

        Raises:
            The last caught network exception, after ``max_retries`` attempts.
        """
        max_retries = 5
        base_delay = 2
        for attempt in range(max_retries):
            try:
                await self.retriever.aadd_documents(batch)  # type: ignore[union-attr]
                logger.info("批次 %d 成功添加 %d 个文档", batch_no, len(batch))
                return
            except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
                if attempt == max_retries - 1:
                    logger.error("批次 %d 重试 %d 次后仍然失败: %s", batch_no, max_retries, e)
                    raise
                wait_time = base_delay * (2 ** attempt)  # backoff: 2, 4, 8, 16s
                error_type = type(e).__name__
                logger.warning(
                    "批次 %d 遇到网络异常 [%s]%d秒后重试 (%d/%d): %s",
                    batch_no, error_type, wait_time, attempt + 1, max_retries, e
                )
                # Drop the (possibly broken) client connection before retrying.
                self.vector_store.refresh_client()
                logger.debug("批次 %d 已刷新 Qdrant 客户端连接", batch_no)
                await asyncio.sleep(wait_time)

    # ---------- introspection ----------
    def get_collection_info(self) -> Any:
        """Return collection information from the vector store."""
        return self.vector_store.get_collection_info()

    def get_child_splitter(self) -> TextSplitter:
        """Return the splitter used for (child) chunks in the current mode."""
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            return self.child_splitter  # type: ignore[return-value]
        return self.splitter  # type: ignore[return-value]

    def get_parent_splitter(self) -> RecursiveCharacterTextSplitter:
        """Return the parent splitter (parent/child mode only).

        Raises:
            RuntimeError: if not in parent/child mode.
        """
        if self.config.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError("父块切分器仅在父子块模式下可用")
        return self.parent_splitter  # type: ignore[return-value]

    def get_docstore(self) -> BaseStore:
        """Return the document store (parent/child mode only).

        Raises:
            RuntimeError: if not in parent/child mode.
        """
        if self.config.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError("文档存储仅在父子块模式下可用")
        assert self.docstore is not None
        return self.docstore

    # ---------- resource management ----------
    def close(self) -> None:
        """Close resources (synchronous version, used by the context manager).

        Prefer ``aclose()`` in async code: when an event loop is already
        running, a sync method can only schedule the async close, not await it.
        """
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop: run the coroutine to completion.
                # asyncio.run() creates, runs and reliably closes a temporary
                # loop (safer than hand-rolled new_event_loop/run_until_complete).
                asyncio.run(self.docstore.aclose())  # type: ignore[attr-defined]
            else:
                # A loop is running: schedule the close and KEEP a reference —
                # a fire-and-forget task may be garbage-collected mid-flight.
                self._close_task = loop.create_task(
                    self.docstore.aclose()  # type: ignore[attr-defined]
                )
        logger.info("IndexBuilder 资源已关闭")

    async def aclose(self) -> None:
        """Close resources asynchronously."""
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            await self.docstore.aclose()  # type: ignore[attr-defined]
        logger.info("IndexBuilder 资源已异步关闭")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.aclose()
        return False