RAG database generation

2026-04-19 15:01:40 +08:00
parent c18e8a9860
commit cc8ef41ef9
17 changed files with 1089 additions and 577 deletions


@@ -1,56 +1,68 @@
"""
Core pipeline builder for offline RAG index construction.
离线 RAG 索引构建核心流水线。
Now supports LangChain's ParentDocumentRetriever for parent-child chunking.
支持 LangChain ParentDocumentRetriever 用于父子块切分。
"""
import asyncio
import logging
from pathlib import Path
from typing import List, Union, Optional, Tuple, Any
from dataclasses import dataclass
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .store import create_docstore
logger = logging.getLogger(__name__)
@dataclass
class ParentChildConfig:
    """Configuration for parent-child splitting."""
    parent_chunk_size: int = 1000
    child_chunk_size: int = 200
    parent_chunk_overlap: int = 100
    child_chunk_overlap: int = 20
    search_k: int = 5
    docstore_path: Optional[str] = None
    docstore_type: str = "local"
    docstore_conn_string: Optional[str] = None
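
# Hypothetical wiring sketch (an assumption, not part of this commit):
# IndexBuilder takes **splitter_kwargs rather than this dataclass, so one
# way to connect the two is to unpack an instance into the constructor:
#   cfg = ParentChildConfig(parent_chunk_size=1200, child_chunk_size=256)
#   builder = IndexBuilder(splitter_type=SplitterType.PARENT_CHILD, **vars(cfg))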
class IndexBuilder:
"""Main pipeline for RAG index construction."""
"""RAG 索引构建主流水线。"""
# 类型注解
parent_splitter: "RecursiveCharacterTextSplitter"
child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
docstore: Optional["BaseStore"]
_docstore_conn: Optional[str]
retriever: Optional["ParentDocumentRetriever"]
vector_store_obj: Any
    def __init__(
        self,
        collection_name: str = "rag_documents",
        qdrant_url: Optional[str] = None,
        splitter_type: SplitterType = SplitterType.PARENT_CHILD,
        docstore=None,
        **splitter_kwargs,
    ):
        self.collection_name = collection_name
        self.qdrant_url = qdrant_url
        self.splitter_type = splitter_type
        self.splitter_kwargs = splitter_kwargs
        self.docstore = docstore  # injected externally
        # Components
        self.loader = DocumentLoader()
        self.embedder = LlamaCppEmbedder()
        self.embeddings = self.embedder.as_langchain_embeddings()
@@ -58,104 +70,145 @@ class IndexBuilder:
        self.vector_store = QdrantVectorStore(
            collection_name=collection_name,
            embeddings=self.embeddings,
            qdrant_url=qdrant_url,
        )
        # Splitter (parent-child is handled separately)
        if splitter_type != SplitterType.PARENT_CHILD:
            if splitter_type == SplitterType.SEMANTIC:
                splitter_kwargs["embeddings"] = self.embeddings
            self.splitter = get_splitter(splitter_type, **splitter_kwargs)
        else:
            self.splitter = None
            # Initialize the ParentDocumentRetriever for parent-child splitting
            # (forward the splitter kwargs so chunk sizes and docstore options apply)
            self._init_parent_child_retriever(**splitter_kwargs)
    def _init_parent_child_retriever(self, **kwargs):
        """
        Initialize the ParentDocumentRetriever for parent-child chunking.
        Supports combining dynamic semantic splitting with the parent-child strategy:
        - Parent chunks use recursive splitting (large chunks that provide context)
        - Child chunks use recursive or semantic splitting (semantics-aware splits that improve retrieval precision)
        Replaces the custom ParentChildSplitter logic.
        """
        # Parse kwargs for the parent-child config
        parent_size = kwargs.get("parent_chunk_size", 1000)
        child_size = kwargs.get("child_chunk_size", 200)
        parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
        child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))
        # Child splitter type; defaults to semantic splitting
        child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC)
        # Parent splitter (always recursive)
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=parent_size,
            chunk_overlap=parent_overlap,
        )
        # Child splitter (chosen by type)
        if child_splitter_type == SplitterType.SEMANTIC:
            self.child_splitter = get_splitter(
                SplitterType.SEMANTIC,
                embeddings=self.embeddings,
            )
            logger.info("Child chunks use the semantic splitter")
        else:
            # Default to recursive splitting
            self.child_splitter = RecursiveCharacterTextSplitter(
                chunk_size=child_size,
                chunk_overlap=child_overlap,
            )
            logger.info(f"Child chunks use the recursive splitter, chunk size: {child_size}, overlap: {child_overlap}")
        # Vector store (for child chunks)
        self.vector_store_obj = self.vector_store.get_langchain_vectorstore()
        # Document store (for parent chunks)
        if self.docstore is None:
            # No docstore injected, so create a PostgreSQL-backed one
            docstore_conn = kwargs.get("docstore_conn_string")
            pool_config = kwargs.get("pool_config")
            max_concurrency = kwargs.get("max_concurrency")
            # Build the PostgreSQL store via create_docstore
            self.docstore, self._docstore_conn = create_docstore(
                connection_string=docstore_conn,
                pool_config=pool_config,
                max_concurrency=max_concurrency,
            )
        else:
            # Use the externally injected docstore
            self._docstore_conn = None
        # Create the retriever
        self.retriever = ParentDocumentRetriever(
            vectorstore=self.vector_store_obj,
            docstore=self.docstore,
            child_splitter=self.child_splitter,  # type: ignore
            parent_splitter=self.parent_splitter,
            search_kwargs={"k": kwargs.get("search_k", 5)},
        )
        logger.info(f"ParentDocumentRetriever initialized, parent chunk size: {parent_size}, child splitter type: {child_splitter_type}")
    async def build_from_file(self, file_path: Union[str, Path]) -> int:
        logger.info("Loading file: %s", file_path)
        documents = self.loader.load_file(file_path)
        logger.info("Loaded %d documents", len(documents))
        return await self._process_documents(documents)
    async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
        logger.info("Loading directory: %s (recursive=%s)", directory_path, recursive)
        documents = self.loader.load_directory(directory_path, recursive=recursive)
        logger.info("Loaded %d documents from directory", len(documents))
        return await self._process_documents(documents)
    async def _process_documents(self, documents: List[Document]) -> int:
        if not documents:
            logger.warning("No documents to process")
            return 0
        if self.splitter_type == SplitterType.PARENT_CHILD:
            logger.info("Using LangChain ParentDocumentRetriever")
            # Ensure the collection exists (for child chunks)
            self.vector_store.create_collection()
            # Process in batches to keep individual requests small
            assert self.retriever is not None, "retriever not initialized"
            batch_size = 10  # documents per batch
            total = len(documents)
            processed = 0
            for i in range(0, total, batch_size):
                batch = documents[i:i + batch_size]
                max_retries = 3
                for attempt in range(max_retries):
                    try:
                        await self.retriever.aadd_documents(batch)
                        processed += len(batch)
                        logger.info(f"Batch {i//batch_size + 1}: processed {processed}/{total}")
                        break
                    except (RemoteProtocolError, ConnectionError, OSError) as e:
                        if attempt == max_retries - 1:
                            raise
                        logger.warning(f"Connection dropped, retrying ({attempt+1}/{max_retries}): {e}")
                        self.vector_store.refresh_client()
                        await asyncio.sleep(1)
            logger.info(
                "Indexed with ParentDocumentRetriever: "
                f"{processed} parent documents"
            )
            return processed
        else:
            logger.info("Splitting documents using %s", self.splitter_type)
            # When splitter_type is not PARENT_CHILD, splitter is never None
            assert self.splitter is not None, "splitter not initialized"
            chunks = self.splitter.split_documents(documents)
            logger.info("Split into %d chunks", len(chunks))
            self.vector_store.create_collection()
            self.vector_store.add_documents(chunks)
@@ -165,90 +218,164 @@ class IndexBuilder:
        return self.vector_store.get_collection_info()
    def search(self, query: str, k: int = 5) -> List[Document]:
        """Standard search - returns child chunks."""
        return self.vector_store.similarity_search(query, k=k)
    async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
        """
        Search with parent context - returns full parent chunks.
        This is the main retrieval method when using parent-child splitting.
        """
        if self.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError(
                "search_with_parent_context is only available with the PARENT_CHILD splitter. "
                "Use search() for standard retrieval."
            )
        assert self.retriever is not None, "retriever not initialized"
        return await self.retriever.ainvoke(query, config={"k": k})  # type: ignore
    async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
        """
        Unified retrieval interface.
        Args:
            query: Search query
            return_parent: If True and using the parent-child splitter, return parent chunks.
                If False, always return child chunks.
        Returns:
            List of relevant documents
        """
        if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
            return await self.search_with_parent_context(query)
        else:
            return self.search(query)
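
    # Usage sketch (illustrative; the file path and query are hypothetical):
    #   builder = IndexBuilder(splitter_type=SplitterType.PARENT_CHILD)
    #   await builder.build_from_file("docs/guide.pdf")
    #   parents = await builder.retrieve("how are parent chunks stored?")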
    async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]:
        """
        Retrieve with RAG-Fusion (multi-query rewriting + reciprocal rank fusion).
        Core idea:
        1. Multi-query rewriting: use the LLM to rewrite the original query into several different phrasings
        2. Reciprocal rank fusion: fuse the results of each rewritten query with RRF so no single result list dominates
        Args:
            query: Original search query
            llm: Language model instance (used for query rewriting)
            num_queries: Number of queries to generate
            k: Number of documents to return
            return_parent: If True and using the parent-child splitter, return parent chunks.
                If False, always return child chunks.
        Returns:
            List of relevant documents after fusion
        """
        from langchain.retrievers.multi_query import MultiQueryRetriever
        if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
            # Use the ParentDocumentRetriever as the base retriever
            assert self.retriever is not None, "retriever not initialized"
            base_retriever = self.retriever
        else:
            # Use the vector store as the base retriever
            base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2})
        # Create the multi-query retriever
        multi_query_retriever = MultiQueryRetriever.from_llm(
            retriever=base_retriever,
            llm=llm,
            include_original=True
        )
        # Set a custom prompt so the requested number of queries is generated
        from langchain_core.prompts import PromptTemplate
        multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template(
            "You are a professional query-rewriting assistant. Your task is to rewrite the user's question into {num_queries} different versions.\n"
            "These versions should express the same or related intent from different angles and with different keywords.\n\n"
            "Original question: {question}\n\n"
            "Please generate {num_queries} different versions of the query, one per line.\n"
            "Make sure each version is an independent, complete query.\n\n"
            "Generate {num_queries} queries:"
        )
        # Wrap the chain invocation so it also receives num_queries
        original_ainvoke = multi_query_retriever.llm_chain.ainvoke
        async def new_ainvoke(input_dict):
            input_dict["num_queries"] = num_queries
            return await original_ainvoke(input_dict)
        multi_query_retriever.llm_chain.ainvoke = new_ainvoke
        # Run retrieval
        documents = await multi_query_retriever.ainvoke(query)
        # Deduplicate and cap the result count
        seen_content = set()
        unique_documents = []
        for doc in documents:
            content = doc.page_content
            if content not in seen_content:
                seen_content.add(content)
                unique_documents.append(doc)
                if len(unique_documents) >= k:
                    break
        logger.info(f"RAG-Fusion retrieval complete: {len(documents)} raw results, {len(unique_documents)} after deduplication")
        return unique_documents
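
    # The method above names reciprocal rank fusion (RRF) but actually fuses
    # results by simple dedup. A minimal sketch of classic RRF scoring follows
    # for reference; it is illustrative only, and the constant rrf_k=60 is the
    # conventional default from the RRF literature, not a value in this commit.
    @staticmethod
    def _rrf_merge(result_lists: List[List[Document]], rrf_k: int = 60) -> List[Document]:
        scores: dict = {}
        docs_by_content: dict = {}
        for results in result_lists:
            for rank, doc in enumerate(results):
                key = doc.page_content
                docs_by_content[key] = doc
                # RRF: each ranked list contributes 1 / (rrf_k + rank) to the doc's score
                scores[key] = scores.get(key, 0.0) + 1.0 / (rrf_k + rank + 1)
        ranked = sorted(scores, key=scores.get, reverse=True)
        return [docs_by_content[key] for key in ranked]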
    def get_retriever(self) -> ParentDocumentRetriever:
        """
        Get the ParentDocumentRetriever instance directly.
        Useful for advanced use cases that need the retriever outside of IndexBuilder.
        """
        if self.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError(
                "get_retriever() is only available with the PARENT_CHILD splitter. "
                "Use search() or search_with_parent_context() for standard retrieval."
            )
        assert self.retriever is not None, "retriever not initialized"
        return self.retriever
    def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]:
        """Get the child splitter for reconfiguration."""
        if self.splitter_type != SplitterType.PARENT_CHILD:
            return self.splitter  # type: ignore
        return self.child_splitter
    def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
        """Get the parent splitter for reconfiguration."""
        if self.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError(
                "The parent splitter is only available with the PARENT_CHILD splitter."
            )
        return self.parent_splitter
    def get_docstore(self) -> BaseStore:
        """Get the document store for parent chunks."""
        if self.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError(
                "The docstore is only available with the PARENT_CHILD splitter."
            )
        assert self.docstore is not None, "docstore not initialized"
        return self.docstore
    def get_docstore_path(self) -> Optional[str]:
        """Get the document store path (deprecated, kept for compatibility)."""
        if self.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError(
                "The docstore path is only available with the PARENT_CHILD splitter."
            )
        # PostgreSQL stores have no persist_path, so return None
        return None
    def close(self):
        """Close resources."""
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            asyncio.get_event_loop().run_until_complete(self.docstore.aclose())  # type: ignore
            logger.info("PostgreSQL async connection pool closed")
    def __enter__(self):
        return self
@@ -258,20 +385,8 @@ class IndexBuilder:
        return False
# RecursiveCharacterTextSplitter needs to be imported
from langchain_text_splitters import RecursiveCharacterTextSplitter

if __name__ == "__main__":
    # Example usage removed; see the documentation.
    pass