# ailine/rag_indexer/builder.py
"""
2026-04-19 15:01:40 +08:00
离线 RAG 索引构建核心流水线
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
支持 LangChain ParentDocumentRetriever 用于父子块切分
2026-04-18 16:56:23 +08:00
"""
import asyncio
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

from httpx import RemoteProtocolError
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_core.documents import Document
from langchain_core.stores import BaseStore
from langchain_experimental.text_splitter import SemanticChunker
from langchain_text_splitters import RecursiveCharacterTextSplitter

from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .store import create_docstore
logger = logging.getLogger(__name__)
@dataclass
class ParentChildConfig:
2026-04-19 15:01:40 +08:00
"""父子块切分配置。"""
2026-04-18 16:56:23 +08:00
parent_chunk_size: int = 1000
child_chunk_size: int = 200
parent_chunk_overlap: int = 100
child_chunk_overlap: int = 20
search_k: int = 5
2026-04-19 15:01:40 +08:00
docstore_path: Optional[str] = None
2026-04-18 16:56:23 +08:00
docstore_type: str = "local"
2026-04-19 15:01:40 +08:00
docstore_conn_string: Optional[str] = None
2026-04-18 16:56:23 +08:00
class IndexBuilder:
2026-04-19 15:01:40 +08:00
"""RAG 索引构建主流水线。"""
# 类型注解
parent_splitter: "RecursiveCharacterTextSplitter"
child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
docstore: Optional["BaseStore"]
_docstore_conn: Optional[str]
retriever: Optional["ParentDocumentRetriever"]
vector_store_obj: Any
2026-04-18 16:56:23 +08:00
def __init__(
self,
collection_name: str = "rag_documents",
2026-04-19 15:01:40 +08:00
splitter_type: SplitterType = SplitterType.PARENT_CHILD,
docstore=None,
2026-04-18 16:56:23 +08:00
**splitter_kwargs,
):
self.collection_name = collection_name
self.splitter_type = splitter_type
self.splitter_kwargs = splitter_kwargs
2026-04-19 15:01:40 +08:00
self.docstore = docstore # 从外部注入
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
# 组件
2026-04-18 16:56:23 +08:00
self.loader = DocumentLoader()
self.embedder = LlamaCppEmbedder()
self.embeddings = self.embedder.as_langchain_embeddings()
self.vector_store = QdrantVectorStore(
collection_name=collection_name,
embeddings=self.embeddings,
)
2026-04-19 15:01:40 +08:00
# 切分器(父子块单独处理)
2026-04-18 16:56:23 +08:00
if splitter_type != SplitterType.PARENT_CHILD:
if splitter_type == SplitterType.SEMANTIC:
splitter_kwargs["embeddings"] = self.embeddings
self.splitter = get_splitter(splitter_type, **splitter_kwargs)
else:
self.splitter = None
2026-04-19 15:01:40 +08:00
# 为父子块切分初始化 ParentDocumentRetriever
2026-04-18 16:56:23 +08:00
self._init_parent_child_retriever()
def _init_parent_child_retriever(self, **kwargs):
"""
2026-04-19 15:01:40 +08:00
初始化 ParentDocumentRetriever 用于父子块切分
支持动态语义切分与父子块策略结合
- 父块使用递归切分大块提供上下文
- 子块可以使用递归切分或语义切分根据语义动态切分提高检索精度
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
替代自定义的 ParentChildSplitter 逻辑
2026-04-18 16:56:23 +08:00
"""
2026-04-19 15:01:40 +08:00
# 解析父子块配置参数
2026-04-18 16:56:23 +08:00
parent_size = kwargs.get("parent_chunk_size", 1000)
child_size = kwargs.get("child_chunk_size", 200)
parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))
2026-04-19 15:01:40 +08:00
# 子块切分器类型,默认为语义切分
child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC)
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
# 定义父块切分器(始终使用递归切分)
2026-04-18 16:56:23 +08:00
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=parent_size,
chunk_overlap=parent_overlap,
)
2026-04-19 15:01:40 +08:00
# 定义子块切分器(根据类型选择)
if child_splitter_type == SplitterType.SEMANTIC:
self.child_splitter = get_splitter(
SplitterType.SEMANTIC,
embeddings=self.embeddings,
)
logger.info(f"子块使用语义切分器")
else:
# 默认使用递归切分
self.child_splitter = RecursiveCharacterTextSplitter(
chunk_size=child_size,
chunk_overlap=child_overlap,
)
logger.info(f"子块使用递归切分器,块大小: {child_size},重叠: {child_overlap}")
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
# 向量存储(用于子块)
2026-04-18 16:56:23 +08:00
self.vector_store_obj = self.vector_store.get_langchain_vectorstore()
2026-04-19 15:01:40 +08:00
# 文档存储(用于父块)
if self.docstore is None:
# 如果没有外部注入 docstore则使用 PostgreSQL 创建
docstore_conn = kwargs.get("docstore_conn_string")
pool_config = kwargs.get("pool_config")
max_concurrency = kwargs.get("max_concurrency")
# 使用 create_docstore 创建 PostgreSQL 存储
self.docstore, self._docstore_conn = create_docstore(
connection_string=docstore_conn,
pool_config=pool_config,
max_concurrency=max_concurrency
)
2026-04-18 16:56:23 +08:00
else:
2026-04-19 15:01:40 +08:00
# 使用外部注入的 docstore
2026-04-18 16:56:23 +08:00
self._docstore_conn = None
2026-04-19 15:01:40 +08:00
# 创建检索器
2026-04-18 16:56:23 +08:00
self.retriever = ParentDocumentRetriever(
vectorstore=self.vector_store_obj,
docstore=self.docstore,
2026-04-19 15:01:40 +08:00
child_splitter=self.child_splitter, # type: ignore
2026-04-18 16:56:23 +08:00
parent_splitter=self.parent_splitter,
search_kwargs={"k": kwargs.get("search_k", 5)},
)
2026-04-19 15:01:40 +08:00
logger.info(f"ParentDocumentRetriever 已初始化,父块大小: {parent_size},子块类型: {child_splitter_type}")
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
async def build_from_file(self, file_path: Union[str, Path]) -> int:
logger.info("加载文件: %s", file_path)
2026-04-18 16:56:23 +08:00
documents = self.loader.load_file(file_path)
2026-04-19 15:01:40 +08:00
logger.info("已加载 %d 个文档", len(documents))
return await self._process_documents(documents)
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
2026-04-18 16:56:23 +08:00
documents = self.loader.load_directory(directory_path, recursive=recursive)
2026-04-19 15:01:40 +08:00
logger.info("已从目录加载 %d 个文档", len(documents))
return await self._process_documents(documents)
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
async def _process_documents(self, documents: List[Document]) -> int:
2026-04-18 16:56:23 +08:00
if not documents:
2026-04-19 15:01:40 +08:00
logger.warning("没有文档需要处理")
2026-04-18 16:56:23 +08:00
return 0
if self.splitter_type == SplitterType.PARENT_CHILD:
2026-04-19 15:01:40 +08:00
logger.info("使用 LangChain ParentDocumentRetriever")
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
# 确保集合存在(用于子块)
2026-04-18 16:56:23 +08:00
self.vector_store.create_collection()
2026-04-19 15:01:40 +08:00
# 分批处理,避免单次请求过大
assert self.retriever is not None, "retriever 未初始化"
batch_size = 10 # 每次处理10个文档
total = len(documents)
processed = 0
for i in range(0, total, batch_size):
batch = documents[i:i + batch_size]
max_retries = 3
for attempt in range(max_retries):
try:
await self.retriever.aadd_documents(batch)
processed += len(batch)
logger.info(f"批次 {i//batch_size + 1}: 已处理 {processed}/{total}")
break
except (RemoteProtocolError, ConnectionError, OSError) as e:
if attempt == max_retries - 1:
raise
logger.warning(f"连接断开,重试 ({attempt+1}/{max_retries}): {e}")
self.vector_store.refresh_client()
await asyncio.sleep(1)
2026-04-18 16:56:23 +08:00
logger.info(
2026-04-19 15:01:40 +08:00
"已使用 ParentDocumentRetriever 索引: "
f"{processed} 个父块"
2026-04-18 16:56:23 +08:00
)
2026-04-19 15:01:40 +08:00
return processed
2026-04-18 16:56:23 +08:00
else:
2026-04-19 15:01:40 +08:00
logger.info("使用 %s 切分文档", self.splitter_type)
# 当 splitter_type 不是 PARENT_CHILD 时splitter 一定不为 None
assert self.splitter is not None, "splitter 未初始化"
2026-04-18 16:56:23 +08:00
chunks = self.splitter.split_documents(documents)
2026-04-19 15:01:40 +08:00
logger.info("已切分为 %d 个块", len(chunks))
2026-04-18 16:56:23 +08:00
self.vector_store.create_collection()
self.vector_store.add_documents(chunks)
return len(chunks)
def get_collection_info(self):
return self.vector_store.get_collection_info()
def search(self, query: str, k: int = 5) -> List[Document]:
2026-04-19 15:01:40 +08:00
"""标准搜索 - 返回子块。"""
2026-04-18 16:56:23 +08:00
return self.vector_store.similarity_search(query, k=k)
2026-04-19 15:01:40 +08:00
async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
2026-04-18 16:56:23 +08:00
"""
2026-04-19 15:01:40 +08:00
带父块上下文的搜索 - 返回完整父块
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
这是使用父子块切分时的主要检索方法
2026-04-18 16:56:23 +08:00
"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
2026-04-19 15:01:40 +08:00
"search_with_parent_context 仅在 PARENT_CHILD 切分器下可用。"
"请使用 search() 进行标准检索。"
2026-04-18 16:56:23 +08:00
)
2026-04-19 15:01:40 +08:00
assert self.retriever is not None, "retriever 未初始化"
return await self.retriever.ainvoke(query, config={"k": k}) # type: ignore
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
2026-04-18 16:56:23 +08:00
"""
2026-04-19 15:01:40 +08:00
统一检索接口
2026-04-18 16:56:23 +08:00
Args:
2026-04-19 15:01:40 +08:00
query: 搜索查询
return_parent: 如果为 True 且使用父子块切分返回父块
如果为 False始终返回子块
2026-04-18 16:56:23 +08:00
Returns:
2026-04-19 15:01:40 +08:00
相关文档列表
2026-04-18 16:56:23 +08:00
"""
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
2026-04-19 15:01:40 +08:00
return await self.search_with_parent_context(query)
2026-04-18 16:56:23 +08:00
else:
return self.search(query)
2026-04-19 15:01:40 +08:00
async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]:
"""
使用 RAG-Fusion 进行检索多路查询改写 + 倒数排名融合
核心原理:
1. 多路查询改写: 利用 LLM 将原始查询改写成多个不同表述
2. 倒数排名融合: 对每个改写查询的结果进行 RRF 融合避免单一检索结果主导
Args:
query: 原始搜索查询
llm: 语言模型实例用于查询改写
num_queries: 生成的查询数量
k: 返回的文档数量
return_parent: 如果为 True 且使用父子块切分返回父块
如果为 False始终返回子块
Returns:
经过融合后的相关文档列表
"""
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers import EnsembleRetriever
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
# 使用 ParentDocumentRetriever 作为基础检索器
assert self.retriever is not None, "retriever 未初始化"
base_retriever = self.retriever
else:
# 使用向量存储作为基础检索器
base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2})
# 创建多路查询检索器
multi_query_retriever = MultiQueryRetriever.from_llm(
retriever=base_retriever,
llm=llm,
include_original=True
)
# 设置自定义提示词以生成指定数量的查询
from langchain_core.prompts import PromptTemplate
multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template(
"你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
"这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
"原始问题: {question}\n\n"
"请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
"确保每个版本都是独立、完整的查询语句。\n\n"
"生成 {num_queries} 个查询:"
)
# 修改调用参数以包含 num_queries
original_ainvoke = multi_query_retriever.llm_chain.ainvoke
async def new_ainvoke(input_dict):
input_dict["num_queries"] = num_queries
return await original_ainvoke(input_dict)
multi_query_retriever.llm_chain.ainvoke = new_ainvoke
# 执行检索
documents = await multi_query_retriever.ainvoke(query)
# 去重并限制数量
seen_content = set()
unique_documents = []
for doc in documents:
content = doc.page_content
if content not in seen_content:
seen_content.add(content)
unique_documents.append(doc)
if len(unique_documents) >= k:
break
logger.info(f"RAG-Fusion 检索完成: 原始 {len(documents)} 个结果,去重后 {len(unique_documents)} 个结果")
return unique_documents
2026-04-18 16:56:23 +08:00
def get_retriever(self) -> ParentDocumentRetriever:
"""
2026-04-19 15:01:40 +08:00
直接获取 ParentDocumentRetriever 实例
2026-04-18 16:56:23 +08:00
2026-04-19 15:01:40 +08:00
适用于需要在 IndexBuilder 外部访问检索器的高级用例
2026-04-18 16:56:23 +08:00
"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
2026-04-19 15:01:40 +08:00
"get_retriever() 仅在 PARENT_CHILD 切分器下可用。"
"请使用 search() 或 search_with_parent_context() 进行标准检索。"
2026-04-18 16:56:23 +08:00
)
2026-04-19 15:01:40 +08:00
assert self.retriever is not None, "retriever 未初始化"
2026-04-18 16:56:23 +08:00
return self.retriever
2026-04-19 15:01:40 +08:00
def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]:
"""获取子块切分器以便重新配置。"""
2026-04-18 16:56:23 +08:00
if self.splitter_type != SplitterType.PARENT_CHILD:
2026-04-19 15:01:40 +08:00
return self.splitter # type: ignore
2026-04-18 16:56:23 +08:00
return self.child_splitter
def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
2026-04-19 15:01:40 +08:00
"""获取父块切分器以便重新配置。"""
2026-04-18 16:56:23 +08:00
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
2026-04-19 15:01:40 +08:00
"父块切分器仅在 PARENT_CHILD 切分器下可用。"
2026-04-18 16:56:23 +08:00
)
return self.parent_splitter
def get_docstore(self) -> BaseStore:
2026-04-19 15:01:40 +08:00
"""获取父块的文档存储。"""
2026-04-18 16:56:23 +08:00
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
2026-04-19 15:01:40 +08:00
"文档存储仅在 PARENT_CHILD 切分器下可用。"
2026-04-18 16:56:23 +08:00
)
2026-04-19 15:01:40 +08:00
assert self.docstore is not None, "docstore 未初始化"
2026-04-18 16:56:23 +08:00
return self.docstore
2026-04-19 15:01:40 +08:00
def get_docstore_path(self) -> Optional[str]:
"""获取文档存储路径(已弃用,仅用于兼容性)。"""
2026-04-18 16:56:23 +08:00
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
2026-04-19 15:01:40 +08:00
"文档存储路径仅在 PARENT_CHILD 切分器下可用。"
2026-04-18 16:56:23 +08:00
)
2026-04-19 15:01:40 +08:00
# PostgreSQL 存储没有 persist_path返回 None
return None
2026-04-18 16:56:23 +08:00
def close(self):
2026-04-19 15:01:40 +08:00
"""关闭资源。"""
if self.docstore is not None and hasattr(self.docstore, "aclose"):
import asyncio
asyncio.get_event_loop().run_until_complete(self.docstore.aclose()) # type: ignore
logger.info("PostgreSQL 异步连接池已关闭")
2026-04-18 16:56:23 +08:00
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
2026-04-19 15:01:40 +08:00
# 需要导入 RecursiveCharacterTextSplitter
2026-04-18 16:56:23 +08:00
from langchain_text_splitters import RecursiveCharacterTextSplitter
2026-04-19 15:01:40 +08:00
# 示例用法已移除,请参考文档