# Source file: ailine/rag_indexer/builder.py
# (web-viewer chrome — "Files / Raw Blame History / Unicode warning" — removed)
"""
离线 RAG 索引构建核心流水线。
支持 LangChain 的 ParentDocumentRetriever 用于父子块切分。
"""
import asyncio
import logging
from pathlib import Path
from typing import List, Union, Optional, Tuple, Any
from dataclasses import dataclass
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .store import create_docstore
logger = logging.getLogger(__name__)
@dataclass
class ParentChildConfig:
    """Configuration bundle for parent/child chunk splitting.

    NOTE(review): this dataclass is not referenced anywhere in this module;
    presumably callers build one and unpack it into IndexBuilder's
    splitter_kwargs — confirm against call sites.
    """
    parent_chunk_size: int = 1000    # size of the large "parent" context chunks
    child_chunk_size: int = 200      # size of the small "child" retrieval chunks
    parent_chunk_overlap: int = 100  # overlap between consecutive parent chunks
    child_chunk_overlap: int = 20    # overlap between consecutive child chunks
    search_k: int = 5                # default number of results to retrieve
    docstore_path: Optional[str] = None          # legacy local-store path (unused by the PostgreSQL path)
    docstore_type: str = "local"                 # store backend selector — semantics not shown here
    docstore_conn_string: Optional[str] = None   # connection string for an external docstore
class IndexBuilder:
    """Main pipeline for building (and querying) a RAG index."""
    # Attribute type annotations; the instances are assigned in
    # __init__ / _init_parent_child_retriever.
    parent_splitter: "RecursiveCharacterTextSplitter"
    child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
    docstore: Optional["BaseStore"]
    # Connection string kept when the docstore was created internally;
    # None when an external docstore was injected.
    _docstore_conn: Optional[str]
    retriever: Optional["ParentDocumentRetriever"]
    # LangChain vectorstore object returned by QdrantVectorStore.get_langchain_vectorstore()
    vector_store_obj: Any
def __init__(
    self,
    collection_name: str = "rag_documents",
    splitter_type: SplitterType = SplitterType.PARENT_CHILD,
    docstore=None,
    **splitter_kwargs,
):
    """Assemble the indexing pipeline.

    Args:
        collection_name: Qdrant collection holding the (child) vectors.
        splitter_type: chunking strategy; PARENT_CHILD enables the
            ParentDocumentRetriever path.
        docstore: optional externally managed parent-document store; when
            None, a PostgreSQL-backed store is created inside
            _init_parent_child_retriever().
        **splitter_kwargs: splitter/retriever options (e.g.
            parent_chunk_size, child_chunk_size, docstore_conn_string,
            search_k), forwarded both to get_splitter() and to
            _init_parent_child_retriever().
    """
    self.collection_name = collection_name
    self.splitter_type = splitter_type
    self.splitter_kwargs = splitter_kwargs
    self.docstore = docstore  # injected from outside; may be None

    # Core components.
    self.loader = DocumentLoader()
    self.embedder = LlamaCppEmbedder()
    self.embeddings = self.embedder.as_langchain_embeddings()
    self.vector_store = QdrantVectorStore(
        collection_name=collection_name,
        embeddings=self.embeddings,
    )

    # Flat (non parent/child) splitters are resolved here; the semantic
    # splitter additionally needs the embeddings instance.
    if splitter_type != SplitterType.PARENT_CHILD:
        if splitter_type == SplitterType.SEMANTIC:
            splitter_kwargs["embeddings"] = self.embeddings
        self.splitter = get_splitter(splitter_type, **splitter_kwargs)
    else:
        self.splitter = None

    # BUG FIX: this used to be called with no arguments, so every option
    # in splitter_kwargs (parent_chunk_size, docstore_conn_string,
    # search_k, ...) was silently ignored and the hard-coded defaults in
    # _init_parent_child_retriever always applied. Forward them.
    self._init_parent_child_retriever(**splitter_kwargs)
def _init_parent_child_retriever(self, **kwargs):
    """
    Initialize a ParentDocumentRetriever for parent/child chunking.

    Combines dynamic semantic chunking with the parent/child strategy:
    - parent chunks always use recursive splitting (large chunks, context)
    - child chunks use recursive OR semantic splitting (semantic gives
      meaning-aware boundaries for better retrieval precision)

    Replaces the custom ParentChildSplitter logic.

    NOTE(review): every option is read via kwargs.get() with a default,
    so calling this with no arguments silently uses all defaults.
    """
    # Resolve chunking parameters ("chunk_overlap" acts as a shared
    # fallback for both overlap values).
    parent_size = kwargs.get("parent_chunk_size", 1000)
    child_size = kwargs.get("child_chunk_size", 200)
    parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
    child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))
    # Child splitter type defaults to semantic chunking.
    child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC)
    # Parent splitter: always recursive character splitting.
    self.parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=parent_size,
        chunk_overlap=parent_overlap,
    )
    # Child splitter: chosen by type. child_size/child_overlap only apply
    # on the recursive path — the semantic chunker does not receive them.
    if child_splitter_type == SplitterType.SEMANTIC:
        self.child_splitter = get_splitter(
            SplitterType.SEMANTIC,
            embeddings=self.embeddings,
        )
        logger.info(f"子块使用语义切分器")
    else:
        # Default: recursive splitting.
        self.child_splitter = RecursiveCharacterTextSplitter(
            chunk_size=child_size,
            chunk_overlap=child_overlap,
        )
        logger.info(f"子块使用递归切分器,块大小: {child_size},重叠: {child_overlap}")
    # Vector store holds the embedded child chunks.
    self.vector_store_obj = self.vector_store.get_langchain_vectorstore()
    # Document store holds the full parent chunks.
    if self.docstore is None:
        # No externally injected docstore: create a PostgreSQL-backed one.
        docstore_conn = kwargs.get("docstore_conn_string")
        pool_config = kwargs.get("pool_config")
        max_concurrency = kwargs.get("max_concurrency")
        # create_docstore returns a (store, connection_string) pair.
        self.docstore, self._docstore_conn = create_docstore(
            connection_string=docstore_conn,
            pool_config=pool_config,
            max_concurrency=max_concurrency
        )
    else:
        # Externally injected docstore: we do not own a connection string.
        self._docstore_conn = None
    # Wire everything into the retriever (child chunks are indexed in the
    # vectorstore; parent chunks live in the docstore).
    self.retriever = ParentDocumentRetriever(
        vectorstore=self.vector_store_obj,
        docstore=self.docstore,
        child_splitter=self.child_splitter,  # type: ignore
        parent_splitter=self.parent_splitter,
        search_kwargs={"k": kwargs.get("search_k", 5)},
    )
    logger.info(f"ParentDocumentRetriever 已初始化,父块大小: {parent_size},子块类型: {child_splitter_type}")
async def build_from_file(self, file_path: Union[str, Path]) -> int:
    """Index one file: load it, then split/embed/store its documents.

    Returns the count reported by _process_documents (parent documents
    or flat chunks, depending on the configured splitter).
    """
    logger.info("加载文件: %s", file_path)
    loaded_docs = self.loader.load_file(file_path)
    logger.info("已加载 %d 个文档", len(loaded_docs))
    return await self._process_documents(loaded_docs)
async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
    """Index every loadable document under a directory.

    Args:
        directory_path: root directory to scan.
        recursive: whether to descend into subdirectories.
    """
    logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
    loaded_docs = self.loader.load_directory(directory_path, recursive=recursive)
    logger.info("已从目录加载 %d 个文档", len(loaded_docs))
    return await self._process_documents(loaded_docs)
async def _process_documents(self, documents: List[Document], batch_size: int = 10, max_retries: int = 3) -> int:
    """Split, embed and store the given documents.

    Args:
        documents: loaded LangChain documents.
        batch_size: documents per retriever request (new parameter;
            defaults to the previously hard-coded 10).
        max_retries: attempts per batch on connection failure (new
            parameter; defaults to the previously hard-coded 3).

    Returns:
        Number of documents processed (PARENT_CHILD path) or number of
        chunks stored (flat-splitter path); 0 when there is no input.
    """
    if not documents:
        logger.warning("没有文档需要处理")
        return 0
    if self.splitter_type == SplitterType.PARENT_CHILD:
        logger.info("使用 LangChain ParentDocumentRetriever")
        # Make sure the (child-chunk) collection exists before writing.
        self.vector_store.create_collection()
        assert self.retriever is not None, "retriever 未初始化"
        total = len(documents)
        processed = 0
        # Process in batches so a single request never gets too large.
        for i in range(0, total, batch_size):
            batch = documents[i:i + batch_size]
            for attempt in range(max_retries):
                try:
                    await self.retriever.aadd_documents(batch)
                    processed += len(batch)
                    logger.info(f"批次 {i//batch_size + 1}: 已处理 {processed}/{total}")
                    break
                except (RemoteProtocolError, ConnectionError, OSError) as e:
                    # Re-raise once the retry budget is exhausted.
                    if attempt == max_retries - 1:
                        raise
                    logger.warning(f"连接断开,重试 ({attempt+1}/{max_retries}): {e}")
                    # Drop the broken client and back off briefly.
                    self.vector_store.refresh_client()
                    await asyncio.sleep(1)
        # BUG FIX: the old message said "个父块" (parent chunks), but
        # `processed` counts input documents, not parent chunks.
        logger.info(
            "已使用 ParentDocumentRetriever 索引: "
            f"{processed} 个文档"
        )
        return processed
    else:
        logger.info("使用 %s 切分文档", self.splitter_type)
        # When splitter_type is not PARENT_CHILD, __init__ always set a splitter.
        assert self.splitter is not None, "splitter 未初始化"
        chunks = self.splitter.split_documents(documents)
        logger.info("已切分为 %d 个块", len(chunks))
        self.vector_store.create_collection()
        self.vector_store.add_documents(chunks)
        return len(chunks)
def get_collection_info(self):
    """Proxy to the underlying vector store's collection metadata."""
    info = self.vector_store.get_collection_info()
    return info
def search(self, query: str, k: int = 5) -> List[Document]:
    """Standard vector search — returns child chunks.

    Args:
        query: the search query text.
        k: number of results to return.
    """
    results = self.vector_store.similarity_search(query, k=k)
    return results
async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
    """
    Search with parent-chunk context — returns full parent chunks.

    This is the primary retrieval method under parent/child chunking.

    Args:
        query: the search query text.
        k: number of parent documents to return.

    Raises:
        RuntimeError: when the builder is not configured with PARENT_CHILD.
    """
    if self.splitter_type != SplitterType.PARENT_CHILD:
        raise RuntimeError(
            "search_with_parent_context 仅在 PARENT_CHILD 切分器下可用。"
            "请使用 search() 进行标准检索。"
        )
    assert self.retriever is not None, "retriever 未初始化"
    # BUG FIX: `k` used to be passed as config={"k": k}. RunnableConfig
    # does not carry search parameters, so the value set in search_kwargs
    # at init time always won and `k` was silently ignored. Override
    # search_kwargs for this call and restore it afterwards.
    previous_k = self.retriever.search_kwargs.get("k")
    self.retriever.search_kwargs["k"] = k
    try:
        return await self.retriever.ainvoke(query)
    finally:
        if previous_k is None:
            self.retriever.search_kwargs.pop("k", None)
        else:
            self.retriever.search_kwargs["k"] = previous_k
async def retrieve(self, query: str, return_parent: bool = True, k: int = 5) -> List[Document]:
    """
    Unified retrieval entry point.

    Args:
        query: search query.
        return_parent: when True and parent/child chunking is active,
            return parent chunks; when False, always return child chunks.
        k: number of documents to return. New parameter — defaults to 5,
            the value both downstream methods previously used implicitly,
            so existing callers see identical behavior.

    Returns:
        List of relevant documents.
    """
    if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
        # GENERALIZATION: k is now forwarded instead of being fixed at 5.
        return await self.search_with_parent_context(query, k=k)
    return self.search(query, k=k)
async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]:
    """
    RAG-Fusion style retrieval: multi-query rewriting + de-duplication.

    Steps:
    1. Multi-query rewriting: an LLM rewrites the original query into
       several differently-phrased variants (the original is included).
    2. The variants' results are merged, exact-duplicate page contents
       are dropped in arrival order, and the list is truncated to `k`.

    NOTE(review): despite the name, no reciprocal-rank-fusion (RRF)
    scoring is performed here — results are only de-duplicated. The
    previous docstring overclaimed RRF, and an unused EnsembleRetriever
    import has been removed.

    Args:
        query: original search query.
        llm: language-model instance used for query rewriting.
        num_queries: number of rewritten queries to request.
        k: maximum number of documents to return.
        return_parent: when True and parent/child chunking is active,
            retrieve parent chunks; otherwise child chunks.

    Returns:
        De-duplicated list of relevant documents (at most k).
    """
    from langchain.retrievers.multi_query import MultiQueryRetriever
    from langchain_core.prompts import PromptTemplate

    # Pick the base retriever.
    if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
        assert self.retriever is not None, "retriever 未初始化"
        base_retriever = self.retriever
    else:
        # Over-fetch (k * 2) so de-duplication still leaves enough results.
        base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2})

    # Wrap the base retriever with LLM-driven query rewriting.
    multi_query_retriever = MultiQueryRetriever.from_llm(
        retriever=base_retriever,
        llm=llm,
        include_original=True
    )
    # Custom prompt that asks for exactly num_queries rewrites.
    multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template(
        "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
        "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
        "原始问题: {question}\n\n"
        "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
        "确保每个版本都是独立、完整的查询语句。\n\n"
        "生成 {num_queries} 个查询:"
    )
    # HACK: MultiQueryRetriever only feeds {question} into its chain, so
    # wrap ainvoke to inject num_queries into the prompt variables.
    original_ainvoke = multi_query_retriever.llm_chain.ainvoke
    async def _ainvoke_with_count(input_dict):
        input_dict["num_queries"] = num_queries
        return await original_ainvoke(input_dict)
    multi_query_retriever.llm_chain.ainvoke = _ainvoke_with_count

    # Run the fan-out retrieval.
    documents = await multi_query_retriever.ainvoke(query)
    # De-duplicate by exact page content, preserving order, capped at k.
    seen_content = set()
    unique_documents = []
    for doc in documents:
        content = doc.page_content
        if content not in seen_content:
            seen_content.add(content)
            unique_documents.append(doc)
            if len(unique_documents) >= k:
                break
    logger.info(f"RAG-Fusion 检索完成: 原始 {len(documents)} 个结果,去重后 {len(unique_documents)} 个结果")
    return unique_documents
def get_retriever(self) -> ParentDocumentRetriever:
    """
    Expose the ParentDocumentRetriever instance directly.

    For advanced use cases that need the retriever outside IndexBuilder.

    Raises:
        RuntimeError: when the builder is not using PARENT_CHILD chunking.
    """
    if self.splitter_type == SplitterType.PARENT_CHILD:
        assert self.retriever is not None, "retriever 未初始化"
        return self.retriever
    raise RuntimeError(
        "get_retriever() 仅在 PARENT_CHILD 切分器下可用。"
        "请使用 search() 或 search_with_parent_context() 进行标准检索。"
    )
def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]:
    """Return the active child splitter (or the flat splitter when not
    in parent/child mode) so callers can reconfigure it."""
    if self.splitter_type == SplitterType.PARENT_CHILD:
        return self.child_splitter
    return self.splitter  # type: ignore
def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
    """Return the parent-chunk splitter so callers can reconfigure it.

    Raises:
        RuntimeError: when the builder is not using PARENT_CHILD chunking.
    """
    if self.splitter_type == SplitterType.PARENT_CHILD:
        return self.parent_splitter
    raise RuntimeError(
        "父块切分器仅在 PARENT_CHILD 切分器下可用。"
    )
def get_docstore(self) -> BaseStore:
    """Return the document store that holds the parent chunks.

    Raises:
        RuntimeError: when the builder is not using PARENT_CHILD chunking.
    """
    if self.splitter_type == SplitterType.PARENT_CHILD:
        assert self.docstore is not None, "docstore 未初始化"
        return self.docstore
    raise RuntimeError(
        "文档存储仅在 PARENT_CHILD 切分器下可用。"
    )
def get_docstore_path(self) -> Optional[str]:
    """Return the docstore path (deprecated; kept for compatibility only).

    The PostgreSQL-backed store has no persist_path, so this is always
    None in PARENT_CHILD mode.

    Raises:
        RuntimeError: when the builder is not using PARENT_CHILD chunking.
    """
    if self.splitter_type != SplitterType.PARENT_CHILD:
        raise RuntimeError(
            "文档存储路径仅在 PARENT_CHILD 切分器下可用。"
        )
    return None
def close(self):
    """Release held resources — currently the async docstore pool.

    No-op when there is no docstore or it has no aclose() coroutine.
    """
    if self.docstore is not None and hasattr(self.docstore, "aclose"):
        # BUG FIX: asyncio.get_event_loop().run_until_complete() is
        # deprecated when no loop is running and raises when one IS
        # running. Use asyncio.run() for the synchronous case and
        # schedule a task when called from inside a running loop.
        # (The redundant local `import asyncio` was removed — the module
        # already imports it at the top.)
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None
        if running_loop is None:
            asyncio.run(self.docstore.aclose())  # type: ignore
        else:
            # Best effort: close asynchronously without blocking the loop.
            running_loop.create_task(self.docstore.aclose())  # type: ignore
        logger.info("PostgreSQL 异步连接池已关闭")
def __enter__(self):
    """Context-manager entry: the builder itself is the managed resource."""
    return self
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context-manager exit: release resources, never suppress exceptions."""
    self.close()
    # Returning False propagates any in-flight exception to the caller.
    return False
# NOTE: a duplicate `from langchain_text_splitters import
# RecursiveCharacterTextSplitter` used to sit here; it was removed because
# the identical import already appears in the module's import block.
# Example usage has been removed; see the project documentation.