393 lines
16 KiB
Python
393 lines
16 KiB
Python
"""
|
||
离线 RAG 索引构建核心流水线。
|
||
|
||
支持 LangChain 的 ParentDocumentRetriever 用于父子块切分。
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import List, Union, Optional, Tuple, Any
|
||
from dataclasses import dataclass
|
||
|
||
from httpx import RemoteProtocolError
|
||
from langchain_core.documents import Document
|
||
from langchain_classic.retrievers import ParentDocumentRetriever
|
||
from langchain_core.stores import BaseStore
|
||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||
from langchain_experimental.text_splitter import SemanticChunker
|
||
|
||
from .loaders import DocumentLoader
|
||
from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
|
||
from .embedders import LlamaCppEmbedder
|
||
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
|
||
from .store import create_docstore
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class ParentChildConfig:
    """Configuration for parent/child chunk splitting."""
    parent_chunk_size: int = 1000   # characters per parent chunk (context unit)
    child_chunk_size: int = 200     # characters per child chunk (retrieval unit)
    parent_chunk_overlap: int = 100
    child_chunk_overlap: int = 20
    search_k: int = 5               # number of child hits fetched per query
    # Legacy local-store settings; PostgreSQL-backed stores ignore the path.
    docstore_path: Optional[str] = None
    docstore_type: str = "local"
    # Connection string for the parent-chunk docstore — presumably a
    # PostgreSQL DSN (see create_docstore usage); confirm against .store.
    docstore_conn_string: Optional[str] = None
|
||
|
||
|
||
class IndexBuilder:
    """Main RAG index-building pipeline.

    Loads documents, splits them into chunks (optionally parent/child chunks
    backed by LangChain's ParentDocumentRetriever), embeds them and stores
    them in a Qdrant collection.
    """

    # Attribute annotations (assigned in _init_parent_child_retriever).
    parent_splitter: "RecursiveCharacterTextSplitter"
    child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
    docstore: Optional["BaseStore"]
    _docstore_conn: Optional[str]
    retriever: Optional["ParentDocumentRetriever"]
    vector_store_obj: Any

    def __init__(
        self,
        collection_name: str = "rag_documents",
        splitter_type: SplitterType = SplitterType.PARENT_CHILD,
        docstore=None,
        **splitter_kwargs,
    ):
        """Create the pipeline and its components.

        Args:
            collection_name: target Qdrant collection name.
            splitter_type: chunking strategy; PARENT_CHILD enables the
                ParentDocumentRetriever path.
            docstore: optional externally managed parent-chunk store.
            **splitter_kwargs: forwarded to the splitter / retriever setup
                (e.g. parent_chunk_size, child_chunk_size, search_k,
                docstore_conn_string — see ParentChildConfig).
        """
        self.collection_name = collection_name
        self.splitter_type = splitter_type
        self.splitter_kwargs = splitter_kwargs
        self.docstore = docstore  # injected from outside (may be None)

        # Components.
        self.loader = DocumentLoader()
        self.embedder = LlamaCppEmbedder()
        self.embeddings = self.embedder.as_langchain_embeddings()

        self.vector_store = QdrantVectorStore(
            collection_name=collection_name,
            embeddings=self.embeddings,
        )

        # Splitter (parent/child mode is handled separately below).
        if splitter_type != SplitterType.PARENT_CHILD:
            if splitter_type == SplitterType.SEMANTIC:
                splitter_kwargs["embeddings"] = self.embeddings
            self.splitter = get_splitter(splitter_type, **splitter_kwargs)
        else:
            self.splitter = None
            # Initialise the ParentDocumentRetriever for parent/child mode.
            # BUG FIX: forward splitter_kwargs — previously the method was
            # called with no arguments, so chunk sizes/overlaps, search_k
            # and docstore settings supplied by callers were silently
            # ignored and the hard-coded defaults were always used.
            self._init_parent_child_retriever(**splitter_kwargs)
|
||
|
||
    def _init_parent_child_retriever(self, **kwargs):
        """
        Initialise the ParentDocumentRetriever used for parent/child splitting.

        Supports combining dynamic semantic chunking with the parent/child
        strategy:
        - parent chunks always use recursive splitting (large chunks that
          provide context)
        - child chunks may use recursive or semantic splitting (splitting on
          semantic boundaries to improve retrieval precision)

        Replaces the custom ParentChildSplitter logic.
        """
        # Resolve parent/child configuration (defaults mirror ParentChildConfig);
        # a generic "chunk_overlap" serves as a fallback for both overlaps.
        parent_size = kwargs.get("parent_chunk_size", 1000)
        child_size = kwargs.get("child_chunk_size", 200)
        parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
        child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))

        # Child splitter type; defaults to semantic chunking.
        child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC)

        # Parent splitter (always recursive character splitting).
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=parent_size,
            chunk_overlap=parent_overlap,
        )

        # Child splitter (selected by type).
        if child_splitter_type == SplitterType.SEMANTIC:
            self.child_splitter = get_splitter(
                SplitterType.SEMANTIC,
                embeddings=self.embeddings,
            )
            logger.info(f"子块使用语义切分器")
        else:
            # Default to recursive splitting.
            self.child_splitter = RecursiveCharacterTextSplitter(
                chunk_size=child_size,
                chunk_overlap=child_overlap,
            )
            logger.info(f"子块使用递归切分器,块大小: {child_size},重叠: {child_overlap}")

        # Vector store (holds the child chunks).
        self.vector_store_obj = self.vector_store.get_langchain_vectorstore()

        # Document store (holds the parent chunks).
        if self.docstore is None:
            # No docstore injected from outside: create one via
            # create_docstore (a PostgreSQL-backed store per .store).
            docstore_conn = kwargs.get("docstore_conn_string")
            pool_config = kwargs.get("pool_config")
            max_concurrency = kwargs.get("max_concurrency")

            self.docstore, self._docstore_conn = create_docstore(
                connection_string=docstore_conn,
                pool_config=pool_config,
                max_concurrency=max_concurrency
            )
        else:
            # Externally injected docstore: nothing of ours to clean up later.
            self._docstore_conn = None

        # Build the retriever: child chunks go to the vector store, parent
        # chunks to the docstore; search_k bounds the child hits per query.
        self.retriever = ParentDocumentRetriever(
            vectorstore=self.vector_store_obj,
            docstore=self.docstore,
            child_splitter=self.child_splitter,  # type: ignore
            parent_splitter=self.parent_splitter,
            search_kwargs={"k": kwargs.get("search_k", 5)},
        )
        logger.info(f"ParentDocumentRetriever 已初始化,父块大小: {parent_size},子块类型: {child_splitter_type}")
|
||
|
||
async def build_from_file(self, file_path: Union[str, Path]) -> int:
|
||
logger.info("加载文件: %s", file_path)
|
||
documents = self.loader.load_file(file_path)
|
||
logger.info("已加载 %d 个文档", len(documents))
|
||
return await self._process_documents(documents)
|
||
|
||
async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
|
||
logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
|
||
documents = self.loader.load_directory(directory_path, recursive=recursive)
|
||
logger.info("已从目录加载 %d 个文档", len(documents))
|
||
return await self._process_documents(documents)
|
||
|
||
    async def _process_documents(self, documents: List[Document]) -> int:
        """
        Split, embed and store the given documents.

        Returns the number of parent chunks processed (PARENT_CHILD mode)
        or the number of plain chunks stored (other splitter types).
        """
        if not documents:
            logger.warning("没有文档需要处理")
            return 0

        if self.splitter_type == SplitterType.PARENT_CHILD:
            logger.info("使用 LangChain ParentDocumentRetriever")

            # Ensure the (child-chunk) collection exists before writing.
            self.vector_store.create_collection()

            # Process in batches so a single request does not grow too large.
            assert self.retriever is not None, "retriever 未初始化"
            batch_size = 10  # documents per batch
            total = len(documents)
            processed = 0

            for i in range(0, total, batch_size):
                batch = documents[i:i + batch_size]
                max_retries = 3
                for attempt in range(max_retries):
                    try:
                        await self.retriever.aadd_documents(batch)
                        processed += len(batch)
                        logger.info(f"批次 {i//batch_size + 1}: 已处理 {processed}/{total}")
                        break
                    except (RemoteProtocolError, ConnectionError, OSError) as e:
                        # Transient connection failure: re-raise only after the
                        # final attempt; otherwise refresh the client and retry.
                        if attempt == max_retries - 1:
                            raise
                        logger.warning(f"连接断开,重试 ({attempt+1}/{max_retries}): {e}")
                        self.vector_store.refresh_client()
                        await asyncio.sleep(1)

            logger.info(
                "已使用 ParentDocumentRetriever 索引: "
                f"共 {processed} 个父块"
            )
            return processed

        else:
            logger.info("使用 %s 切分文档", self.splitter_type)
            # When splitter_type is not PARENT_CHILD, splitter is never None
            # (see __init__).
            assert self.splitter is not None, "splitter 未初始化"
            chunks = self.splitter.split_documents(documents)
            logger.info("已切分为 %d 个块", len(chunks))

            self.vector_store.create_collection()
            self.vector_store.add_documents(chunks)
            return len(chunks)
|
||
|
||
def get_collection_info(self):
|
||
return self.vector_store.get_collection_info()
|
||
|
||
def search(self, query: str, k: int = 5) -> List[Document]:
|
||
"""标准搜索 - 返回子块。"""
|
||
return self.vector_store.similarity_search(query, k=k)
|
||
|
||
async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
|
||
"""
|
||
带父块上下文的搜索 - 返回完整父块。
|
||
|
||
这是使用父子块切分时的主要检索方法。
|
||
"""
|
||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||
raise RuntimeError(
|
||
"search_with_parent_context 仅在 PARENT_CHILD 切分器下可用。"
|
||
"请使用 search() 进行标准检索。"
|
||
)
|
||
assert self.retriever is not None, "retriever 未初始化"
|
||
return await self.retriever.ainvoke(query, config={"k": k}) # type: ignore
|
||
|
||
async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
|
||
"""
|
||
统一检索接口。
|
||
|
||
Args:
|
||
query: 搜索查询
|
||
return_parent: 如果为 True 且使用父子块切分,返回父块
|
||
如果为 False,始终返回子块
|
||
|
||
Returns:
|
||
相关文档列表
|
||
"""
|
||
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
|
||
return await self.search_with_parent_context(query)
|
||
else:
|
||
return self.search(query)
|
||
|
||
async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]:
|
||
"""
|
||
使用 RAG-Fusion 进行检索(多路查询改写 + 倒数排名融合)。
|
||
|
||
核心原理:
|
||
1. 多路查询改写: 利用 LLM 将原始查询改写成多个不同表述
|
||
2. 倒数排名融合: 对每个改写查询的结果进行 RRF 融合,避免单一检索结果主导
|
||
|
||
Args:
|
||
query: 原始搜索查询
|
||
llm: 语言模型实例(用于查询改写)
|
||
num_queries: 生成的查询数量
|
||
k: 返回的文档数量
|
||
return_parent: 如果为 True 且使用父子块切分,返回父块
|
||
如果为 False,始终返回子块
|
||
|
||
Returns:
|
||
经过融合后的相关文档列表
|
||
"""
|
||
from langchain.retrievers.multi_query import MultiQueryRetriever
|
||
from langchain.retrievers import EnsembleRetriever
|
||
|
||
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
|
||
# 使用 ParentDocumentRetriever 作为基础检索器
|
||
assert self.retriever is not None, "retriever 未初始化"
|
||
base_retriever = self.retriever
|
||
else:
|
||
# 使用向量存储作为基础检索器
|
||
base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2})
|
||
|
||
# 创建多路查询检索器
|
||
multi_query_retriever = MultiQueryRetriever.from_llm(
|
||
retriever=base_retriever,
|
||
llm=llm,
|
||
include_original=True
|
||
)
|
||
|
||
# 设置自定义提示词以生成指定数量的查询
|
||
from langchain_core.prompts import PromptTemplate
|
||
multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template(
|
||
"你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
|
||
"这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
|
||
"原始问题: {question}\n\n"
|
||
"请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
|
||
"确保每个版本都是独立、完整的查询语句。\n\n"
|
||
"生成 {num_queries} 个查询:"
|
||
)
|
||
|
||
# 修改调用参数以包含 num_queries
|
||
original_ainvoke = multi_query_retriever.llm_chain.ainvoke
|
||
async def new_ainvoke(input_dict):
|
||
input_dict["num_queries"] = num_queries
|
||
return await original_ainvoke(input_dict)
|
||
multi_query_retriever.llm_chain.ainvoke = new_ainvoke
|
||
|
||
# 执行检索
|
||
documents = await multi_query_retriever.ainvoke(query)
|
||
|
||
# 去重并限制数量
|
||
seen_content = set()
|
||
unique_documents = []
|
||
for doc in documents:
|
||
content = doc.page_content
|
||
if content not in seen_content:
|
||
seen_content.add(content)
|
||
unique_documents.append(doc)
|
||
if len(unique_documents) >= k:
|
||
break
|
||
|
||
logger.info(f"RAG-Fusion 检索完成: 原始 {len(documents)} 个结果,去重后 {len(unique_documents)} 个结果")
|
||
return unique_documents
|
||
|
||
def get_retriever(self) -> ParentDocumentRetriever:
|
||
"""
|
||
直接获取 ParentDocumentRetriever 实例。
|
||
|
||
适用于需要在 IndexBuilder 外部访问检索器的高级用例。
|
||
"""
|
||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||
raise RuntimeError(
|
||
"get_retriever() 仅在 PARENT_CHILD 切分器下可用。"
|
||
"请使用 search() 或 search_with_parent_context() 进行标准检索。"
|
||
)
|
||
assert self.retriever is not None, "retriever 未初始化"
|
||
return self.retriever
|
||
|
||
def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]:
|
||
"""获取子块切分器以便重新配置。"""
|
||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||
return self.splitter # type: ignore
|
||
return self.child_splitter
|
||
|
||
def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
|
||
"""获取父块切分器以便重新配置。"""
|
||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||
raise RuntimeError(
|
||
"父块切分器仅在 PARENT_CHILD 切分器下可用。"
|
||
)
|
||
return self.parent_splitter
|
||
|
||
def get_docstore(self) -> BaseStore:
|
||
"""获取父块的文档存储。"""
|
||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||
raise RuntimeError(
|
||
"文档存储仅在 PARENT_CHILD 切分器下可用。"
|
||
)
|
||
assert self.docstore is not None, "docstore 未初始化"
|
||
return self.docstore
|
||
|
||
def get_docstore_path(self) -> Optional[str]:
|
||
"""获取文档存储路径(已弃用,仅用于兼容性)。"""
|
||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||
raise RuntimeError(
|
||
"文档存储路径仅在 PARENT_CHILD 切分器下可用。"
|
||
)
|
||
# PostgreSQL 存储没有 persist_path,返回 None
|
||
return None
|
||
|
||
def close(self):
|
||
"""关闭资源。"""
|
||
if self.docstore is not None and hasattr(self.docstore, "aclose"):
|
||
import asyncio
|
||
asyncio.get_event_loop().run_until_complete(self.docstore.aclose()) # type: ignore
|
||
logger.info("PostgreSQL 异步连接池已关闭")
|
||
|
||
def __enter__(self):
|
||
return self
|
||
|
||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||
self.close()
|
||
return False
|
||
|
||
|
||
# NOTE(review): redundant — RecursiveCharacterTextSplitter is already imported
# at the top of this module; this re-import is a harmless no-op and can be
# removed.
from langchain_text_splitters import RecursiveCharacterTextSplitter


# Example usage removed; see the project documentation.
|