RAG database generation
@@ -1,56 +1,68 @@
 """
 Core pipeline builder for offline RAG index construction.

 Now supports LangChain's ParentDocumentRetriever for parent-child chunking.
 """

+import asyncio
 import logging
 from pathlib import Path
-from typing import List, Union, Optional, Tuple
+from typing import List, Union, Optional, Tuple, Any
 from dataclasses import dataclass

+from httpx import RemoteProtocolError
 from langchain_core.documents import Document
-from langchain.retrievers import ParentDocumentRetriever
-from langchain.storage import LocalFileStore, BaseStore
+from langchain_classic.retrievers import ParentDocumentRetriever
+from langchain_core.stores import BaseStore
 from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_experimental.text_splitter import SemanticChunker

 from .loaders import DocumentLoader
-from .splitters import SplitterType, get_splitter, ParentChildSplitter
+from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
 from .embedders import LlamaCppEmbedder
-from .vector_store import QdrantVectorStore
-from .docstore_manager import get_docstore, PostgresDocStore, create_docstore
+from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
+from .store import create_docstore

 logger = logging.getLogger(__name__)


 @dataclass
 class ParentChildConfig:
     """Configuration for parent-child splitting."""
     parent_chunk_size: int = 1000
     child_chunk_size: int = 200
     parent_chunk_overlap: int = 100
     child_chunk_overlap: int = 20
     search_k: int = 5
-    docstore_path: str = None
+    docstore_path: Optional[str] = None
     docstore_type: str = "local"
-    docstore_conn_string: str = None
+    docstore_conn_string: Optional[str] = None


 class IndexBuilder:
     """Main pipeline for RAG index construction."""

+    # Type annotations for attributes set in __init__ and the retriever initializer
+    parent_splitter: "RecursiveCharacterTextSplitter"
+    child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
+    docstore: Optional["BaseStore"]
+    _docstore_conn: Optional[str]
+    retriever: Optional["ParentDocumentRetriever"]
+    vector_store_obj: Any
+
     def __init__(
         self,
         collection_name: str = "rag_documents",
         qdrant_url: Optional[str] = None,
-        splitter_type: SplitterType = SplitterType.RECURSIVE,
+        splitter_type: SplitterType = SplitterType.PARENT_CHILD,
+        docstore=None,
         **splitter_kwargs,
     ):
         self.collection_name = collection_name
         self.qdrant_url = qdrant_url
         self.splitter_type = splitter_type
         self.splitter_kwargs = splitter_kwargs
+        self.docstore = docstore  # injected by the caller

         # Components
         self.loader = DocumentLoader()
         self.embedder = LlamaCppEmbedder()
         self.embeddings = self.embedder.as_langchain_embeddings()
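A minimal construction sketch against the new defaults above (collection name and URL are illustrative placeholders; the splitter kwargs mirror ParentChildConfig and are forwarded to the parent-child initializer shown in the next hunk):

builder = IndexBuilder(
    collection_name="my_docs",               # illustrative name
    qdrant_url="http://localhost:6333",      # illustrative URL
    splitter_type=SplitterType.PARENT_CHILD,
    parent_chunk_size=1000,
    child_chunk_size=200,
    search_k=5,
)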
@@ -58,104 +70,145 @@ class IndexBuilder:
         self.vector_store = QdrantVectorStore(
             collection_name=collection_name,
             embeddings=self.embeddings,
             qdrant_url=qdrant_url,
         )

         # Splitter (parent-child is handled separately)
         if splitter_type != SplitterType.PARENT_CHILD:
             if splitter_type == SplitterType.SEMANTIC:
                 splitter_kwargs["embeddings"] = self.embeddings
             self.splitter = get_splitter(splitter_type, **splitter_kwargs)
         else:
             self.splitter = None
             # Initialize ParentDocumentRetriever for parent-child splitting
             self._init_parent_child_retriever(**self.splitter_kwargs)  # forward the splitter kwargs

     def _init_parent_child_retriever(self, **kwargs):
         """
         Initialize ParentDocumentRetriever for parent-child chunking.

+        Supports combining dynamic semantic chunking with the parent-child strategy:
+        - parent chunks use recursive splitting (large chunks that provide context)
+        - child chunks use recursive or semantic splitting (dynamic, semantics-driven splits for better retrieval precision)
+
         This replaces the custom ParentChildSplitter logic.
         """
         # Parse kwargs for parent-child config
         parent_size = kwargs.get("parent_chunk_size", 1000)
         child_size = kwargs.get("child_chunk_size", 200)
         parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
         child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))

+        # Child splitter type; defaults to semantic chunking
+        child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC)
+
-        # Define splitters
+        # Parent splitter (always recursive)
         self.parent_splitter = RecursiveCharacterTextSplitter(
             chunk_size=parent_size,
             chunk_overlap=parent_overlap,
         )
-        self.child_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=child_size,
-            chunk_overlap=child_overlap,
-        )
+
+        # Child splitter (chosen by type)
+        if child_splitter_type == SplitterType.SEMANTIC:
+            self.child_splitter = get_splitter(
+                SplitterType.SEMANTIC,
+                embeddings=self.embeddings,
+            )
+            logger.info("Child chunks use the semantic splitter")
+        else:
+            # Default to recursive splitting
+            self.child_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=child_size,
+                chunk_overlap=child_overlap,
+            )
+            logger.info(f"Child chunks use the recursive splitter, chunk size: {child_size}, overlap: {child_overlap}")

         # Vector store (for child chunks)
         self.vector_store_obj = self.vector_store.get_langchain_vectorstore()

         # Document store (for parent chunks)
-        docstore_path = kwargs.get("docstore_path")
-        docstore_type = kwargs.get("docstore_type", "local")
-        docstore_conn = kwargs.get("docstore_conn_string")
-
-        if docstore_type == "postgres" and docstore_conn:
-            self.docstore = PostgresDocStore(docstore_conn)
-            self._docstore_conn = docstore_conn
-        else:
-            self.docstore = get_docstore(docstore_path)
-            self._docstore_conn = None
+        if self.docstore is None:
+            # No docstore was injected, so create a PostgreSQL-backed one
+            docstore_conn = kwargs.get("docstore_conn_string")
+            pool_config = kwargs.get("pool_config")
+            max_concurrency = kwargs.get("max_concurrency")
+
+            # Create the PostgreSQL store via create_docstore
+            self.docstore, self._docstore_conn = create_docstore(
+                connection_string=docstore_conn,
+                pool_config=pool_config,
+                max_concurrency=max_concurrency
+            )
+        else:
+            # Use the externally injected docstore
+            self._docstore_conn = None

         # Create retriever
         self.retriever = ParentDocumentRetriever(
             vectorstore=self.vector_store_obj,
             docstore=self.docstore,
-            child_splitter=self.child_splitter,
+            child_splitter=self.child_splitter,  # type: ignore
             parent_splitter=self.parent_splitter,
             search_kwargs={"k": kwargs.get("search_k", 5)},
         )
+        logger.info(f"ParentDocumentRetriever initialized, parent chunk size: {parent_size}, child splitter type: {child_splitter_type}")

-    def build_from_file(self, file_path: Union[str, Path]) -> int:
+    async def build_from_file(self, file_path: Union[str, Path]) -> int:
         logger.info("Loading file: %s", file_path)
         documents = self.loader.load_file(file_path)
         logger.info("Loaded %d documents", len(documents))
-        return self._process_documents(documents)
+        return await self._process_documents(documents)

-    def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
+    async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
         logger.info("Loading directory: %s (recursive=%s)", directory_path, recursive)
         documents = self.loader.load_directory(directory_path, recursive=recursive)
         logger.info("Loaded %d documents from directory", len(documents))
-        return self._process_documents(documents)
+        return await self._process_documents(documents)

-    def _process_documents(self, documents: List[Document]) -> int:
+    async def _process_documents(self, documents: List[Document]) -> int:
         if not documents:
             logger.warning("No documents to process")
             return 0

         if self.splitter_type == SplitterType.PARENT_CHILD:
             logger.info("Using LangChain ParentDocumentRetriever")

             # Ensure the collection exists for child chunks
             self.vector_store.create_collection()

-            # Use ParentDocumentRetriever to add documents
-            # This automatically handles parent-child splitting, mapping, and retrieval
-            self.retriever.add_documents(documents)
+            # Process in batches so a single request does not grow too large
+            assert self.retriever is not None, "retriever not initialized"
+            batch_size = 10  # 10 documents per batch
+            total = len(documents)
+            processed = 0
+
+            for i in range(0, total, batch_size):
+                batch = documents[i:i + batch_size]
+                max_retries = 3
+                for attempt in range(max_retries):
+                    try:
+                        await self.retriever.aadd_documents(batch)
+                        processed += len(batch)
+                        logger.info(f"Batch {i//batch_size + 1}: processed {processed}/{total}")
+                        break
+                    except (RemoteProtocolError, ConnectionError, OSError) as e:
+                        if attempt == max_retries - 1:
+                            raise
+                        logger.warning(f"Connection dropped, retrying ({attempt+1}/{max_retries}): {e}")
+                        self.vector_store.refresh_client()
+                        await asyncio.sleep(1)

-            # Log estimated chunk counts
-            estimated_parent_chunks = len(documents) * (self.parent_splitter._chunk_size // self.child_splitter._chunk_size)
             logger.info(
                 "Indexed with ParentDocumentRetriever: "
-                f"~{len(documents)} parent chunks, ~{estimated_parent_chunks} child chunks"
+                f"{processed} parent chunks in total"
             )
-            return len(documents)
+            return processed

         else:
             logger.info("Splitting documents using %s", self.splitter_type)
+            # When splitter_type is not PARENT_CHILD, splitter is never None
+            assert self.splitter is not None, "splitter not initialized"
             chunks = self.splitter.split_documents(documents)
             logger.info("Split into %d chunks", len(chunks))

             self.vector_store.create_collection()
             self.vector_store.add_documents(chunks)
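build_from_file and build_from_directory are coroutines after this change, so callers drive them from an event loop; a minimal sketch, with an illustrative collection name and file path:

import asyncio

async def main() -> None:
    builder = IndexBuilder(collection_name="my_docs")          # illustrative name
    count = await builder.build_from_file("data/handbook.pdf")  # illustrative path
    print(f"Indexed {count} parent documents")

asyncio.run(main())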
@@ -165,90 +218,164 @@ class IndexBuilder:
         return self.vector_store.get_collection_info()

     def search(self, query: str, k: int = 5) -> List[Document]:
         """Standard search - returns child chunks."""
         return self.vector_store.similarity_search(query, k=k)

-    def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
+    async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
         """
         Search with parent context - returns full parent chunks.

         This is the main retrieval method when using parent-child splitting.
         """
         if self.splitter_type != SplitterType.PARENT_CHILD:
             raise RuntimeError(
                 "search_with_parent_context only available with PARENT_CHILD splitter. "
                 "Use search() for standard retrieval."
             )
-        return self.retriever.get_relevant_documents(query, k=k)
+        assert self.retriever is not None, "retriever not initialized"
+        return await self.retriever.ainvoke(query, config={"k": k})  # type: ignore

-    def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
+    async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
         """
         Unified retrieval interface.

         Args:
             query: Search query
             return_parent: If True and using the parent-child splitter, return parent chunks.
                 If False, always return child chunks.

         Returns:
             List of relevant documents
         """
         if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
-            return self.search_with_parent_context(query)
+            return await self.search_with_parent_context(query)
         else:
             return self.search(query)

+    async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]:
+        """
+        Retrieve with RAG-Fusion (multi-query rewriting + reciprocal rank fusion).
+
+        Core idea:
+        1. Multi-query rewriting: use the LLM to rewrite the original query into several different phrasings
+        2. Reciprocal rank fusion: fuse the results of the rewritten queries with RRF so no single result list dominates
+
+        Args:
+            query: Original search query
+            llm: Language model instance (used for query rewriting)
+            num_queries: Number of queries to generate
+            k: Number of documents to return
+            return_parent: If True and using the parent-child splitter, return parent chunks.
+                If False, always return child chunks.
+
+        Returns:
+            List of relevant documents after fusion
+        """
+        from langchain.retrievers.multi_query import MultiQueryRetriever
+        from langchain.retrievers import EnsembleRetriever
+
+        if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
+            # Use the ParentDocumentRetriever as the base retriever
+            assert self.retriever is not None, "retriever not initialized"
+            base_retriever = self.retriever
+        else:
+            # Use the vector store as the base retriever
+            base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2})
+
+        # Build the multi-query retriever
+        multi_query_retriever = MultiQueryRetriever.from_llm(
+            retriever=base_retriever,
+            llm=llm,
+            include_original=True
+        )
+
+        # Set a custom prompt so the requested number of queries is generated
+        from langchain_core.prompts import PromptTemplate
+        multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template(
+            "You are a professional query-rewriting assistant. Your task is to rewrite the user's question into {num_queries} different versions.\n"
+            "These versions should express the same or related intent from different angles and with different keywords.\n\n"
+            "Original question: {question}\n\n"
+            "Please generate {num_queries} different versions of the query, one per line.\n"
+            "Make sure each version is an independent, complete query.\n\n"
+            "Generate {num_queries} queries:"
+        )
+
+        # Patch the chain call so the input includes num_queries
+        original_ainvoke = multi_query_retriever.llm_chain.ainvoke
+        async def new_ainvoke(input_dict):
+            input_dict["num_queries"] = num_queries
+            return await original_ainvoke(input_dict)
+        multi_query_retriever.llm_chain.ainvoke = new_ainvoke
+
+        # Run retrieval
+        documents = await multi_query_retriever.ainvoke(query)
+
+        # Deduplicate and cap the number of results
+        seen_content = set()
+        unique_documents = []
+        for doc in documents:
+            content = doc.page_content
+            if content not in seen_content:
+                seen_content.add(content)
+                unique_documents.append(doc)
+                if len(unique_documents) >= k:
+                    break
+
+        logger.info(f"RAG-Fusion retrieval finished: {len(documents)} raw results, {len(unique_documents)} after deduplication")
+        return unique_documents

     def get_retriever(self) -> ParentDocumentRetriever:
         """
         Get the ParentDocumentRetriever instance directly.

         Useful for advanced use cases where you want to access the retriever
         outside of IndexBuilder.
         """
         if self.splitter_type != SplitterType.PARENT_CHILD:
             raise RuntimeError(
                 "get_retriever() only available with PARENT_CHILD splitter. "
                 "Use search() or search_with_parent_context() for standard retrieval."
             )
+        assert self.retriever is not None, "retriever not initialized"
         return self.retriever

-    def get_child_splitter(self) -> "RecursiveCharacterTextSplitter":
+    def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]:
         """Get the child splitter for reconfiguration."""
         if self.splitter_type != SplitterType.PARENT_CHILD:
-            return self.splitter
+            return self.splitter  # type: ignore
         return self.child_splitter

     def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
         """Get the parent splitter for reconfiguration."""
         if self.splitter_type != SplitterType.PARENT_CHILD:
             raise RuntimeError(
                 "Parent splitter only available with PARENT_CHILD splitter."
             )
         return self.parent_splitter

     def get_docstore(self) -> BaseStore:
         """Get the document store for parent chunks."""
         if self.splitter_type != SplitterType.PARENT_CHILD:
             raise RuntimeError(
                 "Docstore only available with PARENT_CHILD splitter."
             )
+        assert self.docstore is not None, "docstore not initialized"
         return self.docstore

-    def get_docstore_path(self) -> str:
-        """Get the document store path."""
+    def get_docstore_path(self) -> Optional[str]:
+        """Get the document store path (deprecated; kept only for compatibility)."""
         if self.splitter_type != SplitterType.PARENT_CHILD:
             raise RuntimeError(
                 "Docstore path only available with PARENT_CHILD splitter."
             )
-        return self.docstore.persist_path
+        # The PostgreSQL store has no persist_path; return None
+        return None

     def close(self):
         """Close resources."""
-        if hasattr(self, "_docstore_conn") and self._docstore_conn:
-            import psycopg2
-            conn = psycopg2.connect(self._docstore_conn)
-            conn.close()
-            logger.info("Closed PostgreSQL connection")
+        if self.docstore is not None and hasattr(self.docstore, "aclose"):
+            import asyncio
+            asyncio.get_event_loop().run_until_complete(self.docstore.aclose())  # type: ignore
+            logger.info("Closed the PostgreSQL async connection pool")

     def __enter__(self):
         return self
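The retrieve_with_fusion docstring above names reciprocal rank fusion, while the committed body pools and de-duplicates results; for reference, plain RRF over several ranked lists can be sketched as follows (rrf_merge is a hypothetical helper, not part of this commit):

from collections import defaultdict
from typing import Dict, List

from langchain_core.documents import Document

def rrf_merge(result_lists: List[List[Document]], k: int = 5, c: int = 60) -> List[Document]:
    # Reciprocal rank fusion: score(d) = sum over lists of 1 / (c + rank_in_list(d))
    scores: Dict[str, float] = defaultdict(float)
    by_content: Dict[str, Document] = {}
    for docs in result_lists:
        for rank, doc in enumerate(docs, start=1):
            key = doc.page_content
            scores[key] += 1.0 / (c + rank)
            by_content.setdefault(key, doc)
    ranked = sorted(scores, key=scores.get, reverse=True)
    return [by_content[key] for key in ranked[:k]]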
@@ -258,20 +385,8 @@ class IndexBuilder:
         return False


 # RecursiveCharacterTextSplitter needs to be imported
 from langchain_text_splitters import RecursiveCharacterTextSplitter


-if __name__ == "__main__":
-    # Example usage
-    builder = IndexBuilder(
-        splitter_type=SplitterType.PARENT_CHILD,
-        parent_chunk_size=1000,
-        child_chunk_size=200,
-        docstore_path="./my_parent_docs",
-    )
-
-    print("Parent splitter:", builder.get_parent_splitter().chunk_size)
-    print("Child splitter:", builder.get_child_splitter().chunk_size)
-    print("Docstore path:", builder.get_docstore_path())
-    print("Retriever:", builder.get_retriever())
+# Example usage removed; see the documentation
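With the inline __main__ demo removed, an end-to-end sketch of the new async API (connection string, directory, and query are illustrative placeholders):

import asyncio

async def demo() -> None:
    builder = IndexBuilder(
        splitter_type=SplitterType.PARENT_CHILD,
        docstore_conn_string="postgresql://user:pass@localhost:5432/rag",  # placeholder
    )
    await builder.build_from_directory("./docs")                 # illustrative directory
    hits = await builder.retrieve("How is the index rebuilt?")   # illustrative query
    for doc in hits:
        print(doc.page_content[:80])

asyncio.run(demo())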