""" 离线 RAG 索引构建核心流水线。 使用 LangChain 的 ParentDocumentRetriever 实现父子块策略。 """ import asyncio import logging import sys from pathlib import Path from dataclasses import dataclass, field from typing import List, Union, Optional, Any, Dict # 添加 backend 目录到路径以导入 rag_core sys.path.insert(0, str(Path(__file__).parent.parent / "backend")) from httpx import RemoteProtocolError from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.stores import BaseStore from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter from qdrant_client.http.exceptions import ResponseHandlingException from .loaders import DocumentLoader from .splitters import SplitterType, get_splitter # 从 rag_core 导入 import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent / "backend")) from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever # 尝试导入新的 model_services(如果可用) try: from app.model_services import get_embedding_service HAS_MODEL_SERVICES = True except ImportError: HAS_MODEL_SERVICES = False logger = logging.getLogger(__name__) # ---------- 配置数据类 ---------- @dataclass class DocstoreConfig: """文档存储配置(用于父块存储)。""" pool_config: Dict[str, Any] | None = None max_concurrency: int | None = None # 若要从外部注入已创建好的 docstore,可直接设置此字段 instance: BaseStore | None = None @dataclass class IndexBuilderConfig: """索引构建器配置。""" collection_name: str = "rag_documents" splitter_type: SplitterType = SplitterType.PARENT_CHILD # 父块切分参数(仅当 splitter_type 为 PARENT_CHILD 时生效) parent_chunk_size: int = 1000 parent_chunk_overlap: int = 100 # 子块切分参数 child_chunk_size: int = 200 child_chunk_overlap: int = 20 child_splitter_type: SplitterType = SplitterType.SEMANTIC # 子块默认语义切分 # 检索参数 search_k: int = 5 # 文档存储配置(仅父子块模式需要) docstore: DocstoreConfig = field(default_factory=DocstoreConfig) # 其他切分器参数(当 splitter_type 非父子块时使用) extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict) # 
# ---------- Index builder ----------
class IndexBuilder:
    """Main RAG indexing pipeline.

    Supports two modes, selected by ``config.splitter_type``:

    * parent/child mode: documents are added through the retriever returned
      by ``create_parent_retriever`` and parent chunks live in a docstore;
    * single-splitter mode: documents are split once and written straight to
      the vector store.

    The instance can be used as a sync or async context manager; leaving the
    context closes the docstore (if it exposes ``aclose``).
    """

    # Number of source documents handed to the retriever per batch in
    # parent/child mode.
    BATCH_SIZE: int = 10
    # Maximum number of attempts per batch before the error is re-raised.
    MAX_RETRIES: int = 5
    # Base back-off delay in seconds; the wait doubles after each failure.
    RETRY_BASE_DELAY: int = 2

    def __init__(
        self,
        config: "Optional[IndexBuilderConfig]" = None,
        embeddings: "Optional[Embeddings]" = None,
        **kwargs,
    ):
        """Create the builder and all components it needs.

        Args:
            config: Builder configuration. When omitted, one is created from
                ``kwargs``.
            embeddings: Optional externally supplied embedding model; used
                as-is when provided.
            **kwargs: Configuration fields passed directly for convenience.
                When ``config`` is also given, matching fields in ``kwargs``
                override it on a copy (the caller's object is not modified);
                unknown keys are ignored with a warning.
        """
        if config is None:
            config = IndexBuilderConfig(**kwargs)
        elif kwargs:
            config = self._apply_overrides(config, kwargs)

        self.config = config
        # Connection info returned by create_docstore, when we created one.
        self._docstore_conn: Optional[str] = None
        # Strong reference to a close task scheduled by close() from inside a
        # running event loop, so it cannot be garbage-collected mid-flight.
        self._close_task: "Optional[asyncio.Task]" = None

        # Basic components.
        self.loader = DocumentLoader()
        self.embeddings, self.embedder = self._resolve_embeddings(embeddings)

        # Vector store: embeddings are only passed explicitly when they did
        # not come from the LlamaCppEmbedder fallback.
        self.vector_store = QdrantVectorStore(
            collection_name=config.collection_name,
            embeddings=self.embeddings if self.embedder is None else None,
        )

        # Splitters / retriever depend on the configured splitter type.
        self._init_splitters_and_retriever()

    # ---------- Private initialisation helpers ----------

    @staticmethod
    def _apply_overrides(config: Any, overrides: Dict[str, Any]) -> Any:
        """Return a shallow copy of *config* with existing attributes overridden."""
        import copy  # local import: keeps this block independent of the file header

        merged = copy.copy(config)
        for key, value in overrides.items():
            if hasattr(merged, key):
                setattr(merged, key, value)
            else:
                logger.warning("忽略未知的配置参数: %s", key)
        return merged

    @staticmethod
    def _resolve_embeddings(embeddings: Any) -> tuple:
        """Pick the embedding model; return ``(embeddings, embedder_or_None)``.

        Order of preference: the externally supplied model, then the
        model_services embedding service (when importable), then a
        ``LlamaCppEmbedder`` fallback.
        """
        if embeddings is not None:
            logger.info("使用外部提供的嵌入模型")
            return embeddings, None

        if HAS_MODEL_SERVICES:
            try:
                service = get_embedding_service()
            except Exception as e:  # deliberate best-effort: fall back below
                logger.warning("获取嵌入服务失败,回退到 LlamaCppEmbedder: %s", e)
            else:
                logger.info("使用 model_services 提供的嵌入服务")
                return service, None

        embedder = LlamaCppEmbedder()
        return embedder.as_langchain_embeddings(), embedder

    def _init_splitters_and_retriever(self) -> None:
        """Initialise splitters and retriever according to the configuration."""
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            self._init_parent_child_mode()
        else:
            self._init_single_splitter_mode()

    def _init_single_splitter_mode(self) -> None:
        """Set up single-splitter mode (no retriever, no docstore)."""
        splitter_kwargs = dict(self.config.extra_splitter_kwargs)
        if self.config.splitter_type == SplitterType.SEMANTIC:
            splitter_kwargs["embeddings"] = self.embeddings

        self.splitter = get_splitter(self.config.splitter_type, **splitter_kwargs)
        self.retriever = None
        self.docstore = None
        logger.info("使用单一 %s 切分器", self.config.splitter_type.value)

    def _init_parent_child_mode(self) -> None:
        """Set up parent/child mode: both splitters, the docstore and the retriever."""
        cfg = self.config

        # Parent splitter (required for index construction).
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=cfg.parent_chunk_size,
            chunk_overlap=cfg.parent_chunk_overlap,
        )

        # Child splitter (required for index construction).
        if cfg.child_splitter_type == SplitterType.SEMANTIC:
            # Build the kwargs the same way single-splitter mode does so an
            # "embeddings" key in extra_splitter_kwargs cannot clash with ours.
            child_kwargs = dict(cfg.extra_splitter_kwargs)
            child_kwargs["embeddings"] = self.embeddings
            self.child_splitter = get_splitter(SplitterType.SEMANTIC, **child_kwargs)
        else:
            self.child_splitter = RecursiveCharacterTextSplitter(
                chunk_size=cfg.child_chunk_size,
                chunk_overlap=cfg.child_chunk_overlap,
            )

        # Document store for parent chunks.
        self.docstore = self._create_or_use_docstore()

        # Build the retriever through the shared factory function.
        self.retriever = create_parent_retriever(
            collection_name=cfg.collection_name,
            parent_splitter=self.parent_splitter,
            child_splitter=self.child_splitter,
            docstore=self.docstore,
            search_k=cfg.search_k,
            embeddings=self.embeddings if self.embedder is None else None,
        )
        logger.info("ParentDocumentRetriever 初始化完成")

    def _create_or_use_docstore(self) -> "BaseStore":
        """Return the injected docstore, or create one via ``create_docstore``."""
        cfg = self.config.docstore
        if cfg.instance is not None:
            logger.debug("使用外部注入的文档存储")
            return cfg.instance

        docstore, conn_info = create_docstore(
            pool_config=cfg.pool_config,
            max_concurrency=cfg.max_concurrency,
        )
        self._docstore_conn = conn_info
        logger.info("文档存储已创建(PostgreSQL)")
        return docstore

    # ---------- Public build methods ----------

    async def build_from_file(self, file_path: Union[str, Path]) -> int:
        """Load one file and index it; return the count reported by the indexer."""
        logger.info("加载文件: %s", file_path)
        documents = self.loader.load_file(file_path)
        logger.info("已加载 %d 个文档", len(documents))
        return await self._process_documents(documents)

    async def build_from_directory(
        self, directory_path: Union[str, Path], recursive: bool = True
    ) -> int:
        """Load a directory (recursively by default) and index its documents."""
        logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
        documents = self.loader.load_directory(directory_path, recursive=recursive)
        logger.info("已从目录加载 %d 个文档", len(documents))
        return await self._process_documents(documents)

    async def _process_documents(self, documents: "List[Document]") -> int:
        """Dispatch *documents* to the indexing routine for the configured mode.

        Returns 0 without touching any store when *documents* is empty.
        """
        if not documents:
            logger.warning("没有文档需要处理")
            return 0

        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            return await self._index_with_parent_child(documents)
        return await self._index_with_single_splitter(documents)

    async def _index_with_single_splitter(self, documents: "List[Document]") -> int:
        """Single mode: split, then write the chunks to the vector store.

        Returns the number of chunks written.
        """
        chunks = self.splitter.split_documents(documents)  # type: ignore[union-attr]
        logger.info("已切分为 %d 个块", len(chunks))

        # NOTE(review): these vector-store calls are made synchronously from a
        # coroutine; if they block on network I/O they stall the event loop.
        self.vector_store.create_collection()
        self.vector_store.add_documents(chunks)
        return len(chunks)

    async def _index_with_parent_child(self, documents: "List[Document]") -> int:
        """Parent/child mode: add documents through the retriever in batches.

        Returns the number of source documents processed.
        """
        self.vector_store.create_collection()
        assert self.retriever is not None

        step = self.BATCH_SIZE
        total = len(documents)
        processed = 0

        for batch_no, start in enumerate(range(0, total, step), start=1):
            batch = documents[start:start + step]
            await self._add_batch_with_retry(batch, batch_no)
            processed += len(batch)
            logger.info("批次 %d: 已处理 %d/%d", batch_no, processed, total)

        logger.info("ParentDocumentRetriever 索引完成,共处理 %d 个文档", processed)
        return processed

    async def _add_batch_with_retry(self, batch: "List[Document]", batch_no: int) -> None:
        """Add one batch, retrying with exponential back-off on network errors.

        Raises:
            The last network-level exception once ``MAX_RETRIES`` attempts
            have failed.
        """
        # Always make at least one attempt, even if MAX_RETRIES was overridden
        # with a non-positive value; otherwise the batch would be silently
        # skipped.
        max_retries = max(1, self.MAX_RETRIES)
        base_delay = self.RETRY_BASE_DELAY

        for attempt in range(max_retries):
            try:
                await self.retriever.aadd_documents(batch)  # type: ignore[union-attr]
                logger.info("批次 %d 成功添加 %d 个文档", batch_no, len(batch))
                return
            except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
                if attempt == max_retries - 1:
                    logger.error("批次 %d 重试 %d 次后仍然失败: %s", batch_no, max_retries, e)
                    raise

                wait_time = base_delay * (2 ** attempt)
                error_type = type(e).__name__
                logger.warning(
                    "批次 %d 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s",
                    batch_no, error_type, wait_time, attempt + 1, max_retries, e
                )

                # Refresh the Qdrant client connection before the next attempt.
                self.vector_store.refresh_client()
                logger.debug("批次 %d 已刷新 Qdrant 客户端连接", batch_no)

                await asyncio.sleep(wait_time)

    # ---------- Information accessors ----------

    def get_collection_info(self) -> Any:
        """Return whatever the vector store reports for the collection."""
        return self.vector_store.get_collection_info()

    def get_child_splitter(self) -> "TextSplitter":
        """Return the child splitter (parent/child mode) or the single splitter."""
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            return self.child_splitter  # type: ignore[return-value]
        return self.splitter  # type: ignore[return-value]

    def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
        """Return the parent splitter.

        Raises:
            RuntimeError: when the builder is not in parent/child mode.
        """
        if self.config.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError("父块切分器仅在父子块模式下可用")
        return self.parent_splitter  # type: ignore[return-value]

    def get_docstore(self) -> "BaseStore":
        """Return the document store.

        Raises:
            RuntimeError: when the builder is not in parent/child mode.
        """
        if self.config.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError("文档存储仅在父子块模式下可用")
        assert self.docstore is not None
        return self.docstore

    # ---------- Resource management ----------

    def close(self) -> None:
        """Close resources (synchronous variant, used by the context manager).

        Without a running event loop the docstore's ``aclose`` is run to
        completion on a temporary loop. Inside a running loop the close is
        only scheduled; the task is kept on ``self._close_task`` so it is not
        garbage-collected and so callers can await it if they need to.
        """
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop: use a temporary one and always release it,
                # even when aclose() raises.
                loop = asyncio.new_event_loop()
                try:
                    loop.run_until_complete(self.docstore.aclose())  # type: ignore[attr-defined]
                finally:
                    loop.close()
            else:
                # A loop is already running: schedule the close and hold a
                # strong reference to the task.
                self._close_task = loop.create_task(self.docstore.aclose())  # type: ignore[attr-defined]
        logger.info("IndexBuilder 资源已关闭")

    async def aclose(self) -> None:
        """Close resources asynchronously (awaits the docstore's ``aclose``)."""
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            await self.docstore.aclose()  # type: ignore[attr-defined]
        logger.info("IndexBuilder 资源已异步关闭")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False  # never suppress exceptions from the with-body

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.aclose()
        return False  # never suppress exceptions from the with-body