Files
ailine/rag_indexer/IndexBuilder.py
root 933d418d77
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s
检索器重构
2026-04-19 22:01:55 +08:00

299 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
离线 RAG 索引构建核心流水线。
使用 LangChain 的 ParentDocumentRetriever 实现父子块策略。
"""
import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
logger = logging.getLogger(__name__)
# ---------- 配置数据类 ----------
@dataclass
class DocstoreConfig:
"""文档存储配置(用于父块存储)。"""
connection_string: Optional[str] = None
pool_config: Optional[Dict[str, Any]] = None
max_concurrency: Optional[int] = None
# 若要从外部注入已创建好的 docstore可直接设置此字段
instance: Optional[BaseStore] = None
@dataclass
class IndexBuilderConfig:
"""索引构建器配置。"""
collection_name: str = "rag_documents"
splitter_type: SplitterType = SplitterType.PARENT_CHILD
# 父块切分参数(仅当 splitter_type 为 PARENT_CHILD 时生效)
parent_chunk_size: int = 1000
parent_chunk_overlap: int = 100
# 子块切分参数
child_chunk_size: int = 200
child_chunk_overlap: int = 20
child_splitter_type: SplitterType = SplitterType.SEMANTIC # 子块默认语义切分
# 检索参数
search_k: int = 5
# 文档存储配置(仅父子块模式需要)
docstore: DocstoreConfig = field(default_factory=DocstoreConfig)
# 其他切分器参数(当 splitter_type 非父子块时使用)
extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)
# ---------- 索引构建器 ----------
class IndexBuilder:
"""RAG 索引构建主流水线,支持单块切分与父子块切分。"""
def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs):
"""
Args:
config: 索引构建器配置对象,优先级高于 kwargs
**kwargs: 可直接传入配置参数,会合并到 config 中(为方便使用保留)
"""
if config is None:
config = IndexBuilderConfig(**kwargs)
elif kwargs:
# 合并 kwargs 到 config 的字段(仅更新已有字段)
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
self.config = config
self._docstore_conn: Optional[str] = None # 用于记录由 create_docstore 创建的连接信息
# 初始化基础组件
self.loader = DocumentLoader()
self.embedder = LlamaCppEmbedder()
self.embeddings: Embeddings = self.embedder.as_langchain_embeddings()
# 初始化向量存储
self.vector_store = QdrantVectorStore(
collection_name=config.collection_name,
embeddings=self.embeddings,
)
# 根据切分类型初始化相关组件
self._init_splitters_and_retriever()
# ---------- 私有初始化方法 ----------
def _init_splitters_and_retriever(self) -> None:
"""根据配置初始化切分器和检索器。"""
if self.config.splitter_type == SplitterType.PARENT_CHILD:
self._init_parent_child_mode()
else:
self._init_single_splitter_mode()
def _init_single_splitter_mode(self) -> None:
"""单一切分模式(递归或语义)。"""
splitter_kwargs = self.config.extra_splitter_kwargs.copy()
if self.config.splitter_type == SplitterType.SEMANTIC:
splitter_kwargs["embeddings"] = self.embeddings
self.splitter = get_splitter(self.config.splitter_type, **splitter_kwargs)
self.retriever = None
self.docstore = None
logger.info("使用单一 %s 切分器", self.config.splitter_type.value)
def _init_parent_child_mode(self) -> None:
"""父子块切分模式,初始化父块/子块切分器、文档存储和检索器。"""
cfg = self.config
# 父块切分器(始终使用递归切分)
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=cfg.parent_chunk_size,
chunk_overlap=cfg.parent_chunk_overlap,
)
# 子块切分器
if cfg.child_splitter_type == SplitterType.SEMANTIC:
self.child_splitter = get_splitter(
SplitterType.SEMANTIC,
embeddings=self.embeddings,
**cfg.extra_splitter_kwargs
)
logger.info("子块使用语义切分器")
else:
self.child_splitter = RecursiveCharacterTextSplitter(
chunk_size=cfg.child_chunk_size,
chunk_overlap=cfg.child_chunk_overlap,
)
logger.info("子块使用递归切分器,块大小=%d,重叠=%d",
cfg.child_chunk_size, cfg.child_chunk_overlap)
# 初始化文档存储(用于父块)
self.docstore = self._create_or_use_docstore()
# 创建检索器
self.retriever = ParentDocumentRetriever(
vectorstore=self.vector_store.get_langchain_vectorstore(),
docstore=self.docstore,
child_splitter=self.child_splitter, # type: ignore[arg-type]
parent_splitter=self.parent_splitter,
search_kwargs={"k": cfg.search_k},
)
logger.info("ParentDocumentRetriever 初始化完成,父块大小=%d", cfg.parent_chunk_size)
def _create_or_use_docstore(self) -> BaseStore:
"""创建或获取文档存储实例。"""
cfg = self.config.docstore
if cfg.instance is not None:
logger.debug("使用外部注入的文档存储")
return cfg.instance
# 使用 create_docstore 创建 PostgreSQL 存储
docstore, conn_info = create_docstore(
connection_string=cfg.connection_string,
pool_config=cfg.pool_config,
max_concurrency=cfg.max_concurrency,
)
self._docstore_conn = conn_info
logger.info("文档存储已创建PostgreSQL")
return docstore
# ---------- 公共构建方法 ----------
async def build_from_file(self, file_path: Union[str, Path]) -> int:
"""从单个文件构建索引。"""
logger.info("加载文件: %s", file_path)
documents = self.loader.load_file(file_path)
logger.info("已加载 %d 个文档", len(documents))
return await self._process_documents(documents)
async def build_from_directory(
self, directory_path: Union[str, Path], recursive: bool = True
) -> int:
"""从目录递归构建索引。"""
logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
documents = self.loader.load_directory(directory_path, recursive=recursive)
logger.info("已从目录加载 %d 个文档", len(documents))
return await self._process_documents(documents)
async def _process_documents(self, documents: List[Document]) -> int:
"""处理文档列表,分发给相应的索引逻辑。"""
if not documents:
logger.warning("没有文档需要处理")
return 0
if self.config.splitter_type == SplitterType.PARENT_CHILD:
return await self._index_with_parent_child(documents)
else:
return await self._index_with_single_splitter(documents)
async def _index_with_single_splitter(self, documents: List[Document]) -> int:
"""单一模式:切分后直接写入向量库。"""
chunks = self.splitter.split_documents(documents) # type: ignore[union-attr]
logger.info("已切分为 %d 个块", len(chunks))
self.vector_store.create_collection()
self.vector_store.add_documents(chunks)
return len(chunks)
async def _index_with_parent_child(self, documents: List[Document]) -> int:
"""父子模式:使用 ParentDocumentRetriever 批量添加。"""
self.vector_store.create_collection()
assert self.retriever is not None
batch_size = 10
total = len(documents)
processed = 0
for i in range(0, total, batch_size):
batch = documents[i:i + batch_size]
await self._add_batch_with_retry(batch, i // batch_size + 1)
processed += len(batch)
logger.info("批次 %d: 已处理 %d/%d", i // batch_size + 1, processed, total)
logger.info("ParentDocumentRetriever 索引完成,共处理 %d 个文档", processed)
return processed
async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
"""添加批次,失败时自动重试(处理网络波动)。"""
max_retries = 3
for attempt in range(max_retries):
try:
await self.retriever.aadd_documents(batch) # type: ignore[union-attr]
return
except (RemoteProtocolError, ConnectionError, OSError) as e:
if attempt == max_retries - 1:
raise
logger.warning("批次 %d 连接断开,重试 (%d/%d): %s",
batch_no, attempt + 1, max_retries, e)
self.vector_store.refresh_client()
await asyncio.sleep(1)
# ---------- 信息获取方法 ----------
def get_collection_info(self) -> Any:
"""获取向量库集合信息。"""
return self.vector_store.get_collection_info()
def get_child_splitter(self) -> TextSplitter:
"""获取当前使用的子块切分器。"""
if self.config.splitter_type == SplitterType.PARENT_CHILD:
return self.child_splitter # type: ignore[return-value]
return self.splitter # type: ignore[return-value]
def get_parent_splitter(self) -> RecursiveCharacterTextSplitter:
"""获取父块切分器(仅父子模式可用)。"""
if self.config.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError("父块切分器仅在父子块模式下可用")
return self.parent_splitter # type: ignore[return-value]
def get_docstore(self) -> BaseStore:
"""获取文档存储实例(仅父子模式可用)。"""
if self.config.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError("文档存储仅在父子块模式下可用")
assert self.docstore is not None
return self.docstore
# ---------- 资源管理 ----------
def close(self) -> None:
"""关闭资源(同步版本,供上下文管理器使用)。"""
if self.docstore is not None and hasattr(self.docstore, "aclose"):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# 无运行中的事件循环,创建临时循环
loop = asyncio.new_event_loop()
loop.run_until_complete(self.docstore.aclose()) # type: ignore[attr-defined]
loop.close()
else:
# 已有运行中的循环,创建任务(用户自行等待)
loop.create_task(self.docstore.aclose()) # type: ignore[attr-defined]
logger.info("IndexBuilder 资源已关闭")
async def aclose(self) -> None:
"""异步关闭资源。"""
if self.docstore is not None and hasattr(self.docstore, "aclose"):
await self.docstore.aclose() # type: ignore[attr-defined]
logger.info("IndexBuilder 资源已异步关闭")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
return False