""" 离线 RAG 索引构建核心流水线。 支持 LangChain 的 ParentDocumentRetriever 用于父子块切分。 """ import asyncio import logging from pathlib import Path from typing import List, Union, Optional, Tuple, Any from dataclasses import dataclass from httpx import RemoteProtocolError from langchain_core.documents import Document from langchain_classic.retrievers import ParentDocumentRetriever from langchain_core.stores import BaseStore from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_experimental.text_splitter import SemanticChunker from .loaders import DocumentLoader from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter from .embedders import LlamaCppEmbedder from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY from .store import create_docstore logger = logging.getLogger(__name__) @dataclass class ParentChildConfig: """父子块切分配置。""" parent_chunk_size: int = 1000 child_chunk_size: int = 200 parent_chunk_overlap: int = 100 child_chunk_overlap: int = 20 search_k: int = 5 docstore_path: Optional[str] = None docstore_type: str = "local" docstore_conn_string: Optional[str] = None class IndexBuilder: """RAG 索引构建主流水线。""" # 类型注解 parent_splitter: "RecursiveCharacterTextSplitter" child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"] docstore: Optional["BaseStore"] _docstore_conn: Optional[str] retriever: Optional["ParentDocumentRetriever"] vector_store_obj: Any def __init__( self, collection_name: str = "rag_documents", splitter_type: SplitterType = SplitterType.PARENT_CHILD, docstore=None, **splitter_kwargs, ): self.collection_name = collection_name self.splitter_type = splitter_type self.splitter_kwargs = splitter_kwargs self.docstore = docstore # 从外部注入 # 组件 self.loader = DocumentLoader() self.embedder = LlamaCppEmbedder() self.embeddings = self.embedder.as_langchain_embeddings() self.vector_store = QdrantVectorStore( collection_name=collection_name, embeddings=self.embeddings, ) # 切分器(父子块单独处理) if splitter_type != SplitterType.PARENT_CHILD: if splitter_type == SplitterType.SEMANTIC: splitter_kwargs["embeddings"] = self.embeddings self.splitter = get_splitter(splitter_type, **splitter_kwargs) else: self.splitter = None # 为父子块切分初始化 ParentDocumentRetriever self._init_parent_child_retriever() def _init_parent_child_retriever(self, **kwargs): """ 初始化 ParentDocumentRetriever 用于父子块切分。 支持动态语义切分与父子块策略结合: - 父块使用递归切分(大块,提供上下文) - 子块可以使用递归切分或语义切分(根据语义动态切分,提高检索精度) 替代自定义的 ParentChildSplitter 逻辑。 """ # 解析父子块配置参数 parent_size = kwargs.get("parent_chunk_size", 1000) child_size = kwargs.get("child_chunk_size", 200) parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100)) child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20)) # 子块切分器类型,默认为语义切分 child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC) # 定义父块切分器(始终使用递归切分) self.parent_splitter = RecursiveCharacterTextSplitter( chunk_size=parent_size, chunk_overlap=parent_overlap, ) # 定义子块切分器(根据类型选择) if child_splitter_type == SplitterType.SEMANTIC: self.child_splitter = get_splitter( SplitterType.SEMANTIC, embeddings=self.embeddings, ) logger.info(f"子块使用语义切分器") else: # 默认使用递归切分 self.child_splitter = RecursiveCharacterTextSplitter( chunk_size=child_size, chunk_overlap=child_overlap, ) logger.info(f"子块使用递归切分器,块大小: {child_size},重叠: {child_overlap}") # 向量存储(用于子块) self.vector_store_obj = self.vector_store.get_langchain_vectorstore() # 文档存储(用于父块) if self.docstore is None: # 如果没有外部注入 docstore,则使用 PostgreSQL 创建 docstore_conn = kwargs.get("docstore_conn_string") pool_config = kwargs.get("pool_config") max_concurrency = kwargs.get("max_concurrency") # 使用 create_docstore 创建 PostgreSQL 存储 self.docstore, self._docstore_conn = create_docstore( connection_string=docstore_conn, pool_config=pool_config, max_concurrency=max_concurrency ) else: # 使用外部注入的 docstore self._docstore_conn = None # 创建检索器 self.retriever = ParentDocumentRetriever( vectorstore=self.vector_store_obj, docstore=self.docstore, child_splitter=self.child_splitter, # type: ignore parent_splitter=self.parent_splitter, search_kwargs={"k": kwargs.get("search_k", 5)}, ) logger.info(f"ParentDocumentRetriever 已初始化,父块大小: {parent_size},子块类型: {child_splitter_type}") async def build_from_file(self, file_path: Union[str, Path]) -> int: logger.info("加载文件: %s", file_path) documents = self.loader.load_file(file_path) logger.info("已加载 %d 个文档", len(documents)) return await self._process_documents(documents) async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int: logger.info("加载目录: %s (递归=%s)", directory_path, recursive) documents = self.loader.load_directory(directory_path, recursive=recursive) logger.info("已从目录加载 %d 个文档", len(documents)) return await self._process_documents(documents) async def _process_documents(self, documents: List[Document]) -> int: if not documents: logger.warning("没有文档需要处理") return 0 if self.splitter_type == SplitterType.PARENT_CHILD: logger.info("使用 LangChain ParentDocumentRetriever") # 确保集合存在(用于子块) self.vector_store.create_collection() # 分批处理,避免单次请求过大 assert self.retriever is not None, "retriever 未初始化" batch_size = 10 # 每次处理10个文档 total = len(documents) processed = 0 for i in range(0, total, batch_size): batch = documents[i:i + batch_size] max_retries = 3 for attempt in range(max_retries): try: await self.retriever.aadd_documents(batch) processed += len(batch) logger.info(f"批次 {i//batch_size + 1}: 已处理 {processed}/{total}") break except (RemoteProtocolError, ConnectionError, OSError) as e: if attempt == max_retries - 1: raise logger.warning(f"连接断开,重试 ({attempt+1}/{max_retries}): {e}") self.vector_store.refresh_client() await asyncio.sleep(1) logger.info( "已使用 ParentDocumentRetriever 索引: " f"共 {processed} 个父块" ) return processed else: logger.info("使用 %s 切分文档", self.splitter_type) # 当 splitter_type 不是 PARENT_CHILD 时,splitter 一定不为 None assert self.splitter is not None, "splitter 未初始化" chunks = self.splitter.split_documents(documents) logger.info("已切分为 %d 个块", len(chunks)) self.vector_store.create_collection() self.vector_store.add_documents(chunks) return len(chunks) def get_collection_info(self): return self.vector_store.get_collection_info() def search(self, query: str, k: int = 5) -> List[Document]: """标准搜索 - 返回子块。""" return self.vector_store.similarity_search(query, k=k) async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]: """ 带父块上下文的搜索 - 返回完整父块。 这是使用父子块切分时的主要检索方法。 """ if self.splitter_type != SplitterType.PARENT_CHILD: raise RuntimeError( "search_with_parent_context 仅在 PARENT_CHILD 切分器下可用。" "请使用 search() 进行标准检索。" ) assert self.retriever is not None, "retriever 未初始化" return await self.retriever.ainvoke(query, config={"k": k}) # type: ignore async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]: """ 统一检索接口。 Args: query: 搜索查询 return_parent: 如果为 True 且使用父子块切分,返回父块 如果为 False,始终返回子块 Returns: 相关文档列表 """ if self.splitter_type == SplitterType.PARENT_CHILD and return_parent: return await self.search_with_parent_context(query) else: return self.search(query) async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]: """ 使用 RAG-Fusion 进行检索(多路查询改写 + 倒数排名融合)。 核心原理: 1. 多路查询改写: 利用 LLM 将原始查询改写成多个不同表述 2. 倒数排名融合: 对每个改写查询的结果进行 RRF 融合,避免单一检索结果主导 Args: query: 原始搜索查询 llm: 语言模型实例(用于查询改写) num_queries: 生成的查询数量 k: 返回的文档数量 return_parent: 如果为 True 且使用父子块切分,返回父块 如果为 False,始终返回子块 Returns: 经过融合后的相关文档列表 """ from langchain.retrievers.multi_query import MultiQueryRetriever from langchain.retrievers import EnsembleRetriever if self.splitter_type == SplitterType.PARENT_CHILD and return_parent: # 使用 ParentDocumentRetriever 作为基础检索器 assert self.retriever is not None, "retriever 未初始化" base_retriever = self.retriever else: # 使用向量存储作为基础检索器 base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2}) # 创建多路查询检索器 multi_query_retriever = MultiQueryRetriever.from_llm( retriever=base_retriever, llm=llm, include_original=True ) # 设置自定义提示词以生成指定数量的查询 from langchain_core.prompts import PromptTemplate multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template( "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n" "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n" "原始问题: {question}\n\n" "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n" "确保每个版本都是独立、完整的查询语句。\n\n" "生成 {num_queries} 个查询:" ) # 修改调用参数以包含 num_queries original_ainvoke = multi_query_retriever.llm_chain.ainvoke async def new_ainvoke(input_dict): input_dict["num_queries"] = num_queries return await original_ainvoke(input_dict) multi_query_retriever.llm_chain.ainvoke = new_ainvoke # 执行检索 documents = await multi_query_retriever.ainvoke(query) # 去重并限制数量 seen_content = set() unique_documents = [] for doc in documents: content = doc.page_content if content not in seen_content: seen_content.add(content) unique_documents.append(doc) if len(unique_documents) >= k: break logger.info(f"RAG-Fusion 检索完成: 原始 {len(documents)} 个结果,去重后 {len(unique_documents)} 个结果") return unique_documents def get_retriever(self) -> ParentDocumentRetriever: """ 直接获取 ParentDocumentRetriever 实例。 适用于需要在 IndexBuilder 外部访问检索器的高级用例。 """ if self.splitter_type != SplitterType.PARENT_CHILD: raise RuntimeError( "get_retriever() 仅在 PARENT_CHILD 切分器下可用。" "请使用 search() 或 search_with_parent_context() 进行标准检索。" ) assert self.retriever is not None, "retriever 未初始化" return self.retriever def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]: """获取子块切分器以便重新配置。""" if self.splitter_type != SplitterType.PARENT_CHILD: return self.splitter # type: ignore return self.child_splitter def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter": """获取父块切分器以便重新配置。""" if self.splitter_type != SplitterType.PARENT_CHILD: raise RuntimeError( "父块切分器仅在 PARENT_CHILD 切分器下可用。" ) return self.parent_splitter def get_docstore(self) -> BaseStore: """获取父块的文档存储。""" if self.splitter_type != SplitterType.PARENT_CHILD: raise RuntimeError( "文档存储仅在 PARENT_CHILD 切分器下可用。" ) assert self.docstore is not None, "docstore 未初始化" return self.docstore def get_docstore_path(self) -> Optional[str]: """获取文档存储路径(已弃用,仅用于兼容性)。""" if self.splitter_type != SplitterType.PARENT_CHILD: raise RuntimeError( "文档存储路径仅在 PARENT_CHILD 切分器下可用。" ) # PostgreSQL 存储没有 persist_path,返回 None return None def close(self): """关闭资源。""" if self.docstore is not None and hasattr(self.docstore, "aclose"): import asyncio asyncio.get_event_loop().run_until_complete(self.docstore.aclose()) # type: ignore logger.info("PostgreSQL 异步连接池已关闭") def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() return False # 需要导入 RecursiveCharacterTextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter # 示例用法已移除,请参考文档