Retriever refactoring
Some checks failed
Build and deploy AI Agent service / deploy (push) Failing after 17m12s

2026-04-19 22:01:55 +08:00
parent cc8ef41ef9
commit 933d418d77
26 changed files with 1694 additions and 1717 deletions

rag_indexer/IndexBuilder.py (new file, 299 lines)

@@ -0,0 +1,299 @@
"""
Offline RAG index-building core pipeline.
Implements the parent-child chunking strategy with LangChain's ParentDocumentRetriever.
"""
import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
logger = logging.getLogger(__name__)
# ---------- Configuration dataclasses ----------
@dataclass
class DocstoreConfig:
"""Document store configuration (for parent-chunk storage)."""
connection_string: Optional[str] = None
pool_config: Optional[Dict[str, Any]] = None
max_concurrency: Optional[int] = None
# To inject an already-created docstore from outside, set this field directly
instance: Optional[BaseStore] = None
@dataclass
class IndexBuilderConfig:
"""Index builder configuration."""
collection_name: str = "rag_documents"
splitter_type: SplitterType = SplitterType.PARENT_CHILD
# Parent-chunk parameters (only used when splitter_type is PARENT_CHILD)
parent_chunk_size: int = 1000
parent_chunk_overlap: int = 100
# Child-chunk parameters
child_chunk_size: int = 200
child_chunk_overlap: int = 20
child_splitter_type: SplitterType = SplitterType.SEMANTIC  # child chunks default to semantic splitting
# Retrieval parameters
search_k: int = 5
# Document store configuration (needed only in parent-child mode)
docstore: DocstoreConfig = field(default_factory=DocstoreConfig)
# Extra splitter parameters (used when splitter_type is not parent-child)
extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)
# ---------- Index builder ----------
class IndexBuilder:
"""Main RAG index-building pipeline, supporting single-splitter and parent-child chunking."""
def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs):
"""
Args:
config: index builder configuration object; takes precedence over kwargs
**kwargs: configuration parameters passed in directly and merged into config (kept for convenience)
"""
if config is None:
config = IndexBuilderConfig(**kwargs)
elif kwargs:
# Merge kwargs into config (only existing fields are updated)
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
self.config = config
self._docstore_conn: Optional[str] = None  # records the connection info created by create_docstore
# Initialize the basic components
self.loader = DocumentLoader()
self.embedder = LlamaCppEmbedder()
self.embeddings: Embeddings = self.embedder.as_langchain_embeddings()
# Initialize the vector store
self.vector_store = QdrantVectorStore(
collection_name=config.collection_name,
embeddings=self.embeddings,
)
# Initialize the remaining components according to the splitter type
self._init_splitters_and_retriever()
# ---------- Private initialization methods ----------
def _init_splitters_and_retriever(self) -> None:
"""Initialize the splitters and retriever from the configuration."""
if self.config.splitter_type == SplitterType.PARENT_CHILD:
self._init_parent_child_mode()
else:
self._init_single_splitter_mode()
def _init_single_splitter_mode(self) -> None:
"""单一切分模式(递归或语义)。"""
splitter_kwargs = self.config.extra_splitter_kwargs.copy()
if self.config.splitter_type == SplitterType.SEMANTIC:
splitter_kwargs["embeddings"] = self.embeddings
self.splitter = get_splitter(self.config.splitter_type, **splitter_kwargs)
self.retriever = None
self.docstore = None
logger.info("使用单一 %s 切分器", self.config.splitter_type.value)
def _init_parent_child_mode(self) -> None:
"""父子块切分模式,初始化父块/子块切分器、文档存储和检索器。"""
cfg = self.config
# 父块切分器(始终使用递归切分)
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=cfg.parent_chunk_size,
chunk_overlap=cfg.parent_chunk_overlap,
)
# Child splitter
if cfg.child_splitter_type == SplitterType.SEMANTIC:
self.child_splitter = get_splitter(
SplitterType.SEMANTIC,
embeddings=self.embeddings,
**cfg.extra_splitter_kwargs
)
logger.info("子块使用语义切分器")
else:
self.child_splitter = RecursiveCharacterTextSplitter(
chunk_size=cfg.child_chunk_size,
chunk_overlap=cfg.child_chunk_overlap,
)
logger.info("子块使用递归切分器,块大小=%d,重叠=%d",
cfg.child_chunk_size, cfg.child_chunk_overlap)
# 初始化文档存储(用于父块)
self.docstore = self._create_or_use_docstore()
# 创建检索器
self.retriever = ParentDocumentRetriever(
vectorstore=self.vector_store.get_langchain_vectorstore(),
docstore=self.docstore,
child_splitter=self.child_splitter, # type: ignore[arg-type]
parent_splitter=self.parent_splitter,
search_kwargs={"k": cfg.search_k},
)
logger.info("ParentDocumentRetriever 初始化完成,父块大小=%d", cfg.parent_chunk_size)
def _create_or_use_docstore(self) -> BaseStore:
"""创建或获取文档存储实例。"""
cfg = self.config.docstore
if cfg.instance is not None:
logger.debug("使用外部注入的文档存储")
return cfg.instance
# 使用 create_docstore 创建 PostgreSQL 存储
docstore, conn_info = create_docstore(
connection_string=cfg.connection_string,
pool_config=cfg.pool_config,
max_concurrency=cfg.max_concurrency,
)
self._docstore_conn = conn_info
logger.info("文档存储已创建PostgreSQL")
return docstore
# ---------- 公共构建方法 ----------
async def build_from_file(self, file_path: Union[str, Path]) -> int:
"""从单个文件构建索引。"""
logger.info("加载文件: %s", file_path)
documents = self.loader.load_file(file_path)
logger.info("已加载 %d 个文档", len(documents))
return await self._process_documents(documents)
async def build_from_directory(
self, directory_path: Union[str, Path], recursive: bool = True
) -> int:
"""从目录递归构建索引。"""
logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
documents = self.loader.load_directory(directory_path, recursive=recursive)
logger.info("已从目录加载 %d 个文档", len(documents))
return await self._process_documents(documents)
async def _process_documents(self, documents: List[Document]) -> int:
"""处理文档列表,分发给相应的索引逻辑。"""
if not documents:
logger.warning("没有文档需要处理")
return 0
if self.config.splitter_type == SplitterType.PARENT_CHILD:
return await self._index_with_parent_child(documents)
else:
return await self._index_with_single_splitter(documents)
async def _index_with_single_splitter(self, documents: List[Document]) -> int:
"""单一模式:切分后直接写入向量库。"""
chunks = self.splitter.split_documents(documents) # type: ignore[union-attr]
logger.info("已切分为 %d 个块", len(chunks))
self.vector_store.create_collection()
self.vector_store.add_documents(chunks)
return len(chunks)
async def _index_with_parent_child(self, documents: List[Document]) -> int:
"""父子模式:使用 ParentDocumentRetriever 批量添加。"""
self.vector_store.create_collection()
assert self.retriever is not None
batch_size = 10
total = len(documents)
processed = 0
for i in range(0, total, batch_size):
batch = documents[i:i + batch_size]
await self._add_batch_with_retry(batch, i // batch_size + 1)
processed += len(batch)
logger.info("批次 %d: 已处理 %d/%d", i // batch_size + 1, processed, total)
logger.info("ParentDocumentRetriever 索引完成,共处理 %d 个文档", processed)
return processed
async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
"""添加批次,失败时自动重试(处理网络波动)。"""
max_retries = 3
for attempt in range(max_retries):
try:
await self.retriever.aadd_documents(batch) # type: ignore[union-attr]
return
except (RemoteProtocolError, ConnectionError, OSError) as e:
if attempt == max_retries - 1:
raise
logger.warning("批次 %d 连接断开,重试 (%d/%d): %s",
batch_no, attempt + 1, max_retries, e)
self.vector_store.refresh_client()
await asyncio.sleep(1)
# ---------- Introspection methods ----------
def get_collection_info(self) -> Any:
"""Return the vector store collection info."""
return self.vector_store.get_collection_info()
def get_child_splitter(self) -> TextSplitter:
"""Return the child splitter currently in use."""
if self.config.splitter_type == SplitterType.PARENT_CHILD:
return self.child_splitter # type: ignore[return-value]
return self.splitter # type: ignore[return-value]
def get_parent_splitter(self) -> RecursiveCharacterTextSplitter:
"""获取父块切分器(仅父子模式可用)。"""
if self.config.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError("父块切分器仅在父子块模式下可用")
return self.parent_splitter # type: ignore[return-value]
def get_docstore(self) -> BaseStore:
"""获取文档存储实例(仅父子模式可用)。"""
if self.config.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError("文档存储仅在父子块模式下可用")
assert self.docstore is not None
return self.docstore
# ---------- 资源管理 ----------
def close(self) -> None:
"""关闭资源(同步版本,供上下文管理器使用)。"""
if self.docstore is not None and hasattr(self.docstore, "aclose"):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# No running event loop: create a temporary one
loop = asyncio.new_event_loop()
loop.run_until_complete(self.docstore.aclose()) # type: ignore[attr-defined]
loop.close()
else:
# A loop is already running: schedule a task (the caller awaits it)
loop.create_task(self.docstore.aclose())  # type: ignore[attr-defined]
logger.info("IndexBuilder resources closed")
async def aclose(self) -> None:
"""Close resources asynchronously."""
if self.docstore is not None and hasattr(self.docstore, "aclose"):
await self.docstore.aclose()  # type: ignore[attr-defined]
logger.info("IndexBuilder resources closed asynchronously")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
return False


@@ -2,35 +2,13 @@
This module implements Stage 1 of the RAG system: **offline index building**. It cleans, splits, and vectorizes external unstructured data (documents, PDFs, web pages, etc.), finally storing it in the vector database.
## 📊 System Workflow Diagram
```mermaid
graph TD
    A[Raw document collection <br> PDF / Word / Markdown] --> B(DocumentLoader)
    B --> C{Text splitting strategy Splitter}
    C -->|basic| D1[Fixed-length character split <br> Recursive Split]
    C -->|intermediate| D2[Semantic-boundary split <br> Semantic Chunking]
    C -->|advanced| D3[Parent-child split <br> Parent-Child / Auto-merging]
    D1 & D2 & D3 --> E[Embedder <br> llama.cpp: embeddinggemma]
    E --> F[(Qdrant vector database)]
    subgraph "Metadata management"
        G[Extract author, date, page number, etc. Metadata] -.attach.-> E
    end
```
---
## 🎯 Roadmap and Core Algorithms
### Level 1: Basic Recursive Splitting
- **Core algorithm**: recursive character splitting. It tries a predefined list of separators (e.g. `["\n\n", "\n", "。", "！", "？", " ", ""]`) from largest to smallest until every chunk satisfies the maximum length limit.
- **Pros and cons**: extremely simple to implement and fast, but it easily cuts a sentence in half, losing contextual semantics.
- **Implementation guide** (a minimal sketch follows):
  - Import `RecursiveCharacterTextSplitter` from `langchain_text_splitters`
  - Instantiate it with `chunk_size` (e.g. 500) and `chunk_overlap` (e.g. 50), then call `.split_documents(raw_docs)` directly.
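A minimal sketch of this level (assuming `raw_docs` is the `List[Document]` produced by a loader):

```python
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Walk the separators from coarse (paragraphs) to fine (characters) until chunks fit
splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,    # maximum characters per chunk
    chunk_overlap=50,  # overlap softens hard boundary cuts
)
chunks = splitter.split_documents(raw_docs)
```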
### Level 2: Semantic Chunking
@@ -38,58 +16,52 @@ graph TD
1. Split the text into sentences at punctuation marks.
2. Embed each sentence with a lightweight embedding model.
3. Compute the cosine similarity between adjacent sentences.
4. When the similarity falls below a set threshold (the two sentences are no longer about the same thing; the topic has shifted), "cut" there to start a new chunk.
- **Pros and cons**: preserves within-paragraph semantic coherence to a great extent, which is very friendly to LLM answering. But since the splitting stage itself calls the embedding model, it is somewhat slower.
- **Implementation guide** (a sketch follows this list):
  - Import `TextSplitter` from `langchain_text_splitters` as the base class.
  - Import `SemanticChunker` from `langchain_experimental.text_splitter`
  - `SemanticChunkerAdapter` inherits `TextSplitter` to resolve the type incompatibility
  - Instantiate it with your already-configured embedding model instance (e.g. the local model wrapped by `LlamaCppEmbedder`).
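A minimal sketch, assuming `embeddings` is the already-configured embedding instance and `long_text` is the raw text:

```python
from langchain_experimental.text_splitter import SemanticChunker

# Cut wherever adjacent-sentence similarity drops below the chosen percentile threshold
chunker = SemanticChunker(
    embeddings=embeddings,
    breakpoint_threshold_type="percentile",
)
docs = chunker.create_documents([long_text])
```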
### Level 3: Advanced Parent-Child Strategy (Parent-Child / Auto-merging)
- **Core algorithm**: hierarchical dual storage with mapping.
- **Splitting**: documents are first coarsely split into larger "parent chunks (~1000 characters)", and each parent is then finely split into smaller "child chunks (~200 characters)".
- **Storage**: only the **child** vectors go into Qdrant for precise distance computation; the raw **parent** content lives in the PostgreSQL DocStore, with UUIDs mapping the two.
- **Key idea**: resolves the classic RAG tension: at retrieval time, smaller chunks hit more precisely (less noise); at generation time, larger chunks give the model richer context.
- **Implementation guide** (a bare-bones sketch follows):
  - Use the `ParentDocumentRetriever` from `langchain_classic.retrievers`.
  - At write time you need both an underlying `VectorStore` (Qdrant) and a `BaseStore`.
  - **Recommended**: use `PostgresDocStore` as the docstore for persistent storage
  - Assign the two different `TextSplitter`s to the retriever's `child_splitter` and `parent_splitter`, then call `.add_documents()` and the mapping happens automatically.
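A bare-bones sketch of this wiring (an in-memory store stands in for the PostgreSQL docstore for brevity; `vectorstore` and `raw_docs` are assumed configured, and the import path follows the one used elsewhere in this repo):

```python
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_core.stores import InMemoryStore
from langchain_text_splitters import RecursiveCharacterTextSplitter

retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,              # child-chunk index (Qdrant-backed)
    docstore=InMemoryStore(),             # parent chunks, keyed by UUID
    parent_splitter=RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100),
    child_splitter=RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20),
)
retriever.add_documents(raw_docs)         # splits, maps, and writes both stores
hits = retriever.invoke("your question")  # small-chunk match, full parent returned
```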
### Level 3.1: PostgreSQL DocStore Integration
- **Core advantage**: PostgreSQL provides persistent storage fit for production. An async connection pool supports high concurrency
- **Implementation steps**:
  1. **Configure the connection**: set the `DB_URI` environment variable or pass the `docstore_conn_string` parameter
  2. **Create the docstore**: use the `rag_indexer.store.create_docstore()` factory function
  3. **Inject into IndexBuilder**: via constructor parameters
- **Usage example**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType

# Create the IndexBuilder
builder = IndexBuilder(
    collection_name="rag_documents",
    splitter_type=SplitterType.PARENT_CHILD,
    parent_chunk_size=1000,
    child_chunk_size=200,
    docstore_conn_string="postgresql://user:pass@host:5432/db",
)
```
### Level 3.2: Combining Semantic Chunking with the Parent-Child Strategy
- **Core advantage**: combines the coherence of semantic chunking with the hierarchical storage of the parent-child strategy, giving more precise retrieval with richer context.
- **How it works**:
  - **Parent splitting**: `RecursiveCharacterTextSplitter` creates large chunks (~1000 characters) providing full contextual background
  - **Child splitting**: `SemanticChunkerAdapter` creates small chunks, cut dynamically along semantic boundaries, improving retrieval precision
  - **Storage**: child vectors go into Qdrant for precise retrieval; parent content goes into PostgreSQL for full context
- **Usage example**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType
@@ -109,97 +81,55 @@ graph TD
```
- **Configuration parameters**:
  - `child_splitter_type`: child splitter type, either `SplitterType.RECURSIVE` (default) or `SplitterType.SEMANTIC`
  - With semantic splitting, the system automatically uses the configured embedding model for sentence-level similarity computation
### Level 4: RAG-Fusion (multi-query rewriting with Reciprocal Rank Fusion)
- **Core advantage**: lets the LLM think divergently, rewriting a single question into several similar ones to widen the search, then merges the results with a statistical algorithm, improving both coverage and accuracy.
- **How it works**:
  1. **Multi-query rewriting**: the LLM rewrites the original query into 3-5 differently phrased queries that express the same intent from different angles
  2. **Reciprocal Rank Fusion (RRF)**: the results of each rewritten query are fused with $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$, so that no single result list dominates
  3. **Deduplication**: fused results are deduplicated so that the returned documents are unique
- **Usage example**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType
from langchain_openai import OpenAI

# Create the IndexBuilder
builder = IndexBuilder(
    collection_name="rag_documents",
    splitter_type=SplitterType.PARENT_CHILD,
    parent_chunk_size=1000,
    child_chunk_size=200,
    docstore_conn_string="postgresql://user:pass@host:5432/db",
)
# Create the language model used for query rewriting
llm = OpenAI(
    openai_api_base="http://localhost:8000/v1",
    openai_api_key="no-key-needed",
    model_name="Qwen2.5-7B-Instruct",
    temperature=0.3,
)
# Retrieve with RAG-Fusion (inside an async context)
query = "How do I apply for project funding?"
results = await builder.retrieve_with_fusion(
    query=query,
    llm=llm,
    num_queries=3,
    k=5,
    return_parent=True
)
```
- **Configuration parameters**:
  - `llm`: language model instance used for query rewriting
  - `num_queries`: number of generated queries (3-5 recommended)
  - `k`: number of documents to return
  - `return_parent`: whether to return the parent-chunk context
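The RRF step is simple enough to implement by hand; a hedged sketch in plain Python (`ranked_lists` holds one ranked list of document IDs per rewritten query):

```python
from collections import defaultdict

def rrf_fuse(ranked_lists, k=60):
    """Reciprocal Rank Fusion: score(d) = sum over queries of 1 / (k + rank_q(d))."""
    scores = defaultdict(float)
    for ranking in ranked_lists:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first
    return sorted(scores, key=scores.get, reverse=True)
```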
### Level 4: GraphRAG (graph- and relation-based RAG)
- **Core algorithm**: LLM entity-relation extraction (NER & Relation Extraction).
- **Key idea**: addresses the pain point that pure vector retrieval struggles with "cross-document relational reasoning" (e.g. "Who is company A's CEO, and what is the main business of his company B?": leaping questions that span many PDF pages).
- **How it works**:
  1. **Entity extraction**: the LLM extracts entities from documents (people, organizations, places, events, ...)
  2. **Relation extraction**: relations between entities are identified ("CEO of", "founded by", "located in", ...)
  3. **Graph construction**: entities become nodes and relations become edges of a knowledge graph
  4. **Hybrid retrieval**: vector search and graph queries are combined, exploiting both semantic similarity and structural relations
- **Tech stack**:
  - **Graph database**: Neo4j or RedisGraph
  - **LLM tooling**: `LLMGraphTransformer` or custom prompts
  - **Integration**: runs alongside the vector store, forming a hybrid retrieval system
- **Implementation guide** (a hedged sketch follows):
  - Use the `langchain_community.graphs` module
  - Configure a local LLM (e.g. `Gemma-4-E2B`) for entity-relation extraction
  - Build the graph of entities and relations and store it in the graph database
  - Implement hybrid retrieval logic combining vector similarity with graph-path analysis
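A hedged sketch of the extraction step (assuming the `langchain_experimental` graph transformer and a local Neo4j instance; `llm` and `documents` are preconfigured):

```python
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_community.graphs import Neo4jGraph

transformer = LLMGraphTransformer(llm=llm)
graph_docs = transformer.convert_to_graph_documents(documents)  # entities -> nodes, relations -> edges
graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="...")
graph.add_graph_documents(graph_docs)
```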
---
## Dependencies and Installation
To support full document parsing and Qdrant writes, install the following Python packages:
```bash
# Core libraries
pip install langchain langchain-core langchain-openai langchain-qdrant
# Complex document parsing (PDF, Word, Excel, ...)
pip install unstructured pdf2image pdfminer.six
# Semantic chunking (optional)
pip install langchain-experimental
# PostgreSQL storage (optional, for the parent-child strategy)
pip install psycopg2-binary
# RAG-Fusion (optional, requires a language model)
pip install langchain-openai
```
### Level 5: Multi-modal RAG
- **Core algorithm**: cross-modal embeddings and multi-modal fusion.
- **Key idea**: break out of the text-only limit; understand and retrieve images, tables, audio, and other data types.
- **How it works** (an illustrative sketch follows):
  1. **Multi-modal embeddings**: models such as CLIP map different modalities into one shared vector space
  2. **Multi-modal indexing**: dedicated indexes are created per content type
  3. **Cross-modal retrieval**: supports text-to-image, image-to-text, and similar cross-modal queries
- **Tech stack**:
  - **Multi-modal models**: CLIP, BLIP, etc.
  - **Storage**: vector database + object storage
  - **Retrieval**: hybrid vector retrieval
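An illustrative sketch of the shared embedding space, using the sentence-transformers CLIP wrapper (an assumption; the repository ships no multi-modal code):

```python
from PIL import Image
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("clip-ViT-B-32")  # maps text and images into one vector space
img_vec = model.encode(Image.open("figure.png"))
txt_vec = model.encode("a bar chart of quarterly revenue")
# Cosine similarity enables text-to-image search
score = (img_vec @ txt_vec) / ((img_vec @ img_vec) ** 0.5 * (txt_vec @ txt_vec) ** 0.5)
```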
---
## 📂 Architecture and File Layout
The following core files live under the `rag_indexer/` directory:
```text
rag_indexer/
├── __init__.py
├── loaders.py            # parses different file types via unstructured
├── splitters.py          # Recursive and Semantic splitting logic plus adapters
├── embedders.py          # embedding interface wrapping the local llama.cpp service
├── vector_store.py       # Qdrant writes, upserts, and collection initialization
├── builder.py            # core orchestration, wiring the modules into a pipeline
├── cli.py                # command-line entry point
└── store/
    ├── __init__.py
    ├── factory.py        # docstore factory functions
    └── postgres.py       # PostgreSQL DocStore implementation
```
---
@@ -211,36 +141,36 @@ rag_indexer/
```
┌─────────────────────────────────────────┐
│ builder.py                              │
│ IndexBuilder entry point                │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│ loaders.py                              │
│ DocumentLoader.load_file()              │
│ → returns List[Document]                │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│ ParentDocumentRetriever                 │
│ ┌─────────────────────────────────┐     │
│ │ parent_splitter (coarse)        │     │
│ │ parent chunks ~1000 chars       │     │
│ └────────────┬────────────────────┘     │
│              │                          │
│ ┌────────────▼────────────────────┐     │
│ │ child_splitter (fine)           │     │
│ │ child chunks ~200 chars         │     │
│ └────────────┬────────────────────┘     │
│              │                          │
│    ┌─────────┴───────────┐              │
│    ▼                     ▼              │
│ child vectors       parent content      │
│    │                     │              │
│    ▼                     ▼              │
│ ┌────────────┐  ┌─────────────────┐     │
│ │vector_store│  │ store/          │     │
│ │ (Qdrant)   │  │ (PostgreSQL)    │     │
│ └────────────┘  └─────────────────┘     │
└─────────────────────────────────────────┘
```
@@ -250,10 +180,31 @@ rag_indexer/
| File | Responsibility | Key classes / functions |
|------|------|------------|
| **builder.py** | Core orchestration; wires the whole flow together | `IndexBuilder` |
| **loaders.py** | Parses document formats (PDF, Word, TXT, ...) | `DocumentLoader` |
| **splitters.py** | Text splitting strategies (Recursive/Semantic) plus adapters | `SplitterType`, `get_splitter()`, `SemanticChunkerAdapter` |
| **embedders.py** | Vectorization (wraps the llama.cpp embedding API) | `LlamaCppEmbedder` |
| **vector_store.py** | Qdrant vector database operations | `QdrantVectorStore` |
| **store/postgres.py** | PostgreSQL DocStore implementation | `PostgresDocStore` |
| **store/factory.py** | docstore factory functions | `create_docstore()` |
### Core Implementation Details
#### 1. Text Splitting
- **Recursive splitting**: uses `langchain_text_splitters.RecursiveCharacterTextSplitter` with Chinese-aware separators
- **Semantic splitting**: uses `langchain_experimental.text_splitter.SemanticChunker`, adapted to the `TextSplitter` interface via `SemanticChunkerAdapter`
- **Parent-child strategy**: parents use recursive splitting (1000 characters); children can use recursive or semantic splitting (200 characters)
#### 2. Vectorization
- **Embedding API**: `LlamaCppEmbedder` wraps the local llama.cpp service and exposes `embed_documents` and `embed_query`
- **Vector dimension**: the model dimension is detected automatically (default 2560) and a Qdrant collection of matching size is created
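The dimension probe can be as simple as embedding a throwaway string, which is how the `get_embedding_dimension` helper in the old embedders.py did it (`embedder` here being the `LlamaCppEmbedder` instance):

```python
# Probe the model's vector dimension once, then size the Qdrant collection to match
dim = len(embedder.embed_query("test"))
```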
#### 3. Vector Storage
- **Qdrant integration**: `langchain_qdrant.QdrantVectorStore` is the underlying store
- **Collection management**: collections are created or reused automatically; a `force_recreate` parameter is supported
- **Batch writes**: a `batch_size` parameter keeps individual requests from growing too large
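What the collection bootstrap amounts to in raw qdrant-client calls (an illustrative sketch; the collection name and dimension are taken from the bullets above):

```python
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(url="http://localhost:6333")
if not client.collection_exists("rag_documents"):
    client.create_collection(
        collection_name="rag_documents",
        vectors_config=VectorParams(size=2560, distance=Distance.COSINE),
    )
```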
#### 4. Document Storage
- **PostgreSQL**: `PostgresDocStore` persists the parent chunks, backed by an async connection pool
- **Data mapping**: UUIDs link child chunks to their parents; retrieval returns the full parent chunk
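A hedged sketch of the UUID hop performed at query time, inside an async context (assuming child metadata carries the parent key under `doc_id`, which is ParentDocumentRetriever's default `id_key`):

```python
hits = vector_store.similarity_search(query, k=5)  # small child chunks, precise match
parent_ids = [h.metadata["doc_id"] for h in hits]  # UUID keys into the docstore
parents = await docstore.amget(parent_ids)         # full parent chunks for the LLM
```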
### Call Sequence
@@ -265,27 +216,42 @@ from rag_indexer.builder import IndexBuilder, SplitterType
builder = IndexBuilder(
    collection_name="my_docs",
    splitter_type=SplitterType.PARENT_CHILD,
    parent_chunk_size=1000,
    child_chunk_size=200,
    docstore_conn_string="postgresql://user:pass@host:5432/db",
)
```
#### 2. Build the Index
```python
import asyncio

# Option A: build from a single file
async def main():
    count = await builder.build_from_file("/path/to/document.pdf")
    print(f"Indexed {count} chunks")

# Option B: build from a directory in bulk
async def main():
    count = await builder.build_from_directory("/path/to/docs/")
    print(f"Indexed {count} chunks")

asyncio.run(main())
```
#### 3. Retrieval (fetch the full parent-chunk context)
```python
import asyncio

async def main():
    # Retrieval returns the full parent chunks
    results = await builder.search_with_parent_context("your query", k=5)
    for doc in results:
        print(doc.page_content)

asyncio.run(main())
```
### Retrieval Flow
@@ -299,11 +265,16 @@ results = builder.search_with_parent_context("your query")
---
### Wiring and Triggering
Use the `cli.py` entry script:
```bash
# Set the environment variables
export QDRANT_URL="http://115.190.121.151:6333"
export QDRANT_API_KEY="your-api-key"
export DB_URI="postgresql://postgres:password@host:5432/langgraph_db?sslmode=disable"

# Run the index build
python -m rag_indexer.cli --path data/user_docs/tech_manual.pdf
```
This is the system's background **"offline learning stage"**: you can attach a scheduled job to scan folders and incrementally update the knowledge base (an illustrative cron line follows).
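A possible cron entry for that scheduled scan (paths and schedule are illustrative):

```bash
# Re-index the docs folder nightly at 02:30
30 2 * * * cd /opt/agent && python -m rag_indexer.cli --path data/user_docs/ >> logs/indexer.log 2>&1
```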


@@ -9,52 +9,52 @@ Offline RAG Indexer module.
- Parent-document storage (PostgreSQL)
Example usage:
>>> from rag_indexer import IndexBuilder, SplitterType
>>> from rag_indexer import IndexBuilder, IndexBuilderConfig, SplitterType
>>>
>>> builder = IndexBuilder(
>>> config = IndexBuilderConfig(
... collection_name="my_docs",
... splitter_type=SplitterType.PARENT_CHILD,
... qdrant_url="http://localhost:6333"
... )
>>> builder = IndexBuilder(config)
>>>
>>> builder.build_from_file("document.pdf")
>>> # 或直接传参(向后兼容)
>>> builder = IndexBuilder(collection_name="my_docs")
>>>
>>> await builder.build_from_file("document.pdf")
"""
from .IndexBuilder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter
# Re-export common components from rag_core
from rag_core import (
    LlamaCppEmbedder,
    QdrantVectorStore,
    PostgresDocStore,
    create_docstore,
)
__version__ = "2.0.0"
__all__ = [
    # Core builder and configuration
    "IndexBuilder",
    "IndexBuilderConfig",
    "DocstoreConfig",
    # Loaders
    "DocumentLoader",
    # Splitting
    "SplitterType",
    "get_splitter",
    # Embedding and vector storage
    "LlamaCppEmbedder",
    "QdrantVectorStore",
    # Document storage
    "PostgresDocStore",
    "create_docstore",
]


@@ -1,392 +0,0 @@
"""
Offline RAG index-building core pipeline.
Supports LangChain's ParentDocumentRetriever for parent-child chunking.
"""
import asyncio
import logging
from pathlib import Path
from typing import List, Union, Optional, Tuple, Any
from dataclasses import dataclass
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .store import create_docstore
logger = logging.getLogger(__name__)
@dataclass
class ParentChildConfig:
"""父子块切分配置。"""
parent_chunk_size: int = 1000
child_chunk_size: int = 200
parent_chunk_overlap: int = 100
child_chunk_overlap: int = 20
search_k: int = 5
docstore_path: Optional[str] = None
docstore_type: str = "local"
docstore_conn_string: Optional[str] = None
class IndexBuilder:
"""RAG 索引构建主流水线。"""
# 类型注解
parent_splitter: "RecursiveCharacterTextSplitter"
child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
docstore: Optional["BaseStore"]
_docstore_conn: Optional[str]
retriever: Optional["ParentDocumentRetriever"]
vector_store_obj: Any
def __init__(
self,
collection_name: str = "rag_documents",
splitter_type: SplitterType = SplitterType.PARENT_CHILD,
docstore=None,
**splitter_kwargs,
):
self.collection_name = collection_name
self.splitter_type = splitter_type
self.splitter_kwargs = splitter_kwargs
self.docstore = docstore  # injected from outside
# Components
self.loader = DocumentLoader()
self.embedder = LlamaCppEmbedder()
self.embeddings = self.embedder.as_langchain_embeddings()
self.vector_store = QdrantVectorStore(
collection_name=collection_name,
embeddings=self.embeddings,
)
# Splitters (the parent-child case is handled separately)
if splitter_type != SplitterType.PARENT_CHILD:
if splitter_type == SplitterType.SEMANTIC:
splitter_kwargs["embeddings"] = self.embeddings
self.splitter = get_splitter(splitter_type, **splitter_kwargs)
else:
self.splitter = None
# Initialize the ParentDocumentRetriever for parent-child chunking
self._init_parent_child_retriever()
def _init_parent_child_retriever(self, **kwargs):
"""
Initialize the ParentDocumentRetriever for parent-child chunking.
Supports combining dynamic semantic splitting with the parent-child strategy:
- parent chunks use recursive splitting (large chunks, providing context)
- child chunks can use recursive or semantic splitting (dynamic semantic cuts improve retrieval precision)
Replaces the custom ParentChildSplitter logic.
"""
# Parse the parent-child configuration parameters
parent_size = kwargs.get("parent_chunk_size", 1000)
child_size = kwargs.get("child_chunk_size", 200)
parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))
# Child splitter type; defaults to semantic splitting
child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC)
# Define the parent splitter (always recursive)
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=parent_size,
chunk_overlap=parent_overlap,
)
# Define the child splitter (chosen by type)
if child_splitter_type == SplitterType.SEMANTIC:
self.child_splitter = get_splitter(
SplitterType.SEMANTIC,
embeddings=self.embeddings,
)
logger.info("Child chunks use the semantic splitter")
else:
# Default to recursive splitting
self.child_splitter = RecursiveCharacterTextSplitter(
chunk_size=child_size,
chunk_overlap=child_overlap,
)
logger.info(f"Child chunks use the recursive splitter, chunk size: {child_size}, overlap: {child_overlap}")
# Vector store (for child chunks)
self.vector_store_obj = self.vector_store.get_langchain_vectorstore()
# Document store (for parent chunks)
if self.docstore is None:
# No externally injected docstore: create a PostgreSQL one
docstore_conn = kwargs.get("docstore_conn_string")
pool_config = kwargs.get("pool_config")
max_concurrency = kwargs.get("max_concurrency")
# Create the PostgreSQL store via create_docstore
self.docstore, self._docstore_conn = create_docstore(
connection_string=docstore_conn,
pool_config=pool_config,
max_concurrency=max_concurrency
)
else:
# Use the externally injected docstore
self._docstore_conn = None
# Create the retriever
self.retriever = ParentDocumentRetriever(
vectorstore=self.vector_store_obj,
docstore=self.docstore,
child_splitter=self.child_splitter,  # type: ignore
parent_splitter=self.parent_splitter,
search_kwargs={"k": kwargs.get("search_k", 5)},
)
logger.info(f"ParentDocumentRetriever initialized, parent chunk size: {parent_size}, child splitter type: {child_splitter_type}")
async def build_from_file(self, file_path: Union[str, Path]) -> int:
logger.info("Loading file: %s", file_path)
documents = self.loader.load_file(file_path)
logger.info("Loaded %d documents", len(documents))
return await self._process_documents(documents)
async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
logger.info("Loading directory: %s (recursive=%s)", directory_path, recursive)
documents = self.loader.load_directory(directory_path, recursive=recursive)
logger.info("Loaded %d documents from the directory", len(documents))
return await self._process_documents(documents)
async def _process_documents(self, documents: List[Document]) -> int:
if not documents:
logger.warning("No documents to process")
return 0
if self.splitter_type == SplitterType.PARENT_CHILD:
logger.info("Using LangChain ParentDocumentRetriever")
# Make sure the collection exists (for child chunks)
self.vector_store.create_collection()
# Process in batches to avoid oversized single requests
assert self.retriever is not None, "retriever not initialized"
batch_size = 10  # 10 documents per batch
total = len(documents)
processed = 0
for i in range(0, total, batch_size):
batch = documents[i:i + batch_size]
max_retries = 3
for attempt in range(max_retries):
try:
await self.retriever.aadd_documents(batch)
processed += len(batch)
logger.info(f"Batch {i//batch_size + 1}: processed {processed}/{total}")
break
except (RemoteProtocolError, ConnectionError, OSError) as e:
if attempt == max_retries - 1:
raise
logger.warning(f"Connection dropped, retrying ({attempt+1}/{max_retries}): {e}")
self.vector_store.refresh_client()
await asyncio.sleep(1)
logger.info(
"Indexed with ParentDocumentRetriever: "
f"{processed} parent chunks"
)
return processed
else:
logger.info("Splitting documents with %s", self.splitter_type)
# When splitter_type is not PARENT_CHILD, splitter is guaranteed non-None
assert self.splitter is not None, "splitter not initialized"
chunks = self.splitter.split_documents(documents)
logger.info("Split into %d chunks", len(chunks))
self.vector_store.create_collection()
self.vector_store.add_documents(chunks)
return len(chunks)
def get_collection_info(self):
return self.vector_store.get_collection_info()
def search(self, query: str, k: int = 5) -> List[Document]:
"""Standard search - returns child chunks."""
return self.vector_store.similarity_search(query, k=k)
async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
"""
Search with parent context - returns the full parent chunks.
This is the primary retrieval method when parent-child chunking is used.
"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"search_with_parent_context is only available with the PARENT_CHILD splitter. "
"Use search() for standard retrieval."
)
assert self.retriever is not None, "retriever not initialized"
return await self.retriever.ainvoke(query, config={"k": k})  # type: ignore
async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
"""
Unified retrieval interface.
Args:
query: search query
return_parent: if True and parent-child chunking is used, return parent chunks;
if False, always return child chunks
Returns:
List of relevant documents
"""
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
return await self.search_with_parent_context(query)
else:
return self.search(query)
async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]:
"""
Retrieve with RAG-Fusion (multi-query rewriting + reciprocal rank fusion).
Core idea:
1. Multi-query rewriting: the LLM rewrites the original query into several different phrasings
2. Reciprocal rank fusion: the results of each rewritten query are RRF-fused so that no single result list dominates
Args:
query: original search query
llm: language model instance (used for query rewriting)
num_queries: number of queries to generate
k: number of documents to return
return_parent: if True and parent-child chunking is used, return parent chunks;
if False, always return child chunks
Returns:
The fused list of relevant documents
"""
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers import EnsembleRetriever
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
# Use the ParentDocumentRetriever as the base retriever
assert self.retriever is not None, "retriever not initialized"
base_retriever = self.retriever
else:
# Use the vector store as the base retriever
base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2})
# Create the multi-query retriever
multi_query_retriever = MultiQueryRetriever.from_llm(
retriever=base_retriever,
llm=llm,
include_original=True
)
# Set a custom prompt to generate the requested number of queries
from langchain_core.prompts import PromptTemplate
multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template(
"You are a professional query-rewriting assistant. Your task is to rewrite the user's question into {num_queries} different versions.\n"
"These versions should express the same or related intent from different angles and with different keywords.\n\n"
"Original question: {question}\n\n"
"Please generate {num_queries} different versions of the query, one per line.\n"
"Make sure every version is an independent, complete query.\n\n"
"Generate {num_queries} queries:"
)
# Patch the invocation so it includes num_queries
original_ainvoke = multi_query_retriever.llm_chain.ainvoke
async def new_ainvoke(input_dict):
input_dict["num_queries"] = num_queries
return await original_ainvoke(input_dict)
multi_query_retriever.llm_chain.ainvoke = new_ainvoke
# Run the retrieval
documents = await multi_query_retriever.ainvoke(query)
# Deduplicate and cap the count
seen_content = set()
unique_documents = []
for doc in documents:
content = doc.page_content
if content not in seen_content:
seen_content.add(content)
unique_documents.append(doc)
if len(unique_documents) >= k:
break
logger.info(f"RAG-Fusion retrieval done: {len(documents)} raw results, {len(unique_documents)} after deduplication")
return unique_documents
def get_retriever(self) -> ParentDocumentRetriever:
"""
Return the ParentDocumentRetriever instance directly.
For advanced use cases that need the retriever outside IndexBuilder.
"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"get_retriever() is only available with the PARENT_CHILD splitter. "
"Use search() or search_with_parent_context() for standard retrieval."
)
assert self.retriever is not None, "retriever not initialized"
return self.retriever
def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]:
"""Return the child splitter for reconfiguration."""
if self.splitter_type != SplitterType.PARENT_CHILD:
return self.splitter  # type: ignore
return self.child_splitter
def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
"""Return the parent splitter for reconfiguration."""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"The parent splitter is only available with the PARENT_CHILD splitter."
)
return self.parent_splitter
def get_docstore(self) -> BaseStore:
"""Return the document store for parent chunks."""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"The document store is only available with the PARENT_CHILD splitter."
)
assert self.docstore is not None, "docstore not initialized"
return self.docstore
def get_docstore_path(self) -> Optional[str]:
"""Return the document store path (deprecated, kept for compatibility only)."""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"The document store path is only available with the PARENT_CHILD splitter."
)
# PostgreSQL storage has no persist_path; return None
return None
def close(self):
"""Close resources."""
if self.docstore is not None and hasattr(self.docstore, "aclose"):
import asyncio
asyncio.get_event_loop().run_until_complete(self.docstore.aclose())  # type: ignore
logger.info("PostgreSQL async connection pool closed")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
# RecursiveCharacterTextSplitter import needed here
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Example usage removed; see the docs


@@ -1,85 +1,77 @@
"""
Simple command-line entry point that builds the RAG index with the default configuration.
"""
import argparse
import asyncio
import logging
import sys
from pathlib import Path
from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig
from rag_indexer.splitters import SplitterType
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Default configuration (all connection parameters are read from environment variables)
COLLECTION_NAME = "rag_documents"
SPLITTER_TYPE = SplitterType.PARENT_CHILD
CHILD_SPLITTER_TYPE = SplitterType.SEMANTIC
# Parent/child chunk-size parameters (adjust as needed)
PARENT_CHUNK_SIZE = 1000
PARENT_CHUNK_OVERLAP = 100
CHILD_CHUNK_SIZE = 200
CHILD_CHUNK_OVERLAP = 20
SEARCH_K = 5
def get_input_path() -> Path:
"""Take the input path from the command line, falling back to a default sample path."""
if len(sys.argv) > 1:
return Path(sys.argv[1])
# Default test path (change as needed)
return Path("data/user_docs/a.txt")
async def main():
input_path = get_input_path()
if not input_path.exists():
logger.error("Path does not exist: %s", input_path)
sys.exit(1)
# Build the configuration (all defaults)
config = IndexBuilderConfig(
collection_name=COLLECTION_NAME,
splitter_type=SPLITTER_TYPE,
parent_chunk_size=PARENT_CHUNK_SIZE,
parent_chunk_overlap=PARENT_CHUNK_OVERLAP,
child_chunk_size=CHILD_CHUNK_SIZE,
child_chunk_overlap=CHILD_CHUNK_OVERLAP,
child_splitter_type=CHILD_SPLITTER_TYPE,
search_k=SEARCH_K,
# the docstore defaults to create_docstore, which reads the PostgreSQL connection from the environment
)
builder = IndexBuilder(config)
is_directory = input_path.is_dir()
try:
async with builder:
if is_directory:
chunk_count = await builder.build_from_directory(input_path, recursive=True)
else:
chunk_count = await builder.build_from_file(input_path)
print(f"\nIndex build complete. Indexed {chunk_count} chunks")
info = builder.get_collection_info()
print(f"Collection '{info['name']}' contains {info['vectors_count']} vectors (dimension: {info['vector_size']})")
except Exception as e:
logger.exception("Index build failed: %s", e)
sys.exit(1)


@@ -1,88 +0,0 @@
"""
Embedding model wrapper for the llama.cpp service.
"""
import os
import httpx
from typing import List, Optional
from urllib.parse import urljoin
from langchain_core.embeddings import Embeddings
class LlamaCppEmbedder:
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
def __init__(
self,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
model: str = "embeddinggemma-300M-Q8_0",
):
self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
self.model = model
def as_langchain_embeddings(self) -> Embeddings:
"""创建 LangChain 兼容的嵌入实例。"""
return _LlamaCppLangchainAdapter(self)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a batch of documents."""
return self._call_embedding_api(texts)
def embed_query(self, text: str) -> List[float]:
"""Embed a single query."""
return self._call_embedding_api([text])[0]
def get_embedding_dimension(self) -> int:
"""Determine the embedding dimension by embedding a test string."""
test_embedding = self.embed_query("test")
return len(test_embedding)
def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
"""直接调用 llama.cpp 嵌入 API。"""
base = self.base_url.rstrip("/")
if not base.endswith("/v1"):
base = base + "/v1"
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
payload = {
"input": texts,
"model": self.model,
}
with httpx.Client(timeout=120) as client:
response = client.post(
f"{base}/embeddings",
headers=headers,
json=payload,
)
response.raise_for_status()
data = response.json()
# Handle the different response formats
if isinstance(data, list):
# llama.cpp returns a bare list
return [item["embedding"] for item in data]
elif isinstance(data, dict) and "data" in data:
# Standard OpenAI format
return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])]
else:
raise ValueError(f"Unknown embedding API response format: {data}")
class _LlamaCppLangchainAdapter(Embeddings):
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""
def __init__(self, embedder: LlamaCppEmbedder):
self._embedder = embedder
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return self._embedder.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
return self._embedder.embed_query(text)


@@ -3,19 +3,27 @@
"""
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Optional, Union
from langchain_core.documents import Document
from unstructured.documents.elements import Element
from unstructured.partition.auto import partition
logger = logging.getLogger(__name__)
# Set the environment variable once at module load to avoid repeated setup
os.environ.setdefault("UNSTRUCTURED_LANGUAGE_CHECKS", "false")
class DocumentLoader:
"""从各种文件格式加载文档。"""
SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx", ".json"}
SUPPORTED_EXTENSIONS = {
".pdf", ".docx", ".doc", ".txt", ".md",
".html", ".pptx", ".xlsx", ".json"
}
def __init__(
self,
@@ -32,13 +40,11 @@ class DocumentLoader:
extract_images: whether to extract images from PDFs
strategy: parsing strategy (auto, fast, hi_res, ocr_only)
ocr_languages: OCR language list, e.g. ['chi_sim', 'eng']
languages: primary document languages, e.g. ['zh'] (mainly for non-OCR paths)
include_page_breaks: whether to include page breaks
pdf_infer_table_structure: whether to detect table structure (requires the hi_res strategy)
partition_kwargs: dict of extra partition parameters (advanced customization)
"""
self.extract_images = extract_images
self.strategy = strategy
self.ocr_languages = ocr_languages or ["chi_sim", "eng"]
@@ -47,6 +53,52 @@ class DocumentLoader:
self.pdf_infer_table_structure = pdf_infer_table_structure
self.partition_kwargs = partition_kwargs or {}
def _build_partition_kwargs(self, file_path: Path) -> Dict[str, Any]:
"""Build the partition parameters for the file type."""
kwargs: Dict[str, Any] = {
"include_page_breaks": self.include_page_breaks,
}
suffix = file_path.suffix.lower()
# PDF-specific parameters
if suffix == ".pdf":
kwargs.update({
"strategy": self.strategy,
"ocr_languages": self.ocr_languages,
"extract_images_in_pdf": self.extract_images,
"pdf_infer_table_structure": self.pdf_infer_table_structure,
})
# Language parameters apply to all file types
if self.languages:
kwargs["languages"] = self.languages
# User-supplied parameters override the defaults
kwargs.update(self.partition_kwargs)
return kwargs
def _element_to_document(self, element: Element, file_path: Path) -> Optional[Document]:
"""Convert a single Element into a Document, preserving key metadata."""
text = getattr(element, "text", "")
if not text or not text.strip():
return None
# Pick out the metadata unstructured provides (select what you actually need)
metadata = {
"source": str(file_path),
"file_name": file_path.name,
"file_type": file_path.suffix.lower(),
# The following metadata comes from the Element object and may be None
"page_number": getattr(getattr(element, "metadata", None), "page_number", None),
"category": getattr(getattr(element, "metadata", None), "category", None),
}
# Drop metadata entries whose value is None
metadata = {k: v for k, v in metadata.items() if v is not None}
return Document(page_content=text, metadata=metadata)
def load_file(self, file_path: Union[str, Path]) -> List[Document]:
"""Load a single file into LangChain Document objects."""
file_path = Path(file_path).resolve()
@@ -59,68 +111,58 @@ class DocumentLoader:
f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}"
)
# 根据文件类型动态调整参数
extra_kwargs = {}
if suffix == ".pdf":
extra_kwargs["strategy"] = self.strategy
extra_kwargs["ocr_languages"] = self.ocr_languages
extra_kwargs["extract_images_in_pdf"] = self.extract_images
extra_kwargs["pdf_infer_table_structure"] = self.pdf_infer_table_structure
# languages 参数适用于所有文件类型
if self.languages:
extra_kwargs["languages"] = self.languages
extra_kwargs["include_page_breaks"] = self.include_page_breaks
kwargs = self._build_partition_kwargs(file_path)
# 合并用户自定义的额外参数(优先级最高)
extra_kwargs.update(self.partition_kwargs)
# 使用 unstructured 解析
elements = partition(
filename=str(file_path),
**extra_kwargs
)
try:
elements = partition(filename=str(file_path), **kwargs)
except Exception as e:
logger.exception("解析文件 %s 失败", file_path)
raise RuntimeError(f"文件解析失败: {file_path}") from e
documents = []
for elem in elements:
doc = self._element_to_document(elem, file_path)
if doc:
documents.append(doc)
if not documents:
logger.warning("No text content extracted from %s", file_path)
return []
return documents
def load_directory(
self,
directory_path: Union[str, Path],
recursive: bool = True,
fail_fast: bool = False
) -> List[Document]:
"""
Load all supported files from a directory.
Args:
directory_path: directory path
recursive: whether to recurse into subdirectories
fail_fast: whether to raise immediately on the first failure
"""
directory_path = Path(directory_path).resolve()
if not directory_path.is_dir():
raise NotADirectoryError(f"Not a directory: {directory_path}")
all_documents: List[Document] = []
pattern = "**/*" if recursive else "*"
for file_path in directory_path.glob(pattern):
if not file_path.is_file():
continue
if file_path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
continue
try:
docs = self.load_file(file_path)
all_documents.extend(docs)
except Exception as e:
logger.error("Failed to load %s: %s", file_path, e)
if fail_fast:
raise
return all_documents


@@ -3,7 +3,8 @@
"""
from enum import Enum
from typing import List, Optional, Tuple, Dict, Any
from dataclasses import dataclass, field
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
@@ -16,68 +17,195 @@ class SplitterType(str, Enum):
PARENT_CHILD = "parent_child"
def get_splitter(splitter_type: SplitterType, **kwargs):
"""Factory function that creates text splitters."""
if splitter_type == SplitterType.RECURSIVE:
chunk_size = kwargs.get("chunk_size", 500)
chunk_overlap = kwargs.get("chunk_overlap", 50)
return RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=["\n\n", "\n", "。", "！", "？", " ", ""],
)
elif splitter_type == SplitterType.SEMANTIC:
embeddings = kwargs.pop("embeddings", None)
if embeddings is None:
raise ValueError("The semantic splitter requires an 'embeddings' parameter")
return SemanticChunkerAdapter(embeddings=embeddings, **kwargs)
else:
raise ValueError(f"Unsupported splitter type: {splitter_type}")
# ---------- Configuration dataclasses to unify parameters ----------
@dataclass
class RecursiveSplitterConfig:
"""Recursive character splitter configuration"""
chunk_size: int = 500
chunk_overlap: int = 50
separators: List[str] = field(default_factory=lambda: ["\n\n", "\n", "。", "！", "？", " ", ""])
keep_separator: bool = True
strip_whitespace: bool = True
@dataclass
class SemanticSplitterConfig:
"""语义切分器配置,仅包含 SemanticChunker 支持的参数。"""
embeddings: Any
buffer_size: int = 1
add_start_index: bool = False
breakpoint_threshold_type: str = "percentile"
breakpoint_threshold_amount: Optional[float] = None
number_of_chunks: Optional[int] = None
sentence_split_regex: str = r"(?<=[.?!。?!])\s+"
min_chunk_size: int = 100
@dataclass
class ParentChildSplitterConfig:
"""父子切分器配置"""
embeddings: Any # 子块语义切分所需
parent_chunk_size: int = 1000
parent_chunk_overlap: int = 100
child_buffer_size: int = 1
child_breakpoint_threshold_type: str = "percentile"
child_breakpoint_threshold_amount: Optional[float] = None
child_min_chunk_size: int = 100
child_max_chunk_size: Optional[int] = 200
# ---------- Adapter: make SemanticChunker implement the TextSplitter interface ----------
class SemanticChunkerAdapter(TextSplitter):
"""Adapts SemanticChunker to the LangChain TextSplitter interface."""
def __init__(self, config: SemanticSplitterConfig, **kwargs):
super().__init__(**kwargs)
self._config = config
self._chunker = SemanticChunker(
embeddings=config.embeddings,
buffer_size=config.buffer_size,
add_start_index=config.add_start_index,
breakpoint_threshold_type=config.breakpoint_threshold_type,
breakpoint_threshold_amount=config.breakpoint_threshold_amount,
number_of_chunks=config.number_of_chunks,
sentence_split_regex=config.sentence_split_regex,
min_chunk_size=config.min_chunk_size,
)
def split_text(self, text: str) -> List[str]:
return self._chunker.split_text(text)
def split_documents(self, documents: List[Document]) -> List[Document]:
result = []
for doc in documents:
chunks = self.split_text(doc.page_content)
for i, chunk in enumerate(chunks):
result.append(Document(
page_content=chunk,
metadata={**doc.metadata, "chunk_index": i}
))
return result
# ---------- Factory function: create splitters uniformly ----------
def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter:
"""
Create a splitter by type.
Accepts either a config object or direct parameters.
"""
if splitter_type == SplitterType.RECURSIVE:
config = RecursiveSplitterConfig(
chunk_size=kwargs.get("chunk_size", 500),
chunk_overlap=kwargs.get("chunk_overlap", 50),
separators=kwargs.get("separators", ["\n\n", "\n", "", "", "", " ", ""]),
)
return RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
separators=config.separators,
keep_separator=config.keep_separator,
strip_whitespace=config.strip_whitespace,
)
elif splitter_type == SplitterType.SEMANTIC:
embeddings = kwargs.get("embeddings")
if embeddings is None:
raise ValueError("语义切分器需要提供 'embeddings' 参数")
if "config" in kwargs and isinstance(kwargs["config"], SemanticSplitterConfig):
config = kwargs["config"]
else:
# Keep only the fields SemanticSplitterConfig supports
config_kwargs = {
"embeddings": embeddings,
"buffer_size": kwargs.get("buffer_size", 1),
"breakpoint_threshold_type": kwargs.get("breakpoint_threshold_type", "percentile"),
"breakpoint_threshold_amount": kwargs.get("breakpoint_threshold_amount"),
"number_of_chunks": kwargs.get("number_of_chunks"),
"min_chunk_size": kwargs.get("min_chunk_size", 100),
}
config = SemanticSplitterConfig(**config_kwargs)
return SemanticChunkerAdapter(config)
elif splitter_type == SplitterType.PARENT_CHILD:
# The parent-child splitter is handled separately in the builder, not created here
raise ValueError("The parent-child splitter should be created via IndexBuilder, not get_splitter")
else:
raise ValueError(f"不支持的切分器类型: {splitter_type}")
# ---------- 父子切分器实现 ----------
class ParentChildSplitter:
"""
将文档切分为父块(大块)和子块(小块)。
子块用于索引检索,父块用于存储上下文
将文档切分为父块(大块,用于上下文)和子块(小块,用于索引检索)。
内部维护父子块之间的映射关系
"""
def __init__(self, config: ParentChildSplitterConfig):
self.config = config
# Parent chunks use recursive character splitting
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.parent_chunk_size,
chunk_overlap=config.parent_chunk_overlap,
)
# Child chunks use semantic splitting
semantic_config = SemanticSplitterConfig(
embeddings=config.embeddings,
buffer_size=config.child_buffer_size,
breakpoint_threshold_type=config.child_breakpoint_threshold_type,
breakpoint_threshold_amount=config.child_breakpoint_threshold_amount,
min_chunk_size=config.child_min_chunk_size,
)
self.child_splitter = SemanticChunkerAdapter(semantic_config)
def split_documents(self, documents: List[Document]) -> tuple[List[Document], List[Document]]:
# 存储父子块映射关系(可选)
self.parent_to_children: Dict[str, List[str]] = {}
self.child_to_parent: Dict[str, str] = {}
def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]:
"""
返回:
(父块列表, 子块列表)
同时填充内部映射字典。
"""
parent_chunks = self.parent_splitter.split_documents(documents)
child_chunks = self.child_splitter.split_documents(documents)
# Build the mappings (simplified example: rough matching by text containment; real use needs a more precise algorithm)
# Illustrative only; in production prefer embedding similarity or exact substring location
self._build_mappings(parent_chunks, child_chunks)
return parent_chunks, child_chunks
def _build_mappings(self, parents: List[Document], children: List[Document]) -> None:
"""
Build the parent-child mappings from text content.
This is a simplified implementation; replace it with more reliable matching logic in real use.
"""
self.parent_to_children.clear()
self.child_to_parent.clear()
# Give every parent chunk a unique ID (fall back to the index when absent)
for p_idx, parent in enumerate(parents):
parent_id = parent.metadata.get("id", f"parent_{p_idx}")
parent.metadata["id"] = parent_id
self.parent_to_children[parent_id] = []
# Assign each child chunk to the first parent whose text contains it
for c_idx, child in enumerate(children):
child_id = child.metadata.get("id", f"child_{c_idx}")
child.metadata["id"] = child_id
for parent in parents:
if child.page_content in parent.page_content:
parent_id = parent.metadata["id"]
self.parent_to_children[parent_id].append(child_id)
self.child_to_parent[child_id] = parent_id
child.metadata["parent_id"] = parent_id
break
def get_parent_for_child(self, child_id: str) -> Optional[str]:
"""根据子块 ID 获取父块 ID"""
return self.child_to_parent.get(child_id)
def get_children_for_parent(self, parent_id: str) -> List[str]:
"""根据父块 ID 获取所有子块 ID"""
return self.parent_to_children.get(parent_id, [])


@@ -1,31 +0,0 @@
"""
Document store module - parent-document storage for ParentDocumentRetriever.
Provides a PostgreSQL storage backend:
- PostgresDocStore: PostgreSQL database storage (for production)
Example usage:
>>> from rag_indexer.store import create_docstore
>>> # Create a PostgreSQL store
>>> store, conn = create_docstore(
... connection_string="postgresql://user:pass@host:5432/db",
... table_name="parent_docs"
... )
"""
from .postgres import PostgresDocStore
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
__version__ = "2.0.0"
__all__ = [
# Concrete implementation
"PostgresDocStore",
# Factory functions
"create_docstore",
"get_docstore_uri",
"DEFAULT_DB_URI",
]


@@ -1,73 +0,0 @@
"""
Document store factory - creates storage instances of different types.
Provides a unified interface for creating local-file or PostgreSQL stores.
"""
import os
import logging
from typing import Optional, Tuple
from langchain_core.stores import BaseStore
from .postgres import PostgresDocStore
logger = logging.getLogger(__name__)
# Default connection string (read from the environment)
DEFAULT_DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)
def get_docstore_uri() -> str:
"""获取 docstore 专用的数据库连接字符串(可与主库相同)"""
return os.getenv("DOCSTORE_URI", DEFAULT_DB_URI)
def create_docstore(
store_type: str = "postgres",
connection_string: Optional[str] = None,
table_name: str = "parent_documents",
pool_config: Optional[dict] = None,
max_concurrency: Optional[int] = None
) -> Tuple[BaseStore, Optional[str]]:
"""
Factory function that creates a PostgreSQL document store.
Args:
store_type: storage type; currently only "postgres" (default)
connection_string: PostgreSQL connection string
table_name: PostgreSQL table name (default: parent_documents)
pool_config: connection pool configuration
max_concurrency: maximum number of concurrent operations; None means unlimited
Returns:
Tuple of (store instance, connection string)
Raises:
ValueError: unsupported storage type
ImportError: required dependency missing
Example:
>>> # Create a PostgreSQL store
>>> store, conn = create_docstore(
...     connection_string="postgresql://user:pass@host:5432/db",
...     table_name="parent_docs",
...     max_concurrency=10
... )
"""
store_type = store_type.lower()
if store_type == "postgres":
conn_str = connection_string or get_docstore_uri()
store = PostgresDocStore(
connection_string=conn_str,
table_name=table_name,
pool_config=pool_config,
max_concurrency=max_concurrency
)
return store, conn_str
else:
raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres")


@@ -1,249 +0,0 @@
"""
Async PostgreSQL store implementation - for production environments.
Uses asyncpg for a truly asynchronous PostgreSQL document store with high-concurrency support.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence, cast
from langchain_core.documents import Document
from langchain_core.stores import BaseStore
import asyncpg
logger = logging.getLogger(__name__)
class PostgresDocStore(BaseStore[str, Any]):
"""
Async PostgreSQL document store implementation.
Uses asyncpg as the async PostgreSQL client and supports:
- truly asynchronous operations
- connection pool management
- automatic table creation
- batch operations (amget/amset/amdelete)
- JSONB data storage
- concurrency control
Suitable for production; provides high-performance async persistence.
Attributes:
dsn: PostgreSQL connection string
table_name: storage table name, default "parent_documents"
_pool: asyncpg connection pool instance
_semaphore: semaphore limiting concurrency (optional)
"""
def __init__(
self,
connection_string: str,
table_name: str = "parent_documents",
pool_config: Optional[Dict[str, Any]] = None,
max_concurrency: Optional[int] = None
):
"""
Initialize the async PostgreSQL document store.
Args:
connection_string: PostgreSQL connection URL, in the form
"postgresql://user:password@host:port/database?sslmode=disable"
table_name: storage table name, default "parent_documents"
pool_config: connection pool config dict containing:
- min_size: minimum number of connections (default 2)
- max_size: maximum number of connections (default 10)
max_concurrency: maximum number of concurrent operations; None means unlimited
Raises:
ImportError: raised when asyncpg is not installed
Example:
>>> store = PostgresDocStore(
...     "postgresql://user:pass@localhost:5432/mydb",
...     table_name="parent_docs",
...     pool_config={"min_size": 5, "max_size": 20},
...     max_concurrency=10
... )
"""
self.dsn = connection_string
self.table_name = table_name
self._pool: Optional["asyncpg.Pool"] = None
self._pool_config = pool_config or {}
# Concurrency-control semaphore
self._semaphore = None
if max_concurrency is not None and max_concurrency > 0:
self._semaphore = asyncio.Semaphore(max_concurrency)
# Note: async initialization of the pool is deferred until first use
# Table creation is likewise deferred to the first operation
async def _get_pool(self):
"""Get or create the asyncpg connection pool."""
if self._pool is None:
import asyncpg
min_size = self._pool_config.get("min_size", 2)
max_size = self._pool_config.get("max_size", 10)
try:
self._pool = await asyncpg.create_pool(
dsn=self.dsn,
min_size=min_size,
max_size=max_size
)
logger.info(f"PostgreSQL 异步连接池已创建: {self.table_name}")
# 初始化表结构
await self._create_table()
except Exception as e:
raise RuntimeError(f"PostgreSQL 异步连接池创建失败: {e}")
return self._pool
async def _create_table(self):
"""创建存储表(如果不存在)。"""
pool = await self._get_pool()
async with pool.acquire() as conn:
async with conn.transaction():
await conn.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value JSONB NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW()
)
""")
logger.info(f"{self.table_name} 已就绪")
    async def _with_concurrency_control(self, coro):
        """Run the given coroutine under the concurrency semaphore, if one is set."""
        if self._semaphore is None:
            return await coro
        async with self._semaphore:
            return await coro
    # --- Sync methods (kept for interface compatibility; not supported) ---
    def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Synchronous access is not supported; use the async amget method."""
        raise NotImplementedError("Synchronous access is not supported; use the async amget method.")
    def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Synchronous access is not supported; use the async amset method."""
        raise NotImplementedError("Synchronous access is not supported; use the async amset method.")
    def mdelete(self, keys: Sequence[str]) -> None:
        """Synchronous access is not supported; use the async amdelete method."""
        raise NotImplementedError("Synchronous access is not supported; use the async amdelete method.")
    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
        """Synchronous access is not supported; use the async ayield_keys method."""
        raise NotImplementedError("Synchronous access is not supported; use the async ayield_keys method.")
    # --- Async methods (the real implementation) ---
    async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Asynchronously fetch documents in batch."""
        if not keys:
            return []
        async def _amget():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)",
                    list(keys)
                )
                result_map = {}
                for row in rows:
                    val = row['value']
                    if isinstance(val, str):
                        val = json.loads(val)
                    # Rehydrate stored Documents; pass anything else through as-is.
                    if isinstance(val, dict) and 'page_content' in val:
                        result_map[row['key']] = Document(**val)
                    else:
                        result_map[row['key']] = val
                return [result_map.get(key) for key in keys]
        return await self._with_concurrency_control(_amget())
    async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Asynchronously upsert documents in batch."""
        if not key_value_pairs:
            return
        async def _amset():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.executemany(
                        f"""
                        INSERT INTO {self.table_name} (key, value)
                        VALUES ($1, $2)
                        ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
                        """,
                        [
                            # Serialize Documents to JSON; store other values as-is.
                            (k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False))
                            for k, v in key_value_pairs
                        ]
                    )
            logger.debug(f"Batch-set {len(key_value_pairs)} documents asynchronously")
        await self._with_concurrency_control(_amset())
    async def amdelete(self, keys: Sequence[str]) -> None:
        """Asynchronously delete documents in batch."""
        if not keys:
            return
        async def _amdelete():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.execute(
                        f"DELETE FROM {self.table_name} WHERE key = ANY($1)",
                        list(keys)
                    )
            logger.debug(f"Batch-deleted {len(keys)} documents asynchronously")
        await self._with_concurrency_control(_amdelete())
    async def ayield_keys(self, *, prefix: str | None = None) -> AsyncIterator[str]:
        """Asynchronously iterate over all keys.
        Note: this is an async generator and must be consumed with `async for`.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if prefix:
                rows = await conn.fetch(
                    f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key",
                    f"{prefix}%"
                )
            else:
                rows = await conn.fetch(
                    f"SELECT key FROM {self.table_name} ORDER BY key"
                )
        # Release the connection before yielding so it is not held while the caller iterates.
        for row in rows:
            yield row['key']
    async def aclose(self) -> None:
        """Asynchronously close the connection pool and release resources."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("PostgreSQL async connection pool closed")
    def close(self) -> None:
        """Synchronous close (no-op).
        Note: there is no event loop available here to close the pool safely;
        in an async context, use aclose instead.
        """
        pass
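# Usage sketch (illustrative only): streaming keys with the async generator.
# Assumes a reachable PostgreSQL instance; the DSN and key prefix are placeholders.
if __name__ == "__main__":
    async def _list_keys() -> None:
        store = PostgresDocStore("postgresql://user:pass@localhost:5432/ragdb")
        # ayield_keys is an async generator, so it is consumed with `async for`.
        async for key in store.ayield_keys(prefix="doc:"):
            print(key)
        await store.aclose()
    asyncio.run(_list_keys())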

View File

@@ -0,0 +1,80 @@
"""清理 RAG 索引数据。
用法:
python reset_index.py # 清理全部
python reset_index.py --qdrant # 仅清理 Qdrant
python reset_index.py --postgres # 仅清理 PostgreSQL
"""
import asyncio
import os
import argparse
from dotenv import load_dotenv
load_dotenv()
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable")
COLLECTION_NAME = "rag_documents"
TABLE_NAME = "parent_documents"
def clear_qdrant():
    """Delete the Qdrant collection."""
    from qdrant_client import QdrantClient
    print("Clearing Qdrant...")
    client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
    collections = client.get_collections().collections
    if any(c.name == COLLECTION_NAME for c in collections):
        client.delete_collection(COLLECTION_NAME)
        print(f"  Collection '{COLLECTION_NAME}' deleted")
    else:
        print(f"  Collection '{COLLECTION_NAME}' does not exist")
async def clear_postgres():
    """Clear the PostgreSQL table."""
    import asyncpg
    print("Clearing PostgreSQL...")
    conn = await asyncpg.connect(dsn=DB_URI)
    try:
        exists = await conn.fetchval(
            "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1)",
            TABLE_NAME
        )
        if exists:
            count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
            await conn.execute(f"DELETE FROM {TABLE_NAME}")
            print(f"  Table '{TABLE_NAME}' cleared, {count} rows deleted")
        else:
            print(f"  Table '{TABLE_NAME}' does not exist")
    finally:
        await conn.close()
async def main():
    parser = argparse.ArgumentParser(description="Reset RAG index data")
    parser.add_argument("--qdrant", action="store_true", help="clear Qdrant only")
    parser.add_argument("--postgres", action="store_true", help="clear PostgreSQL only")
    args = parser.parse_args()
    # With no flags, clear both stores.
    if not args.qdrant and not args.postgres:
        args.qdrant = True
        args.postgres = True
    if args.qdrant:
        clear_qdrant()
    if args.postgres:
        await clear_postgres()
    print("\nDone. Run `python -m rag_indexer.cli` to rebuild the index")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,63 @@
"""检查 Qdrant 中存储的向量质量。"""
import os
import sys
import numpy as np
from dotenv import load_dotenv
from qdrant_client import QdrantClient
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from rag_core import LlamaCppEmbedder
load_dotenv()
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
COLLECTION_NAME = "rag_documents"
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
embedder = LlamaCppEmbedder()
# Fetch a sample point
points, _ = client.scroll(
    collection_name=COLLECTION_NAME,
    limit=1,
    with_vectors=True,
    with_payload=True,
)
if not points:
    print(f"Collection '{COLLECTION_NAME}' is empty")
    sys.exit(0)
sample = points[0]
raw_vec = sample.vector
# Named-vector collections return a dict; unnamed ones return a flat list.
if isinstance(raw_vec, dict):
    stored_vec = list(raw_vec.values())[0]
elif isinstance(raw_vec, list):
    stored_vec = raw_vec
else:
    stored_vec = []
stored_payload = sample.payload or {}
stored_text = str(stored_payload.get("page_content", ""))[:200]
print(f"Content preview:\n{stored_text}...\n")
print(f"Vector dimension: {len(stored_vec)}")  # type: ignore
print(f"First 5 values: {stored_vec[:5]}")  # type: ignore
print(f"All zeros: {all(v == 0.0 for v in stored_vec)}")  # type: ignore
# Re-embed the stored text and compare
if stored_text:
    new_vec = embedder.embed_query(stored_text)
    # Cosine similarity between the stored vector and a fresh embedding
    similarity = np.dot(stored_vec, new_vec) / (np.linalg.norm(stored_vec) * np.linalg.norm(new_vec))  # type: ignore
    print(f"\nFirst 5 values after re-embedding: {new_vec[:5]}")
    print(f"Cosine similarity: {similarity:.4f}")
    if similarity < 0.8:
        print("\n⚠️ Similarity is too low; consider deleting the collection and rebuilding the index")
    else:
        print("\n✅ Vectors are consistent")
else:
    print("\n⚠️ Sample has no text content")

View File

@@ -0,0 +1,83 @@
#!/usr/bin/env python3
"""
Tests for the refactored IndexBuilder and RAGRetriever.
"""
import asyncio
import os
import sys
# Add the project root to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from rag_indexer.IndexBuilder import IndexBuilder
from rag_indexer.splitters import SplitterType
async def test_index_builder():
    """Exercise index building and retrieval."""
    print("Testing index building...")
    # Create an IndexBuilder instance
    builder = IndexBuilder(
        collection_name="test_collection",
        splitter_type=SplitterType.PARENT_CHILD,
        parent_chunk_size=1000,
        child_chunk_size=200
    )
    # Path to the test document
    test_file = os.path.join(os.path.dirname(__file__), "..", "data", "corpus", "三国演义.txt")
    if os.path.exists(test_file):
        # Build the index
        print(f"Building index for {test_file}...")
        processed = await builder.build_from_file(test_file)
        print(f"Index build finished, processed {processed} documents")
        # Collection info
        info = builder.get_collection_info()
        print(f"Collection info: {info}")
    else:
        print(f"Test file does not exist: {test_file}")
    # Plain search
    print("\nTesting search...")
    try:
        results = builder.search("吕布", k=3)
        print(f"Number of results: {len(results)}")
        for i, result in enumerate(results):
            print(f"\nResult {i+1}:")
            print(f"Content: {result.page_content[:100]}...")
    except Exception as e:
        print(f"Search test failed: {e}")
    # Search with parent-chunk context
    print("\nTesting search with parent context...")
    try:
        results = await builder.search_with_parent_context("吕布", k=3)
        print(f"Number of results: {len(results)}")
        for i, result in enumerate(results):
            print(f"\nResult {i+1}:")
            print(f"Content: {result.page_content[:100]}...")
    except Exception as e:
        print(f"Parent-context search test failed: {e}")
    # Unified retrieval interface
    print("\nTesting the unified retrieve interface...")
    try:
        # Return parent chunks
        results_parent = await builder.retrieve("吕布", return_parent=True)
        print(f"Number of parent-chunk results: {len(results_parent)}")
        # Return child chunks
        results_child = await builder.retrieve("吕布", return_parent=False)
        print(f"Number of child-chunk results: {len(results_child)}")
    except Exception as e:
        print(f"Unified retrieve test failed: {e}")
    # Release resources
    builder.close()
    print("\nTests finished")
if __name__ == "__main__":
asyncio.run(test_index_builder())

View File

@@ -0,0 +1,188 @@
"""
验证 RAG 索引完整性。
检查 Qdrant 向量库、PostgreSQL 文档存储及检索功能。
"""
import asyncio
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from dotenv import load_dotenv
load_dotenv()
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable")
COLLECTION_NAME = "rag_documents"
TABLE_NAME = "parent_documents"
def check_qdrant():
    """Check the Qdrant vector store."""
    from qdrant_client import QdrantClient
    print("=" * 60)
    print("Qdrant vector store")
    print("=" * 60)
    client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
    # List collections
    collections = client.get_collections().collections
    print(f"\nCollections: {len(collections)}")
    for c in collections:
        print(f"  - {c.name}")
    # Target collection info
    if not any(c.name == COLLECTION_NAME for c in collections):
        print(f"\nCollection '{COLLECTION_NAME}' does not exist")
        return
    info = client.get_collection(COLLECTION_NAME)
    print(f"\nCollection '{COLLECTION_NAME}':")
    print(f"  Status: {info.status}")
    print(f"  Points: {info.points_count}")
    vectors_config = info.config.params.vectors
    if isinstance(vectors_config, dict):
        for name, vc in vectors_config.items():
            print(f"  Vector '{name}': size={vc.size}, distance={vc.distance}")
    else:
        print(f"  Vector size: {vectors_config.size}")
    # Sample a few points
    print(f"\nFirst 3 points:")
    points = client.scroll(
        collection_name=COLLECTION_NAME,
        limit=3,
        with_payload=True,
        with_vectors=False
    )
    for i, point in enumerate(points[0]):
        print(f"\n  {i+1}. ID: {point.id}")
        payload = point.payload or {}
        print(f"     Content: {payload.get('page_content', '')[:100]}...")
async def check_postgres():
    """Check the PostgreSQL document store."""
    import asyncpg
    print("\n" + "=" * 60)
    print("PostgreSQL document store")
    print("=" * 60)
    conn = await asyncpg.connect(dsn=DB_URI)
    try:
        # Does the table exist?
        tables = await conn.fetch(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
        )
        table_names = [t['table_name'] for t in tables]
        if TABLE_NAME not in table_names:
            print(f"\nTable '{TABLE_NAME}' does not exist")
            return
        # Row count
        count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
        print(f"\nTable '{TABLE_NAME}': {count} rows")
        # Sample a few documents
        print(f"\nFirst 3 documents:")
        rows = await conn.fetch(
            f"SELECT key, value FROM {TABLE_NAME} ORDER BY key LIMIT 3"
        )
        for i, row in enumerate(rows):
            print(f"\n  {i+1}. Key: {row['key']}")
            val = row['value']
            if isinstance(val, dict) and 'page_content' in val:
                print(f"     Content: {val['page_content'][:100]}...")
        # Key prefix distribution
        key_prefixes = await conn.fetch(
            f"""
            SELECT
                CASE
                    WHEN key LIKE '%:%' THEN split_part(key, ':', 1)
                    ELSE 'no_prefix'
                END AS prefix,
                COUNT(*) AS cnt
            FROM {TABLE_NAME}
            GROUP BY prefix
            ORDER BY cnt DESC
            LIMIT 10
            """
        )
        print(f"\nKey prefix distribution:")
        for row in key_prefixes:
            print(f"  {row['prefix']}: {row['cnt']}")
    finally:
        await conn.close()
async def test_search():
    """Test retrieval."""
    from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig
    from rag_indexer.splitters import SplitterType
    print("\n" + "=" * 60)
    print("Retrieval test")
    print("=" * 60)
    # Initialize from a config object (matching the default build setup)
    config = IndexBuilderConfig(
        collection_name=COLLECTION_NAME,
        splitter_type=SplitterType.PARENT_CHILD,
    )
    builder = IndexBuilder(config)
    # Make sure the retriever is initialized
    if builder.retriever is None:
        print("Error: retriever is not initialized; check the splitter strategy")
        return
    query = input("\nQuery (press Enter for default): ").strip() or "你好"
    print(f"\nQuery: {query}")
    # Standard retrieval (ParentDocumentRetriever returns parent chunks by default)
    print("\n--- Standard retrieval (parent chunks) ---")
    results = await builder.retriever.ainvoke(query)
    for i, doc in enumerate(results):
        content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
        print(f"\n  {i+1}. {content}...")
        if hasattr(doc, 'metadata'):
            source = doc.metadata.get('source', '')
            if source:
                print(f"     Source: {source}")
    # To get only child chunks, bypass ParentDocumentRetriever and run a
    # similarity search directly against the underlying vector store.
    print("\n--- Child-chunk retrieval (direct vector store search) ---")
    vectorstore = builder.vector_store.get_langchain_vectorstore()
    sub_results = await vectorstore.asimilarity_search(query, k=3)
    for i, doc in enumerate(sub_results):
        content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
        print(f"\n  {i+1}. {content}...")
        if hasattr(doc, 'metadata'):
            parent_id = doc.metadata.get('parent_id', '')
            if parent_id:
                print(f"     Parent chunk ID: {parent_id}")
async def main():
check_qdrant()
await check_postgres()
await test_search()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,133 +0,0 @@
"""
Qdrant 向量数据库包装器。
"""
import logging
import os
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams
from .embedders import LlamaCppEmbedder
logger = logging.getLogger(__name__)
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
class QdrantVectorStore:
    """Wrapper around Qdrant vector database operations."""
    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Any] = None,
    ):
        self.collection_name = collection_name
        self._client: Optional[QdrantClient] = None
        # Embedding model
        if embeddings is None:
            embedder = LlamaCppEmbedder()
            self.embeddings = embedder.as_langchain_embeddings()
        else:
            self.embeddings = embeddings
        # Create the collection first
        self.create_collection()
        # LangChain vector store
        self.vector_store = LangchainQdrantVS(
            client=self.get_client(),
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )
    def get_client(self) -> QdrantClient:
        """Lazily create the client, ensuring a usable connection on each access."""
        if self._client is None:
            self._client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                timeout=120,
                http2=False,
            )
        return self._client
    def refresh_client(self):
        """Close the old connection so the next access creates a new one."""
        if self._client is not None:
            self._client.close()
            self._client = None
    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create the collection with the appropriate vector dimension."""
        if vector_size is None:
            embedder = LlamaCppEmbedder()
            vector_size = embedder.get_embedding_dimension()
        client = self.get_client()
        collections = client.get_collections().collections
        exists = any(c.name == self.collection_name for c in collections)
        if exists and force_recreate:
            client.delete_collection(self.collection_name)
            exists = False
        if not exists:
            client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )
            logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)
        else:
            logger.info("Collection '%s' already exists", self.collection_name)
    def add_documents(self, documents: List[Document], batch_size: int = 100):
        """Add documents to the vector database."""
        if not documents:
            return []
        self.create_collection()
        ids = self.vector_store.add_documents(documents, batch_size=batch_size)
        logger.info("Added %d documents to '%s'", len(ids), self.collection_name)
        return ids
    def similarity_search(self, query: str, k: int = 5) -> List[Document]:
        return self.vector_store.similarity_search(query, k=k)
    def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
        return self.vector_store.similarity_search_with_score(query, k=k)
    def delete_collection(self):
        self.get_client().delete_collection(self.collection_name)
        logger.info("Collection '%s' deleted", self.collection_name)
    def get_collection_info(self) -> Dict[str, Any]:
        info = self.get_client().get_collection(self.collection_name)
        vectors_config = info.config.params.vectors
        # Named-vector configs come back as a dict; take the first entry's size.
        if isinstance(vectors_config, dict):
            vector_size = next(iter(vectors_config.values())).size
        else:
            vector_size = vectors_config.size
        return {
            "name": self.collection_name,
            "vectors_count": info.points_count or 0,
            "status": info.status,
            "vector_size": vector_size,
        }
    def as_langchain_vectorstore(self):
        return self.vector_store
    def get_langchain_vectorstore(self):
        """Return the LangChain Qdrant vector store object (alias)."""
        return self.vector_store
    def get_qdrant_client(self):
        """Return the native Qdrant client (for manual collection management)."""
        return self.get_client()
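# Usage sketch (illustrative only): index a couple of documents and query them.
# Assumes QDRANT_URL points at a reachable Qdrant instance and the embedding
# model behind LlamaCppEmbedder is available; the collection name is a placeholder.
if __name__ == "__main__":
    vs = QdrantVectorStore(collection_name="demo_collection")
    vs.add_documents([
        Document(page_content="Qdrant stores vectors.", metadata={"source": "demo"}),
        Document(page_content="LangChain wraps retrievers.", metadata={"source": "demo"}),
    ])
    # Scores come back alongside documents for cosine-distance ranking.
    for doc, score in vs.similarity_search_with_score("vector database", k=2):
        print(f"{score:.3f}  {doc.page_content}")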