# ailine/rag_indexer/index_builder.py
"""
离线 RAG 索引构建核心流水线
使用 LangChain ParentDocumentRetriever 实现父子块策略
"""
import asyncio
import logging
2026-04-21 10:26:37 +08:00
import sys
2026-04-19 22:01:55 +08:00
from pathlib import Path
2026-04-21 10:26:37 +08:00
from dataclasses import dataclass, field
from typing import List, Union, Optional, Any, Dict
2026-04-19 22:01:55 +08:00
2026-04-21 10:26:37 +08:00
# 添加 backend 目录到路径以导入 rag_core
sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
2026-04-19 22:01:55 +08:00
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
2026-04-20 14:05:57 +08:00
from qdrant_client.http.exceptions import ResponseHandlingException
2026-04-19 22:01:55 +08:00
2026-04-21 10:26:37 +08:00
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter
2026-04-21 18:41:14 +08:00
# 从 rag_core 导入
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
2026-04-20 01:10:18 +08:00
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
2026-04-19 22:01:55 +08:00
logger = logging.getLogger(__name__)
# ---------- 配置数据类 ----------
@dataclass
class DocstoreConfig:
    """Configuration for the document store that holds parent chunks."""

    # Connection-pool settings forwarded to ``create_docstore``.
    pool_config: Dict[str, Any] | None = None
    # Upper bound on concurrent store operations (forwarded to ``create_docstore``).
    max_concurrency: int | None = None
    # Set this field to inject an externally created docstore; when not None
    # it is used as-is and no PostgreSQL store is created.
    instance: BaseStore | None = None
@dataclass
class IndexBuilderConfig:
    """Top-level configuration for the index builder."""
    collection_name: str = "rag_documents"
    splitter_type: SplitterType = SplitterType.PARENT_CHILD
    # Parent-chunk splitting parameters (only used when splitter_type is PARENT_CHILD).
    parent_chunk_size: int = 1000
    parent_chunk_overlap: int = 100
    # Child-chunk splitting parameters.
    child_chunk_size: int = 200
    child_chunk_overlap: int = 20
    child_splitter_type: SplitterType = SplitterType.SEMANTIC  # children default to semantic splitting
    # Retrieval parameter (passed to the retriever as ``search_k``).
    search_k: int = 5
    # Document-store configuration (needed only in parent/child mode).
    docstore: DocstoreConfig = field(default_factory=DocstoreConfig)
    # Extra splitter kwargs (used when splitter_type is not parent/child).
    extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)
# ---------- Index builder ----------
class IndexBuilder:
    """Main RAG index-building pipeline.

    Supports a single-splitter mode (recursive or semantic chunks written
    straight to the vector store) and a parent/child mode built on
    LangChain's ParentDocumentRetriever.
    """

    # Number of documents handed to the retriever per batch in parent/child mode.
    _BATCH_SIZE = 10

    def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs):
        """
        Args:
            config: Builder configuration; takes precedence over ``kwargs``.
            **kwargs: Individual configuration values merged into ``config``
                (only fields that already exist on the config are updated;
                kept for convenience).
        """
        if config is None:
            config = IndexBuilderConfig(**kwargs)
        elif kwargs:
            # Merge kwargs into the config object (existing fields only).
            for key, value in kwargs.items():
                if hasattr(config, key):
                    setattr(config, key, value)
        self.config = config
        # Connection info recorded when create_docstore builds the store.
        self._docstore_conn: Optional[str] = None

        # Core components.
        self.loader = DocumentLoader()
        self.embedder = LlamaCppEmbedder()
        self.embeddings: Embeddings = self.embedder.as_langchain_embeddings()

        # Vector store backing the (child) chunks.
        self.vector_store = QdrantVectorStore(
            collection_name=config.collection_name,
        )

        # Splitters and retriever depend on the configured splitter type.
        self._init_splitters_and_retriever()

    # ---------- private initialisation ----------

    def _init_splitters_and_retriever(self) -> None:
        """Initialise splitters and retriever according to the configuration."""
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            self._init_parent_child_mode()
        else:
            self._init_single_splitter_mode()

    def _init_single_splitter_mode(self) -> None:
        """Single-splitter mode (recursive or semantic chunking)."""
        splitter_kwargs = self.config.extra_splitter_kwargs.copy()
        if self.config.splitter_type == SplitterType.SEMANTIC:
            splitter_kwargs["embeddings"] = self.embeddings
        self.splitter = get_splitter(self.config.splitter_type, **splitter_kwargs)
        self.retriever = None
        self.docstore = None
        logger.info("使用单一 %s 切分器", self.config.splitter_type.value)

    def _init_parent_child_mode(self) -> None:
        """Parent/child mode: build both splitters, the docstore and the retriever."""
        cfg = self.config

        # Parent splitter (required for index building, must be kept).
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=cfg.parent_chunk_size,
            chunk_overlap=cfg.parent_chunk_overlap,
        )

        # Child splitter (required for index building).
        if cfg.child_splitter_type == SplitterType.SEMANTIC:
            self.child_splitter = get_splitter(
                SplitterType.SEMANTIC,
                embeddings=self.embeddings,
                **cfg.extra_splitter_kwargs,
            )
        else:
            self.child_splitter = RecursiveCharacterTextSplitter(
                chunk_size=cfg.child_chunk_size,
                chunk_overlap=cfg.child_chunk_overlap,
            )

        # Document store for the parent chunks.
        self.docstore = self._create_or_use_docstore()

        # Factory function builds the retriever, avoiding duplicated wiring.
        self.retriever = create_parent_retriever(
            collection_name=cfg.collection_name,
            parent_splitter=self.parent_splitter,
            child_splitter=self.child_splitter,
            docstore=self.docstore,
            search_k=cfg.search_k,
        )
        logger.info("ParentDocumentRetriever 初始化完成")

    def _create_or_use_docstore(self) -> BaseStore:
        """Create the document store, or reuse an externally injected one."""
        cfg = self.config.docstore
        if cfg.instance is not None:
            logger.debug("使用外部注入的文档存储")
            return cfg.instance
        # Build a PostgreSQL-backed store via create_docstore.
        docstore, conn_info = create_docstore(
            pool_config=cfg.pool_config,
            max_concurrency=cfg.max_concurrency,
        )
        self._docstore_conn = conn_info
        logger.info("文档存储已创建PostgreSQL")
        return docstore

    # ---------- public build methods ----------

    async def build_from_file(self, file_path: Union[str, Path]) -> int:
        """Build the index from a single file; return the number of items indexed."""
        logger.info("加载文件: %s", file_path)
        documents = self.loader.load_file(file_path)
        logger.info("已加载 %d 个文档", len(documents))
        return await self._process_documents(documents)

    async def build_from_directory(
        self, directory_path: Union[str, Path], recursive: bool = True
    ) -> int:
        """Build the index from a directory (optionally recursing into it)."""
        logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
        documents = self.loader.load_directory(directory_path, recursive=recursive)
        logger.info("已从目录加载 %d 个文档", len(documents))
        return await self._process_documents(documents)

    async def _process_documents(self, documents: List[Document]) -> int:
        """Dispatch the loaded documents to the configured indexing strategy."""
        if not documents:
            logger.warning("没有文档需要处理")
            return 0
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            return await self._index_with_parent_child(documents)
        return await self._index_with_single_splitter(documents)

    async def _index_with_single_splitter(self, documents: List[Document]) -> int:
        """Single mode: split, then write the chunks straight to the vector store."""
        chunks = self.splitter.split_documents(documents)  # type: ignore[union-attr]
        logger.info("已切分为 %d 个块", len(chunks))
        self.vector_store.create_collection()
        self.vector_store.add_documents(chunks)
        return len(chunks)

    async def _index_with_parent_child(self, documents: List[Document]) -> int:
        """Parent/child mode: add documents in batches via the retriever."""
        self.vector_store.create_collection()
        assert self.retriever is not None
        batch_size = self._BATCH_SIZE
        total = len(documents)
        processed = 0
        for start in range(0, total, batch_size):
            batch = documents[start:start + batch_size]
            batch_no = start // batch_size + 1
            await self._add_batch_with_retry(batch, batch_no)
            processed += len(batch)
            logger.info("批次 %d: 已处理 %d/%d", batch_no, processed, total)
        logger.info("ParentDocumentRetriever 索引完成,共处理 %d 个文档", processed)
        return processed

    async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
        """Add one batch, retrying with exponential backoff on network errors.

        Raises:
            The last caught network exception, once all retries are exhausted.
        """
        max_retries = 5
        base_delay = 2
        for attempt in range(max_retries):
            try:
                await self.retriever.aadd_documents(batch)  # type: ignore[union-attr]
                logger.info("批次 %d 成功添加 %d 个文档", batch_no, len(batch))
                return
            except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
                if attempt == max_retries - 1:
                    logger.error("批次 %d 重试 %d 次后仍然失败: %s", batch_no, max_retries, e)
                    raise
                wait_time = base_delay * (2 ** attempt)  # exponential backoff: 2, 4, 8, ...
                error_type = type(e).__name__
                logger.warning(
                    "批次 %d 遇到网络异常 [%s]%d秒后重试 (%d/%d): %s",
                    batch_no, error_type, wait_time, attempt + 1, max_retries, e
                )
                # Rebuild the Qdrant client before retrying; stale connections
                # are a likely cause of these transport errors.
                self.vector_store.refresh_client()
                logger.debug("批次 %d 已刷新 Qdrant 客户端连接", batch_no)
                await asyncio.sleep(wait_time)

    # ---------- introspection ----------

    def get_collection_info(self) -> Any:
        """Return collection information from the vector store."""
        return self.vector_store.get_collection_info()

    def get_child_splitter(self) -> TextSplitter:
        """Return the splitter currently used for (child) chunks."""
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            return self.child_splitter  # type: ignore[return-value]
        return self.splitter  # type: ignore[return-value]

    def get_parent_splitter(self) -> RecursiveCharacterTextSplitter:
        """Return the parent splitter (parent/child mode only).

        Raises:
            RuntimeError: If the builder is not in parent/child mode.
        """
        if self.config.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError("父块切分器仅在父子块模式下可用")
        return self.parent_splitter  # type: ignore[return-value]

    def get_docstore(self) -> BaseStore:
        """Return the document store (parent/child mode only).

        Raises:
            RuntimeError: If the builder is not in parent/child mode.
        """
        if self.config.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError("文档存储仅在父子块模式下可用")
        assert self.docstore is not None
        return self.docstore

    # ---------- resource management ----------

    def close(self) -> None:
        """Close resources (synchronous version, used by the context manager)."""
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running event loop: drive the coroutine on a temporary
                # loop, closing the loop even if aclose() raises.
                loop = asyncio.new_event_loop()
                try:
                    loop.run_until_complete(self.docstore.aclose())  # type: ignore[attr-defined]
                finally:
                    loop.close()
            else:
                # A loop is already running: schedule the close and keep a
                # reference so the task is not garbage-collected mid-flight.
                # Callers in async code should prefer ``aclose()``.
                self._close_task = loop.create_task(self.docstore.aclose())  # type: ignore[attr-defined]
        logger.info("IndexBuilder 资源已关闭")

    async def aclose(self) -> None:
        """Close resources asynchronously."""
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            await self.docstore.aclose()  # type: ignore[attr-defined]
        logger.info("IndexBuilder 资源已异步关闭")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.aclose()
        return False