RAG数据库生成

2026-04-19 15:01:40 +08:00
parent c18e8a9860
commit cc8ef41ef9
17 changed files with 1089 additions and 577 deletions

View File

@@ -64,3 +64,14 @@ API_URL=http://backend:8083/chat
# 应用行为配置
# -----------------------------------------------------------------------------
MEMORY_SUMMARIZE_INTERVAL=10
# -----------------------------------------------------------------------------
# unstructured 库 spaCy 模型配置
# -----------------------------------------------------------------------------
# 指定文档解析使用的语言: eng (英语) 或 zho (中文)
UNSTRUCTURED_LANGUAGE=zho
# 指定 spaCy 模型名称(需与 UNSTRUCTURED_LANGUAGE 对应)
# eng -> en_core_web_sm
# zho -> zh_core_web_sm
SPACY_MODEL=zh_core_web_sm

View File

@@ -1,23 +0,0 @@
# RAG 系统依赖
# 基础框架
langchain>=0.1.0
langchain-core>=0.1.0
langchain-openai>=0.0.1
langchain-qdrant>=0.1.0
# 用于 Cross-Encoder 重排序模型
sentence-transformers>=2.2.0
# 用于 BM25 关键词混合检索
rank-bm25>=0.2.2
# Qdrant 客户端
qdrant-client>=1.6.0
# 可选的本地模型支持
# vllm>=0.5.0 # 如果需要本地模型推理
# transformers>=4.35.0 # 如果需要其他模型支持
# 开发依赖(测试用)
pytest>=7.0.0
pytest-asyncio>=0.21.0

View File

@@ -18,6 +18,10 @@ ENV QDRANT_COLLECTION_NAME=mem0_user_memories
ENV MEMORY_SUMMARIZE_INTERVAL=10
ENV ENABLE_GRAPH_TRACE=false
# unstructured 库 spaCy 模型配置
ENV UNSTRUCTURED_LANGUAGE=eng
ENV SPACY_MODEL=en_core_web_sm
# 日志配置
ENV LOG_LEVEL=WARNING
ENV DEBUG=false
@@ -28,6 +32,13 @@ ENV DEBUG=false
COPY requirement.txt .
RUN pip install --no-cache-dir -r requirement.txt
# =============================================================================
# 预下载 spaCy 语言模型(避免容器启动时重复下载)
# =============================================================================
RUN pip install --no-cache-dir spacy && \
python -m spacy download en_core_web_sm && \
python -m spacy download zh_core_web_sm
# =============================================================================
# 复制项目代码 (只复制必需的文件夹,避免依赖被忽略的目录)
# =============================================================================

View File

@@ -51,10 +51,111 @@ graph TD
- **核心思路**: 解决 RAG 领域经典的矛盾——检索时块越小越容易精确命中(去除噪声);但生成回答时,块越大越能给大模型提供充足的上下文背景。
- **实现指南**:
- 使用 `langchain.retrievers` 中的 `ParentDocumentRetriever` 模块。
- 在写入时,你需要同时准备一个底层的 `VectorStore` (即 Qdrant) 和一个 `BaseStore` (比如原生的 `InMemoryStore` 或 `Redis`)
- 在写入时,你需要同时准备一个底层的 `VectorStore` (即 Qdrant) 和一个 `BaseStore`
- **推荐方案**: 使用 `LocalFileStore` (默认) 或 `PostgresDocStore` 作为 docstore。
- 将两种不同的 `TextSplitter` 分别赋值给检索器的 `child_splitter``parent_splitter`,然后调用 `.add_documents()` 即可让系统自动完成映射。
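下面给出一个最小示意(假设 `vectorstore` 是已初始化、指向 Qdrant 的 LangChain 向量库,`documents` 为已由 `DocumentLoader` 加载的 `Document` 列表;演示用 `InMemoryStore` 代替持久化 docstore
```python
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain_text_splitters import RecursiveCharacterTextSplitter

# 父块切分器:大块,保留上下文;子块切分器:小块,提升检索精度
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20)

retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,    # 假设:已指向 Qdrant 的 VectorStore
    docstore=InMemoryStore(),   # 演示用内存存储,生产环境请换成持久化 BaseStore
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)

# 写入:自动完成父/子切分、子块向量化与父子 ID 映射
retriever.add_documents(documents)
# 检索:按子块命中,返回对应的完整父块
parent_docs = retriever.invoke("查询内容")
```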
### Level 4: GraphRAG 与 多模态 (Graph & Multi-modal)
### Level 3.1: PostgreSQL DocStore 集成
- **核心优势**: 利用 PostgreSQL 作为持久化存储,适合生产环境。底层使用 asyncpg 异步连接池,支持高并发访问。
- **实现步骤**:
1. **安装依赖**: `pip install asyncpg`
2. **配置连接**: 设置 `DB_URI` 环境变量或直接在代码中指定 PostgreSQL 连接字符串
3. **创建 docstore**: 使用 `PostgresDocStore` 类直接创建
4. **注入到 IndexBuilder**: 在创建 `IndexBuilder` 时通过 `docstore` 参数注入
- **使用示例**:
```python
from rag_indexer.store import PostgresDocStore
from rag_indexer.builder import IndexBuilder, SplitterType
# 创建 PostgreSQL docstore
docstore = PostgresDocStore(
connection_string="postgresql://user:pass@host:5432/db",
table_name="parent_documents"
)
# 创建 IndexBuilder 并注入 docstore
builder = IndexBuilder(
collection_name="rag_documents",
splitter_type=SplitterType.PARENT_CHILD,
docstore=docstore,
parent_chunk_size=1000,
child_chunk_size=200,
)
```
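`PostgresDocStore` 的读写接口均为异步方法(`amget`/`amset`/`amdelete`/`ayield_keys`),下面是一个直接读写的最小示意(连接串为占位,请按实际环境替换):
```python
import asyncio
from langchain_core.documents import Document
from rag_indexer.store import PostgresDocStore

async def demo():
    store = PostgresDocStore(
        connection_string="postgresql://user:pass@host:5432/db",  # 占位连接串
        table_name="parent_documents",
    )
    # 异步批量写入父文档(以 JSONB 形式存储)
    await store.amset([("doc-1", Document(page_content="父块内容示例"))])
    # 异步批量读取
    docs = await store.amget(["doc-1"])
    print(docs[0].page_content if docs[0] else "未找到")
    await store.aclose()  # 关闭连接池

asyncio.run(demo())
```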
### Level 3.2: 语义切分与父子块策略结合
- **核心优势**: 结合语义切分的连贯性和父子块策略的层次化存储优势,实现更精准的检索和更丰富的上下文。
- **实现原理**:
- **父块切分**: 使用递归字符切分创建大块约1000词提供完整的上下文背景
- **子块切分**: 使用语义切分创建小块约200词根据语义连贯性动态确定边界提高检索精度
- **存储机制**: 子块向量存入 Qdrant用于精准检索父块内容存入 PostgreSQL提供完整上下文
- **使用示例**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType
# 创建 IndexBuilder结合语义切分与父子块策略
builder = IndexBuilder(
collection_name="rag_documents",
splitter_type=SplitterType.PARENT_CHILD,
# 父子块配置
parent_chunk_size=1000,
child_chunk_size=200,
# 子块使用语义切分
child_splitter_type=SplitterType.SEMANTIC,
# PostgreSQL 存储配置
docstore_conn_string="postgresql://user:pass@host:5432/db",
)
```
- **配置参数**:
- `child_splitter_type`: 子块切分器类型,可选 `SplitterType.SEMANTIC`(默认)或 `SplitterType.RECURSIVE`
- 当使用语义切分时,系统会自动使用已配置的 Embedding 模型进行句子级相似度计算
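子块的语义切分底层依赖 `langchain_experimental` 的 `SemanticChunker`(在 `splitters.py` 中由 `SemanticChunkerAdapter` 适配为 `TextSplitter` 接口)。下面是一个独立示意(`embeddings` 与 `long_text` 为假设已存在的对象,例如 `LlamaCppEmbedder().as_langchain_embeddings()` 返回的实例):
```python
from langchain_experimental.text_splitter import SemanticChunker

# 按相邻句子嵌入相似度的变化点动态断句,而非固定字符数
# 注:默认按英文标点断句,中文语料可能需要自定义断句规则(视版本而定)
chunker = SemanticChunker(embeddings=embeddings)  # 假设embeddings 已初始化
chunks = chunker.split_text(long_text)            # 假设long_text 为待切分的原始文本
print(f"共切分出 {len(chunks)} 个语义块")
```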
### Level 4: RAG-Fusion (多路改写与倒数排名融合)
- **核心优势**: 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果,提高检索的全面性和准确性。
- **实现原理**:
1. **多路查询改写**: 利用 LLM 将原始查询改写成 3-5 个不同表述的查询,从不同角度表达相同意图
2. **倒数排名融合 (RRF)**: 对每个改写查询的结果进行 RRF 融合,公式为 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$,避免单一检索结果主导(独立的示意实现见本节末尾)
3. **结果去重**: 对融合后的结果进行去重,确保返回的文档唯一
- **使用示例**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType
from langchain_openai import OpenAI
# 创建 IndexBuilder
builder = IndexBuilder(
collection_name="rag_documents",
splitter_type=SplitterType.PARENT_CHILD,
parent_chunk_size=1000,
child_chunk_size=200,
docstore_conn_string="postgresql://user:pass@host:5432/db",
)
# 创建语言模型用于查询改写
llm = OpenAI(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model_name="Qwen2.5-7B-Instruct",
temperature=0.3,
)
# 使用 RAG-Fusion 检索
query = "如何申请项目资金?"
results = builder.retrieve_with_fusion(
query=query,
llm=llm,
num_queries=3,
k=5,
return_parent=True
)
```
- **配置参数**:
- `llm`: 语言模型实例,用于查询改写
- `num_queries`: 生成的查询数量(建议 3-5 个)
- `k`: 返回的文档数量
- `return_parent`: 是否返回父块上下文
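上文第 2 步提到的倒数排名融合本身只需几行代码,下面给出一个与公式对应的独立示意实现(仅用于说明算法,不代表 `retrieve_with_fusion` 的内部实现k 为平滑常数,常取 60
```python
from collections import defaultdict

def rrf_merge(ranked_lists, k=60):
    """倒数排名融合RRFranked_lists 中每个元素是一路查询按相关性降序排列的文档 ID 列表。"""
    scores = defaultdict(float)
    for ranking in ranked_lists:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)  # 对应公式中的 1 / (k + rank_q(d))
    return sorted(scores.items(), key=lambda x: x[1], reverse=True)

# 三路改写查询各自的召回结果d2 被多路命中,融合后排名靠前
merged = rrf_merge([["d1", "d2", "d3"], ["d2", "d1", "d4"], ["d3", "d2", "d5"]])
print(merged[:3])
```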
### Level 5: GraphRAG 与 多模态 (Graph & Multi-modal)
- **核心算法**: LLM 实体关系抽取 (NER & Relation Extraction)。
- **核心思路**: 解决传统纯向量检索难以处理“跨文档复杂关系推理”的痛点A公司的CEO是谁他名下的B公司主要业务是什么这种需要横跨多页 PDF 的跳跃性问题)。
- **实现指南**:
@@ -63,7 +164,7 @@ graph TD
---
## 所需依赖与安装
为了支持完整的文档解析和 Qdrant 写入,需要安装以下 Python 包:
@@ -76,6 +177,12 @@ pip install unstructured pdf2image pdfminer.six
# 用于语义分块 (可选)
pip install langchain-experimental
# 用于 PostgreSQL 存储 (可选,用于 Parent-Child 策略,异步驱动)
pip install asyncpg
# 用于 RAG-Fusion (可选,需要语言模型)
pip install langchain-openai
```
---
@@ -91,12 +198,105 @@ rag_indexer/
├── splitters.py # 负责实现 Recursive、Semantic、Parent-Child 切分逻辑
├── embedders.py # 封装本地 llama.cpp 交互的 Embedding 接口
├── vector_store.py # 封装 Qdrant 写入、Upsert、Collection 初始化操作
├── store/ # 文档存储包factory.py 工厂函数 + postgres.py 异步 PostgreSQL 存储)
└── builder.py # 核心编排文件,将上述模块串联成 Pipeline
```
---
## 🔄 工作流程详解
### 数据流向总览
```
┌─────────────────────────────────────────┐
│ builder.py │
│ IndexBuilder 入口 │
└─────────────────┬───────────────────────┘
┌─────────────────▼───────────────────────┐
│ loaders.py │
│ DocumentLoader.load_file() │
│ → 返回 List[Document] │
└─────────────────┬───────────────────────┘
┌─────────────────▼───────────────────────┐
│ ParentDocumentRetriever.add_documents()│
│ ┌─────────────────────────────────┐ │
│ │ parent_splitter (粗切) │ │
│ │ 父块 ~1000 词 │ │
│ └────────────┬────────────────────┘ │
│ │ │
│ ┌────────────▼────────────────────┐ │
│ │ child_splitter (细切) │ │
│ │ 子块 ~200 词 │ │
│ └────────────┬────────────────────┘ │
│ │ │
│ ┌──────────┴──────────┐ │
│ ▼ ▼ │
│ 子块向量 父块原始内容 │
│ │ │ │
│ ▼ ▼ │
│ ┌────────────┐ ┌─────────────────┐ │
│ │vector_store│ │ docstore_manager│ │
│ │ (Qdrant) │ │ (PostgreSQL) │ │
│ └────────────┘ └─────────────────┘ │
└─────────────────────────────────────────┘
```
### 文件职责详解
| 文件 | 职责 | 关键类/函数 |
|------|------|------------|
| **builder.py** | 核心编排,负责串联整个流程 | `IndexBuilder` |
| **loaders.py** | 解析各种文档格式PDF、Word、TXT 等) | `DocumentLoader` |
| **splitters.py** | 文本切分策略Recursive/Semantic/Parent-Child | `SplitterType`, `get_splitter()` |
| **embedders.py** | 向量化(封装 llama.cpp embedding 接口) | `LlamaCppEmbedder` |
| **vector_store.py** | Qdrant 向量数据库操作 | `QdrantVectorStore` |
| **store/** | 父文档存储(异步 PostgreSQL | `PostgresDocStore`, `create_docstore()` |
### 调用顺序
#### 1. 创建 IndexBuilder入口
```python
from rag_indexer.builder import IndexBuilder, SplitterType
builder = IndexBuilder(
collection_name="my_docs",
splitter_type=SplitterType.PARENT_CHILD,
qdrant_url="http://localhost:6333",
parent_chunk_size=1000,
child_chunk_size=200,
)
```
#### 2. 构建索引
```python
# 方式A从单个文件构建异步接口需在 async 函数中 await
await builder.build_from_file("/path/to/document.pdf")
# 方式B从目录批量构建
await builder.build_from_directory("/path/to/docs/")
```
#### 3. 检索(获取完整父块上下文)
```python
# 检索时返回完整父块(异步接口)
results = await builder.search_with_parent_context("查询内容")
```
### 检索流程
```
1. vector_store.similarity_search() → 从 Qdrant 找到相关子块
2. retriever.ainvoke() → 根据子块 ID 获取对应父块
3. 返回完整父块给用户
```
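构建与检索接口均为异步方法,下面给出一个把上述三步串起来的最小执行示意(集合名、语料路径与连接串均为示例值):
```python
import asyncio
from rag_indexer.builder import IndexBuilder, SplitterType

async def main():
    builder = IndexBuilder(
        collection_name="my_docs",
        splitter_type=SplitterType.PARENT_CHILD,
        parent_chunk_size=1000,
        child_chunk_size=200,
        docstore_conn_string="postgresql://user:pass@host:5432/db",  # 示例连接串
    )
    # 1) 构建索引(异步,内部分批写入并在连接断开时重试)
    count = await builder.build_from_directory("data/corpus/")
    print(f"已索引 {count} 个父块")
    # 2) 检索完整父块上下文(异步)
    docs = await builder.search_with_parent_context("如何申请项目资金?", k=5)
    for doc in docs:
        print(doc.page_content[:80])

asyncio.run(main())
```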
---
### 串联与触发方式
在你的 LangGraph 系统外,创建一个执行脚本 `scripts/run_indexer.py`

View File

@@ -1,25 +1,60 @@
"""
Offline RAG Indexer module.
提供完整的离线索引构建功能,包括:
- 文档加载PDF、Word、TXT 等)
- 文本切分(递归、语义、父子块)
- 向量嵌入(支持 llama.cpp
- 向量存储Qdrant
- 父文档存储PostgreSQL
示例用法:
>>> from rag_indexer import IndexBuilder, SplitterType
>>>
>>> builder = IndexBuilder(
... collection_name="my_docs",
... splitter_type=SplitterType.PARENT_CHILD,
... qdrant_url="http://localhost:6333"
... )
>>>
>>> builder.build_from_file("document.pdf")
"""
from .loaders import DocumentLoader
from .splitters import (
RecursiveSplitter,
SemanticSplitter,
ParentChildSplitter,
SplitterType,
get_splitter,
ParentChildSplitter,
)
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore
from .builder import IndexBuilder
# 导出存储相关类(从新的 store 包)
from .store import (
PostgresDocStore,
create_docstore,
)
__version__ = "2.0.0"
__all__ = [
# 核心类
"DocumentLoader",
"RecursiveSplitter",
"SemanticSplitter",
"ParentChildSplitter",
"IndexBuilder",
# 切分相关
"SplitterType",
"get_splitter",
"ParentChildSplitter",
# 嵌入和向量存储
"LlamaCppEmbedder",
"QdrantVectorStore",
"IndexBuilder",
# 存储(新的 store 包)
"PostgresDocStore",
"create_docstore",
]

View File

@@ -1,56 +1,68 @@
"""
Core pipeline builder for offline RAG index construction.
离线 RAG 索引构建核心流水线。
Now supports LangChain's ParentDocumentRetriever for parent-child chunking.
支持 LangChain ParentDocumentRetriever 用于父子块切分。
"""
import asyncio
import logging
from pathlib import Path
from typing import List, Union, Optional, Tuple
from typing import List, Union, Optional, Tuple, Any
from dataclasses import dataclass
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import LocalFileStore, BaseStore
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, ParentChildSplitter
from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore
from .docstore_manager import get_docstore, PostgresDocStore, create_docstore
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .store import create_docstore
logger = logging.getLogger(__name__)
@dataclass
class ParentChildConfig:
"""Configuration for parent-child splitting."""
"""父子块切分配置。"""
parent_chunk_size: int = 1000
child_chunk_size: int = 200
parent_chunk_overlap: int = 100
child_chunk_overlap: int = 20
search_k: int = 5
docstore_path: str = None
docstore_path: Optional[str] = None
docstore_type: str = "local"
docstore_conn_string: str = None
docstore_conn_string: Optional[str] = None
class IndexBuilder:
"""Main pipeline for RAG index construction."""
"""RAG 索引构建主流水线。"""
# 类型注解
parent_splitter: "RecursiveCharacterTextSplitter"
child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
docstore: Optional["BaseStore"]
_docstore_conn: Optional[str]
retriever: Optional["ParentDocumentRetriever"]
vector_store_obj: Any
def __init__(
self,
collection_name: str = "rag_documents",
qdrant_url: str = None,
splitter_type: SplitterType = SplitterType.RECURSIVE,
splitter_type: SplitterType = SplitterType.PARENT_CHILD,
docstore=None,
**splitter_kwargs,
):
self.collection_name = collection_name
self.qdrant_url = qdrant_url
self.splitter_type = splitter_type
self.splitter_kwargs = splitter_kwargs
self.docstore = docstore # 从外部注入
# Components
# 组件
self.loader = DocumentLoader()
self.embedder = LlamaCppEmbedder()
self.embeddings = self.embedder.as_langchain_embeddings()
@@ -58,104 +70,145 @@ class IndexBuilder:
self.vector_store = QdrantVectorStore(
collection_name=collection_name,
embeddings=self.embeddings,
qdrant_url=qdrant_url,
)
# Splitter (except parent-child which is handled separately)
# 切分器(父子块单独处理)
if splitter_type != SplitterType.PARENT_CHILD:
if splitter_type == SplitterType.SEMANTIC:
splitter_kwargs["embeddings"] = self.embeddings
self.splitter = get_splitter(splitter_type, **splitter_kwargs)
else:
self.splitter = None
# Initialize ParentDocumentRetriever for parent-child splitting
# 为父子块切分初始化 ParentDocumentRetriever
self._init_parent_child_retriever(**self.splitter_kwargs)
def _init_parent_child_retriever(self, **kwargs):
"""
Initialize ParentDocumentRetriever for parent-child chunking.
初始化 ParentDocumentRetriever 用于父子块切分。
This replaces the custom ParentChildSplitter logic.
支持动态语义切分与父子块策略结合:
- 父块使用递归切分(大块,提供上下文)
- 子块可以使用递归切分或语义切分(根据语义动态切分,提高检索精度)
替代自定义的 ParentChildSplitter 逻辑。
"""
# Parse kwargs for parent-child config
# 解析父子块配置参数
parent_size = kwargs.get("parent_chunk_size", 1000)
child_size = kwargs.get("child_chunk_size", 200)
parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))
# Define splitters
# 子块切分器类型,默认为语义切分
child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC)
# 定义父块切分器(始终使用递归切分)
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=parent_size,
chunk_overlap=parent_overlap,
)
# 定义子块切分器(根据类型选择)
if child_splitter_type == SplitterType.SEMANTIC:
self.child_splitter = get_splitter(
SplitterType.SEMANTIC,
embeddings=self.embeddings,
)
logger.info(f"子块使用语义切分器")
else:
# 默认使用递归切分
self.child_splitter = RecursiveCharacterTextSplitter(
chunk_size=child_size,
chunk_overlap=child_overlap,
)
logger.info(f"子块使用递归切分器,块大小: {child_size},重叠: {child_overlap}")
# Vector store (for child chunks)
# 向量存储(用于子块)
self.vector_store_obj = self.vector_store.get_langchain_vectorstore()
# Document store (for parent chunks)
docstore_path = kwargs.get("docstore_path")
docstore_type = kwargs.get("docstore_type", "local")
# 文档存储(用于父块)
if self.docstore is None:
# 如果没有外部注入 docstore则使用 PostgreSQL 创建
docstore_conn = kwargs.get("docstore_conn_string")
pool_config = kwargs.get("pool_config")
max_concurrency = kwargs.get("max_concurrency")
if docstore_type == "postgres" and docstore_conn:
self.docstore = PostgresDocStore(docstore_conn)
self._docstore_conn = docstore_conn
# 使用 create_docstore 创建 PostgreSQL 存储
self.docstore, self._docstore_conn = create_docstore(
connection_string=docstore_conn,
pool_config=pool_config,
max_concurrency=max_concurrency
)
else:
self.docstore = get_docstore(docstore_path)
# 使用外部注入的 docstore
self._docstore_conn = None
# Create retriever
# 创建检索器
self.retriever = ParentDocumentRetriever(
vectorstore=self.vector_store_obj,
docstore=self.docstore,
child_splitter=self.child_splitter,
child_splitter=self.child_splitter, # type: ignore
parent_splitter=self.parent_splitter,
search_kwargs={"k": kwargs.get("search_k", 5)},
)
logger.info(f"ParentDocumentRetriever 已初始化,父块大小: {parent_size},子块类型: {child_splitter_type}")
def build_from_file(self, file_path: Union[str, Path]) -> int:
logger.info("Loading file: %s", file_path)
async def build_from_file(self, file_path: Union[str, Path]) -> int:
logger.info("加载文件: %s", file_path)
documents = self.loader.load_file(file_path)
logger.info("Loaded %d documents", len(documents))
return self._process_documents(documents)
logger.info("已加载 %d 个文档", len(documents))
return await self._process_documents(documents)
def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
logger.info("Loading directory: %s (recursive=%s)", directory_path, recursive)
async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
documents = self.loader.load_directory(directory_path, recursive=recursive)
logger.info("Loaded %d documents from directory", len(documents))
return self._process_documents(documents)
logger.info("已从目录加载 %d 个文档", len(documents))
return await self._process_documents(documents)
def _process_documents(self, documents: List[Document]) -> int:
async def _process_documents(self, documents: List[Document]) -> int:
if not documents:
logger.warning("No documents to process")
logger.warning("没有文档需要处理")
return 0
if self.splitter_type == SplitterType.PARENT_CHILD:
logger.info("Using LangChain ParentDocumentRetriever")
logger.info("使用 LangChain ParentDocumentRetriever")
# Ensure collection exists for child chunks
# 确保集合存在(用于子块)
self.vector_store.create_collection()
# Use ParentDocumentRetriever to add documents
# This automatically handles parent-child splitting, mapping, and retrieval
self.retriever.add_documents(documents)
# 分批处理,避免单次请求过大
assert self.retriever is not None, "retriever 未初始化"
batch_size = 10 # 每次处理10个文档
total = len(documents)
processed = 0
for i in range(0, total, batch_size):
batch = documents[i:i + batch_size]
max_retries = 3
for attempt in range(max_retries):
try:
await self.retriever.aadd_documents(batch)
processed += len(batch)
logger.info(f"批次 {i//batch_size + 1}: 已处理 {processed}/{total}")
break
except (RemoteProtocolError, ConnectionError, OSError) as e:
if attempt == max_retries - 1:
raise
logger.warning(f"连接断开,重试 ({attempt+1}/{max_retries}): {e}")
self.vector_store.refresh_client()
await asyncio.sleep(1)
# Log estimated chunk counts
estimated_parent_chunks = len(documents) * (self.parent_splitter._chunk_size // self.child_splitter._chunk_size)
logger.info(
"Indexed with ParentDocumentRetriever: "
f"~{len(documents)} parent chunks, ~{estimated_parent_chunks} child chunks"
"已使用 ParentDocumentRetriever 索引: "
f"{processed} 个父块"
)
return len(documents)
return processed
else:
logger.info("Splitting documents using %s", self.splitter_type)
logger.info("使用 %s 切分文档", self.splitter_type)
# 当 splitter_type 不是 PARENT_CHILD 时splitter 一定不为 None
assert self.splitter is not None, "splitter 未初始化"
chunks = self.splitter.split_documents(documents)
logger.info("Split into %d chunks", len(chunks))
logger.info("已切分为 %d 个块", len(chunks))
self.vector_store.create_collection()
self.vector_store.add_documents(chunks)
@@ -165,90 +218,164 @@ class IndexBuilder:
return self.vector_store.get_collection_info()
def search(self, query: str, k: int = 5) -> List[Document]:
"""Standard search - returns child chunks."""
"""标准搜索 - 返回子块。"""
return self.vector_store.similarity_search(query, k=k)
def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
"""
Search with parent context - returns full parent chunks.
带父块上下文的搜索 - 返回完整父块。
This is the main retrieval method when using parent-child splitting.
这是使用父子块切分时的主要检索方法。
"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"search_with_parent_context only available with PARENT_CHILD splitter. "
"Use search() for standard retrieval."
"search_with_parent_context 仅在 PARENT_CHILD 切分器下可用。"
"请使用 search() 进行标准检索。"
)
return self.retriever.get_relevant_documents(query, k=k)
assert self.retriever is not None, "retriever 未初始化"
return await self.retriever.ainvoke(query, config={"k": k}) # type: ignore
def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
"""
Unified retrieval interface.
统一检索接口。
Args:
query: Search query
return_parent: If True and using parent-child splitter, return parent chunks
If False, always return child chunks
query: 搜索查询
return_parent: 如果为 True 且使用父子块切分,返回父块
如果为 False,始终返回子块
Returns:
List of relevant documents
相关文档列表
"""
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
return self.search_with_parent_context(query)
return await self.search_with_parent_context(query)
else:
return self.search(query)
async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]:
"""
使用 RAG-Fusion 进行检索(多路查询改写 + 倒数排名融合)。
核心原理:
1. 多路查询改写: 利用 LLM 将原始查询改写成多个不同表述
2. 倒数排名融合: 对每个改写查询的结果进行 RRF 融合,避免单一检索结果主导
Args:
query: 原始搜索查询
llm: 语言模型实例(用于查询改写)
num_queries: 生成的查询数量
k: 返回的文档数量
return_parent: 如果为 True 且使用父子块切分,返回父块
如果为 False始终返回子块
Returns:
经过融合后的相关文档列表
"""
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers import EnsembleRetriever
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
# 使用 ParentDocumentRetriever 作为基础检索器
assert self.retriever is not None, "retriever 未初始化"
base_retriever = self.retriever
else:
# 使用向量存储作为基础检索器
base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2})
# 创建多路查询检索器
multi_query_retriever = MultiQueryRetriever.from_llm(
retriever=base_retriever,
llm=llm,
include_original=True
)
# 设置自定义提示词以生成指定数量的查询
from langchain_core.prompts import PromptTemplate
multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template(
"你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
"这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
"原始问题: {question}\n\n"
"请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
"确保每个版本都是独立、完整的查询语句。\n\n"
"生成 {num_queries} 个查询:"
)
# 修改调用参数以包含 num_queries
original_ainvoke = multi_query_retriever.llm_chain.ainvoke
async def new_ainvoke(input_dict):
input_dict["num_queries"] = num_queries
return await original_ainvoke(input_dict)
multi_query_retriever.llm_chain.ainvoke = new_ainvoke
# 执行检索
documents = await multi_query_retriever.ainvoke(query)
# 去重并限制数量
seen_content = set()
unique_documents = []
for doc in documents:
content = doc.page_content
if content not in seen_content:
seen_content.add(content)
unique_documents.append(doc)
if len(unique_documents) >= k:
break
logger.info(f"RAG-Fusion 检索完成: 原始 {len(documents)} 个结果,去重后 {len(unique_documents)} 个结果")
return unique_documents
def get_retriever(self) -> ParentDocumentRetriever:
"""
Get the ParentDocumentRetriever instance directly.
直接获取 ParentDocumentRetriever 实例。
Useful for advanced use cases where you want to access the retriever
outside of IndexBuilder.
适用于需要在 IndexBuilder 外部访问检索器的高级用例。
"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"get_retriever() only available with PARENT_CHILD splitter. "
"Use search() or search_with_parent_context() for standard retrieval."
"get_retriever() 仅在 PARENT_CHILD 切分器下可用。"
"请使用 search() search_with_parent_context() 进行标准检索。"
)
assert self.retriever is not None, "retriever 未初始化"
return self.retriever
def get_child_splitter(self) -> "RecursiveCharacterTextSplitter":
"""Get the child splitter for reconfiguration."""
def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]:
"""获取子块切分器以便重新配置。"""
if self.splitter_type != SplitterType.PARENT_CHILD:
return self.splitter
return self.splitter # type: ignore
return self.child_splitter
def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
"""Get the parent splitter for reconfiguration."""
"""获取父块切分器以便重新配置。"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"Parent splitter only available with PARENT_CHILD splitter."
"父块切分器仅在 PARENT_CHILD 切分器下可用。"
)
return self.parent_splitter
def get_docstore(self) -> BaseStore:
"""Get the document store for parent chunks."""
"""获取父块的文档存储。"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"Docstore only available with PARENT_CHILD splitter."
"文档存储仅在 PARENT_CHILD 切分器下可用。"
)
assert self.docstore is not None, "docstore 未初始化"
return self.docstore
def get_docstore_path(self) -> str:
"""Get the document store path."""
def get_docstore_path(self) -> Optional[str]:
"""获取文档存储路径(已弃用,仅用于兼容性)。"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"Docstore path only available with PARENT_CHILD splitter."
"文档存储路径仅在 PARENT_CHILD 切分器下可用。"
)
return self.docstore.persist_path
# PostgreSQL 存储没有 persist_path,返回 None
return None
def close(self):
"""Close resources."""
if hasattr(self, "_docstore_conn") and self._docstore_conn:
import psycopg2
conn = psycopg2.connect(self._docstore_conn)
conn.close()
logger.info("Closed PostgreSQL connection")
"""关闭资源。"""
if self.docstore is not None and hasattr(self.docstore, "aclose"):
import asyncio
asyncio.get_event_loop().run_until_complete(self.docstore.aclose()) # type: ignore
logger.info("PostgreSQL 异步连接池已关闭")
def __enter__(self):
return self
@@ -258,20 +385,8 @@ class IndexBuilder:
return False
# RecursiveCharacterTextSplitter needs to be imported
# 需要导入 RecursiveCharacterTextSplitter
from langchain_text_splitters import RecursiveCharacterTextSplitter
if __name__ == "__main__":
# Example usage
builder = IndexBuilder(
splitter_type=SplitterType.PARENT_CHILD,
parent_chunk_size=1000,
child_chunk_size=200,
docstore_path="./my_parent_docs",
)
print("Parent splitter:", builder.get_parent_splitter().chunk_size)
print("Child splitter:", builder.get_child_splitter().chunk_size)
print("Docstore path:", builder.get_docstore_path())
print("Retriever:", builder.get_retriever())
# 示例用法已移除,请参考文档

View File

@@ -3,100 +3,85 @@ Command-line interface for the RAG index builder.
"""
import argparse
import asyncio
import logging
import sys
from builder import IndexBuilder
from splitters import SplitterType
from rag_indexer.builder import IndexBuilder
from rag_indexer.splitters import SplitterType
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# 基础配置
COLLECTION_NAME = "rag_documents"
DB_URI = "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable"
def main():
parser = argparse.ArgumentParser(description="Offline RAG Index Builder")
parser.add_argument("--file", type=str, help="Path to file to index")
parser.add_argument("--dir", type=str, help="Path to directory to index")
parser.add_argument("--recursive", action="store_true", default=True,
help="Recursively process directories (default: True)")
parser.add_argument("--collection", type=str, default="rag_documents",
help="Qdrant collection name (default: rag_documents)")
parser.add_argument("--qdrant-url", type=str,
help="Qdrant server URL (default: http://127.0.0.1:6333)")
parser.add_argument("--splitter", type=str,
choices=["recursive", "semantic", "parent_child"],
default="recursive",
help="Text splitting strategy (default: recursive)")
parser.add_argument("--chunk-size", type=int, default=500,
help="Chunk size for recursive/parent splitter (default: 500)")
parser.add_argument("--chunk-overlap", type=int, default=50,
parser.add_argument("--docstore-path", type=str,
default=None,
help="Path to store parent documents for parent-child splitter (default: ./parent_docs or HERMES_HOME/parent_docs)")
parser.add_argument("--docstore-type", type=str,
choices=["local", "postgres"],
default="local",
help="Type of docstore: 'local' (default) or 'postgres' for PostgreSQL-backed storage")
parser.add_argument("--docstore-conn", type=str,
default=None,
help="PostgreSQL connection string for postgres docstore")
# 基础切分参数
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50
help="Chunk overlap (default: 50)")
parser.add_argument("--parent-size", type=int, default=1000,
help="Parent chunk size for parent-child splitter (default: 1000)")
parser.add_argument("--child-size", type=int, default=200,
help="Child chunk size for parent-child splitter (default: 200)")
# 父子块切分参数
PARENT_CHUNK_SIZE = 1000
CHILD_CHUNK_SIZE = 200
PARENT_CHUNK_OVERLAP = 100
CHILD_CHUNK_OVERLAP = 20
args = parser.parse_args()
# 切分策略basic基础、semantic语义、parent-child父子块
STRATEGY = "parent-child"
if not args.file and not args.dir:
print("Error: Either --file or --dir must be specified", file=sys.stderr)
parser.print_help()
sys.exit(1)
# 存储类型postgresPostgreSQL、local本地文件
STORAGE_TYPE = "postgres"
splitter_map = {
"recursive": SplitterType.RECURSIVE,
"semantic": SplitterType.SEMANTIC,
"parent_child": SplitterType.PARENT_CHILD,
}
splitter_type = splitter_map[args.splitter]
async def main():
# 使用固定策略
splitter_type = SplitterType.PARENT_CHILD
child_splitter_type = SplitterType.SEMANTIC
splitter_kwargs = {}
if splitter_type == SplitterType.RECURSIVE:
splitter_kwargs["chunk_size"] = args.chunk_size
splitter_kwargs["chunk_overlap"] = args.chunk_overlap
splitter_kwargs["chunk_size"] = CHUNK_SIZE
splitter_kwargs["chunk_overlap"] = CHUNK_OVERLAP
elif splitter_type == SplitterType.PARENT_CHILD:
splitter_kwargs["parent_chunk_size"] = args.parent_size
splitter_kwargs["child_chunk_size"] = args.child_size
splitter_kwargs["parent_chunk_overlap"] = args.chunk_overlap
splitter_kwargs["child_chunk_overlap"] = args.chunk_overlap // 2
splitter_kwargs["docstore_path"] = args.docstore_path
splitter_kwargs["docstore_type"] = args.docstore_type
splitter_kwargs["docstore_conn_string"] = args.docstore_conn
splitter_kwargs["parent_chunk_size"] = PARENT_CHUNK_SIZE
splitter_kwargs["child_chunk_size"] = CHILD_CHUNK_SIZE
splitter_kwargs["parent_chunk_overlap"] = PARENT_CHUNK_OVERLAP
splitter_kwargs["child_chunk_overlap"] = CHILD_CHUNK_OVERLAP
splitter_kwargs["child_splitter_type"] = child_splitter_type
if STORAGE_TYPE == "postgres":
splitter_kwargs["docstore_conn_string"] = DB_URI
elif STORAGE_TYPE == "local":
splitter_kwargs["docstore_path"] = "./parent_docs"
else:
splitter_kwargs["docstore_conn_string"] = DB_URI
builder = IndexBuilder(
collection_name=args.collection,
qdrant_url=args.qdrant_url,
collection_name=COLLECTION_NAME,
splitter_type=splitter_type,
**splitter_kwargs
)
try:
if args.file:
chunk_count = builder.build_from_file(args.file)
else:
chunk_count = builder.build_from_directory(args.dir, args.recursive)
is_file=False
path="data/corpus/"
print(f"Indexing completed. Total chunks indexed: {chunk_count}")
try:
if is_file:
chunk_count = await builder.build_from_file(path)
else:
chunk_count = await builder.build_from_directory(path, recursive=True)
print(f"索引构建完成。共索引 {chunk_count} 个块")
info = builder.get_collection_info()
print(f"Collection '{info['name']}' has {info['vectors_count']} vectors (dim={info['vector_size']})")
print(f"集合 '{info['name']}' 包含 {info['vectors_count']} 个向量(维度:{info['vector_size']}")
except Exception as e:
logging.exception("Indexing failed")
logging.exception(f"索引构建失败:{e}")
sys.exit(1)
if __name__ == "__main__":
main()
asyncio.run(main())

View File

@@ -1,142 +0,0 @@
"""
Document store manager for ParentDocumentRetriever.
Supports both LocalFileStore (default) and custom PostgreSQL-backed stores.
"""
import os
from typing import Optional
from langchain.storage import BaseStore, LocalFileStore
def get_docstore(persist_path: str = None) -> LocalFileStore:
"""
Create and return a document store for parent chunks.
Args:
persist_path: Path to store parent documents. Defaults to ./parent_docs
or HERMES_HOME/parent_docs if set.
"""
if persist_path is None:
# Use HERMES_HOME if available, otherwise default to current directory
persist_path = os.getenv("HERMES_HOME")
if persist_path:
persist_path = os.path.join(persist_path, "parent_docs")
else:
persist_path = "./parent_docs"
os.makedirs(persist_path, exist_ok=True)
return LocalFileStore(persist_path)
class PostgresDocStore(BaseStore):
"""
PostgreSQL-backed document store for parent chunks.
This is an optional advanced feature. For most use cases,
LocalFileStore is sufficient and simpler.
"""
def __init__(self, connection_string: str):
"""
Initialize PostgreSQL document store.
Args:
connection_string: PostgreSQL connection URL
"""
import psycopg2
from psycopg2 import sql
self.conn_string = connection_string
self._conn = None
# Create table if not exists
self._create_table()
def _create_table(self):
"""Create the parent documents table if not exists."""
try:
self._conn = psycopg2.connect(self.conn_string)
cursor = self._conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS parent_documents (
key TEXT PRIMARY KEY,
value JSONB NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW()
)
""")
self._conn.commit()
cursor.close()
except Exception as e:
raise RuntimeError(f"Failed to create PostgreSQL table: {e}")
def get(self, key: str) -> Optional[dict]:
"""Retrieve a document by key."""
try:
self._ensure_connection()
cursor = self._conn.cursor()
cursor.execute("SELECT value FROM parent_documents WHERE key = %s", (key,))
row = cursor.fetchone()
cursor.close()
if row:
import json
return json.loads(row[0])
return None
except Exception as e:
raise RuntimeError(f"Failed to retrieve document: {e}")
def set(self, key: str, value: dict) -> None:
"""Store a document."""
try:
self._ensure_connection()
cursor = self._conn.cursor()
# Upsert
insert_query = sql.SQL(
"INSERT INTO parent_documents (key, value) VALUES (%s, %s)"
)
update_query = sql.SQL(
"UPDATE parent_documents SET value = %s WHERE key = %s"
)
cursor.execute(insert_query, (key, json.dumps(value)))
try:
cursor.execute(update_query, (key, json.dumps(value)))
except psycopg2.IntegrityError:
pass # Key exists, ignore
self._conn.commit()
cursor.close()
except Exception as e:
raise RuntimeError(f"Failed to store document: {e}")
def _ensure_connection(self):
"""Ensure we have an open connection."""
if self._conn is None or self._conn.closed:
self._conn = psycopg2.connect(self.conn_string)
def close(self):
"""Close the connection."""
if self._conn and not self._conn.closed:
self._conn.close()
# Factory function for creating custom docstores
# Returns a tuple: (BaseStore instance, connection_string or None)
def create_docstore(
store_type: str = "local",
persist_path: str = None,
connection_string: str = None
) -> tuple:
"""
Factory function to create different types of document stores.
Args:
store_type: "local" (default), "postgres"
persist_path: Path for local file store
connection_string: PostgreSQL connection string
Returns:
Tuple of (BaseStore instance, connection_string or None)
"""
if store_type == "postgres" and connection_string:
return (PostgresDocStore(connection_string), connection_string)
else:
return (get_docstore(persist_path), None)

View File

@@ -1,16 +1,17 @@
"""
Embedding model wrapper for llama.cpp service.
嵌入模型包装器,用于 llama.cpp 服务。
"""
import os
import httpx
from typing import List, Optional
from urllib.parse import urljoin
from langchain_openai import OpenAIEmbeddings
from langchain_core.embeddings import Embeddings
class LlamaCppEmbedder:
"""Wrapper for llama.cpp embedding service via OpenAI-compatible API."""
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
def __init__(
self,
@@ -22,47 +23,66 @@ class LlamaCppEmbedder:
self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
self.model = model
# Ensure URL ends with /v1
self.base_url = urljoin(self.base_url.rstrip("/") + "/", "v1")
def as_langchain_embeddings(self) -> OpenAIEmbeddings:
"""Create LangChain OpenAIEmbeddings instance."""
return OpenAIEmbeddings(
openai_api_base=self.base_url,
openai_api_key=self.api_key,
model=self.model,
)
def as_langchain_embeddings(self) -> Embeddings:
"""创建 LangChain 兼容的嵌入实例。"""
return _LlamaCppLangchainAdapter(self)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of documents."""
emb = self.as_langchain_embeddings()
return emb.embed_documents(texts)
"""嵌入一批文档。"""
return self._call_embedding_api(texts)
def embed_query(self, text: str) -> List[float]:
"""Embed a single query."""
emb = self.as_langchain_embeddings()
return emb.embed_query(text)
"""嵌入单个查询。"""
return self._call_embedding_api([text])[0]
def get_embedding_dimension(self) -> int:
"""Get embedding dimension by embedding a test string."""
"""通过嵌入测试字符串获取嵌入维度。"""
test_embedding = self.embed_query("test")
return len(test_embedding)
def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
"""直接调用 llama.cpp 嵌入 API。"""
base = self.base_url.rstrip("/")
if not base.endswith("/v1"):
base = base + "/v1"
class MockEmbedder:
"""Mock embedder for testing without a real service."""
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
def __init__(self, dimension: int = 768):
self.dimension = dimension
payload = {
"input": texts,
"model": self.model,
}
def as_langchain_embeddings(self) -> OpenAIEmbeddings:
raise NotImplementedError("MockEmbedder cannot be used as LangChain embeddings")
with httpx.Client(timeout=120) as client:
response = client.post(
f"{base}/embeddings",
headers=headers,
json=payload,
)
response.raise_for_status()
data = response.json()
# 处理不同响应格式
if isinstance(data, list):
# llama.cpp 直接返回列表
return [item["embedding"] for item in data]
elif isinstance(data, dict) and "data" in data:
# OpenAI 标准格式
return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])]
else:
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
class _LlamaCppLangchainAdapter(Embeddings):
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""
def __init__(self, embedder: LlamaCppEmbedder):
self._embedder = embedder
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [[0.0] * self.dimension for _ in texts]
return self._embedder.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
return [0.0] * self.dimension
def get_embedding_dimension(self) -> int:
return self.dimension
return self._embedder.embed_query(text)

View File

@@ -1,124 +0,0 @@
"""
Example demonstrating ParentDocumentRetriever usage.
This script shows how to:
1. Build an index with parent-child chunking
2. Search with child chunks (fast, precise)
3. Search with parent context (large context)
4. Access the retriever directly for advanced use cases
"""
import logging
from pathlib import Path
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
from builder import IndexBuilder
from splitters import SplitterType
def main():
print("=" * 70)
print("ParentDocumentRetriever Example")
print("=" * 70)
# Step 1: Create IndexBuilder with parent-child splitting
print("\n1. Creating IndexBuilder with parent-child splitting...")
builder = IndexBuilder(
collection_name="parent_child_demo",
splitter_type=SplitterType.PARENT_CHILD,
parent_chunk_size=1000, # Parent chunks: larger context
child_chunk_size=200, # Child chunks: smaller for precision
docstore_path="./my_parent_docs", # Where to store parent chunks
search_k=5, # Number of child chunks to retrieve
)
print(f" Parent splitter: chunk_size={builder.get_parent_splitter().chunk_size}")
print(f" Child splitter: chunk_size={builder.get_child_splitter().chunk_size}")
print(f" Docstore path: {builder.get_docstore_path()}")
print(f" Search k: {builder.retriever.search_kwargs['k']}")
# Step 2: Build index from a sample file
print("\n2. Building index from sample file...")
# Create a test document
test_content = """
This is a test document for demonstrating ParentDocumentRetriever.
Parent chunks contain larger portions of text (1000 characters),
while child chunks are smaller (200 characters) for precise retrieval.
When you search with ParentDocumentRetriever:
- It first retrieves relevant child chunks
- Then replaces them with their corresponding parent chunks
- This gives you large context while maintaining precision
Example search queries:
- "ParentDocumentRetriever"
- "child chunks"
- "large context"
- "precise retrieval"
"""
test_file = Path("./test_document.txt")
test_file.write_text(test_content)
chunk_count = builder.build_from_file(str(test_file))
print(f" Indexed {chunk_count} documents")
# Step 3: Search with child chunks (fast, precise)
print("\n3. Searching with child chunks (fast, precise)...")
child_results = builder.search("ParentDocumentRetriever", k=3)
print(f" Found {len(child_results)} child chunks:")
for i, doc in enumerate(child_results, 1):
print(f" [{i}] {doc.page_content[:100]}...")
# Step 4: Search with parent context (large context)
print("\n4. Searching with parent context (large context)...")
parent_results = builder.search_with_parent_context("ParentDocumentRetriever", k=3)
print(f" Found {len(parent_results)} parent chunks:")
for i, doc in enumerate(parent_results, 1):
print(f" [{i}] {doc.page_content[:150]}...")
# Step 5: Compare results
print("\n5. Comparing child vs parent results...")
print(f" Child chunks total length: {sum(len(d.page_content) for d in child_results)}")
print(f" Parent chunks total length: {sum(len(d.page_content) for d in parent_results)}")
print(f" Ratio: parent/child = {sum(len(d.page_content) for d in parent_results) / max(sum(len(d.page_content) for d in child_results), 1):.2f}x larger")
# Step 6: Access retriever directly
print("\n6. Accessing retriever directly...")
retriever = builder.get_retriever()
print(f" Retriever type: {type(retriever).__name__}")
print(f" Vectorstore: {retriever.vectorstore}")
print(f" Docstore: {retriever.docstore}")
# Step 7: Unified retrieval interface
print("\n7. Using unified retrieval interface...")
unified_results = builder.retrieve("ParentDocumentRetriever", return_parent=True)
print(f" Retrieved {len(unified_results)} documents (with parent context)")
# Step 8: Collection info
print("\n8. Collection info...")
info = builder.get_collection_info()
print(f" Collection: {info['name']}")
print(f" Vectors: {info['vectors_count']}")
print(f" Vector size: {info['vector_size']}")
# Cleanup
print("\n9. Cleaning up...")
builder.close()
print("\n" + "=" * 70)
print("Example completed successfully!")
print("=" * 70)
return builder
if __name__ == "__main__":
builder = main()

View File

@@ -1,10 +1,10 @@
"""
Document loaders using unstructured library.
文档加载器,使用 unstructured 库解析文档。
"""
import logging
from pathlib import Path
from typing import List, Union
from typing import Any, Dict, List, Mapping, Optional, Union
from langchain_core.documents import Document
from unstructured.partition.auto import partition
@@ -13,33 +13,74 @@ logger = logging.getLogger(__name__)
class DocumentLoader:
"""Load documents from various file formats."""
"""从各种文件格式加载文档。"""
SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx"}
SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx", ".json"}
def __init__(self, extract_images: bool = False):
def __init__(
self,
extract_images: bool = False,
strategy: str = "auto",
ocr_languages: Optional[List[str]] = None,
languages: Optional[List[str]] = None,
include_page_breaks: bool = False,
pdf_infer_table_structure: bool = True,
partition_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Args:
extract_images: Whether to extract images from documents (requires additional dependencies)
extract_images: 是否提取 PDF 中的图片
strategy: 解析策略 (auto, fast, hi_res, ocr_only)
ocr_languages: OCR 语言列表,如 ['chi_sim', 'eng']
languages: 文档主语言,如 ['zh']
include_page_breaks: 是否包含分页符
pdf_infer_table_structure: 是否识别表格结构 (需 hi_res 策略)
partition_kwargs: 额外的 partition 参数字典(高级定制)
"""
import os
os.environ["UNSTRUCTURED_LANGUAGE_CHECKS"] = "false"
self.extract_images = extract_images
self.strategy = strategy
self.ocr_languages = ocr_languages or ["chi_sim", "eng"]
self.languages = languages or ["zh"]
self.include_page_breaks = include_page_breaks
self.pdf_infer_table_structure = pdf_infer_table_structure
self.partition_kwargs = partition_kwargs or {}
def load_file(self, file_path: Union[str, Path]) -> List[Document]:
"""Load a single file into LangChain Document objects."""
"""将单个文件加载为 LangChain Document 对象。"""
file_path = Path(file_path).resolve()
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
raise FileNotFoundError(f"文件不存在: {file_path}")
suffix = file_path.suffix.lower()
if suffix not in self.SUPPORTED_EXTENSIONS:
raise ValueError(
f"Unsupported file extension: {suffix}. Supported: {self.SUPPORTED_EXTENSIONS}"
f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}"
)
# Parse with unstructured
# 根据文件类型动态调整参数
extra_kwargs = {}
if suffix == ".pdf":
extra_kwargs["strategy"] = self.strategy
extra_kwargs["ocr_languages"] = self.ocr_languages
extra_kwargs["extract_images_in_pdf"] = self.extract_images
extra_kwargs["pdf_infer_table_structure"] = self.pdf_infer_table_structure
# languages 参数适用于所有文件类型
if self.languages:
extra_kwargs["languages"] = self.languages
extra_kwargs["include_page_breaks"] = self.include_page_breaks
# 合并用户自定义的额外参数(优先级最高)
extra_kwargs.update(self.partition_kwargs)
# 使用 unstructured 解析
elements = partition(
filename=str(file_path),
extract_images_in_pdf=self.extract_images,
**extra_kwargs
)
documents = []
@@ -48,23 +89,17 @@ class DocumentLoader:
if not text or not text.strip():
continue
# Base metadata
# 基础元数据
metadata = {
"source": str(file_path),
"file_name": file_path.name,
"file_type": suffix,
}
# Merge element-specific metadata without overwriting base fields
elem_meta = getattr(elem, "metadata", {}) or {}
for key, value in elem_meta.items():
if value and key not in metadata:
metadata[key] = value
documents.append(Document(page_content=text, metadata=metadata))
if not documents:
logger.warning("No text content extracted from %s", file_path)
logger.warning("未从 %s 提取到文本内容", file_path)
return []
return documents
@@ -72,10 +107,10 @@ class DocumentLoader:
def load_directory(
self, directory_path: Union[str, Path], recursive: bool = True
) -> List[Document]:
"""Load all supported files from a directory."""
"""从目录加载所有支持的文件。"""
directory_path = Path(directory_path).resolve()
if not directory_path.is_dir():
raise NotADirectoryError(f"Not a directory: {directory_path}")
raise NotADirectoryError(f"不是目录: {directory_path}")
all_documents = []
pattern = "**/*" if recursive else "*"
@@ -86,6 +121,6 @@ class DocumentLoader:
docs = self.load_file(file_path)
all_documents.extend(docs)
except Exception as e:
logger.error("Failed to load %s: %s", file_path, e)
logger.error("加载 %s 失败: %s", file_path, e)
return all_documents

View File

@@ -1,12 +1,12 @@
"""
Text splitters for chunking documents.
文本切分器,用于将文档切分成块。
"""
from enum import Enum
from typing import List, Optional
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_experimental.text_splitter import SemanticChunker
@@ -17,7 +17,7 @@ class SplitterType(str, Enum):
def get_splitter(splitter_type: SplitterType, **kwargs):
"""Factory function to create a text splitter."""
"""工厂函数,创建文本切分器。"""
if splitter_type == SplitterType.RECURSIVE:
chunk_size = kwargs.get("chunk_size", 500)
chunk_overlap = kwargs.get("chunk_overlap", 50)
@@ -27,19 +27,31 @@ def get_splitter(splitter_type: SplitterType, **kwargs):
separators=["\n\n", "\n", "。", "", "", " ", ""],
)
elif splitter_type == SplitterType.SEMANTIC:
# Requires embeddings for semantic splitting
embeddings = kwargs.get("embeddings")
embeddings = kwargs.pop("embeddings", None)
if embeddings is None:
raise ValueError("Semantic splitter requires 'embeddings' parameter")
return SemanticChunker(embeddings=embeddings)
raise ValueError("语义切分器需要提供 'embeddings' 参数")
return SemanticChunkerAdapter(embeddings=embeddings, **kwargs)
else:
raise ValueError(f"Unsupported splitter type: {splitter_type}")
raise ValueError(f"不支持的切分器类型: {splitter_type}")
class SemanticChunkerAdapter(TextSplitter):
"""将 SemanticChunker 适配为 TextSplitter 接口。"""
def __init__(self, embeddings, **kwargs):
super().__init__(**kwargs)
chunk_size = kwargs.pop("chunk_size", None)
chunk_overlap = kwargs.pop("chunk_overlap", None)
self._chunker = SemanticChunker(embeddings=embeddings, **kwargs)
def split_text(self, text: str) -> List[str]:
return self._chunker.split_text(text)
class ParentChildSplitter:
"""
Splits documents into parent (large) and child (small) chunks.
Child chunks are indexed for retrieval, parent chunks are stored for context.
将文档切分为父块(大块)和子块(小块)。
子块用于索引检索,父块用于存储上下文。
"""
def __init__(
@@ -60,12 +72,12 @@ class ParentChildSplitter:
def split_documents(self, documents: List[Document]) -> tuple[List[Document], List[Document]]:
"""
Returns:
(parent_chunks, child_chunks)
返回:
(父块列表, 子块列表)
"""
parent_chunks = self.parent_splitter.split_documents(documents)
child_chunks = self.child_splitter.split_documents(documents)
# Link child chunks to parent IDs (optional metadata)
# In a real implementation, you'd map each child to a parent chunk ID.
# 将子块与父块 ID 关联(可选元数据)
# 在实际实现中,需要将每个子块映射到对应的父块 ID
return parent_chunks, child_chunks

View File

@@ -0,0 +1,31 @@
"""
文档存储模块 - 用于 ParentDocumentRetriever 的父文档存储。
提供 PostgreSQL 存储后端:
- PostgresDocStore: PostgreSQL 数据库存储(生产环境)
示例用法:
>>> from rag_indexer.store import create_docstore
>>> # 创建 PostgreSQL 存储
>>> store, conn = create_docstore(
... connection_string="postgresql://user:pass@host:5432/db",
... table_name="parent_docs"
... )
"""
from .postgres import PostgresDocStore
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
__version__ = "2.0.0"
__all__ = [
# 具体实现
"PostgresDocStore",
# 工厂函数
"create_docstore",
"get_docstore_uri",
"DEFAULT_DB_URI",
]

View File

@@ -0,0 +1,73 @@
"""
文档存储工厂 - 创建不同类型的存储实例。
提供统一的接口来创建本地文件存储或 PostgreSQL 存储。
"""
import os
import logging
from typing import Optional, Tuple
from langchain_core.stores import BaseStore
from .postgres import PostgresDocStore
logger = logging.getLogger(__name__)
# 默认连接字符串(从环境变量读取)
DEFAULT_DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)
def get_docstore_uri() -> str:
"""获取 docstore 专用的数据库连接字符串(可与主库相同)"""
return os.getenv("DOCSTORE_URI", DEFAULT_DB_URI)
def create_docstore(
store_type: str = "postgres",
connection_string: Optional[str] = None,
table_name: str = "parent_documents",
pool_config: Optional[dict] = None,
max_concurrency: Optional[int] = None
) -> Tuple[BaseStore, Optional[str]]:
"""
工厂函数,创建 PostgreSQL 文档存储。
Args:
store_type: 存储类型,目前仅支持 "postgres"(默认)
connection_string: PostgreSQL 连接字符串
table_name: PostgreSQL 表名默认parent_documents
pool_config: 连接池配置
max_concurrency: 最大并发操作数,如果为 None 则不限制
Returns:
元组 (存储实例, 连接字符串)
Raises:
ValueError: 不支持的存储类型
ImportError: 缺少必要的依赖
Example:
>>> # 创建 PostgreSQL 存储
>>> store, conn = create_docstore(
... connection_string="postgresql://user:pass@host:5432/db",
... table_name="parent_docs",
... max_concurrency=10
... )
"""
store_type = store_type.lower()
if store_type == "postgres":
conn_str = connection_string or get_docstore_uri()
store = PostgresDocStore(
connection_string=conn_str,
table_name=table_name,
pool_config=pool_config,
max_concurrency=max_concurrency
)
return store, conn_str
else:
raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres")

View File

@@ -0,0 +1,249 @@
"""
异步 PostgreSQL 存储实现 - 用于生产环境。
使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import List, Dict, Any, Optional, Iterator, AsyncIterator, Tuple, Sequence, cast
from langchain_core.documents import Document
from langchain_core.stores import BaseStore
import asyncpg
logger = logging.getLogger(__name__)
class PostgresDocStore(BaseStore[str, Any]):
"""
异步 PostgreSQL 文档存储实现。
使用 asyncpg 作为异步 PostgreSQL 客户端,支持:
- 真正的异步操作
- 连接池管理
- 自动表创建
- 批量操作amget/amset/amdelete
- JSONB 数据存储
- 并发控制
适用于生产环境,提供高性能的异步数据持久化。
Attributes:
dsn: PostgreSQL 连接字符串
table_name: 存储表名,默认为 "parent_documents"
_pool: asyncpg 连接池实例
_semaphore: 控制并发数的信号量(可选)
"""
def __init__(
self,
connection_string: str,
table_name: str = "parent_documents",
pool_config: Optional[Dict[str, Any]] = None,
max_concurrency: Optional[int] = None
):
"""
初始化异步 PostgreSQL 文档存储。
Args:
connection_string: PostgreSQL 连接 URL格式
"postgresql://user:password@host:port/database?sslmode=disable"
table_name: 存储表名,默认为 "parent_documents"
pool_config: 连接池配置字典,包含:
- min_size: 最小连接数(默认 2
- max_size: 最大连接数(默认 10
max_concurrency: 最大并发操作数,如果为 None 则不限制
Raises:
ImportError: 未安装 asyncpg 时抛出
Example:
>>> store = PostgresDocStore(
... "postgresql://user:pass@localhost:5432/mydb",
... table_name="parent_docs",
... pool_config={"min_size": 5, "max_size": 20},
... max_concurrency=10
... )
"""
self.dsn = connection_string
self.table_name = table_name
self._pool: Optional["asyncpg.Pool"] = None
self._pool_config = pool_config or {}
# 并发控制信号量
self._semaphore = None
if max_concurrency is not None and max_concurrency > 0:
self._semaphore = asyncio.Semaphore(max_concurrency)
# 注意:连接池的异步初始化延迟到第一次使用时
# 表结构创建也延迟到第一次操作时
async def _get_pool(self):
"""获取或创建 asyncpg 连接池。"""
if self._pool is None:
import asyncpg
min_size = self._pool_config.get("min_size", 2)
max_size = self._pool_config.get("max_size", 10)
try:
self._pool = await asyncpg.create_pool(
dsn=self.dsn,
min_size=min_size,
max_size=max_size
)
logger.info(f"PostgreSQL 异步连接池已创建: {self.table_name}")
# 初始化表结构
await self._create_table()
except Exception as e:
raise RuntimeError(f"PostgreSQL 异步连接池创建失败: {e}")
return self._pool
async def _create_table(self):
"""创建存储表(如果不存在)。"""
pool = await self._get_pool()
async with pool.acquire() as conn:
async with conn.transaction():
await conn.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value JSONB NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW()
)
""")
logger.info(f"{self.table_name} 已就绪")
async def _with_concurrency_control(self, coro):
"""使用信号量控制并发执行。"""
if self._semaphore is None:
return await coro
async with self._semaphore:
return await coro
# --- 同步方法(保持兼容性,但功能有限)---
def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
"""不支持同步操作,请使用异步 amget 方法。"""
raise NotImplementedError("不支持同步操作,请使用异步 amget 方法。")
def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
"""不支持同步操作,请使用异步 amset 方法。"""
raise NotImplementedError("不支持同步操作,请使用异步 amset 方法。")
def mdelete(self, keys: Sequence[str]) -> None:
"""不支持同步操作,请使用异步 amdelete 方法。"""
raise NotImplementedError("不支持同步操作,请使用异步 amdelete 方法。")
def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
"""不支持同步操作,请使用异步 ayield_keys 方法。"""
raise NotImplementedError("不支持同步操作,请使用异步 ayield_keys 方法。")
# --- 异步方法(真正的实现)---
async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]:
"""异步批量获取文档。"""
if not keys:
return []
async def _amget():
pool = await self._get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(
f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)",
keys
)
result_map = {}
for row in rows:
val = row['value']
if isinstance(val, str):
val = json.loads(val)
if isinstance(val, dict) and 'page_content' in val:
result_map[row['key']] = Document(**val)
else:
result_map[row['key']] = val
return [result_map.get(key) for key in keys]
return await self._with_concurrency_control(_amget())
async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
"""异步批量设置文档。"""
if not key_value_pairs:
return
async def _amset():
pool = await self._get_pool()
async with pool.acquire() as conn:
async with conn.transaction():
await conn.executemany(
f"""
INSERT INTO {self.table_name} (key, value)
VALUES ($1, $2)
ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
""",
[
(k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False))
for k, v in key_value_pairs
]
)
logger.debug(f"已异步批量设置 {len(key_value_pairs)} 个文档")
await self._with_concurrency_control(_amset())
async def amdelete(self, keys: Sequence[str]) -> None:
"""异步批量删除文档。"""
if not keys:
return
async def _amdelete():
pool = await self._get_pool()
async with pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
f"DELETE FROM {self.table_name} WHERE key = ANY($1)",
keys
)
logger.debug(f"已异步批量删除 {len(keys)} 个文档")
await self._with_concurrency_control(_amdelete())
async def ayield_keys(self, *, prefix: str | None = None) -> AsyncIterator[str]:
"""异步迭代所有键。
注意:这是一个异步生成器,需要使用 async for 迭代。
"""
pool = await self._get_pool()
async with pool.acquire() as conn:
if prefix:
rows = await conn.fetch(
f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key",
f"{prefix}%"
)
else:
rows = await conn.fetch(
f"SELECT key FROM {self.table_name} ORDER BY key"
)
for row in rows:
yield row['key']
async def aclose(self) -> None:
"""异步关闭连接池,释放资源。"""
if self._pool:
await self._pool.close()
self._pool = None
logger.info("PostgreSQL 异步连接池已关闭")
def close(self) -> None:
"""同步关闭连接池(功能有限)。
注意:在异步环境中,请使用 aclose 方法。
"""
pass

View File

@@ -1,5 +1,5 @@
"""
Qdrant vector store wrapper.
Qdrant 向量数据库包装器。
"""
import logging
@@ -16,67 +16,85 @@ from .embedders import LlamaCppEmbedder
logger = logging.getLogger(__name__)
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
class QdrantVectorStore:
"""Wrapper for Qdrant vector database operations."""
"""Qdrant 向量数据库操作包装器。"""
def __init__(
self,
collection_name: str,
embeddings: Optional[Any] = None,
qdrant_url: Optional[str] = None,
api_key: Optional[str] = None,
):
self.collection_name = collection_name
self.qdrant_url = qdrant_url or os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
self.api_key = api_key
self._client: Optional[QdrantClient] = None
# Embeddings
# 嵌入模型
if embeddings is None:
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
else:
self.embeddings = embeddings
# Qdrant client
self.client = QdrantClient(url=self.qdrant_url, api_key=self.api_key)
# 先创建集合
self.create_collection()
# LangChain vector store
# LangChain 向量存储
self.vector_store = LangchainQdrantVS(
client=self.client,
client=self.get_client(),
collection_name=self.collection_name,
embeddings=self.embeddings,
embedding=self.embeddings,
)
def get_client(self) -> QdrantClient:
"""懒加载客户端,每次获取时确保连接可用。"""
if self._client is None:
self._client = QdrantClient(
url=self.qdrant_url,
api_key=self.api_key or QDRANT_API_KEY,
timeout=120,
http2=False,
)
return self._client
def refresh_client(self):
"""关闭旧连接,创建新连接。"""
if self._client is not None:
self._client.close()
self._client = None
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
"""Create collection with appropriate vector size."""
"""创建集合,设置合适的向量维度。"""
if vector_size is None:
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
collections = self.client.get_collections().collections
client = self.get_client()
collections = client.get_collections().collections
exists = any(c.name == self.collection_name for c in collections)
if exists and force_recreate:
self.client.delete_collection(self.collection_name)
client.delete_collection(self.collection_name)
exists = False
if not exists:
self.client.create_collection(
client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)
logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)
logger.info("集合 '%s' 已创建(维度=%d", self.collection_name, vector_size)
else:
logger.info("Collection '%s' already exists", self.collection_name)
logger.info("集合 '%s' 已存在", self.collection_name)
def add_documents(self, documents: List[Document], batch_size: int = 100):
"""Add documents to vector store."""
"""将文档添加到向量数据库。"""
if not documents:
return []
self.create_collection()
ids = self.vector_store.add_documents(documents, batch_size=batch_size)
logger.info("Added %d documents to '%s'", len(ids), self.collection_name)
logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
return ids
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
@@ -86,16 +104,21 @@ class QdrantVectorStore:
return self.vector_store.similarity_search_with_score(query, k=k)
def delete_collection(self):
self.client.delete_collection(self.collection_name)
logger.info("Collection '%s' deleted", self.collection_name)
self.get_client().delete_collection(self.collection_name)
logger.info("集合 '%s' 已删除", self.collection_name)
def get_collection_info(self) -> Dict[str, Any]:
info = self.client.get_collection(self.collection_name)
info = self.get_client().get_collection(self.collection_name)
vectors_config = info.config.params.vectors
if isinstance(vectors_config, dict):
vector_size = next(iter(vectors_config.values())).size
else:
vector_size = vectors_config.size
return {
"name": info.name,
"vectors_count": info.vectors_count,
"name": self.collection_name,
"vectors_count": info.points_count or 0,
"status": info.status,
"vector_size": info.config.params.vectors.size,
"vector_size": vector_size,
}
def as_langchain_vectorstore(self):
@@ -107,4 +130,4 @@ class QdrantVectorStore:
def get_qdrant_client(self):
"""返回原生 Qdrant 客户端(如需手动管理 collection"""
return self.client
return self.get_client()

View File

@@ -49,7 +49,8 @@ python-dotenv==1.2.2
typing-extensions==4.15.0
unstructured>=0.0.1
spacy>=3.7.0
langchain_experimental>=0.0.1
# ============================================================================
# 注意:
# 1. 此文件包含项目直接依赖的精确版本