rag_indexer/IndexBuilder.py (new file, 299 lines)
@@ -0,0 +1,299 @@
"""
Core pipeline for offline RAG index construction.

Implements the parent-child chunking strategy via LangChain's ParentDocumentRetriever.
"""

import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict

from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever

from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore

logger = logging.getLogger(__name__)


# ---------- Configuration dataclasses ----------
@dataclass
class DocstoreConfig:
    """Document store configuration (used for parent-chunk storage)."""
    connection_string: Optional[str] = None
    pool_config: Optional[Dict[str, Any]] = None
    max_concurrency: Optional[int] = None
    # To inject an externally created docstore, set this field directly.
    instance: Optional[BaseStore] = None


@dataclass
class IndexBuilderConfig:
    """Index builder configuration."""
    collection_name: str = "rag_documents"
    splitter_type: SplitterType = SplitterType.PARENT_CHILD

    # Parent-chunk parameters (only used when splitter_type is PARENT_CHILD)
    parent_chunk_size: int = 1000
    parent_chunk_overlap: int = 100
    # Child-chunk parameters
    child_chunk_size: int = 200
    child_chunk_overlap: int = 20
    child_splitter_type: SplitterType = SplitterType.SEMANTIC  # children default to semantic splitting

    # Retrieval parameters
    search_k: int = 5

    # Document store configuration (only needed in parent-child mode)
    docstore: DocstoreConfig = field(default_factory=DocstoreConfig)

    # Extra splitter parameters (used when splitter_type is not parent-child)
    extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)


# ---------- Index builder ----------
class IndexBuilder:
    """Main RAG index-building pipeline; supports single-splitter and parent-child modes."""

    def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs):
        """
        Args:
            config: Index builder configuration object; takes precedence over kwargs.
            **kwargs: Configuration parameters passed directly; merged into config
                (kept for convenience).
        """
        if config is None:
            config = IndexBuilderConfig(**kwargs)
        elif kwargs:
            # Merge kwargs into config (only update fields that already exist).
            for key, value in kwargs.items():
                if hasattr(config, key):
                    setattr(config, key, value)

        self.config = config
        self._docstore_conn: Optional[str] = None  # records connection info returned by create_docstore

        # Initialize base components.
        self.loader = DocumentLoader()
        self.embedder = LlamaCppEmbedder()
        self.embeddings: Embeddings = self.embedder.as_langchain_embeddings()

        # Initialize the vector store.
        self.vector_store = QdrantVectorStore(
            collection_name=config.collection_name,
            embeddings=self.embeddings,
        )

        # Initialize the components that depend on the splitter type.
        self._init_splitters_and_retriever()

    # ---------- Private initialization ----------
    def _init_splitters_and_retriever(self) -> None:
        """Initialize splitters and the retriever according to the configuration."""
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            self._init_parent_child_mode()
        else:
            self._init_single_splitter_mode()

    def _init_single_splitter_mode(self) -> None:
        """Single-splitter mode (recursive or semantic)."""
        splitter_kwargs = self.config.extra_splitter_kwargs.copy()
        if self.config.splitter_type == SplitterType.SEMANTIC:
            splitter_kwargs["embeddings"] = self.embeddings
        self.splitter = get_splitter(self.config.splitter_type, **splitter_kwargs)
        self.retriever = None
        self.docstore = None
        logger.info("Using a single %s splitter", self.config.splitter_type.value)

    def _init_parent_child_mode(self) -> None:
        """Parent-child mode: initialize the parent/child splitters, docstore, and retriever."""
        cfg = self.config

        # Parent splitter (always recursive).
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=cfg.parent_chunk_size,
            chunk_overlap=cfg.parent_chunk_overlap,
        )

        # Child splitter.
        if cfg.child_splitter_type == SplitterType.SEMANTIC:
            self.child_splitter = get_splitter(
                SplitterType.SEMANTIC,
                embeddings=self.embeddings,
                **cfg.extra_splitter_kwargs
            )
            logger.info("Child chunks use the semantic splitter")
        else:
            self.child_splitter = RecursiveCharacterTextSplitter(
                chunk_size=cfg.child_chunk_size,
                chunk_overlap=cfg.child_chunk_overlap,
            )
            logger.info("Child chunks use the recursive splitter, size=%d, overlap=%d",
                        cfg.child_chunk_size, cfg.child_chunk_overlap)

        # Initialize the document store (for parent chunks).
        self.docstore = self._create_or_use_docstore()

        # Create the retriever.
        self.retriever = ParentDocumentRetriever(
            vectorstore=self.vector_store.get_langchain_vectorstore(),
            docstore=self.docstore,
            child_splitter=self.child_splitter,  # type: ignore[arg-type]
            parent_splitter=self.parent_splitter,
            search_kwargs={"k": cfg.search_k},
        )
        logger.info("ParentDocumentRetriever initialized, parent chunk size=%d", cfg.parent_chunk_size)

    def _create_or_use_docstore(self) -> BaseStore:
        """Create a document store, or use the injected instance."""
        cfg = self.config.docstore
        if cfg.instance is not None:
            logger.debug("Using the externally injected document store")
            return cfg.instance

        # Create a PostgreSQL store via create_docstore.
        docstore, conn_info = create_docstore(
            connection_string=cfg.connection_string,
            pool_config=cfg.pool_config,
            max_concurrency=cfg.max_concurrency,
        )
        self._docstore_conn = conn_info
        logger.info("Document store created (PostgreSQL)")
        return docstore

    # ---------- Public build methods ----------
    async def build_from_file(self, file_path: Union[str, Path]) -> int:
        """Build the index from a single file."""
        logger.info("Loading file: %s", file_path)
        documents = self.loader.load_file(file_path)
        logger.info("Loaded %d documents", len(documents))
        return await self._process_documents(documents)

    async def build_from_directory(
        self, directory_path: Union[str, Path], recursive: bool = True
    ) -> int:
        """Build the index from a directory, optionally recursing."""
        logger.info("Loading directory: %s (recursive=%s)", directory_path, recursive)
        documents = self.loader.load_directory(directory_path, recursive=recursive)
        logger.info("Loaded %d documents from the directory", len(documents))
        return await self._process_documents(documents)

    async def _process_documents(self, documents: List[Document]) -> int:
        """Dispatch the document list to the appropriate indexing logic."""
        if not documents:
            logger.warning("No documents to process")
            return 0

        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            return await self._index_with_parent_child(documents)
        else:
            return await self._index_with_single_splitter(documents)

    async def _index_with_single_splitter(self, documents: List[Document]) -> int:
        """Single-splitter mode: split and write directly to the vector store."""
        chunks = self.splitter.split_documents(documents)  # type: ignore[union-attr]
        logger.info("Split into %d chunks", len(chunks))

        self.vector_store.create_collection()
        self.vector_store.add_documents(chunks)
        return len(chunks)

    async def _index_with_parent_child(self, documents: List[Document]) -> int:
        """Parent-child mode: add documents in batches via ParentDocumentRetriever."""
        self.vector_store.create_collection()
        assert self.retriever is not None

        batch_size = 10
        total = len(documents)
        processed = 0

        for i in range(0, total, batch_size):
            batch = documents[i:i + batch_size]
            await self._add_batch_with_retry(batch, i // batch_size + 1)
            processed += len(batch)
            logger.info("Batch %d: processed %d/%d", i // batch_size + 1, processed, total)

        logger.info("ParentDocumentRetriever indexing finished, %d documents processed", processed)
        return processed

    async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
        """Add one batch, retrying automatically on failure (handles transient network errors)."""
        max_retries = 3
        for attempt in range(max_retries):
            try:
                await self.retriever.aadd_documents(batch)  # type: ignore[union-attr]
                return
            except (RemoteProtocolError, ConnectionError, OSError) as e:
                if attempt == max_retries - 1:
                    raise
                logger.warning("Batch %d connection dropped, retrying (%d/%d): %s",
                               batch_no, attempt + 1, max_retries, e)
                self.vector_store.refresh_client()
                await asyncio.sleep(1)

    # ---------- Introspection ----------
    def get_collection_info(self) -> Any:
        """Get information about the vector store collection."""
        return self.vector_store.get_collection_info()

    def get_child_splitter(self) -> TextSplitter:
        """Get the child splitter currently in use."""
        if self.config.splitter_type == SplitterType.PARENT_CHILD:
            return self.child_splitter  # type: ignore[return-value]
        return self.splitter  # type: ignore[return-value]

    def get_parent_splitter(self) -> RecursiveCharacterTextSplitter:
        """Get the parent splitter (parent-child mode only)."""
        if self.config.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError("The parent splitter is only available in parent-child mode")
        return self.parent_splitter  # type: ignore[return-value]

    def get_docstore(self) -> BaseStore:
        """Get the document store instance (parent-child mode only)."""
        if self.config.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError("The document store is only available in parent-child mode")
        assert self.docstore is not None
        return self.docstore

    # ---------- Resource management ----------
    def close(self) -> None:
        """Close resources (synchronous version, used by the context manager)."""
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running event loop; create a temporary one.
                loop = asyncio.new_event_loop()
                loop.run_until_complete(self.docstore.aclose())  # type: ignore[attr-defined]
                loop.close()
            else:
                # A loop is already running; schedule a task (the caller awaits it).
                loop.create_task(self.docstore.aclose())  # type: ignore[attr-defined]
        logger.info("IndexBuilder resources closed")

    async def aclose(self) -> None:
        """Close resources asynchronously."""
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            await self.docstore.aclose()  # type: ignore[attr-defined]
        logger.info("IndexBuilder resources closed (async)")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.aclose()
        return False
@@ -2,35 +2,13 @@

This module implements stage one of the RAG system: **offline index construction**. It takes external unstructured data (documents, PDFs, web pages, and so on), cleans it, splits it into chunks, converts the chunks into vectors, and finally writes them into the vector database.

## 📊 System Workflow

```mermaid
graph TD
    A[Raw document collection <br> PDF / Word / Markdown] --> B(DocumentLoader)
    B --> C{Text splitting strategy}

    C -->|basic| D1[Fixed-length character splitting <br> Recursive Split]
    C -->|intermediate| D2[Semantic-boundary splitting <br> Semantic Chunking]
    C -->|advanced| D3[Parent-child document splitting <br> Parent-Child / Auto-merging]

    D1 & D2 & D3 --> E[Embedder <br> llama.cpp: embeddinggemma]

    E --> F[(Qdrant vector database)]

    subgraph "Metadata management"
        G[Extract author / date / page-number metadata] -.attach.-> E
    end
```

---

## 🎯 Roadmap and Core Algorithms

### Level 1: Basic Recursive Splitting
- **Core algorithm**: recursive character splitting. It tries a predefined list of separators (e.g. `["\n\n", "\n", "。", "!", "?", " ", ""]`) from coarse to fine, splitting the text until every chunk satisfies the maximum-length limit.
- **Pros/cons**: extremely simple and fast to implement, but it very easily cuts a sentence in half, losing contextual meaning.
- **Implementation guide** (a sketch follows this list):
  - Import `RecursiveCharacterTextSplitter` from `langchain_text_splitters`.
  - Instantiate it with `chunk_size` (e.g. 500) and `chunk_overlap` (e.g. 50), then call `.split_documents(raw_docs)` directly.
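
A minimal sketch of this level; the parameter values and sample text are illustrative:

```python
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Separators are tried from coarse (paragraph) to fine (character) until
# every chunk fits within chunk_size.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    separators=["\n\n", "\n", "。", "!", "?", " ", ""],
)

raw_docs = [Document(page_content="第一段。\n\n第二段。" * 200)]
chunks = splitter.split_documents(raw_docs)
print(len(chunks), max(len(c.page_content) for c in chunks))
```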

### Level 2: Semantic Chunking
@@ -38,58 +16,52 @@ graph TD
1. Split the text into sentences at punctuation boundaries.
2. Vectorize each sentence with a lightweight embedding model.
3. Compute the cosine similarity between adjacent sentences.
4. When the similarity drops below a configured threshold (the two sentences are no longer about the same thing, i.e. the topic has shifted), "cut" at that point to start a new chunk.
- **Pros/cons**: largely preserves semantic coherence within a passage, which is very friendly to the LLM's answers; but because the embedding model is already invoked at split time, it is somewhat slower.
- **Implementation guide** (a sketch follows this list):
  - Import `TextSplitter` from `langchain_text_splitters` as the base class.
  - Import `SemanticChunker` from `langchain_experimental.text_splitter`.
  - Implement `SemanticChunkerAdapter`, inheriting from `TextSplitter`, to resolve the type mismatch.
  - Instantiate it with your already-configured embedding model instance (e.g. the local model wrapped by `LlamaCppEmbedder`).
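
A minimal sketch of direct `SemanticChunker` usage; `embeddings` is assumed to be the `LlamaCppEmbedder`-backed instance configured elsewhere in this project (any LangChain `Embeddings` implementation would do):

```python
from langchain_experimental.text_splitter import SemanticChunker

# Split wherever the similarity between adjacent sentences falls below the
# chosen percentile threshold, i.e. at semantic breakpoints.
chunker = SemanticChunker(
    embeddings=embeddings,  # assumed: a configured Embeddings instance
    breakpoint_threshold_type="percentile",
)
chunks = chunker.split_text("句子一。句子二。一个话题突然转变的句子。")
```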

### Level 3: Advanced Parent-Child Strategy (Parent-Child / Auto-merging)
- **Core algorithm**: hierarchical dual storage with cross-mapping.
- **Splitting**: first coarsely split documents into larger "parent chunks" (about 1000 characters), then finely split each parent into smaller "child chunks" (about 200 characters).
- **Storage**: only the **child** chunk vectors go into Qdrant for precise distance computation; the raw content of the **parent** chunks is kept in a PostgreSQL DocStore, with UUIDs mapping the two together.
- **Key idea**: it resolves the classic RAG tension — smaller chunks are easier to hit precisely at retrieval time (less noise), while larger chunks give the LLM richer context at generation time.
- **Implementation guide** (a wiring sketch follows this list):
  - Use the `ParentDocumentRetriever` module from `langchain_classic.retrievers`.
  - At write time you need both an underlying `VectorStore` (Qdrant here) and a `BaseStore`.
  - **Recommended**: use `PostgresDocStore` as the docstore, which supports persistent storage.
  - Assign two different `TextSplitter`s to the retriever's `child_splitter` and `parent_splitter`, then call `.add_documents()` and the system maintains the mapping automatically.
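
A minimal wiring sketch; `vectorstore` and `docs` are assumed to exist, and `InMemoryStore` stands in for the PostgreSQL docstore used by this project:

```python
from langchain.storage import InMemoryStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever

retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,   # assumed: a configured Qdrant-backed VectorStore
    docstore=InMemoryStore(),  # stand-in for PostgresDocStore
    parent_splitter=RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100),
    child_splitter=RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20),
)

# Child vectors go to the vectorstore; parent originals go to the docstore.
retriever.add_documents(docs)
# Queries hit the small child chunks but return the mapped parent chunks.
parents = retriever.invoke("查询内容")
```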

### Level 3.1: PostgreSQL DocStore Integration
- **Key advantage**: PostgreSQL provides persistent storage suitable for production; the async connection pool supports high concurrency.
- **Steps**:
1. **Configure the connection**: set the `DB_URI` environment variable or pass the `docstore_conn_string` parameter
2. **Create the docstore**: use the `rag_indexer.store.create_docstore()` factory function
3. **Inject it into IndexBuilder**: via a constructor parameter

- **Usage example**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType

# Create the IndexBuilder
builder = IndexBuilder(
    collection_name="rag_documents",
    splitter_type=SplitterType.PARENT_CHILD,
    parent_chunk_size=1000,
    child_chunk_size=200,
    docstore_conn_string="postgresql://user:pass@host:5432/db",
)
```

### Level 3.2: Combining Semantic Splitting with the Parent-Child Strategy
- **Key advantage**: combines the coherence of semantic splitting with the hierarchical storage of the parent-child strategy, for more precise retrieval and richer context.
- **How it works**:
  - **Parent splitting**: `RecursiveCharacterTextSplitter` creates large chunks (about 1000 characters) that provide full context
  - **Child splitting**: `SemanticChunkerAdapter` creates small chunks, split dynamically on semantic coherence for higher retrieval precision
  - **Storage**: child vectors go into Qdrant for precise retrieval; parent content goes into PostgreSQL to provide full context
- **Usage example**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType
@@ -109,97 +81,55 @@ graph TD
```
- **Configuration**:
  - `child_splitter_type`: child splitter type, either `SplitterType.RECURSIVE` (default) or `SplitterType.SEMANTIC`
  - When semantic splitting is used, the system automatically reuses the configured embedding model for sentence-level similarity computation

### Level 4: RAG-Fusion (Multi-Query Rewriting with Reciprocal Rank Fusion)
- **Key advantage**: uses an LLM's divergent rewriting to turn one question into several similar ones, widening the search, then merges the results with a rank-statistics algorithm, improving both coverage and accuracy.
- **How it works** (a small RRF sketch appears at the end of this section):
1. **Multi-query rewriting**: the LLM rewrites the original query into 3-5 differently phrased queries expressing the same intent from different angles
2. **Reciprocal Rank Fusion (RRF)**: the result lists of the rewritten queries are fused with $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$, so no single result list dominates
3. **Deduplication**: the fused results are deduplicated so each returned document is unique
- **Usage example**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType
from langchain_openai import OpenAI

# Create the IndexBuilder
builder = IndexBuilder(
    collection_name="rag_documents",
    splitter_type=SplitterType.PARENT_CHILD,
    parent_chunk_size=1000,
    child_chunk_size=200,
    docstore_conn_string="postgresql://user:pass@host:5432/db",
)

# Language model used for query rewriting
llm = OpenAI(
    openai_api_base="http://localhost:8000/v1",
    openai_api_key="no-key-needed",
    model_name="Qwen2.5-7B-Instruct",
    temperature=0.3,
)

# Retrieve with RAG-Fusion
query = "如何申请项目资金?"
results = builder.retrieve_with_fusion(
    query=query,
    llm=llm,
    num_queries=3,
    k=5,
    return_parent=True
)
```
- **Configuration**:
  - `llm`: language model instance used for query rewriting
  - `num_queries`: number of generated queries, 3-5 recommended
  - `k`: number of documents to return
  - `return_parent`: whether to return parent-chunk context
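
The RRF step itself is a few lines of plain Python. This sketch (document ids are hypothetical) shows how the formula above merges the ranked lists of the rewritten queries:

```python
from collections import defaultdict

def rrf_merge(ranked_lists, k=60):
    """RRF_score(d) = sum over queries q of 1 / (k + rank_q(d))."""
    scores = defaultdict(float)
    for ranking in ranked_lists:  # one ranked result list per rewritten query
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# Three rewrites returned these (hypothetical) result lists; d2 wins the fusion.
print(rrf_merge([["d1", "d2", "d3"], ["d2", "d1", "d4"], ["d2", "d5"]]))
```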

### Level 4: GraphRAG (Graph- and Relation-Based RAG)
- **Core algorithm**: LLM entity and relation extraction (NER & Relation Extraction).
- **Key idea**: addresses the classic weakness of pure vector retrieval — "complex cross-document relational reasoning" (e.g. Who is the CEO of company A? What is the main business of company B, which he owns? — leap-style questions that span many PDF pages).
- **How it works**:
1. **Entity extraction**: an LLM extracts entities (people, organizations, locations, events, ...) from the documents
2. **Relation extraction**: relations between entities are identified ("CEO of", "founded by", "located in", ...)
3. **Graph construction**: entities become nodes and relations become edges of a knowledge graph
4. **Hybrid retrieval**: vector search and graph queries are combined, exploiting both semantic similarity and structural relations
- **Tech stack**:
  - **Graph database**: Neo4j or RedisGraph
  - **LLM tooling**: `LLMGraphTransformer` or custom prompts
  - **Integration**: runs alongside the vector store to form a hybrid retrieval system
- **Implementation guide** (a sketch follows this list):
  - Use the `langchain_community.graphs` module
  - Configure a local LLM (e.g. `Gemma-4-E2B`) for entity/relation extraction
  - Build the graph structure of entities and relations and store it in the graph database
  - Implement the hybrid retrieval logic, combining vector similarity with graph-path analysis
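
A hedged sketch of the extraction step with `LLMGraphTransformer`; the connection details are placeholders, and `llm`/`documents` are assumed to be configured elsewhere:

```python
from langchain_community.graphs import Neo4jGraph
from langchain_experimental.graph_transformers import LLMGraphTransformer

# Connection details are illustrative placeholders.
graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")
transformer = LLMGraphTransformer(llm=llm)  # assumed: a locally served chat model

# The LLM extracts entities (nodes) and relations (edges) from each document,
# and the resulting graph documents are written into Neo4j.
graph_documents = transformer.convert_to_graph_documents(documents)
graph.add_graph_documents(graph_documents)
```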

---

## Dependencies and Installation

The following Python packages are required for full document parsing and Qdrant writes:

```bash
# Core libraries
pip install langchain langchain-core langchain-openai langchain-qdrant

# Complex document parsing (PDF, Word, Excel, ...)
pip install unstructured pdf2image pdfminer.six

# Semantic chunking (optional)
pip install langchain-experimental

# PostgreSQL storage (optional, for the Parent-Child strategy)
pip install psycopg2-binary

# RAG-Fusion (optional, requires a language model)
pip install langchain-openai
```

### Level 5: Multi-modal RAG
- **Core algorithm**: cross-modal embeddings and multi-modal fusion.
- **Key idea**: go beyond plain text to understand and retrieve images, tables, audio, and other data types.
- **How it works** (a sketch follows this list):
1. **Multi-modal embeddings**: models such as CLIP map data of different modalities into a shared vector space
2. **Multi-modal indexing**: dedicated indexes are created for each content type
3. **Cross-modal retrieval**: supports text-to-image, image-to-text, and similar cross-modal queries
- **Tech stack**:
  - **Multi-modal models**: CLIP, BLIP, etc.
  - **Storage**: vector database + object storage
  - **Retrieval**: hybrid vector retrieval
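
A small cross-modal sketch using a CLIP checkpoint via sentence-transformers; the model name and file path are illustrative:

```python
from PIL import Image
from sentence_transformers import SentenceTransformer, util

# CLIP maps images and text into the same vector space.
model = SentenceTransformer("clip-ViT-B-32")
img_vec = model.encode(Image.open("figure.png"))  # path is illustrative
txt_vec = model.encode("一张系统架构示意图")

# Cosine similarity between the two vectors enables text-to-image retrieval.
print(util.cos_sim(txt_vec, img_vec))
```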

---

## 📂 Architecture and File Layout

Create the following core files under the `rag_indexer/` directory:

```text
rag_indexer/
├── __init__.py
├── loaders.py            # parses different file types via unstructured
├── splitters.py          # Recursive and Semantic splitting logic plus the adapter
├── embedders.py          # embedding interface wrapping the local llama.cpp service
├── vector_store.py       # Qdrant writes, upserts, and collection initialization
├── builder.py            # core orchestration, wiring the modules into a pipeline
├── cli.py                # command-line entry point
└── store/
    ├── __init__.py
    ├── factory.py        # docstore factory function
    └── postgres.py       # PostgreSQL DocStore implementation
```

---

@@ -211,36 +141,36 @@ rag_indexer/
```
┌─────────────────────────────────────────┐
│               builder.py                │
│           IndexBuilder entry            │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│               loaders.py                │
│       DocumentLoader.load_file()        │
│       → returns List[Document]          │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│         ParentDocumentRetriever         │
│   ┌─────────────────────────────────┐   │
│   │ parent_splitter (coarse split)  │   │
│   │ parent chunks ~1000 chars       │   │
│   └────────────┬────────────────────┘   │
│                │                        │
│   ┌────────────▼────────────────────┐   │
│   │ child_splitter (fine split)     │   │
│   │ child chunks ~200 chars         │   │
│   └────────────┬────────────────────┘   │
│                │                        │
│     ┌──────────┴──────────┐             │
│     ▼                     ▼             │
│  child vectors     parent originals     │
│     │                     │             │
│     ▼                     ▼             │
│ ┌────────────┐  ┌─────────────────┐     │
│ │vector_store│  │     store/      │     │
│ │  (Qdrant)  │  │  (PostgreSQL)   │     │
│ └────────────┘  └─────────────────┘     │
└─────────────────────────────────────────┘
```

@@ -250,10 +180,31 @@ rag_indexer/
| Module | Responsibility | Key exports |
|------|------|------------|
| **builder.py** | Core orchestration; wires the whole pipeline together | `IndexBuilder` |
| **loaders.py** | Parses document formats (PDF, Word, TXT, ...) | `DocumentLoader` |
| **splitters.py** | Text splitting strategies (Recursive/Semantic) and the adapter | `SplitterType`, `get_splitter()`, `SemanticChunkerAdapter` |
| **embedders.py** | Vectorization (wraps the llama.cpp embedding API) | `LlamaCppEmbedder` |
| **vector_store.py** | Qdrant vector database operations | `QdrantVectorStore` |
| **store/postgres.py** | PostgreSQL DocStore implementation | `PostgresDocStore` |
| **store/factory.py** | docstore factory function | `create_docstore()` |

### Core Implementation Details

#### 1. Text splitting
- **Recursive splitting**: `langchain_text_splitters.RecursiveCharacterTextSplitter`, with Chinese-aware separators
- **Semantic splitting**: `langchain_experimental.text_splitter.SemanticChunker`, adapted to the `TextSplitter` interface via `SemanticChunkerAdapter`
- **Parent-child strategy**: parents use recursive splitting (1000 chars); children use recursive or semantic splitting (200 chars)

#### 2. Vectorization
- **Embedding API**: `LlamaCppEmbedder` wraps the local llama.cpp service and exposes `embed_documents` and `embed_query`
- **Vector dimension**: the model dimension is detected automatically (default 2560) and the Qdrant collection is created to match

#### 3. Vector storage
- **Qdrant integration**: `langchain_qdrant.QdrantVectorStore` is the underlying store
- **Collection management**: collections are created or reused automatically; a `force_recreate` parameter is supported
- **Batched writes**: a `batch_size` parameter avoids oversized single requests

#### 4. Document storage
- **PostgreSQL**: `PostgresDocStore` persists the parent chunks and supports an async connection pool
- **Mapping**: child chunks are linked to their parents by UUID, and retrieval returns the full parent chunk (see the sketch below)
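
A hypothetical illustration of that mapping; the `doc_id` metadata key matches ParentDocumentRetriever's default id key, and `vector_store`/`docstore` are assumed to be the configured instances:

```python
# Child chunks come back from Qdrant with their parent's UUID in metadata...
hits = vector_store.similarity_search("查询内容", k=5)
parent_ids = {h.metadata["doc_id"] for h in hits}  # metadata key assumed

# ...and the docstore resolves those UUIDs to the full parent documents.
parents = docstore.mget(list(parent_ids))
```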

### Call Sequence

@@ -265,27 +216,42 @@ from rag_indexer.builder import IndexBuilder, SplitterType
builder = IndexBuilder(
    collection_name="my_docs",
    splitter_type=SplitterType.PARENT_CHILD,
    parent_chunk_size=1000,
    child_chunk_size=200,
    docstore_conn_string="postgresql://user:pass@host:5432/db",
)
```

#### 2. Build the index

```python
import asyncio

# Option A: build from a single file
async def main():
    count = await builder.build_from_file("/path/to/document.pdf")
    print(f"Indexed {count} chunks")

# Option B: batch-build from a directory
async def main():
    count = await builder.build_from_directory("/path/to/docs/")
    print(f"Indexed {count} chunks")

asyncio.run(main())
```

#### 3. Retrieval (returns the full parent-chunk context)

```python
import asyncio

async def main():
    # Retrieval returns the complete parent chunks
    results = await builder.search_with_parent_context("查询内容", k=5)
    for doc in results:
        print(doc.page_content)

asyncio.run(main())
```

### Retrieval Flow

@@ -299,11 +265,16 @@ results = builder.search_with_parent_context("查询内容")

---

### Wiring and Triggering
Use the `cli.py` entry point:

```bash
# Set the environment variables
export QDRANT_URL="http://115.190.121.151:6333"
export QDRANT_API_KEY="your-api-key"
export DB_URI="postgresql://postgres:password@host:5432/langgraph_db?sslmode=disable"

# Build the index
python -m rag_indexer.cli --path data/user_docs/tech_manual.pdf
```

This is, in effect, the system's background **"offline learning phase"**; you can schedule a recurring job to scan a folder and incrementally update the knowledge base.
@@ -9,52 +9,52 @@ Offline RAG Indexer module.
- Parent-document storage (PostgreSQL)

Example usage:
>>> from rag_indexer import IndexBuilder, IndexBuilderConfig, SplitterType
>>>
>>> config = IndexBuilderConfig(
...     collection_name="my_docs",
...     splitter_type=SplitterType.PARENT_CHILD,
... )
>>> builder = IndexBuilder(config)
>>>
>>> # Or pass parameters directly (backwards compatible)
>>> builder = IndexBuilder(collection_name="my_docs")
>>>
>>> await builder.build_from_file("document.pdf")
"""

from .IndexBuilder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter

# Re-export commonly used components from rag_core
from rag_core import (
    LlamaCppEmbedder,
    QdrantVectorStore,
    PostgresDocStore,
    create_docstore,
)


__version__ = "2.0.0"

__all__ = [
    # Core builder and configuration
    "IndexBuilder",
    "IndexBuilderConfig",
    "DocstoreConfig",

    # Loaders
    "DocumentLoader",

    # Splitting
    "SplitterType",
    "get_splitter",

    # Embeddings and vector store
    "LlamaCppEmbedder",
    "QdrantVectorStore",

    # Document stores
    "PostgresDocStore",
    "create_docstore",
]
@@ -1,392 +0,0 @@
"""
Core pipeline for offline RAG index construction.

Supports LangChain's ParentDocumentRetriever for parent-child chunking.
"""

import asyncio
import logging
from pathlib import Path
from typing import List, Union, Optional, Tuple, Any
from dataclasses import dataclass

from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker

from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .store import create_docstore

logger = logging.getLogger(__name__)


@dataclass
class ParentChildConfig:
    """Parent-child splitting configuration."""
    parent_chunk_size: int = 1000
    child_chunk_size: int = 200
    parent_chunk_overlap: int = 100
    child_chunk_overlap: int = 20
    search_k: int = 5
    docstore_path: Optional[str] = None
    docstore_type: str = "local"
    docstore_conn_string: Optional[str] = None


class IndexBuilder:
    """Main RAG index-building pipeline."""

    # Type annotations
    parent_splitter: "RecursiveCharacterTextSplitter"
    child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
    docstore: Optional["BaseStore"]
    _docstore_conn: Optional[str]
    retriever: Optional["ParentDocumentRetriever"]
    vector_store_obj: Any

    def __init__(
        self,
        collection_name: str = "rag_documents",
        splitter_type: SplitterType = SplitterType.PARENT_CHILD,
        docstore=None,
        **splitter_kwargs,
    ):
        self.collection_name = collection_name
        self.splitter_type = splitter_type
        self.splitter_kwargs = splitter_kwargs
        self.docstore = docstore  # injected from outside

        # Components
        self.loader = DocumentLoader()
        self.embedder = LlamaCppEmbedder()
        self.embeddings = self.embedder.as_langchain_embeddings()

        self.vector_store = QdrantVectorStore(
            collection_name=collection_name,
            embeddings=self.embeddings,
        )

        # Splitters (parent-child is handled separately)
        if splitter_type != SplitterType.PARENT_CHILD:
            if splitter_type == SplitterType.SEMANTIC:
                splitter_kwargs["embeddings"] = self.embeddings
            self.splitter = get_splitter(splitter_type, **splitter_kwargs)
        else:
            self.splitter = None
            # Initialize the ParentDocumentRetriever for parent-child splitting
            # (forward the splitter kwargs so the sizes/overlaps actually apply).
            self._init_parent_child_retriever(**splitter_kwargs)

    def _init_parent_child_retriever(self, **kwargs):
        """
        Initialize the ParentDocumentRetriever for parent-child splitting.

        Supports combining dynamic semantic splitting with the parent-child strategy:
        - parents use recursive splitting (large chunks that provide context)
        - children may use recursive or semantic splitting (dynamic, semantics-driven
          splits that improve retrieval precision)

        Replaces the custom ParentChildSplitter logic.
        """
        # Resolve the parent-child configuration parameters.
        parent_size = kwargs.get("parent_chunk_size", 1000)
        child_size = kwargs.get("child_chunk_size", 200)
        parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
        child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))

        # Child splitter type, defaulting to semantic splitting.
        child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC)

        # Parent splitter (always recursive).
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=parent_size,
            chunk_overlap=parent_overlap,
        )

        # Child splitter (chosen by type).
        if child_splitter_type == SplitterType.SEMANTIC:
            self.child_splitter = get_splitter(
                SplitterType.SEMANTIC,
                embeddings=self.embeddings,
            )
            logger.info("Child chunks use the semantic splitter")
        else:
            # Default to recursive splitting.
            self.child_splitter = RecursiveCharacterTextSplitter(
                chunk_size=child_size,
                chunk_overlap=child_overlap,
            )
            logger.info(f"Child chunks use the recursive splitter, size: {child_size}, overlap: {child_overlap}")

        # Vector store (for child chunks).
        self.vector_store_obj = self.vector_store.get_langchain_vectorstore()

        # Document store (for parent chunks).
        if self.docstore is None:
            # No externally injected docstore; create a PostgreSQL one.
            docstore_conn = kwargs.get("docstore_conn_string")
            pool_config = kwargs.get("pool_config")
            max_concurrency = kwargs.get("max_concurrency")

            # Create the PostgreSQL store via create_docstore.
            self.docstore, self._docstore_conn = create_docstore(
                connection_string=docstore_conn,
                pool_config=pool_config,
                max_concurrency=max_concurrency
            )
        else:
            # Use the externally injected docstore.
            self._docstore_conn = None

        # Create the retriever.
        self.retriever = ParentDocumentRetriever(
            vectorstore=self.vector_store_obj,
            docstore=self.docstore,
            child_splitter=self.child_splitter,  # type: ignore
            parent_splitter=self.parent_splitter,
            search_kwargs={"k": kwargs.get("search_k", 5)},
        )
        logger.info(f"ParentDocumentRetriever initialized, parent chunk size: {parent_size}, child splitter type: {child_splitter_type}")

    async def build_from_file(self, file_path: Union[str, Path]) -> int:
        logger.info("Loading file: %s", file_path)
        documents = self.loader.load_file(file_path)
        logger.info("Loaded %d documents", len(documents))
        return await self._process_documents(documents)

    async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
        logger.info("Loading directory: %s (recursive=%s)", directory_path, recursive)
        documents = self.loader.load_directory(directory_path, recursive=recursive)
        logger.info("Loaded %d documents from the directory", len(documents))
        return await self._process_documents(documents)

    async def _process_documents(self, documents: List[Document]) -> int:
        if not documents:
            logger.warning("No documents to process")
            return 0

        if self.splitter_type == SplitterType.PARENT_CHILD:
            logger.info("Using LangChain ParentDocumentRetriever")

            # Make sure the collection exists (for child chunks).
            self.vector_store.create_collection()

            # Process in batches to avoid oversized single requests.
            assert self.retriever is not None, "retriever not initialized"
            batch_size = 10  # 10 documents per batch
            total = len(documents)
            processed = 0

            for i in range(0, total, batch_size):
                batch = documents[i:i + batch_size]
                max_retries = 3
                for attempt in range(max_retries):
                    try:
                        await self.retriever.aadd_documents(batch)
                        processed += len(batch)
                        logger.info(f"Batch {i//batch_size + 1}: processed {processed}/{total}")
                        break
                    except (RemoteProtocolError, ConnectionError, OSError) as e:
                        if attempt == max_retries - 1:
                            raise
                        logger.warning(f"Connection dropped, retrying ({attempt+1}/{max_retries}): {e}")
                        self.vector_store.refresh_client()
                        await asyncio.sleep(1)

            logger.info(
                "Indexed with ParentDocumentRetriever: "
                f"{processed} parent documents in total"
            )
            return processed

        else:
            logger.info("Splitting documents with %s", self.splitter_type)
            # When splitter_type is not PARENT_CHILD, splitter is never None.
            assert self.splitter is not None, "splitter not initialized"
            chunks = self.splitter.split_documents(documents)
            logger.info("Split into %d chunks", len(chunks))

            self.vector_store.create_collection()
            self.vector_store.add_documents(chunks)
            return len(chunks)

    def get_collection_info(self):
        return self.vector_store.get_collection_info()

    def search(self, query: str, k: int = 5) -> List[Document]:
        """Standard search - returns child chunks."""
        return self.vector_store.similarity_search(query, k=k)

    async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
        """
        Search with parent-chunk context - returns the full parent chunks.

        This is the main retrieval method when parent-child splitting is used.
        """
        if self.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError(
                "search_with_parent_context is only available with the PARENT_CHILD splitter. "
                "Use search() for standard retrieval."
            )
        assert self.retriever is not None, "retriever not initialized"
        return await self.retriever.ainvoke(query, config={"k": k})  # type: ignore

    async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
        """
        Unified retrieval interface.

        Args:
            query: search query
            return_parent: if True and parent-child splitting is in use, return parent chunks;
                if False, always return child chunks

        Returns:
            List of relevant documents
        """
        if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
            return await self.search_with_parent_context(query)
        else:
            return self.search(query)

    async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]:
        """
        Retrieve with RAG-Fusion (multi-query rewriting + reciprocal rank fusion).

        Core idea:
        1. Multi-query rewriting: an LLM rewrites the original query into several variants.
        2. Reciprocal rank fusion: the results of the rewritten queries are RRF-fused
           so that no single result list dominates.

        Args:
            query: original search query
            llm: language model instance (used for query rewriting)
            num_queries: number of queries to generate
            k: number of documents to return
            return_parent: if True and parent-child splitting is in use, return parent chunks;
                if False, always return child chunks

        Returns:
            List of relevant documents after fusion
        """
        from langchain.retrievers.multi_query import MultiQueryRetriever

        if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
            # Use the ParentDocumentRetriever as the base retriever.
            assert self.retriever is not None, "retriever not initialized"
            base_retriever = self.retriever
        else:
            # Use the vector store as the base retriever.
            base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2})

        # Create the multi-query retriever.
        multi_query_retriever = MultiQueryRetriever.from_llm(
            retriever=base_retriever,
            llm=llm,
            include_original=True
        )

        # Set a custom prompt so the requested number of queries is generated.
        from langchain_core.prompts import PromptTemplate
        multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template(
            "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
            "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
            "原始问题: {question}\n\n"
            "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
            "确保每个版本都是独立、完整的查询语句。\n\n"
            "生成 {num_queries} 个查询:"
        )

        # Patch the invocation so the prompt receives num_queries.
        original_ainvoke = multi_query_retriever.llm_chain.ainvoke
        async def new_ainvoke(input_dict):
            input_dict["num_queries"] = num_queries
            return await original_ainvoke(input_dict)
        multi_query_retriever.llm_chain.ainvoke = new_ainvoke

        # Run the retrieval.
        documents = await multi_query_retriever.ainvoke(query)

        # Deduplicate and cap the result count.
        seen_content = set()
        unique_documents = []
        for doc in documents:
            content = doc.page_content
            if content not in seen_content:
                seen_content.add(content)
                unique_documents.append(doc)
                if len(unique_documents) >= k:
                    break

        logger.info(f"RAG-Fusion retrieval finished: {len(documents)} raw results, {len(unique_documents)} after deduplication")
        return unique_documents

    def get_retriever(self) -> ParentDocumentRetriever:
        """
        Get the ParentDocumentRetriever instance directly.

        For advanced use cases that need the retriever outside IndexBuilder.
        """
        if self.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError(
                "get_retriever() is only available with the PARENT_CHILD splitter. "
                "Use search() or search_with_parent_context() for standard retrieval."
            )
        assert self.retriever is not None, "retriever not initialized"
        return self.retriever

    def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]:
        """Get the child splitter for reconfiguration."""
        if self.splitter_type != SplitterType.PARENT_CHILD:
            return self.splitter  # type: ignore
        return self.child_splitter

    def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
        """Get the parent splitter for reconfiguration."""
        if self.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError(
                "The parent splitter is only available with the PARENT_CHILD splitter."
            )
        return self.parent_splitter

    def get_docstore(self) -> BaseStore:
        """Get the document store for parent chunks."""
        if self.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError(
                "The document store is only available with the PARENT_CHILD splitter."
            )
        assert self.docstore is not None, "docstore not initialized"
        return self.docstore

    def get_docstore_path(self) -> Optional[str]:
        """Get the docstore path (deprecated, kept for compatibility)."""
        if self.splitter_type != SplitterType.PARENT_CHILD:
            raise RuntimeError(
                "The docstore path is only available with the PARENT_CHILD splitter."
            )
        # The PostgreSQL store has no persist_path; return None.
        return None

    def close(self):
        """Close resources."""
        if self.docstore is not None and hasattr(self.docstore, "aclose"):
            asyncio.get_event_loop().run_until_complete(self.docstore.aclose())  # type: ignore
            logger.info("PostgreSQL async connection pool closed")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False


# Example usage removed; see the documentation.
@@ -1,85 +1,77 @@
"""
Minimal command-line entry point; builds the RAG index with the default configuration.
"""

import asyncio
import logging
import sys
from pathlib import Path

from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig
from rag_indexer.splitters import SplitterType

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Default configuration (all connection parameters are read from environment variables)
COLLECTION_NAME = "rag_documents"
SPLITTER_TYPE = SplitterType.PARENT_CHILD
CHILD_SPLITTER_TYPE = SplitterType.SEMANTIC

# Parent/child chunk size parameters (adjust as needed)
PARENT_CHUNK_SIZE = 1000
PARENT_CHUNK_OVERLAP = 100
CHILD_CHUNK_SIZE = 200
CHILD_CHUNK_OVERLAP = 20
SEARCH_K = 5


def get_input_path() -> Path:
    """Get the input path from the command line; fall back to a default sample path."""
    if len(sys.argv) > 1:
        return Path(sys.argv[1])
    # Default test path (change as needed).
    return Path("data/user_docs/a.txt")


async def main():
    input_path = get_input_path()
    if not input_path.exists():
        logger.error("Path does not exist: %s", input_path)
        sys.exit(1)

    # Build the configuration (all defaults).
    config = IndexBuilderConfig(
        collection_name=COLLECTION_NAME,
        splitter_type=SPLITTER_TYPE,
        parent_chunk_size=PARENT_CHUNK_SIZE,
        parent_chunk_overlap=PARENT_CHUNK_OVERLAP,
        child_chunk_size=CHILD_CHUNK_SIZE,
        child_chunk_overlap=CHILD_CHUNK_OVERLAP,
        child_splitter_type=CHILD_SPLITTER_TYPE,
        search_k=SEARCH_K,
        # The docstore defaults to create_docstore, which reads the PostgreSQL connection from env vars.
    )

    builder = IndexBuilder(config)
    is_directory = input_path.is_dir()

    try:
        async with builder:
            if is_directory:
                chunk_count = await builder.build_from_directory(input_path, recursive=True)
            else:
                chunk_count = await builder.build_from_file(input_path)

            print(f"\nIndex build complete. Indexed {chunk_count} chunks")
            info = builder.get_collection_info()
            print(f"Collection '{info['name']}' contains {info['vectors_count']} vectors (dimension: {info['vector_size']})")

    except Exception as e:
        logger.exception("Index build failed: %s", e)
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,88 +0,0 @@
"""
Embedding model wrapper for the llama.cpp service.
"""

import os
import httpx
from typing import List, Optional

from langchain_core.embeddings import Embeddings


class LlamaCppEmbedder:
    """Wraps the llama.cpp embedding service via its OpenAI-compatible API."""

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: str = "embeddinggemma-300M-Q8_0",
    ):
        self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
        self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
        self.model = model

    def as_langchain_embeddings(self) -> Embeddings:
        """Create a LangChain-compatible embeddings instance."""
        return _LlamaCppLangchainAdapter(self)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents."""
        return self._call_embedding_api(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query."""
        return self._call_embedding_api([text])[0]

    def get_embedding_dimension(self) -> int:
        """Determine the embedding dimension by embedding a test string."""
        test_embedding = self.embed_query("test")
        return len(test_embedding)

    def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
        """Call the llama.cpp embedding API directly."""
        base = self.base_url.rstrip("/")
        if not base.endswith("/v1"):
            base = base + "/v1"

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        payload = {
            "input": texts,
            "model": self.model,
        }

        with httpx.Client(timeout=120) as client:
            response = client.post(
                f"{base}/embeddings",
                headers=headers,
                json=payload,
            )
            response.raise_for_status()
            data = response.json()

        # Handle the different response formats.
        if isinstance(data, list):
            # llama.cpp returns a bare list.
            return [item["embedding"] for item in data]
        elif isinstance(data, dict) and "data" in data:
            # OpenAI standard format.
            return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])]
        else:
            raise ValueError(f"Unknown embedding API response format: {data}")


class _LlamaCppLangchainAdapter(Embeddings):
    """Adapts LlamaCppEmbedder to the LangChain Embeddings interface."""

    def __init__(self, embedder: LlamaCppEmbedder):
        self._embedder = embedder

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self._embedder.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        return self._embedder.embed_query(text)
@@ -3,19 +3,27 @@
"""

import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from langchain_core.documents import Document
from unstructured.documents.elements import Element
from unstructured.partition.auto import partition

logger = logging.getLogger(__name__)

# Set this environment variable once at module load time instead of on every call.
os.environ.setdefault("UNSTRUCTURED_LANGUAGE_CHECKS", "false")


class DocumentLoader:
    """Loads documents from a variety of file formats."""

    SUPPORTED_EXTENSIONS = {
        ".pdf", ".docx", ".doc", ".txt", ".md",
        ".html", ".pptx", ".xlsx", ".json"
    }

    def __init__(
        self,
@@ -32,13 +40,11 @@ class DocumentLoader:
            extract_images: whether to extract images embedded in PDFs
            strategy: parsing strategy (auto, fast, hi_res, ocr_only)
            ocr_languages: OCR language list, e.g. ['chi_sim', 'eng']
            languages: primary document languages, e.g. ['zh'] (mainly for non-OCR cases)
            include_page_breaks: whether to include page breaks
            pdf_infer_table_structure: whether to detect table structure (requires the hi_res strategy)
            partition_kwargs: extra keyword arguments passed to partition (advanced customization)
        """
        self.extract_images = extract_images
        self.strategy = strategy
        self.ocr_languages = ocr_languages or ["chi_sim", "eng"]
@@ -47,6 +53,52 @@ class DocumentLoader:
        self.pdf_infer_table_structure = pdf_infer_table_structure
        self.partition_kwargs = partition_kwargs or {}

    def _build_partition_kwargs(self, file_path: Path) -> Dict[str, Any]:
        """Build the partition parameters based on the file type."""
        kwargs: Dict[str, Any] = {
            "include_page_breaks": self.include_page_breaks,
        }

        suffix = file_path.suffix.lower()

        # PDF-specific parameters.
        if suffix == ".pdf":
            kwargs.update({
                "strategy": self.strategy,
                "ocr_languages": self.ocr_languages,
                "extract_images_in_pdf": self.extract_images,
                "pdf_infer_table_structure": self.pdf_infer_table_structure,
            })

        # Language parameters apply to all file types.
        if self.languages:
            kwargs["languages"] = self.languages

        # User-supplied parameters override the defaults.
        kwargs.update(self.partition_kwargs)

        return kwargs

    def _element_to_document(self, element: Element, file_path: Path) -> Optional[Document]:
        """Convert a single Element into a Document, preserving key metadata."""
        text = getattr(element, "text", "")
        if not text or not text.strip():
            return None

        # Metadata provided by unstructured (select what is actually needed).
        metadata = {
            "source": str(file_path),
            "file_name": file_path.name,
            "file_type": file_path.suffix.lower(),
            # The following come from the Element object and may be None.
            "page_number": getattr(getattr(element, "metadata", None), "page_number", None),
            "category": getattr(getattr(element, "metadata", None), "category", None),
        }
        # Drop metadata entries whose value is None.
        metadata = {k: v for k, v in metadata.items() if v is not None}

        return Document(page_content=text, metadata=metadata)

    def load_file(self, file_path: Union[str, Path]) -> List[Document]:
        """Load a single file as LangChain Document objects."""
        file_path = Path(file_path).resolve()
@@ -59,68 +111,58 @@ class DocumentLoader:
                f"Unsupported file extension: {suffix}. Supported formats: {self.SUPPORTED_EXTENSIONS}"
            )

        kwargs = self._build_partition_kwargs(file_path)

        try:
            elements = partition(filename=str(file_path), **kwargs)
        except Exception as e:
            logger.exception("Failed to parse file %s", file_path)
            raise RuntimeError(f"File parsing failed: {file_path}") from e

        documents = []
        for elem in elements:
            doc = self._element_to_document(elem, file_path)
            if doc:
                documents.append(doc)

        if not documents:
            logger.warning("No text content extracted from %s", file_path)
            return []

        return documents

    def load_directory(
        self,
        directory_path: Union[str, Path],
        recursive: bool = True,
        fail_fast: bool = False
    ) -> List[Document]:
        """
        Load all supported files from a directory.

        Args:
            directory_path: directory path
            recursive: whether to recurse into subdirectories
            fail_fast: whether to raise immediately on the first failure
        """
        directory_path = Path(directory_path).resolve()
        if not directory_path.is_dir():
            raise NotADirectoryError(f"Not a directory: {directory_path}")

        all_documents: List[Document] = []
        pattern = "**/*" if recursive else "*"

        for file_path in directory_path.glob(pattern):
            if not file_path.is_file():
                continue
            if file_path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
                continue

            try:
                docs = self.load_file(file_path)
                all_documents.extend(docs)
            except Exception as e:
                logger.error("Failed to load %s: %s", file_path, e)
                if fail_fast:
                    raise

        return all_documents
@@ -3,7 +3,8 @@
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
@@ -16,68 +17,195 @@ class SplitterType(str, Enum):
|
||||
PARENT_CHILD = "parent_child"
|
||||
|
||||
|
||||
def get_splitter(splitter_type: SplitterType, **kwargs):
|
||||
"""工厂函数,创建文本切分器。"""
|
||||
if splitter_type == SplitterType.RECURSIVE:
|
||||
chunk_size = kwargs.get("chunk_size", 500)
|
||||
chunk_overlap = kwargs.get("chunk_overlap", 50)
|
||||
return RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separators=["\n\n", "\n", "。", "!", "?", " ", ""],
|
||||
)
|
||||
elif splitter_type == SplitterType.SEMANTIC:
|
||||
embeddings = kwargs.pop("embeddings", None)
|
||||
if embeddings is None:
|
||||
raise ValueError("语义切分器需要提供 'embeddings' 参数")
|
||||
return SemanticChunkerAdapter(embeddings=embeddings, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"不支持的切分器类型: {splitter_type}")
|
||||
# ---------- 配置数据类,统一参数 ----------
|
||||
@dataclass
|
||||
class RecursiveSplitterConfig:
|
||||
"""递归字符切分器配置"""
|
||||
chunk_size: int = 500
|
||||
chunk_overlap: int = 50
|
||||
separators: List[str] = field(default_factory=lambda: ["\n\n", "\n", "。", "!", "?", " ", ""])
|
||||
keep_separator: bool = True
|
||||
strip_whitespace: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class SemanticSplitterConfig:
|
||||
"""语义切分器配置,仅包含 SemanticChunker 支持的参数。"""
|
||||
embeddings: Any
|
||||
buffer_size: int = 1
|
||||
add_start_index: bool = False
|
||||
breakpoint_threshold_type: str = "percentile"
|
||||
breakpoint_threshold_amount: Optional[float] = None
|
||||
number_of_chunks: Optional[int] = None
|
||||
sentence_split_regex: str = r"(?<=[.?!。?!])\s+"
|
||||
min_chunk_size: int = 100
|
||||
|
||||
@dataclass
|
||||
class ParentChildSplitterConfig:
|
||||
"""父子切分器配置"""
|
||||
embeddings: Any # 子块语义切分所需
|
||||
parent_chunk_size: int = 1000
|
||||
parent_chunk_overlap: int = 100
|
||||
child_buffer_size: int = 1
|
||||
child_breakpoint_threshold_type: str = "percentile"
|
||||
child_breakpoint_threshold_amount: Optional[float] = None
|
||||
child_min_chunk_size: int = 100
|
||||
child_max_chunk_size: Optional[int] = 200
|
||||
|
||||
|
||||
# ---------- 适配器:让 SemanticChunker 实现 TextSplitter 接口 ----------
|
||||
class SemanticChunkerAdapter(TextSplitter):
|
||||
"""将 SemanticChunker 适配为 TextSplitter 接口。"""
|
||||
"""将 SemanticChunker 适配为 LangChain TextSplitter 接口。"""
|
||||
|
||||
def __init__(self, embeddings, **kwargs):
|
||||
def __init__(self, config: SemanticSplitterConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
chunk_size = kwargs.pop("chunk_size", None)
|
||||
chunk_overlap = kwargs.pop("chunk_overlap", None)
|
||||
self._chunker = SemanticChunker(embeddings=embeddings, **kwargs)
|
||||
self._config = config
|
||||
self._chunker = SemanticChunker(
|
||||
embeddings=config.embeddings,
|
||||
buffer_size=config.buffer_size,
|
||||
add_start_index=config.add_start_index,
|
||||
breakpoint_threshold_type=config.breakpoint_threshold_type,
|
||||
breakpoint_threshold_amount=config.breakpoint_threshold_amount,
|
||||
number_of_chunks=config.number_of_chunks,
|
||||
sentence_split_regex=config.sentence_split_regex,
|
||||
min_chunk_size=config.min_chunk_size,
|
||||
)
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
return self._chunker.split_text(text)
|
||||
|
||||
def split_documents(self, documents: List[Document]) -> List[Document]:
|
||||
result = []
|
||||
for doc in documents:
|
||||
chunks = self.split_text(doc.page_content)
|
||||
for i, chunk in enumerate(chunks):
|
||||
result.append(Document(
|
||||
page_content=chunk,
|
||||
metadata={**doc.metadata, "chunk_index": i}
|
||||
))
|
||||
return result


# ---------- Factory: single entry point for creating splitters ----------
def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter:
    """
    Create a splitter by type.
    Accepts either a config object or plain keyword arguments.
    """
    if splitter_type == SplitterType.RECURSIVE:
        config = RecursiveSplitterConfig(
            chunk_size=kwargs.get("chunk_size", 500),
            chunk_overlap=kwargs.get("chunk_overlap", 50),
            separators=kwargs.get("separators", ["\n\n", "\n", "。", "!", "?", " ", ""]),
        )
        return RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            separators=config.separators,
            keep_separator=config.keep_separator,
            strip_whitespace=config.strip_whitespace,
        )

    elif splitter_type == SplitterType.SEMANTIC:
        embeddings = kwargs.get("embeddings")
        if embeddings is None:
            raise ValueError("The semantic splitter requires an 'embeddings' argument")

        if "config" in kwargs and isinstance(kwargs["config"], SemanticSplitterConfig):
            config = kwargs["config"]
        else:
            # Keep only the fields supported by SemanticSplitterConfig
            config_kwargs = {
                "embeddings": embeddings,
                "buffer_size": kwargs.get("buffer_size", 1),
                "breakpoint_threshold_type": kwargs.get("breakpoint_threshold_type", "percentile"),
                "breakpoint_threshold_amount": kwargs.get("breakpoint_threshold_amount"),
                "number_of_chunks": kwargs.get("number_of_chunks"),
                "min_chunk_size": kwargs.get("min_chunk_size", 100),
            }
            config = SemanticSplitterConfig(**config_kwargs)
        return SemanticChunkerAdapter(config)

    elif splitter_type == SplitterType.PARENT_CHILD:
        # The parent-child splitter is wired up inside the builder, not created here
        raise ValueError("The parent-child splitter is created by IndexBuilder; get_splitter does not build it directly")

    else:
        raise ValueError(f"Unsupported splitter type: {splitter_type}")


# ---------- Parent-child splitter implementation ----------
class ParentChildSplitter:
    """
    Splits documents into parent chunks (large, for context) and child chunks
    (small, for index retrieval).
    Maintains the parent-child mapping internally.
    """

    def __init__(self, config: ParentChildSplitterConfig):
        self.config = config
        # Parent chunks use recursive character splitting
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.parent_chunk_size,
            chunk_overlap=config.parent_chunk_overlap,
        )
        # Child chunks use semantic splitting
        semantic_config = SemanticSplitterConfig(
            embeddings=config.embeddings,
            buffer_size=config.child_buffer_size,
            breakpoint_threshold_type=config.child_breakpoint_threshold_type,
            breakpoint_threshold_amount=config.child_breakpoint_threshold_amount,
            min_chunk_size=config.child_min_chunk_size,
        )
        self.child_splitter = SemanticChunkerAdapter(semantic_config)

        # Parent-child mapping (optional)
        self.parent_to_children: Dict[str, List[str]] = {}
        self.child_to_parent: Dict[str, str] = {}

    def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]:
        """
        Returns:
            (parent chunks, child chunks)
        Also populates the internal mapping dictionaries.
        """
        parent_chunks = self.parent_splitter.split_documents(documents)
        child_chunks = self.child_splitter.split_documents(documents)

        # Build the mapping (simplified: rough matching by substring containment;
        # a real implementation needs a more precise algorithm)
        # For illustration only; in production, prefer embedding similarity or exact substring offsets
        self._build_mappings(parent_chunks, child_chunks)

        return parent_chunks, child_chunks

    def _build_mappings(self, parents: List[Document], children: List[Document]) -> None:
        """
        Build the parent-child mapping from text content.
        This is a simplified implementation; replace it with more reliable matching logic in real use.
        """
        self.parent_to_children.clear()
        self.child_to_parent.clear()

        # Assign each parent chunk a unique ID (fall back to its index)
        for p_idx, parent in enumerate(parents):
            parent_id = parent.metadata.get("id", f"parent_{p_idx}")
            parent.metadata["id"] = parent_id
            self.parent_to_children[parent_id] = []

        # Assign each child chunk to the first parent chunk that contains its text
        for c_idx, child in enumerate(children):
            child_id = child.metadata.get("id", f"child_{c_idx}")
            child.metadata["id"] = child_id
            for parent in parents:
                if child.page_content in parent.page_content:
                    parent_id = parent.metadata["id"]
                    self.parent_to_children[parent_id].append(child_id)
                    self.child_to_parent[child_id] = parent_id
                    child.metadata["parent_id"] = parent_id
                    break

    def get_parent_for_child(self, child_id: str) -> Optional[str]:
        """Return the parent chunk ID for a child chunk ID."""
        return self.child_to_parent.get(child_id)

    def get_children_for_parent(self, parent_id: str) -> List[str]:
        """Return all child chunk IDs for a parent chunk ID."""
        return self.parent_to_children.get(parent_id, [])
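
A short usage sketch for the parent-child splitter (editor's illustration, not part of the diff); `embeddings` and `docs` are assumed variables, with `embeddings` again a LangChain-compatible embeddings object:

    config = ParentChildSplitterConfig(embeddings=embeddings, parent_chunk_size=800)
    splitter = ParentChildSplitter(config)
    parents, children = splitter.split_documents(docs)

    # The substring-based mapping is best-effort: children whose text was
    # rewritten by semantic splitting may not match any parent
    first_child_id = children[0].metadata["id"]
    print(splitter.get_parent_for_child(first_child_id))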
@@ -1,31 +0,0 @@
"""
Document store module - parent-document storage for ParentDocumentRetriever.

Provides a PostgreSQL storage backend:
- PostgresDocStore: PostgreSQL database storage (production)

Example usage:
    >>> from rag_indexer.store import create_docstore

    >>> # Create a PostgreSQL store
    >>> store, conn = create_docstore(
    ...     connection_string="postgresql://user:pass@host:5432/db",
    ...     table_name="parent_docs"
    ... )
"""


from .postgres import PostgresDocStore
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI

__version__ = "2.0.0"

__all__ = [
    # Concrete implementations
    "PostgresDocStore",

    # Factory functions
    "create_docstore",
    "get_docstore_uri",
    "DEFAULT_DB_URI",
]
@@ -1,73 +0,0 @@
"""
Document store factory - creates storage instances of different types.

Provides a unified interface for creating local file stores or PostgreSQL stores.
"""

import os
import logging
from typing import Optional, Tuple

from langchain_core.stores import BaseStore
from .postgres import PostgresDocStore

logger = logging.getLogger(__name__)

# Default connection string (read from the environment)
DEFAULT_DB_URI = os.getenv(
    "DB_URI",
    "postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)


def get_docstore_uri() -> str:
    """Return the docstore-specific database connection string (may be the same as the main DB)."""
    return os.getenv("DOCSTORE_URI", DEFAULT_DB_URI)


def create_docstore(
    store_type: str = "postgres",
    connection_string: Optional[str] = None,
    table_name: str = "parent_documents",
    pool_config: Optional[dict] = None,
    max_concurrency: Optional[int] = None
) -> Tuple[BaseStore, Optional[str]]:
    """
    Factory function that creates a PostgreSQL document store.

    Args:
        store_type: storage type; currently only "postgres" (default)
        connection_string: PostgreSQL connection string
        table_name: PostgreSQL table name (default: parent_documents)
        pool_config: connection pool configuration
        max_concurrency: maximum number of concurrent operations; None means unlimited

    Returns:
        A tuple (store instance, connection string)

    Raises:
        ValueError: unsupported storage type
        ImportError: missing required dependencies

    Example:
        >>> # Create a PostgreSQL store
        >>> store, conn = create_docstore(
        ...     connection_string="postgresql://user:pass@host:5432/db",
        ...     table_name="parent_docs",
        ...     max_concurrency=10
        ... )
    """
    store_type = store_type.lower()

    if store_type == "postgres":
        conn_str = connection_string or get_docstore_uri()
        store = PostgresDocStore(
            connection_string=conn_str,
            table_name=table_name,
            pool_config=pool_config,
            max_concurrency=max_concurrency
        )
        return store, conn_str

    else:
        raise ValueError(f"Unsupported store type: {store_type}. Only 'postgres' is currently supported")
@@ -1,249 +0,0 @@
"""
Async PostgreSQL store implementation - for production use.

Uses asyncpg for a truly asynchronous PostgreSQL document store that supports highly concurrent access.
"""

from __future__ import annotations

import asyncio
import json
import logging
from typing import List, Dict, Any, Optional, Iterator, AsyncIterator, Tuple, Sequence, cast

from langchain_core.documents import Document
from langchain_core.stores import BaseStore

import asyncpg

logger = logging.getLogger(__name__)


class PostgresDocStore(BaseStore[str, Any]):
    """
    Async PostgreSQL document store implementation.

    Uses asyncpg as the async PostgreSQL client and supports:
    - truly asynchronous operations
    - connection pool management
    - automatic table creation
    - batch operations (amget/amset/amdelete)
    - JSONB data storage
    - concurrency control

    Intended for production use; provides high-performance async persistence.

    Attributes:
        dsn: PostgreSQL connection string
        table_name: storage table name, defaults to "parent_documents"
        _pool: asyncpg connection pool instance
        _semaphore: semaphore limiting concurrency (optional)
    """

    def __init__(
        self,
        connection_string: str,
        table_name: str = "parent_documents",
        pool_config: Optional[Dict[str, Any]] = None,
        max_concurrency: Optional[int] = None
    ):
        """
        Initialize the async PostgreSQL document store.

        Args:
            connection_string: PostgreSQL connection URL, in the form:
                "postgresql://user:password@host:port/database?sslmode=disable"
            table_name: storage table name, defaults to "parent_documents"
            pool_config: connection pool configuration dict with:
                - min_size: minimum number of connections (default 2)
                - max_size: maximum number of connections (default 10)
            max_concurrency: maximum number of concurrent operations; None means unlimited

        Raises:
            ImportError: raised when asyncpg is not installed

        Example:
            >>> store = PostgresDocStore(
            ...     "postgresql://user:pass@localhost:5432/mydb",
            ...     table_name="parent_docs",
            ...     pool_config={"min_size": 5, "max_size": 20},
            ...     max_concurrency=10
            ... )
        """
        self.dsn = connection_string
        self.table_name = table_name
        self._pool: Optional["asyncpg.Pool"] = None
        self._pool_config = pool_config or {}

        # Concurrency-limiting semaphore
        self._semaphore = None
        if max_concurrency is not None and max_concurrency > 0:
            self._semaphore = asyncio.Semaphore(max_concurrency)

        # Note: async pool initialization is deferred to first use,
        # and table creation is deferred to the first operation

    async def _get_pool(self):
        """Get or create the asyncpg connection pool."""
        if self._pool is None:
            import asyncpg
            min_size = self._pool_config.get("min_size", 2)
            max_size = self._pool_config.get("max_size", 10)

            try:
                self._pool = await asyncpg.create_pool(
                    dsn=self.dsn,
                    min_size=min_size,
                    max_size=max_size
                )
                logger.info(f"PostgreSQL async connection pool created: {self.table_name}")

                # Initialize the table schema
                await self._create_table()
            except Exception as e:
                raise RuntimeError(f"Failed to create PostgreSQL async connection pool: {e}") from e

        return self._pool

    async def _create_table(self):
        """Create the storage table if it does not exist."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.transaction():
                await conn.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.table_name} (
                        key TEXT PRIMARY KEY,
                        value JSONB NOT NULL,
                        created_at TIMESTAMPTZ DEFAULT NOW()
                    )
                """)
                logger.info(f"Table {self.table_name} is ready")

    async def _with_concurrency_control(self, coro):
        """Run a coroutine under the concurrency-limiting semaphore."""
        if self._semaphore is None:
            return await coro
        async with self._semaphore:
            return await coro

    # --- Sync methods (kept for interface compatibility, but not supported) ---

    def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Sync access is not supported; use the async amget method."""
        raise NotImplementedError("Sync access is not supported; use the async amget method.")

    def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Sync access is not supported; use the async amset method."""
        raise NotImplementedError("Sync access is not supported; use the async amset method.")

    def mdelete(self, keys: Sequence[str]) -> None:
        """Sync access is not supported; use the async amdelete method."""
        raise NotImplementedError("Sync access is not supported; use the async amdelete method.")

    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
        """Sync access is not supported; use the async ayield_keys method."""
        raise NotImplementedError("Sync access is not supported; use the async ayield_keys method.")

    # --- Async methods (the real implementation) ---

    async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Async batch-get documents."""
        if not keys:
            return []

        async def _amget():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)",
                    keys
                )
                result_map = {}
                for row in rows:
                    val = row['value']
                    if isinstance(val, str):
                        val = json.loads(val)
                    if isinstance(val, dict) and 'page_content' in val:
                        result_map[row['key']] = Document(**val)
                    else:
                        result_map[row['key']] = val
                return [result_map.get(key) for key in keys]

        return await self._with_concurrency_control(_amget())

    async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Async batch-set documents."""
        if not key_value_pairs:
            return

        async def _amset():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.executemany(
                        f"""
                        INSERT INTO {self.table_name} (key, value)
                        VALUES ($1, $2)
                        ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
                        """,
                        [
                            (k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False))
                            for k, v in key_value_pairs
                        ]
                    )
            logger.debug(f"Async batch-set {len(key_value_pairs)} documents")

        await self._with_concurrency_control(_amset())

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Async batch-delete documents."""
        if not keys:
            return

        async def _amdelete():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.execute(
                        f"DELETE FROM {self.table_name} WHERE key = ANY($1)",
                        keys
                    )
            logger.debug(f"Async batch-deleted {len(keys)} documents")

        await self._with_concurrency_control(_amdelete())

    async def ayield_keys(self, *, prefix: str | None = None) -> AsyncIterator[str]:
        """Async iterate over all keys.

        Note: this is an async generator; iterate it with `async for`.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if prefix:
                rows = await conn.fetch(
                    f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key",
                    f"{prefix}%"
                )
            else:
                rows = await conn.fetch(
                    f"SELECT key FROM {self.table_name} ORDER BY key"
                )

            for row in rows:
                yield row['key']

    async def aclose(self) -> None:
        """Close the connection pool asynchronously and release resources."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("PostgreSQL async connection pool closed")

    def close(self) -> None:
        """Close the pool synchronously (limited functionality).

        Note: in an async environment, use the aclose method instead.
        """
        pass
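
A minimal async usage sketch for the store above (editor's illustration, not part of the diff); it assumes a reachable PostgreSQL instance at the given DSN:

    import asyncio
    from langchain_core.documents import Document

    async def demo():
        store = PostgresDocStore("postgresql://user:pass@localhost:5432/mydb")
        await store.amset([("doc:1", Document(page_content="hello"))])
        [doc] = await store.amget(["doc:1"])
        print(doc.page_content)
        await store.aclose()

    asyncio.run(demo())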
80
rag_indexer/test/reset_index.py
Normal file
@@ -0,0 +1,80 @@
"""Reset RAG index data.

Usage:
    python reset_index.py             # clear everything
    python reset_index.py --qdrant    # clear Qdrant only
    python reset_index.py --postgres  # clear PostgreSQL only
"""

import asyncio
import os
import argparse

from dotenv import load_dotenv
load_dotenv()

QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable")
COLLECTION_NAME = "rag_documents"
TABLE_NAME = "parent_documents"


def clear_qdrant():
    """Delete the Qdrant collection."""
    from qdrant_client import QdrantClient

    print("Clearing Qdrant...")
    client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

    collections = client.get_collections().collections
    if any(c.name == COLLECTION_NAME for c in collections):
        client.delete_collection(COLLECTION_NAME)
        print(f"  Collection '{COLLECTION_NAME}' deleted")
    else:
        print(f"  Collection '{COLLECTION_NAME}' does not exist")


async def clear_postgres():
    """Clear the PostgreSQL table data."""
    import asyncpg

    print("Clearing PostgreSQL...")
    conn = await asyncpg.connect(dsn=DB_URI)

    try:
        exists = await conn.fetchval(
            "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1)",
            TABLE_NAME
        )
        if exists:
            count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
            await conn.execute(f"DELETE FROM {TABLE_NAME}")
            print(f"  Table '{TABLE_NAME}' cleared, {count} rows deleted")
        else:
            print(f"  Table '{TABLE_NAME}' does not exist")
    finally:
        await conn.close()


async def main():
    parser = argparse.ArgumentParser(description="Reset RAG index data")
    parser.add_argument("--qdrant", action="store_true", help="clear Qdrant only")
    parser.add_argument("--postgres", action="store_true", help="clear PostgreSQL only")
    args = parser.parse_args()

    if not args.qdrant and not args.postgres:
        args.qdrant = True
        args.postgres = True

    if args.qdrant:
        clear_qdrant()

    if args.postgres:
        await clear_postgres()

    print("\nDone. Run `python -m rag_indexer.cli` to rebuild the index")


if __name__ == "__main__":
    asyncio.run(main())
63
rag_indexer/test/test_inspect_vectors.py
Normal file
@@ -0,0 +1,63 @@
"""Inspect the quality of vectors stored in Qdrant."""

import os
import sys
import numpy as np
from dotenv import load_dotenv
from qdrant_client import QdrantClient

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from rag_core import LlamaCppEmbedder

load_dotenv()

QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
COLLECTION_NAME = "rag_documents"

client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
embedder = LlamaCppEmbedder()

# Fetch a sample point
points, _ = client.scroll(
    collection_name=COLLECTION_NAME,
    limit=1,
    with_vectors=True,
    with_payload=True,
)

if not points:
    print(f"Collection '{COLLECTION_NAME}' is empty")
    sys.exit(0)

sample = points[0]
raw_vec = sample.vector
if isinstance(raw_vec, dict):
    stored_vec = list(raw_vec.values())[0]
elif isinstance(raw_vec, list):
    stored_vec = raw_vec
else:
    stored_vec = []

stored_payload = sample.payload or {}
stored_text = str(stored_payload.get("page_content", ""))[:200]

print(f"Content preview:\n{stored_text}...\n")
print(f"Vector dimension: {len(stored_vec)}")  # type: ignore
print(f"First 5 values: {stored_vec[:5]}")  # type: ignore
print(f"All zeros: {all(v == 0.0 for v in stored_vec)}")  # type: ignore

# Re-encode and compare
if stored_text:
    new_vec = embedder.embed_query(stored_text)
    similarity = np.dot(stored_vec, new_vec) / (np.linalg.norm(stored_vec) * np.linalg.norm(new_vec))  # type: ignore
    print(f"\nFirst 5 re-encoded values: {new_vec[:5]}")
    print(f"Cosine similarity: {similarity:.4f}")

    if similarity < 0.8:
        print("\n⚠️ Similarity too low; consider deleting the collection and rebuilding the index")
    else:
        print("\n✅ Vectors are consistent")
else:
    print("\n⚠️ Sample has no text content")
83
rag_indexer/test/test_refactored.py
Normal file
@@ -0,0 +1,83 @@
#!/usr/bin/env python3
"""
Test the refactored IndexBuilder and RAGRetriever
"""

import asyncio
import os
import sys

# Add the project root to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from rag_indexer.IndexBuilder import IndexBuilder
from rag_indexer.splitters import SplitterType


async def test_index_builder():
    """Test index building"""
    print("Testing index building...")

    # Create an IndexBuilder instance
    builder = IndexBuilder(
        collection_name="test_collection",
        splitter_type=SplitterType.PARENT_CHILD,
        parent_chunk_size=1000,
        child_chunk_size=200
    )

    # Test document path
    test_file = os.path.join(os.path.dirname(__file__), "..", "data", "corpus", "三国演义.txt")

    if os.path.exists(test_file):
        # Build the index
        print(f"Building index for file {test_file}...")
        processed = await builder.build_from_file(test_file)
        print(f"Index build finished; processed {processed} documents")

        # Collection info
        info = builder.get_collection_info()
        print(f"Collection info: {info}")
    else:
        print(f"Test file does not exist: {test_file}")

    # Test search
    print("\nTesting search...")
    try:
        results = builder.search("吕布", k=3)
        print(f"Number of search results: {len(results)}")
        for i, result in enumerate(results):
            print(f"\nResult {i+1}:")
            print(f"Content: {result.page_content[:100]}...")
    except Exception as e:
        print(f"Search test failed: {e}")

    # Test search with parent-chunk context
    print("\nTesting search with parent-chunk context...")
    try:
        results = await builder.search_with_parent_context("吕布", k=3)
        print(f"Number of search results: {len(results)}")
        for i, result in enumerate(results):
            print(f"\nResult {i+1}:")
            print(f"Content: {result.page_content[:100]}...")
    except Exception as e:
        print(f"Search with parent-chunk context failed: {e}")

    # Test the unified retrieval interface
    print("\nTesting the unified retrieval interface...")
    try:
        # Return parent chunks
        results_parent = await builder.retrieve("吕布", return_parent=True)
        print(f"Number of parent-chunk results: {len(results_parent)}")

        # Return child chunks
        results_child = await builder.retrieve("吕布", return_parent=False)
        print(f"Number of child-chunk results: {len(results_child)}")
    except Exception as e:
        print(f"Unified retrieval interface test failed: {e}")

    # Release resources
    builder.close()
    print("\nTest finished")


if __name__ == "__main__":
    asyncio.run(test_index_builder())
188
rag_indexer/test/test_validate_index.py
Normal file
@@ -0,0 +1,188 @@
"""
Validate RAG index integrity.

Checks the Qdrant vector store, the PostgreSQL document store, and retrieval.
"""

import asyncio
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))

from dotenv import load_dotenv
load_dotenv()

QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable")
COLLECTION_NAME = "rag_documents"
TABLE_NAME = "parent_documents"


def check_qdrant():
    """Check the Qdrant vector store."""
    from qdrant_client import QdrantClient

    print("=" * 60)
    print("Qdrant vector store")
    print("=" * 60)

    client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

    # List collections
    collections = client.get_collections().collections
    print(f"\nCollections: {len(collections)}")
    for c in collections:
        print(f"  - {c.name}")

    # Target collection info
    if not any(c.name == COLLECTION_NAME for c in collections):
        print(f"\nCollection '{COLLECTION_NAME}' does not exist")
        return

    info = client.get_collection(COLLECTION_NAME)
    print(f"\nCollection '{COLLECTION_NAME}':")
    print(f"  Status: {info.status}")
    print(f"  Vectors: {info.points_count}")

    vectors_config = info.config.params.vectors
    if isinstance(vectors_config, dict):
        for name, vc in vectors_config.items():
            print(f"  Vector '{name}': dim={vc.size}, distance={vc.distance}")
    else:
        print(f"  Vector dimension: {vectors_config.size}")

    # Sample a few points
    print("\nFirst 3 points:")
    points = client.scroll(
        collection_name=COLLECTION_NAME,
        limit=3,
        with_payload=True,
        with_vectors=False
    )
    for i, point in enumerate(points[0]):
        print(f"\n  {i+1}. ID: {point.id}")
        payload = point.payload or {}
        print(f"     Content: {payload.get('page_content', '')[:100]}...")


async def check_postgres():
    """Check the PostgreSQL document store."""
    import asyncpg

    print("\n" + "=" * 60)
    print("PostgreSQL document store")
    print("=" * 60)

    conn = await asyncpg.connect(dsn=DB_URI)

    try:
        # Does the table exist?
        tables = await conn.fetch(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
        )
        table_names = [t['table_name'] for t in tables]

        if TABLE_NAME not in table_names:
            print(f"\nTable '{TABLE_NAME}' does not exist")
            return

        # Row count
        count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
        print(f"\nTable '{TABLE_NAME}': {count} rows")

        # Sample a few documents
        print("\nFirst 3 documents:")
        rows = await conn.fetch(
            f"SELECT key, value FROM {TABLE_NAME} ORDER BY key LIMIT 3"
        )
        for i, row in enumerate(rows):
            print(f"\n  {i+1}. Key: {row['key']}")
            val = row['value']
            if isinstance(val, dict) and 'page_content' in val:
                print(f"     Content: {val['page_content'][:100]}...")

        # Key prefix distribution
        key_prefixes = await conn.fetch(
            f"""
            SELECT
                CASE
                    WHEN key LIKE '%:%' THEN split_part(key, ':', 1)
                    ELSE 'no_prefix'
                END AS prefix,
                COUNT(*) AS cnt
            FROM {TABLE_NAME}
            GROUP BY prefix
            ORDER BY cnt DESC
            LIMIT 10
            """
        )
        print("\nKey prefix distribution:")
        for row in key_prefixes:
            print(f"  {row['prefix']}: {row['cnt']}")

    finally:
        await conn.close()


async def test_search():
    """Test retrieval."""
    from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig
    from rag_indexer.splitters import SplitterType

    print("\n" + "=" * 60)
    print("Retrieval test")
    print("=" * 60)

    # Initialize from a config object (matches the default build path)
    config = IndexBuilderConfig(
        collection_name=COLLECTION_NAME,
        splitter_type=SplitterType.PARENT_CHILD,
    )
    builder = IndexBuilder(config)

    # Make sure the retriever is initialized
    if builder.retriever is None:
        print("Error: retriever not initialized; check the splitting strategy")
        return

    query = input("\nQuery (press Enter for default): ").strip() or "你好"
    print(f"\nQuery: {query}")

    # Standard retrieval (returns parent chunks, since ParentDocumentRetriever returns parents by default)
    print("\n--- Standard retrieval (parent chunks) ---")
    results = await builder.retriever.ainvoke(query)
    for i, doc in enumerate(results):
        content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
        print(f"\n  {i+1}. {content}...")
        if hasattr(doc, 'metadata'):
            source = doc.metadata.get('source', '')
            if source:
                print(f"     Source: {source}")

    # To return only child chunks, the retriever's search_type could be adjusted
    # (note: ParentDocumentRetriever's search_type defaults to "similarity")
    print("\n--- Retrieving child chunks (via the vector store directly) ---")
    # For simplicity, run a similarity search on the vectorstore to get child chunks
    vectorstore = builder.vector_store.get_langchain_vectorstore()
    sub_results = await vectorstore.asimilarity_search(query, k=3)
    for i, doc in enumerate(sub_results):
        content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
        print(f"\n  {i+1}. {content}...")
        if hasattr(doc, 'metadata'):
            parent_id = doc.metadata.get('parent_id', '')
            if parent_id:
                print(f"     Parent ID: {parent_id}")


async def main():
    check_qdrant()
    await check_postgres()
    await test_search()


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,133 +0,0 @@
"""
Qdrant vector database wrapper.
"""

import logging
import os
from typing import List, Optional, Dict, Any

from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams

from .embedders import LlamaCppEmbedder

logger = logging.getLogger(__name__)

QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")


class QdrantVectorStore:
    """Wrapper around Qdrant vector database operations."""

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Any] = None,
    ):
        self.collection_name = collection_name
        self._client: Optional[QdrantClient] = None

        # Embedding model
        if embeddings is None:
            embedder = LlamaCppEmbedder()
            self.embeddings = embedder.as_langchain_embeddings()
        else:
            self.embeddings = embeddings

        # Create the collection first
        self.create_collection()

        # LangChain vector store
        self.vector_store = LangchainQdrantVS(
            client=self.get_client(),
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )

    def get_client(self) -> QdrantClient:
        """Lazily create the client; ensure a usable connection on each access."""
        if self._client is None:
            self._client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                timeout=120,
                http2=False,
            )
        return self._client

    def refresh_client(self):
        """Close the old connection and create a new one."""
        if self._client is not None:
            self._client.close()
        self._client = None

    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create the collection with the appropriate vector dimension."""
        if vector_size is None:
            embedder = LlamaCppEmbedder()
            vector_size = embedder.get_embedding_dimension()

        client = self.get_client()
        collections = client.get_collections().collections
        exists = any(c.name == self.collection_name for c in collections)

        if exists and force_recreate:
            client.delete_collection(self.collection_name)
            exists = False

        if not exists:
            client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )
            logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)
        else:
            logger.info("Collection '%s' already exists", self.collection_name)

    def add_documents(self, documents: List[Document], batch_size: int = 100):
        """Add documents to the vector database."""
        if not documents:
            return []
        self.create_collection()
        ids = self.vector_store.add_documents(documents, batch_size=batch_size)
        logger.info("Added %d documents to '%s'", len(ids), self.collection_name)
        return ids

    def similarity_search(self, query: str, k: int = 5) -> List[Document]:
        return self.vector_store.similarity_search(query, k=k)

    def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
        return self.vector_store.similarity_search_with_score(query, k=k)

    def delete_collection(self):
        self.get_client().delete_collection(self.collection_name)
        logger.info("Collection '%s' deleted", self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        info = self.get_client().get_collection(self.collection_name)
        vectors_config = info.config.params.vectors
        if isinstance(vectors_config, dict):
            vector_size = next(iter(vectors_config.values())).size
        else:
            vector_size = vectors_config.size
        return {
            "name": self.collection_name,
            "vectors_count": info.points_count or 0,
            "status": info.status,
            "vector_size": vector_size,
        }

    def as_langchain_vectorstore(self):
        return self.vector_store

    def get_langchain_vectorstore(self):
        """Return the LangChain Qdrant vector store object (alias)."""
        return self.vector_store

    def get_qdrant_client(self):
        """Return the raw Qdrant client (for manual collection management)."""
        return self.get_client()
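
A minimal usage sketch for the wrapper above (editor's illustration, not part of the diff); it assumes a running Qdrant instance and the embedding service behind LlamaCppEmbedder being reachable:

    from langchain_core.documents import Document

    store = QdrantVectorStore(collection_name="demo_docs")
    store.add_documents([Document(page_content="Qdrant stores vectors.")])
    hits = store.similarity_search("what stores vectors?", k=1)
    print(hits[0].page_content)
    print(store.get_collection_info())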