RAG database generation(RAG数据库生成)

.env.docker (11)
@@ -64,3 +64,14 @@ API_URL=http://backend:8083/chat
# 应用行为配置
# -----------------------------------------------------------------------------
MEMORY_SUMMARIZE_INTERVAL=10

# -----------------------------------------------------------------------------
# unstructured 库 spaCy 模型配置
# -----------------------------------------------------------------------------
# 指定文档解析使用的语言: eng (英语) 或 zho (中文)
UNSTRUCTURED_LANGUAGE=zho

# 指定 spaCy 模型名称(需与 UNSTRUCTURED_LANGUAGE 对应)
# eng -> en_core_web_sm
# zho -> zh_core_web_sm
SPACY_MODEL=zh_core_web_sm

@@ -1,23 +0,0 @@
# RAG 系统依赖
# 基础框架
langchain>=0.1.0
langchain-core>=0.1.0
langchain-openai>=0.0.1
langchain-qdrant>=0.1.0

# 用于 Cross-Encoder 重排序模型
sentence-transformers>=2.2.0

# 用于 BM25 关键词混合检索
rank-bm25>=0.2.2

# Qdrant 客户端
qdrant-client>=1.6.0

# 可选的本地模型支持
# vllm>=0.5.0          # 如果需要本地模型推理
# transformers>=4.35.0 # 如果需要其他模型支持

# 开发依赖(测试用)
pytest>=7.0.0
pytest-asyncio>=0.21.0
@@ -18,6 +18,10 @@ ENV QDRANT_COLLECTION_NAME=mem0_user_memories
ENV MEMORY_SUMMARIZE_INTERVAL=10
ENV ENABLE_GRAPH_TRACE=false

# unstructured 库 spaCy 模型配置
ENV UNSTRUCTURED_LANGUAGE=eng
ENV SPACY_MODEL=en_core_web_sm

# 日志配置
ENV LOG_LEVEL=WARNING
ENV DEBUG=false
@@ -28,6 +32,13 @@ ENV DEBUG=false
COPY requirement.txt .
RUN pip install --no-cache-dir -r requirement.txt

# =============================================================================
# 预下载 spaCy 语言模型(避免容器启动时重复下载)
# =============================================================================
RUN pip install --no-cache-dir spacy && \
    python -m spacy download en_core_web_sm && \
    python -m spacy download zh_core_web_sm

# =============================================================================
# 复制项目代码 (只复制必需的文件夹,避免依赖被忽略的目录)
# =============================================================================

@@ -51,10 +51,111 @@ graph TD
- **核心思路**: 解决 RAG 领域经典的矛盾——检索时块越小越容易精确命中(去除噪声);但生成回答时,块越大越能给大模型提供充足的上下文背景。
- **实现指南**:
  - 使用 `langchain.retrievers` 中的 `ParentDocumentRetriever` 模块。
  - 在写入时,你需要同时准备一个底层的 `VectorStore` (即 Qdrant) 和一个 `BaseStore`。
  - **推荐方案**: 使用 `LocalFileStore` (默认) 或 `PostgresDocStore` 作为 docstore。
  - 将两种不同的 `TextSplitter` 分别赋值给检索器的 `child_splitter` 和 `parent_splitter`,然后调用 `.add_documents()` 即可让系统自动完成映射(最小组装方式见下方示意代码)。

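下面是一段最小示意代码(假设性示例,非本仓库实现):演示父/子切分器、向量库与 docstore 的组装方式。其中 `DeterministicFakeEmbedding`、`InMemoryVectorStore`、`InMemoryStore` 仅作占位,生产环境应替换为真实 Embedding 服务、Qdrant 与 PostgreSQL docstore。

```python
# 最小示意(假设性示例):ParentDocumentRetriever 自动维护 子块 -> 父块 的映射。
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain_core.documents import Document
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter

retriever = ParentDocumentRetriever(
    vectorstore=InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=384)),
    docstore=InMemoryStore(),  # 父块原文存这里;子块向量存 vectorstore
    parent_splitter=RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100),
    child_splitter=RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20),
)

# 写入:自动完成 父块切分 -> 子块切分 -> 建立映射
retriever.add_documents([Document(page_content="这里是待索引的长文档内容……")])

# 检索:命中子块,返回对应的完整父块
parent_docs = retriever.invoke("查询内容")
```
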
### Level 3.1: PostgreSQL DocStore 集成
- **核心优势**: 利用 PostgreSQL 作为持久化存储,适合生产环境。使用同步连接池,避免异步复杂度。
- **实现步骤**:
  1. **安装依赖**: `pip install psycopg2-binary`
  2. **配置连接**: 设置 `DB_URI` 环境变量或直接在代码中指定 PostgreSQL 连接字符串
  3. **创建 docstore**: 使用 `PostgresDocStore` 类直接创建
  4. **注入到 IndexBuilder**: 在创建 `IndexBuilder` 时通过 `docstore` 参数注入

- **使用示例**:
```python
from rag_indexer.docstore_manager import PostgresDocStore
from rag_indexer.builder import IndexBuilder, SplitterType

# 创建 PostgreSQL docstore
docstore = PostgresDocStore(
    connection_string="postgresql://user:pass@host:5432/db",
    table_name="parent_documents"
)

# 创建 IndexBuilder 并注入 docstore
builder = IndexBuilder(
    collection_name="rag_documents",
    splitter_type=SplitterType.PARENT_CHILD,
    docstore=docstore,
    parent_chunk_size=1000,
    child_chunk_size=200,
)
```

### Level 3.2: 语义切分与父子块策略结合
- **核心优势**: 结合语义切分的连贯性和父子块策略的层次化存储优势,实现更精准的检索和更丰富的上下文。
- **实现原理**:
  - **父块切分**: 使用递归字符切分创建大块(约 1000 词),提供完整的上下文背景
  - **子块切分**: 使用语义动态切分创建小块(约 200 词),根据语义连贯性动态切分,提高检索精度
  - **存储机制**: 子块向量存入 Qdrant 用于精准检索,父块内容存入 PostgreSQL 提供完整上下文
- **使用示例**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType

# 创建 IndexBuilder,结合语义切分与父子块策略
builder = IndexBuilder(
    collection_name="rag_documents",
    splitter_type=SplitterType.PARENT_CHILD,
    # 父子块配置
    parent_chunk_size=1000,
    child_chunk_size=200,
    # 子块使用语义切分
    child_splitter_type=SplitterType.SEMANTIC,
    # PostgreSQL 存储配置
    docstore_conn_string="postgresql://user:pass@host:5432/db",
)
```
- **配置参数**:
  - `child_splitter_type`: 子块切分器类型,可选 `SplitterType.RECURSIVE`(默认)或 `SplitterType.SEMANTIC`
  - 当使用语义切分时,系统会自动使用已配置的 Embedding 模型进行句子级相似度计算

### Level 4: RAG-Fusion (多路改写与倒数排名融合)
- **核心优势**: 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果,提高检索的全面性和准确性。
- **实现原理**:
  1. **多路查询改写**: 利用 LLM 将原始查询改写成 3-5 个不同表述的查询,从不同角度表达相同意图
  2. **倒数排名融合 (RRF)**: 对每个改写查询的结果进行 RRF 融合,公式为 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$,避免单一检索结果主导(计算示意见下方代码)
  3. **结果去重**: 对融合后的结果进行去重,确保返回的文档唯一

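下面是 RRF 融合公式的一个最小计算示意(假设性示例,独立于仓库代码):输入为各改写查询的排序结果,输出按 RRF 分数降序且去重的文档列表。

```python
from collections import defaultdict

def rrf_fuse(ranked_lists: list[list[str]], k: int = 60) -> list[str]:
    """倒数排名融合:score(d) = sum_q 1 / (k + rank_q(d)),rank 从 1 开始。"""
    scores: dict[str, float] = defaultdict(float)
    for ranking in ranked_lists:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # 以 doc_id 为键天然去重,再按融合分数降序输出
    return sorted(scores, key=scores.get, reverse=True)

# 三个改写查询各自的检索排序结果(示例数据)
fused = rrf_fuse([
    ["doc_a", "doc_b", "doc_c"],
    ["doc_b", "doc_a", "doc_d"],
    ["doc_c", "doc_b", "doc_e"],
])
print(fused)  # doc_b 在多路结果中排名都靠前,融合后居首
```
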
- **使用示例**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType
from langchain_openai import OpenAI

# 创建 IndexBuilder
builder = IndexBuilder(
    collection_name="rag_documents",
    splitter_type=SplitterType.PARENT_CHILD,
    parent_chunk_size=1000,
    child_chunk_size=200,
    docstore_conn_string="postgresql://user:pass@host:5432/db",
)

# 创建语言模型用于查询改写
llm = OpenAI(
    openai_api_base="http://localhost:8000/v1",
    openai_api_key="no-key-needed",
    model_name="Qwen2.5-7B-Instruct",
    temperature=0.3,
)

# 使用 RAG-Fusion 检索(retrieve_with_fusion 为异步方法,需在异步上下文中 await)
query = "如何申请项目资金?"
results = await builder.retrieve_with_fusion(
    query=query,
    llm=llm,
    num_queries=3,
    k=5,
    return_parent=True
)
```
- **配置参数**:
  - `llm`: 语言模型实例,用于查询改写
  - `num_queries`: 生成的查询数量,建议 3-5 个
  - `k`: 返回的文档数量
  - `return_parent`: 是否返回父块上下文

### Level 5: GraphRAG 与 多模态 (Graph & Multi-modal)
- **核心算法**: LLM 实体关系抽取 (NER & Relation Extraction)。
- **核心思路**: 解决传统纯向量检索难以处理“跨文档复杂关系推理”的痛点(如:A公司的CEO是谁?他名下的B公司主要业务是什么?这种需要横跨多页 PDF 的跳跃性问题)。
- **实现指南**:
@@ -63,7 +164,7 @@ graph TD

---

## 所需依赖与安装

为了支持完整的文档解析和 Qdrant 写入,需要安装以下 Python 包:

@@ -76,6 +177,12 @@ pip install unstructured pdf2image pdfminer.six

# 用于语义分块 (可选)
pip install langchain-experimental

# 用于 PostgreSQL 存储 (可选,用于 Parent-Child 策略)
pip install psycopg2-binary

# 用于 RAG-Fusion (可选,需要语言模型)
pip install langchain-openai
```

---
@@ -91,12 +198,105 @@ rag_indexer/
├── splitters.py          # 负责实现 Recursive、Semantic、Parent-Child 切分逻辑
├── embedders.py          # 封装本地 llama.cpp 交互的 Embedding 接口
├── vector_store.py       # 封装 Qdrant 写入、Upsert、Collection 初始化操作
├── docstore_manager.py   # 文档存储管理器,支持 LocalFileStore 和 PostgreSQL
└── builder.py            # 核心编排文件,将上述模块串联成 Pipeline
```

---

## 🔄 工作流程详解

### 数据流向总览

```
┌─────────────────────────────────────────┐
│              builder.py                 │
│           IndexBuilder 入口             │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│              loaders.py                 │
│       DocumentLoader.load_file()        │
│        → 返回 List[Document]            │
└─────────────────┬───────────────────────┘
                  │
┌─────────────────▼───────────────────────┐
│ ParentDocumentRetriever.add_documents() │
│   ┌─────────────────────────────────┐   │
│   │ parent_splitter (粗切)          │   │
│   │ 父块 ~1000 词                   │   │
│   └────────────┬────────────────────┘   │
│                │                        │
│   ┌────────────▼────────────────────┐   │
│   │ child_splitter (细切)           │   │
│   │ 子块 ~200 词                    │   │
│   └────────────┬────────────────────┘   │
│                │                        │
│     ┌──────────┴──────────┐             │
│     ▼                     ▼             │
│  子块向量             父块原始内容      │
│     │                     │             │
│     ▼                     ▼             │
│ ┌────────────┐   ┌─────────────────┐    │
│ │vector_store│   │ docstore_manager│    │
│ │  (Qdrant)  │   │  (PostgreSQL)   │    │
│ └────────────┘   └─────────────────┘    │
└─────────────────────────────────────────┘
```

### 文件职责详解

| 文件 | 职责 | 关键类/函数 |
|------|------|------------|
| **builder.py** | 核心编排,负责串联整个流程 | `IndexBuilder` |
| **loaders.py** | 解析各种文档格式(PDF、Word、TXT等) | `DocumentLoader` |
| **splitters.py** | 文本切分策略(Recursive/Semantic/Parent-Child) | `SplitterType`, `get_splitter()` |
| **embedders.py** | 向量化(封装 llama.cpp embedding 接口) | `LlamaCppEmbedder` |
| **vector_store.py** | Qdrant 向量数据库操作 | `QdrantVectorStore` |
| **docstore_manager.py** | 父文档存储(PostgreSQL/本地文件) | `PostgresDocStore`, `get_docstore()` |

### 调用顺序

#### 1. 创建 IndexBuilder(入口)

```python
from rag_indexer.builder import IndexBuilder, SplitterType

builder = IndexBuilder(
    collection_name="my_docs",
    splitter_type=SplitterType.PARENT_CHILD,
    qdrant_url="http://localhost:6333",
    parent_chunk_size=1000,
    child_chunk_size=200,
)
```

#### 2. 构建索引

```python
# 方式A:从单个文件构建(异步方法,需在异步上下文中调用)
await builder.build_from_file("/path/to/document.pdf")

# 方式B:从目录批量构建
await builder.build_from_directory("/path/to/docs/")
```

#### 3. 检索(获取完整父块上下文)

```python
# 检索时返回完整父块(同样为异步方法)
results = await builder.search_with_parent_context("查询内容")
```

### 检索流程

```
1. vector_store.similarity_search() → 从 Qdrant 找到相关子块
2. retriever.get_relevant_documents() → 根据子块 ID 获取对应父块
3. 返回完整父块给用户
```

---

### 串联与触发方式
在你的 LangGraph 系统外,创建一个执行脚本 `scripts/run_indexer.py`:

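下面是 `scripts/run_indexer.py` 的一个假设性最小草图(仓库中未给出完整文件,路径与连接串仅作示意):加载语料目录、按父子块策略构建索引,并打印集合信息。

```python
"""scripts/run_indexer.py —— 离线索引构建入口(示意)。"""
import asyncio

from rag_indexer.builder import IndexBuilder, SplitterType


async def main() -> None:
    builder = IndexBuilder(
        collection_name="rag_documents",
        splitter_type=SplitterType.PARENT_CHILD,
        parent_chunk_size=1000,
        child_chunk_size=200,
        # 假设的本地连接串,按实际环境替换
        docstore_conn_string="postgresql://user:pass@localhost:5432/rag",
    )
    try:
        count = await builder.build_from_directory("data/corpus/", recursive=True)
        print(f"索引构建完成,共 {count} 个文档")
        print(builder.get_collection_info())
    finally:
        builder.close()


if __name__ == "__main__":
    asyncio.run(main())
```
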
@@ -1,25 +1,60 @@
"""
Offline RAG Indexer module.

提供完整的离线索引构建功能,包括:
- 文档加载(PDF、Word、TXT 等)
- 文本切分(递归、语义、父子块)
- 向量嵌入(支持 llama.cpp)
- 向量存储(Qdrant)
- 父文档存储(PostgreSQL)

示例用法:
    >>> from rag_indexer import IndexBuilder, SplitterType
    >>>
    >>> builder = IndexBuilder(
    ...     collection_name="my_docs",
    ...     splitter_type=SplitterType.PARENT_CHILD,
    ...     qdrant_url="http://localhost:6333"
    ... )
    >>>
    >>> builder.build_from_file("document.pdf")
"""

from .loaders import DocumentLoader
from .splitters import (
    RecursiveSplitter,
    SemanticSplitter,
    ParentChildSplitter,
    SplitterType,
    get_splitter,
    ParentChildSplitter,
)
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore
from .builder import IndexBuilder

# 导出存储相关类(从新的 store 包)
from .store import (
    PostgresDocStore,
    create_docstore,
)


__version__ = "2.0.0"

__all__ = [
    # 核心类
    "DocumentLoader",
    "RecursiveSplitter",
    "SemanticSplitter",
    "ParentChildSplitter",
    "IndexBuilder",

    # 切分相关
    "SplitterType",
    "get_splitter",
    "ParentChildSplitter",

    # 嵌入和向量存储
    "LlamaCppEmbedder",
    "QdrantVectorStore",
    "IndexBuilder",

    # 存储(新的 store 包)
    "PostgresDocStore",
    "create_docstore",
]
@@ -1,56 +1,68 @@
|
||||
"""
|
||||
Core pipeline builder for offline RAG index construction.
|
||||
离线 RAG 索引构建核心流水线。
|
||||
|
||||
Now supports LangChain's ParentDocumentRetriever for parent-child chunking.
|
||||
支持 LangChain 的 ParentDocumentRetriever 用于父子块切分。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional, Tuple
|
||||
from typing import List, Union, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from httpx import RemoteProtocolError
|
||||
from langchain_core.documents import Document
|
||||
from langchain.retrievers import ParentDocumentRetriever
|
||||
from langchain.storage import LocalFileStore, BaseStore
|
||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
||||
from langchain_core.stores import BaseStore
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from langchain_experimental.text_splitter import SemanticChunker
|
||||
|
||||
from .loaders import DocumentLoader
|
||||
from .splitters import SplitterType, get_splitter, ParentChildSplitter
|
||||
from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .vector_store import QdrantVectorStore
|
||||
from .docstore_manager import get_docstore, PostgresDocStore, create_docstore
|
||||
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
|
||||
from .store import create_docstore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParentChildConfig:
|
||||
"""Configuration for parent-child splitting."""
|
||||
"""父子块切分配置。"""
|
||||
parent_chunk_size: int = 1000
|
||||
child_chunk_size: int = 200
|
||||
parent_chunk_overlap: int = 100
|
||||
child_chunk_overlap: int = 20
|
||||
search_k: int = 5
|
||||
docstore_path: str = None
|
||||
docstore_path: Optional[str] = None
|
||||
docstore_type: str = "local"
|
||||
docstore_conn_string: str = None
|
||||
docstore_conn_string: Optional[str] = None
|
||||
|
||||
|
||||
class IndexBuilder:
|
||||
"""Main pipeline for RAG index construction."""
|
||||
"""RAG 索引构建主流水线。"""
|
||||
|
||||
# 类型注解
|
||||
parent_splitter: "RecursiveCharacterTextSplitter"
|
||||
child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
|
||||
docstore: Optional["BaseStore"]
|
||||
_docstore_conn: Optional[str]
|
||||
retriever: Optional["ParentDocumentRetriever"]
|
||||
vector_store_obj: Any
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str = "rag_documents",
|
||||
qdrant_url: str = None,
|
||||
splitter_type: SplitterType = SplitterType.RECURSIVE,
|
||||
splitter_type: SplitterType = SplitterType.PARENT_CHILD,
|
||||
docstore=None,
|
||||
**splitter_kwargs,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self.qdrant_url = qdrant_url
|
||||
self.splitter_type = splitter_type
|
||||
self.splitter_kwargs = splitter_kwargs
|
||||
self.docstore = docstore # 从外部注入
|
||||
|
||||
# Components
|
||||
# 组件
|
||||
self.loader = DocumentLoader()
|
||||
self.embedder = LlamaCppEmbedder()
|
||||
self.embeddings = self.embedder.as_langchain_embeddings()
|
||||
@@ -58,104 +70,145 @@ class IndexBuilder:
|
||||
self.vector_store = QdrantVectorStore(
|
||||
collection_name=collection_name,
|
||||
embeddings=self.embeddings,
|
||||
qdrant_url=qdrant_url,
|
||||
)
|
||||
|
||||
# Splitter (except parent-child which is handled separately)
|
||||
# 切分器(父子块单独处理)
|
||||
if splitter_type != SplitterType.PARENT_CHILD:
|
||||
if splitter_type == SplitterType.SEMANTIC:
|
||||
splitter_kwargs["embeddings"] = self.embeddings
|
||||
self.splitter = get_splitter(splitter_type, **splitter_kwargs)
|
||||
else:
|
||||
self.splitter = None
|
||||
# Initialize ParentDocumentRetriever for parent-child splitting
|
||||
# 为父子块切分初始化 ParentDocumentRetriever
|
||||
self._init_parent_child_retriever()
|
||||
|
||||
def _init_parent_child_retriever(self, **kwargs):
|
||||
"""
|
||||
Initialize ParentDocumentRetriever for parent-child chunking.
|
||||
初始化 ParentDocumentRetriever 用于父子块切分。
|
||||
|
||||
This replaces the custom ParentChildSplitter logic.
|
||||
支持动态语义切分与父子块策略结合:
|
||||
- 父块使用递归切分(大块,提供上下文)
|
||||
- 子块可以使用递归切分或语义切分(根据语义动态切分,提高检索精度)
|
||||
|
||||
替代自定义的 ParentChildSplitter 逻辑。
|
||||
"""
|
||||
# Parse kwargs for parent-child config
|
||||
# 解析父子块配置参数
|
||||
parent_size = kwargs.get("parent_chunk_size", 1000)
|
||||
child_size = kwargs.get("child_chunk_size", 200)
|
||||
parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
|
||||
child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))
|
||||
|
||||
# Define splitters
|
||||
# 子块切分器类型,默认为语义切分
|
||||
child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC)
|
||||
|
||||
# 定义父块切分器(始终使用递归切分)
|
||||
self.parent_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=parent_size,
|
||||
chunk_overlap=parent_overlap,
|
||||
)
|
||||
|
||||
# 定义子块切分器(根据类型选择)
|
||||
if child_splitter_type == SplitterType.SEMANTIC:
|
||||
self.child_splitter = get_splitter(
|
||||
SplitterType.SEMANTIC,
|
||||
embeddings=self.embeddings,
|
||||
)
|
||||
logger.info(f"子块使用语义切分器")
|
||||
else:
|
||||
# 默认使用递归切分
|
||||
self.child_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=child_size,
|
||||
chunk_overlap=child_overlap,
|
||||
)
|
||||
logger.info(f"子块使用递归切分器,块大小: {child_size},重叠: {child_overlap}")
|
||||
|
||||
# Vector store (for child chunks)
|
||||
# 向量存储(用于子块)
|
||||
self.vector_store_obj = self.vector_store.get_langchain_vectorstore()
|
||||
|
||||
# Document store (for parent chunks)
|
||||
docstore_path = kwargs.get("docstore_path")
|
||||
docstore_type = kwargs.get("docstore_type", "local")
|
||||
# 文档存储(用于父块)
|
||||
if self.docstore is None:
|
||||
# 如果没有外部注入 docstore,则使用 PostgreSQL 创建
|
||||
docstore_conn = kwargs.get("docstore_conn_string")
|
||||
pool_config = kwargs.get("pool_config")
|
||||
max_concurrency = kwargs.get("max_concurrency")
|
||||
|
||||
if docstore_type == "postgres" and docstore_conn:
|
||||
self.docstore = PostgresDocStore(docstore_conn)
|
||||
self._docstore_conn = docstore_conn
|
||||
# 使用 create_docstore 创建 PostgreSQL 存储
|
||||
self.docstore, self._docstore_conn = create_docstore(
|
||||
connection_string=docstore_conn,
|
||||
pool_config=pool_config,
|
||||
max_concurrency=max_concurrency
|
||||
)
|
||||
else:
|
||||
self.docstore = get_docstore(docstore_path)
|
||||
# 使用外部注入的 docstore
|
||||
self._docstore_conn = None
|
||||
|
||||
# Create retriever
|
||||
# 创建检索器
|
||||
self.retriever = ParentDocumentRetriever(
|
||||
vectorstore=self.vector_store_obj,
|
||||
docstore=self.docstore,
|
||||
child_splitter=self.child_splitter,
|
||||
child_splitter=self.child_splitter, # type: ignore
|
||||
parent_splitter=self.parent_splitter,
|
||||
search_kwargs={"k": kwargs.get("search_k", 5)},
|
||||
)
|
||||
logger.info(f"ParentDocumentRetriever 已初始化,父块大小: {parent_size},子块类型: {child_splitter_type}")
|
||||
|
||||
def build_from_file(self, file_path: Union[str, Path]) -> int:
|
||||
logger.info("Loading file: %s", file_path)
|
||||
async def build_from_file(self, file_path: Union[str, Path]) -> int:
|
||||
logger.info("加载文件: %s", file_path)
|
||||
documents = self.loader.load_file(file_path)
|
||||
logger.info("Loaded %d documents", len(documents))
|
||||
return self._process_documents(documents)
|
||||
logger.info("已加载 %d 个文档", len(documents))
|
||||
return await self._process_documents(documents)
|
||||
|
||||
def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
|
||||
logger.info("Loading directory: %s (recursive=%s)", directory_path, recursive)
|
||||
async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
|
||||
logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
|
||||
documents = self.loader.load_directory(directory_path, recursive=recursive)
|
||||
logger.info("Loaded %d documents from directory", len(documents))
|
||||
return self._process_documents(documents)
|
||||
logger.info("已从目录加载 %d 个文档", len(documents))
|
||||
return await self._process_documents(documents)
|
||||
|
||||
def _process_documents(self, documents: List[Document]) -> int:
|
||||
async def _process_documents(self, documents: List[Document]) -> int:
|
||||
if not documents:
|
||||
logger.warning("No documents to process")
|
||||
logger.warning("没有文档需要处理")
|
||||
return 0
|
||||
|
||||
if self.splitter_type == SplitterType.PARENT_CHILD:
|
||||
logger.info("Using LangChain ParentDocumentRetriever")
|
||||
logger.info("使用 LangChain ParentDocumentRetriever")
|
||||
|
||||
# Ensure collection exists for child chunks
|
||||
# 确保集合存在(用于子块)
|
||||
self.vector_store.create_collection()
|
||||
|
||||
# Use ParentDocumentRetriever to add documents
|
||||
# This automatically handles parent-child splitting, mapping, and retrieval
|
||||
self.retriever.add_documents(documents)
|
||||
# 分批处理,避免单次请求过大
|
||||
assert self.retriever is not None, "retriever 未初始化"
|
||||
batch_size = 10 # 每次处理10个文档
|
||||
total = len(documents)
|
||||
processed = 0
|
||||
|
||||
for i in range(0, total, batch_size):
|
||||
batch = documents[i:i + batch_size]
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
await self.retriever.aadd_documents(batch)
|
||||
processed += len(batch)
|
||||
logger.info(f"批次 {i//batch_size + 1}: 已处理 {processed}/{total}")
|
||||
break
|
||||
except (RemoteProtocolError, ConnectionError, OSError) as e:
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
logger.warning(f"连接断开,重试 ({attempt+1}/{max_retries}): {e}")
|
||||
self.vector_store.refresh_client()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Log estimated chunk counts
|
||||
estimated_parent_chunks = len(documents) * (self.parent_splitter._chunk_size // self.child_splitter._chunk_size)
|
||||
logger.info(
|
||||
"Indexed with ParentDocumentRetriever: "
|
||||
f"~{len(documents)} parent chunks, ~{estimated_parent_chunks} child chunks"
|
||||
"已使用 ParentDocumentRetriever 索引: "
|
||||
f"共 {processed} 个父块"
|
||||
)
|
||||
return len(documents)
|
||||
return processed
|
||||
|
||||
else:
|
||||
logger.info("Splitting documents using %s", self.splitter_type)
|
||||
logger.info("使用 %s 切分文档", self.splitter_type)
|
||||
# 当 splitter_type 不是 PARENT_CHILD 时,splitter 一定不为 None
|
||||
assert self.splitter is not None, "splitter 未初始化"
|
||||
chunks = self.splitter.split_documents(documents)
|
||||
logger.info("Split into %d chunks", len(chunks))
|
||||
logger.info("已切分为 %d 个块", len(chunks))
|
||||
|
||||
self.vector_store.create_collection()
|
||||
self.vector_store.add_documents(chunks)
|
||||
@@ -165,90 +218,164 @@ class IndexBuilder:
|
||||
return self.vector_store.get_collection_info()
|
||||
|
||||
def search(self, query: str, k: int = 5) -> List[Document]:
|
||||
"""Standard search - returns child chunks."""
|
||||
"""标准搜索 - 返回子块。"""
|
||||
return self.vector_store.similarity_search(query, k=k)
|
||||
|
||||
def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
|
||||
async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
|
||||
"""
|
||||
Search with parent context - returns full parent chunks.
|
||||
带父块上下文的搜索 - 返回完整父块。
|
||||
|
||||
This is the main retrieval method when using parent-child splitting.
|
||||
这是使用父子块切分时的主要检索方法。
|
||||
"""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
raise RuntimeError(
|
||||
"search_with_parent_context only available with PARENT_CHILD splitter. "
|
||||
"Use search() for standard retrieval."
|
||||
"search_with_parent_context 仅在 PARENT_CHILD 切分器下可用。"
|
||||
"请使用 search() 进行标准检索。"
|
||||
)
|
||||
return self.retriever.get_relevant_documents(query, k=k)
|
||||
assert self.retriever is not None, "retriever 未初始化"
|
||||
return await self.retriever.ainvoke(query, config={"k": k}) # type: ignore
|
||||
|
||||
def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
|
||||
async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
|
||||
"""
|
||||
Unified retrieval interface.
|
||||
统一检索接口。
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
return_parent: If True and using parent-child splitter, return parent chunks
|
||||
If False, always return child chunks
|
||||
query: 搜索查询
|
||||
return_parent: 如果为 True 且使用父子块切分,返回父块
|
||||
如果为 False,始终返回子块
|
||||
|
||||
Returns:
|
||||
List of relevant documents
|
||||
相关文档列表
|
||||
"""
|
||||
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
|
||||
return self.search_with_parent_context(query)
|
||||
return await self.search_with_parent_context(query)
|
||||
else:
|
||||
return self.search(query)
|
||||
|
||||
async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]:
|
||||
"""
|
||||
使用 RAG-Fusion 进行检索(多路查询改写 + 倒数排名融合)。
|
||||
|
||||
核心原理:
|
||||
1. 多路查询改写: 利用 LLM 将原始查询改写成多个不同表述
|
||||
2. 倒数排名融合: 对每个改写查询的结果进行 RRF 融合,避免单一检索结果主导
|
||||
|
||||
Args:
|
||||
query: 原始搜索查询
|
||||
llm: 语言模型实例(用于查询改写)
|
||||
num_queries: 生成的查询数量
|
||||
k: 返回的文档数量
|
||||
return_parent: 如果为 True 且使用父子块切分,返回父块
|
||||
如果为 False,始终返回子块
|
||||
|
||||
Returns:
|
||||
经过融合后的相关文档列表
|
||||
"""
|
||||
from langchain.retrievers.multi_query import MultiQueryRetriever
|
||||
from langchain.retrievers import EnsembleRetriever
|
||||
|
||||
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
|
||||
# 使用 ParentDocumentRetriever 作为基础检索器
|
||||
assert self.retriever is not None, "retriever 未初始化"
|
||||
base_retriever = self.retriever
|
||||
else:
|
||||
# 使用向量存储作为基础检索器
|
||||
base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2})
|
||||
|
||||
# 创建多路查询检索器
|
||||
multi_query_retriever = MultiQueryRetriever.from_llm(
|
||||
retriever=base_retriever,
|
||||
llm=llm,
|
||||
include_original=True
|
||||
)
|
||||
|
||||
# 设置自定义提示词以生成指定数量的查询
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template(
|
||||
"你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
|
||||
"这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
|
||||
"原始问题: {question}\n\n"
|
||||
"请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
|
||||
"确保每个版本都是独立、完整的查询语句。\n\n"
|
||||
"生成 {num_queries} 个查询:"
|
||||
)
|
||||
|
||||
# 修改调用参数以包含 num_queries
|
||||
original_ainvoke = multi_query_retriever.llm_chain.ainvoke
|
||||
async def new_ainvoke(input_dict):
|
||||
input_dict["num_queries"] = num_queries
|
||||
return await original_ainvoke(input_dict)
|
||||
multi_query_retriever.llm_chain.ainvoke = new_ainvoke
|
||||
|
||||
# 执行检索
|
||||
documents = await multi_query_retriever.ainvoke(query)
|
||||
|
||||
# 去重并限制数量
|
||||
seen_content = set()
|
||||
unique_documents = []
|
||||
for doc in documents:
|
||||
content = doc.page_content
|
||||
if content not in seen_content:
|
||||
seen_content.add(content)
|
||||
unique_documents.append(doc)
|
||||
if len(unique_documents) >= k:
|
||||
break
|
||||
|
||||
logger.info(f"RAG-Fusion 检索完成: 原始 {len(documents)} 个结果,去重后 {len(unique_documents)} 个结果")
|
||||
return unique_documents
|
||||
|
||||
def get_retriever(self) -> ParentDocumentRetriever:
|
||||
"""
|
||||
Get the ParentDocumentRetriever instance directly.
|
||||
直接获取 ParentDocumentRetriever 实例。
|
||||
|
||||
Useful for advanced use cases where you want to access the retriever
|
||||
outside of IndexBuilder.
|
||||
适用于需要在 IndexBuilder 外部访问检索器的高级用例。
|
||||
"""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
raise RuntimeError(
|
||||
"get_retriever() only available with PARENT_CHILD splitter. "
|
||||
"Use search() or search_with_parent_context() for standard retrieval."
|
||||
"get_retriever() 仅在 PARENT_CHILD 切分器下可用。"
|
||||
"请使用 search() 或 search_with_parent_context() 进行标准检索。"
|
||||
)
|
||||
assert self.retriever is not None, "retriever 未初始化"
|
||||
return self.retriever
|
||||
|
||||
def get_child_splitter(self) -> "RecursiveCharacterTextSplitter":
|
||||
"""Get the child splitter for reconfiguration."""
|
||||
def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]:
|
||||
"""获取子块切分器以便重新配置。"""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
return self.splitter
|
||||
return self.splitter # type: ignore
|
||||
return self.child_splitter
|
||||
|
||||
def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
|
||||
"""Get the parent splitter for reconfiguration."""
|
||||
"""获取父块切分器以便重新配置。"""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
raise RuntimeError(
|
||||
"Parent splitter only available with PARENT_CHILD splitter."
|
||||
"父块切分器仅在 PARENT_CHILD 切分器下可用。"
|
||||
)
|
||||
return self.parent_splitter
|
||||
|
||||
def get_docstore(self) -> BaseStore:
|
||||
"""Get the document store for parent chunks."""
|
||||
"""获取父块的文档存储。"""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
raise RuntimeError(
|
||||
"Docstore only available with PARENT_CHILD splitter."
|
||||
"文档存储仅在 PARENT_CHILD 切分器下可用。"
|
||||
)
|
||||
assert self.docstore is not None, "docstore 未初始化"
|
||||
return self.docstore
|
||||
|
||||
def get_docstore_path(self) -> str:
|
||||
"""Get the document store path."""
|
||||
def get_docstore_path(self) -> Optional[str]:
|
||||
"""获取文档存储路径(已弃用,仅用于兼容性)。"""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
raise RuntimeError(
|
||||
"Docstore path only available with PARENT_CHILD splitter."
|
||||
"文档存储路径仅在 PARENT_CHILD 切分器下可用。"
|
||||
)
|
||||
return self.docstore.persist_path
|
||||
# PostgreSQL 存储没有 persist_path,返回 None
|
||||
return None
|
||||
|
||||
def close(self):
|
||||
"""Close resources."""
|
||||
if hasattr(self, "_docstore_conn") and self._docstore_conn:
|
||||
import psycopg2
|
||||
conn = psycopg2.connect(self._docstore_conn)
|
||||
conn.close()
|
||||
logger.info("Closed PostgreSQL connection")
|
||||
"""关闭资源。"""
|
||||
if self.docstore is not None and hasattr(self.docstore, "aclose"):
|
||||
import asyncio
|
||||
asyncio.get_event_loop().run_until_complete(self.docstore.aclose()) # type: ignore
|
||||
logger.info("PostgreSQL 异步连接池已关闭")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@@ -258,20 +385,8 @@ class IndexBuilder:
|
||||
return False
|
||||
|
||||
|
||||
# RecursiveCharacterTextSplitter needs to be imported
|
||||
# 需要导入 RecursiveCharacterTextSplitter
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
builder = IndexBuilder(
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
parent_chunk_size=1000,
|
||||
child_chunk_size=200,
|
||||
docstore_path="./my_parent_docs",
|
||||
)
|
||||
|
||||
print("Parent splitter:", builder.get_parent_splitter().chunk_size)
|
||||
print("Child splitter:", builder.get_child_splitter().chunk_size)
|
||||
print("Docstore path:", builder.get_docstore_path())
|
||||
print("Retriever:", builder.get_retriever())
|
||||
# 示例用法已移除,请参考文档
|
||||
|
||||
@@ -3,100 +3,85 @@ Command-line interface for the RAG index builder.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from builder import IndexBuilder
|
||||
from splitters import SplitterType
|
||||
from rag_indexer.builder import IndexBuilder
|
||||
from rag_indexer.splitters import SplitterType
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
# 基础配置
|
||||
COLLECTION_NAME = "rag_documents"
|
||||
DB_URI = "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable"
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Offline RAG Index Builder")
|
||||
parser.add_argument("--file", type=str, help="Path to file to index")
|
||||
parser.add_argument("--dir", type=str, help="Path to directory to index")
|
||||
parser.add_argument("--recursive", action="store_true", default=True,
|
||||
help="Recursively process directories (default: True)")
|
||||
parser.add_argument("--collection", type=str, default="rag_documents",
|
||||
help="Qdrant collection name (default: rag_documents)")
|
||||
parser.add_argument("--qdrant-url", type=str,
|
||||
help="Qdrant server URL (default: http://127.0.0.1:6333)")
|
||||
parser.add_argument("--splitter", type=str,
|
||||
choices=["recursive", "semantic", "parent_child"],
|
||||
default="recursive",
|
||||
help="Text splitting strategy (default: recursive)")
|
||||
parser.add_argument("--chunk-size", type=int, default=500,
|
||||
help="Chunk size for recursive/parent splitter (default: 500)")
|
||||
parser.add_argument("--chunk-overlap", type=int, default=50,
|
||||
parser.add_argument("--docstore-path", type=str,
|
||||
default=None,
|
||||
help="Path to store parent documents for parent-child splitter (default: ./parent_docs or HERMES_HOME/parent_docs)")
|
||||
parser.add_argument("--docstore-type", type=str,
|
||||
choices=["local", "postgres"],
|
||||
default="local",
|
||||
help="Type of docstore: 'local' (default) or 'postgres' for PostgreSQL-backed storage")
|
||||
parser.add_argument("--docstore-conn", type=str,
|
||||
default=None,
|
||||
help="PostgreSQL connection string for postgres docstore")
|
||||
# 基础切分参数
|
||||
CHUNK_SIZE = 500
|
||||
CHUNK_OVERLAP = 50
|
||||
|
||||
help="Chunk overlap (default: 50)")
|
||||
parser.add_argument("--parent-size", type=int, default=1000,
|
||||
help="Parent chunk size for parent-child splitter (default: 1000)")
|
||||
parser.add_argument("--child-size", type=int, default=200,
|
||||
help="Child chunk size for parent-child splitter (default: 200)")
|
||||
# 父子块切分参数
|
||||
PARENT_CHUNK_SIZE = 1000
|
||||
CHILD_CHUNK_SIZE = 200
|
||||
PARENT_CHUNK_OVERLAP = 100
|
||||
CHILD_CHUNK_OVERLAP = 20
|
||||
|
||||
args = parser.parse_args()
|
||||
# 切分策略:basic(基础)、semantic(语义)、parent-child(父子块)
|
||||
STRATEGY = "parent-child"
|
||||
|
||||
if not args.file and not args.dir:
|
||||
print("Error: Either --file or --dir must be specified", file=sys.stderr)
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
# 存储类型:postgres(PostgreSQL)、local(本地文件)
|
||||
STORAGE_TYPE = "postgres"
|
||||
|
||||
splitter_map = {
|
||||
"recursive": SplitterType.RECURSIVE,
|
||||
"semantic": SplitterType.SEMANTIC,
|
||||
"parent_child": SplitterType.PARENT_CHILD,
|
||||
}
|
||||
splitter_type = splitter_map[args.splitter]
|
||||
|
||||
async def main():
|
||||
# 使用固定策略
|
||||
splitter_type = SplitterType.PARENT_CHILD
|
||||
child_splitter_type = SplitterType.SEMANTIC
|
||||
|
||||
splitter_kwargs = {}
|
||||
|
||||
if splitter_type == SplitterType.RECURSIVE:
|
||||
splitter_kwargs["chunk_size"] = args.chunk_size
|
||||
splitter_kwargs["chunk_overlap"] = args.chunk_overlap
|
||||
splitter_kwargs["chunk_size"] = CHUNK_SIZE
|
||||
splitter_kwargs["chunk_overlap"] = CHUNK_OVERLAP
|
||||
elif splitter_type == SplitterType.PARENT_CHILD:
|
||||
splitter_kwargs["parent_chunk_size"] = args.parent_size
|
||||
splitter_kwargs["child_chunk_size"] = args.child_size
|
||||
splitter_kwargs["parent_chunk_overlap"] = args.chunk_overlap
|
||||
splitter_kwargs["child_chunk_overlap"] = args.chunk_overlap // 2
|
||||
splitter_kwargs["docstore_path"] = args.docstore_path
|
||||
splitter_kwargs["docstore_type"] = args.docstore_type
|
||||
splitter_kwargs["docstore_conn_string"] = args.docstore_conn
|
||||
splitter_kwargs["parent_chunk_size"] = PARENT_CHUNK_SIZE
|
||||
splitter_kwargs["child_chunk_size"] = CHILD_CHUNK_SIZE
|
||||
splitter_kwargs["parent_chunk_overlap"] = PARENT_CHUNK_OVERLAP
|
||||
splitter_kwargs["child_chunk_overlap"] = CHILD_CHUNK_OVERLAP
|
||||
splitter_kwargs["child_splitter_type"] = child_splitter_type
|
||||
if STORAGE_TYPE == "postgres":
|
||||
splitter_kwargs["docstore_conn_string"] = DB_URI
|
||||
elif STORAGE_TYPE == "local":
|
||||
splitter_kwargs["docstore_path"] = "./parent_docs"
|
||||
else:
|
||||
splitter_kwargs["docstore_conn_string"] = DB_URI
|
||||
|
||||
builder = IndexBuilder(
|
||||
collection_name=args.collection,
|
||||
qdrant_url=args.qdrant_url,
|
||||
collection_name=COLLECTION_NAME,
|
||||
splitter_type=splitter_type,
|
||||
**splitter_kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
if args.file:
|
||||
chunk_count = builder.build_from_file(args.file)
|
||||
else:
|
||||
chunk_count = builder.build_from_directory(args.dir, args.recursive)
|
||||
is_file=False
|
||||
path="data/corpus/"
|
||||
|
||||
print(f"Indexing completed. Total chunks indexed: {chunk_count}")
|
||||
try:
|
||||
if is_file:
|
||||
chunk_count = await builder.build_from_file(path)
|
||||
else:
|
||||
chunk_count = await builder.build_from_directory(path, recursive=True)
|
||||
|
||||
print(f"索引构建完成。共索引 {chunk_count} 个块")
|
||||
info = builder.get_collection_info()
|
||||
print(f"Collection '{info['name']}' has {info['vectors_count']} vectors (dim={info['vector_size']})")
|
||||
print(f"集合 '{info['name']}' 包含 {info['vectors_count']} 个向量(维度:{info['vector_size']})")
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Indexing failed")
|
||||
logging.exception(f"索引构建失败:{e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
asyncio.run(main())
|
||||
@@ -1,142 +0,0 @@
|
||||
"""
|
||||
Document store manager for ParentDocumentRetriever.
|
||||
|
||||
Supports both LocalFileStore (default) and custom PostgreSQL-backed stores.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from langchain.storage import BaseStore, LocalFileStore
|
||||
|
||||
|
||||
def get_docstore(persist_path: str = None) -> LocalFileStore:
|
||||
"""
|
||||
Create and return a document store for parent chunks.
|
||||
|
||||
Args:
|
||||
persist_path: Path to store parent documents. Defaults to ./parent_docs
|
||||
or HERMES_HOME/parent_docs if set.
|
||||
"""
|
||||
if persist_path is None:
|
||||
# Use HERMES_HOME if available, otherwise default to current directory
|
||||
persist_path = os.getenv("HERMES_HOME")
|
||||
if persist_path:
|
||||
persist_path = os.path.join(persist_path, "parent_docs")
|
||||
else:
|
||||
persist_path = "./parent_docs"
|
||||
|
||||
os.makedirs(persist_path, exist_ok=True)
|
||||
return LocalFileStore(persist_path)
|
||||
|
||||
|
||||
class PostgresDocStore(BaseStore):
|
||||
"""
|
||||
PostgreSQL-backed document store for parent chunks.
|
||||
|
||||
This is an optional advanced feature. For most use cases,
|
||||
LocalFileStore is sufficient and simpler.
|
||||
"""
|
||||
|
||||
def __init__(self, connection_string: str):
|
||||
"""
|
||||
Initialize PostgreSQL document store.
|
||||
|
||||
Args:
|
||||
connection_string: PostgreSQL connection URL
|
||||
"""
|
||||
import psycopg2
|
||||
from psycopg2 import sql
|
||||
|
||||
self.conn_string = connection_string
|
||||
self._conn = None
|
||||
|
||||
# Create table if not exists
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
"""Create the parent documents table if not exists."""
|
||||
try:
|
||||
self._conn = psycopg2.connect(self.conn_string)
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS parent_documents (
|
||||
key TEXT PRIMARY KEY,
|
||||
value JSONB NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)
|
||||
""")
|
||||
self._conn.commit()
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to create PostgreSQL table: {e}")
|
||||
|
||||
def get(self, key: str) -> Optional[dict]:
|
||||
"""Retrieve a document by key."""
|
||||
try:
|
||||
self._ensure_connection()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute("SELECT value FROM parent_documents WHERE key = %s", (key,))
|
||||
row = cursor.fetchone()
|
||||
cursor.close()
|
||||
if row:
|
||||
import json
|
||||
return json.loads(row[0])
|
||||
return None
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to retrieve document: {e}")
|
||||
|
||||
def set(self, key: str, value: dict) -> None:
|
||||
"""Store a document."""
|
||||
try:
|
||||
self._ensure_connection()
|
||||
cursor = self._conn.cursor()
|
||||
# Upsert
|
||||
insert_query = sql.SQL(
|
||||
"INSERT INTO parent_documents (key, value) VALUES (%s, %s)"
|
||||
)
|
||||
update_query = sql.SQL(
|
||||
"UPDATE parent_documents SET value = %s WHERE key = %s"
|
||||
)
|
||||
cursor.execute(insert_query, (key, json.dumps(value)))
|
||||
try:
|
||||
cursor.execute(update_query, (key, json.dumps(value)))
|
||||
except psycopg2.IntegrityError:
|
||||
pass # Key exists, ignore
|
||||
self._conn.commit()
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to store document: {e}")
|
||||
|
||||
def _ensure_connection(self):
|
||||
"""Ensure we have an open connection."""
|
||||
if self._conn is None or self._conn.closed:
|
||||
self._conn = psycopg2.connect(self.conn_string)
|
||||
|
||||
def close(self):
|
||||
"""Close the connection."""
|
||||
if self._conn and not self._conn.closed:
|
||||
self._conn.close()
|
||||
|
||||
|
||||
# Factory function for creating custom docstores
|
||||
# Returns a tuple: (BaseStore instance, connection_string or None)
|
||||
def create_docstore(
|
||||
store_type: str = "local",
|
||||
persist_path: str = None,
|
||||
connection_string: str = None
|
||||
) -> tuple:
|
||||
"""
|
||||
Factory function to create different types of document stores.
|
||||
|
||||
Args:
|
||||
store_type: "local" (default), "postgres"
|
||||
persist_path: Path for local file store
|
||||
connection_string: PostgreSQL connection string
|
||||
|
||||
Returns:
|
||||
Tuple of (BaseStore instance, connection_string or None)
|
||||
"""
|
||||
if store_type == "postgres" and connection_string:
|
||||
return (PostgresDocStore(connection_string), connection_string)
|
||||
else:
|
||||
return (get_docstore(persist_path), None)
|
||||
@@ -1,16 +1,17 @@
|
||||
"""
|
||||
Embedding model wrapper for llama.cpp service.
|
||||
嵌入模型包装器,用于 llama.cpp 服务。
|
||||
"""
|
||||
|
||||
import os
|
||||
import httpx
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
|
||||
class LlamaCppEmbedder:
|
||||
"""Wrapper for llama.cpp embedding service via OpenAI-compatible API."""
|
||||
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -22,47 +23,66 @@ class LlamaCppEmbedder:
|
||||
self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
|
||||
self.model = model
|
||||
|
||||
# Ensure URL ends with /v1
|
||||
self.base_url = urljoin(self.base_url.rstrip("/") + "/", "v1")
|
||||
|
||||
def as_langchain_embeddings(self) -> OpenAIEmbeddings:
|
||||
"""Create LangChain OpenAIEmbeddings instance."""
|
||||
return OpenAIEmbeddings(
|
||||
openai_api_base=self.base_url,
|
||||
openai_api_key=self.api_key,
|
||||
model=self.model,
|
||||
)
|
||||
def as_langchain_embeddings(self) -> Embeddings:
|
||||
"""创建 LangChain 兼容的嵌入实例。"""
|
||||
return _LlamaCppLangchainAdapter(self)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of documents."""
|
||||
emb = self.as_langchain_embeddings()
|
||||
return emb.embed_documents(texts)
|
||||
"""嵌入一批文档。"""
|
||||
return self._call_embedding_api(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a single query."""
|
||||
emb = self.as_langchain_embeddings()
|
||||
return emb.embed_query(text)
|
||||
"""嵌入单个查询。"""
|
||||
return self._call_embedding_api([text])[0]
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Get embedding dimension by embedding a test string."""
|
||||
"""通过嵌入测试字符串获取嵌入维度。"""
|
||||
test_embedding = self.embed_query("test")
|
||||
return len(test_embedding)
|
||||
|
||||
def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
|
||||
"""直接调用 llama.cpp 嵌入 API。"""
|
||||
base = self.base_url.rstrip("/")
|
||||
if not base.endswith("/v1"):
|
||||
base = base + "/v1"
|
||||
|
||||
class MockEmbedder:
|
||||
"""Mock embedder for testing without a real service."""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
def __init__(self, dimension: int = 768):
|
||||
self.dimension = dimension
|
||||
payload = {
|
||||
"input": texts,
|
||||
"model": self.model,
|
||||
}
|
||||
|
||||
def as_langchain_embeddings(self) -> OpenAIEmbeddings:
|
||||
raise NotImplementedError("MockEmbedder cannot be used as LangChain embeddings")
|
||||
with httpx.Client(timeout=120) as client:
|
||||
response = client.post(
|
||||
f"{base}/embeddings",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# 处理不同响应格式
|
||||
if isinstance(data, list):
|
||||
# llama.cpp 直接返回列表
|
||||
return [item["embedding"] for item in data]
|
||||
elif isinstance(data, dict) and "data" in data:
|
||||
# OpenAI 标准格式
|
||||
return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])]
|
||||
else:
|
||||
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
|
||||
|
||||
|
||||
class _LlamaCppLangchainAdapter(Embeddings):
|
||||
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""
|
||||
|
||||
def __init__(self, embedder: LlamaCppEmbedder):
|
||||
self._embedder = embedder
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return [[0.0] * self.dimension for _ in texts]
|
||||
return self._embedder.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return [0.0] * self.dimension
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
return self.dimension
|
||||
return self._embedder.embed_query(text)
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
"""
|
||||
Example demonstrating ParentDocumentRetriever usage.
|
||||
|
||||
This script shows how to:
|
||||
1. Build an index with parent-child chunking
|
||||
2. Search with child chunks (fast, precise)
|
||||
3. Search with parent context (large context)
|
||||
4. Access the retriever directly for advanced use cases
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
from builder import IndexBuilder
|
||||
from splitters import SplitterType
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 70)
|
||||
print("ParentDocumentRetriever Example")
|
||||
print("=" * 70)
|
||||
|
||||
# Step 1: Create IndexBuilder with parent-child splitting
|
||||
print("\n1. Creating IndexBuilder with parent-child splitting...")
|
||||
builder = IndexBuilder(
|
||||
collection_name="parent_child_demo",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
parent_chunk_size=1000, # Parent chunks: larger context
|
||||
child_chunk_size=200, # Child chunks: smaller for precision
|
||||
docstore_path="./my_parent_docs", # Where to store parent chunks
|
||||
search_k=5, # Number of child chunks to retrieve
|
||||
)
|
||||
|
||||
print(f" Parent splitter: chunk_size={builder.get_parent_splitter().chunk_size}")
|
||||
print(f" Child splitter: chunk_size={builder.get_child_splitter().chunk_size}")
|
||||
print(f" Docstore path: {builder.get_docstore_path()}")
|
||||
print(f" Search k: {builder.retriever.search_kwargs['k']}")
|
||||
|
||||
# Step 2: Build index from a sample file
|
||||
print("\n2. Building index from sample file...")
|
||||
|
||||
# Create a test document
|
||||
test_content = """
|
||||
This is a test document for demonstrating ParentDocumentRetriever.
|
||||
|
||||
Parent chunks contain larger portions of text (1000 characters),
|
||||
while child chunks are smaller (200 characters) for precise retrieval.
|
||||
|
||||
When you search with ParentDocumentRetriever:
|
||||
- It first retrieves relevant child chunks
|
||||
- Then replaces them with their corresponding parent chunks
|
||||
- This gives you large context while maintaining precision
|
||||
|
||||
Example search queries:
|
||||
- "ParentDocumentRetriever"
|
||||
- "child chunks"
|
||||
- "large context"
|
||||
- "precise retrieval"
|
||||
"""
|
||||
|
||||
test_file = Path("./test_document.txt")
|
||||
test_file.write_text(test_content)
|
||||
|
||||
chunk_count = builder.build_from_file(str(test_file))
|
||||
print(f" Indexed {chunk_count} documents")
|
||||
|
||||
# Step 3: Search with child chunks (fast, precise)
|
||||
print("\n3. Searching with child chunks (fast, precise)...")
|
||||
child_results = builder.search("ParentDocumentRetriever", k=3)
|
||||
print(f" Found {len(child_results)} child chunks:")
|
||||
for i, doc in enumerate(child_results, 1):
|
||||
print(f" [{i}] {doc.page_content[:100]}...")
|
||||
|
||||
# Step 4: Search with parent context (large context)
|
||||
print("\n4. Searching with parent context (large context)...")
|
||||
parent_results = builder.search_with_parent_context("ParentDocumentRetriever", k=3)
|
||||
print(f" Found {len(parent_results)} parent chunks:")
|
||||
for i, doc in enumerate(parent_results, 1):
|
||||
print(f" [{i}] {doc.page_content[:150]}...")
|
||||
|
||||
# Step 5: Compare results
|
||||
print("\n5. Comparing child vs parent results...")
|
||||
print(f" Child chunks total length: {sum(len(d.page_content) for d in child_results)}")
|
||||
print(f" Parent chunks total length: {sum(len(d.page_content) for d in parent_results)}")
|
||||
print(f" Ratio: parent/child = {sum(len(d.page_content) for d in parent_results) / max(sum(len(d.page_content) for d in child_results), 1):.2f}x larger")
|
||||
|
||||
# Step 6: Access retriever directly
|
||||
print("\n6. Accessing retriever directly...")
|
||||
retriever = builder.get_retriever()
|
||||
print(f" Retriever type: {type(retriever).__name__}")
|
||||
print(f" Vectorstore: {retriever.vectorstore}")
|
||||
print(f" Docstore: {retriever.docstore}")
|
||||
|
||||
# Step 7: Unified retrieval interface
|
||||
print("\n7. Using unified retrieval interface...")
|
||||
unified_results = builder.retrieve("ParentDocumentRetriever", return_parent=True)
|
||||
print(f" Retrieved {len(unified_results)} documents (with parent context)")
|
||||
|
||||
# Step 8: Collection info
|
||||
print("\n8. Collection info...")
|
||||
info = builder.get_collection_info()
|
||||
print(f" Collection: {info['name']}")
|
||||
print(f" Vectors: {info['vectors_count']}")
|
||||
print(f" Vector size: {info['vector_size']}")
|
||||
|
||||
# Cleanup
|
||||
print("\n9. Cleaning up...")
|
||||
builder.close()
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Example completed successfully!")
|
||||
print("=" * 70)
|
||||
|
||||
return builder
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
builder = main()
|
||||
@@ -1,10 +1,10 @@
|
||||
"""
|
||||
Document loaders using unstructured library.
|
||||
文档加载器,使用 unstructured 库解析文档。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from unstructured.partition.auto import partition
|
||||
@@ -13,33 +13,74 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentLoader:
|
||||
"""Load documents from various file formats."""
|
||||
"""从各种文件格式加载文档。"""
|
||||
|
||||
SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx"}
|
||||
SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx", ".json"}
|
||||
|
||||
def __init__(self, extract_images: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
extract_images: bool = False,
|
||||
strategy: str = "auto",
|
||||
ocr_languages: Optional[List[str]] = None,
|
||||
languages: Optional[List[str]] = None,
|
||||
include_page_breaks: bool = False,
|
||||
pdf_infer_table_structure: bool = True,
|
||||
partition_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
extract_images: Whether to extract images from documents (requires additional dependencies)
|
||||
extract_images: 是否提取 PDF 中的图片
|
||||
strategy: 解析策略 (auto, fast, hi_res, ocr_only)
|
||||
ocr_languages: OCR 语言列表,如 ['chi_sim', 'eng']
|
||||
languages: 文档主语言,如 ['zh']
|
||||
include_page_breaks: 是否包含分页符
|
||||
pdf_infer_table_structure: 是否识别表格结构 (需 hi_res 策略)
|
||||
partition_kwargs: 额外的 partition 参数字典(高级定制)
|
||||
"""
|
||||
import os
|
||||
os.environ["UNSTRUCTURED_LANGUAGE_CHECKS"] = "false"
|
||||
self.extract_images = extract_images
|
||||
self.strategy = strategy
|
||||
self.ocr_languages = ocr_languages or ["chi_sim", "eng"]
|
||||
self.languages = languages or ["zh"]
|
||||
self.include_page_breaks = include_page_breaks
|
||||
self.pdf_infer_table_structure = pdf_infer_table_structure
|
||||
self.partition_kwargs = partition_kwargs or {}
|
||||
|
||||
def load_file(self, file_path: Union[str, Path]) -> List[Document]:
|
||||
"""Load a single file into LangChain Document objects."""
|
||||
"""将单个文件加载为 LangChain Document 对象。"""
|
||||
file_path = Path(file_path).resolve()
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
suffix = file_path.suffix.lower()
|
||||
if suffix not in self.SUPPORTED_EXTENSIONS:
|
||||
raise ValueError(
|
||||
f"Unsupported file extension: {suffix}. Supported: {self.SUPPORTED_EXTENSIONS}"
|
||||
f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}"
|
||||
)
|
||||
|
||||
# Parse with unstructured
|
||||
# 根据文件类型动态调整参数
|
||||
extra_kwargs = {}
|
||||
if suffix == ".pdf":
|
||||
extra_kwargs["strategy"] = self.strategy
|
||||
extra_kwargs["ocr_languages"] = self.ocr_languages
|
||||
extra_kwargs["extract_images_in_pdf"] = self.extract_images
|
||||
extra_kwargs["pdf_infer_table_structure"] = self.pdf_infer_table_structure
|
||||
|
||||
# languages 参数适用于所有文件类型
|
||||
if self.languages:
|
||||
extra_kwargs["languages"] = self.languages
|
||||
|
||||
extra_kwargs["include_page_breaks"] = self.include_page_breaks
|
||||
|
||||
# 合并用户自定义的额外参数(优先级最高)
|
||||
extra_kwargs.update(self.partition_kwargs)
|
||||
|
||||
# 使用 unstructured 解析
|
||||
elements = partition(
|
||||
filename=str(file_path),
|
||||
extract_images_in_pdf=self.extract_images,
|
||||
|
||||
**extra_kwargs
|
||||
)
|
||||
|
||||
documents = []
|
||||
@@ -48,23 +89,17 @@ class DocumentLoader:
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
|
||||
# Base metadata
|
||||
# 基础元数据
|
||||
metadata = {
|
||||
"source": str(file_path),
|
||||
"file_name": file_path.name,
|
||||
"file_type": suffix,
|
||||
}
|
||||
|
||||
# Merge element-specific metadata without overwriting base fields
|
||||
elem_meta = getattr(elem, "metadata", {}) or {}
|
||||
for key, value in elem_meta.items():
|
||||
if value and key not in metadata:
|
||||
metadata[key] = value
|
||||
|
||||
documents.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
if not documents:
|
||||
logger.warning("No text content extracted from %s", file_path)
|
||||
logger.warning("未从 %s 提取到文本内容", file_path)
|
||||
return []
|
||||
|
||||
return documents
|
||||
@@ -72,10 +107,10 @@ class DocumentLoader:
|
||||
def load_directory(
|
||||
self, directory_path: Union[str, Path], recursive: bool = True
|
||||
) -> List[Document]:
|
||||
"""Load all supported files from a directory."""
|
||||
"""从目录加载所有支持的文件。"""
|
||||
directory_path = Path(directory_path).resolve()
|
||||
if not directory_path.is_dir():
|
||||
raise NotADirectoryError(f"Not a directory: {directory_path}")
|
||||
raise NotADirectoryError(f"不是目录: {directory_path}")
|
||||
|
||||
all_documents = []
|
||||
pattern = "**/*" if recursive else "*"
|
||||
@@ -86,6 +121,6 @@ class DocumentLoader:
|
||||
docs = self.load_file(file_path)
|
||||
all_documents.extend(docs)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load %s: %s", file_path, e)
|
||||
logger.error("加载 %s 失败: %s", file_path, e)
|
||||
|
||||
return all_documents
|
||||
@@ -1,12 +1,12 @@
"""
Text splitters for chunking documents.
"""

from enum import Enum
from typing import List, Optional

from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_experimental.text_splitter import SemanticChunker


@@ -17,7 +17,7 @@ class SplitterType(str, Enum):


def get_splitter(splitter_type: SplitterType, **kwargs):
    """Factory function to create a text splitter."""
    if splitter_type == SplitterType.RECURSIVE:
        chunk_size = kwargs.get("chunk_size", 500)
        chunk_overlap = kwargs.get("chunk_overlap", 50)
@@ -27,19 +27,31 @@ def get_splitter(splitter_type: SplitterType, **kwargs):
            separators=["\n\n", "\n", "。", "!", "?", " ", ""],
        )
    elif splitter_type == SplitterType.SEMANTIC:
        # Semantic splitting requires an embeddings model
        embeddings = kwargs.pop("embeddings", None)
        if embeddings is None:
            raise ValueError("Semantic splitter requires 'embeddings' parameter")
        return SemanticChunkerAdapter(embeddings=embeddings, **kwargs)
    else:
        raise ValueError(f"Unsupported splitter type: {splitter_type}")


class SemanticChunkerAdapter(TextSplitter):
    """Adapts SemanticChunker to the TextSplitter interface."""

    def __init__(self, embeddings, **kwargs):
        super().__init__(**kwargs)
        # chunk_size / chunk_overlap are TextSplitter options; strip them so they
        # are not forwarded to SemanticChunker, which does not accept them.
        kwargs.pop("chunk_size", None)
        kwargs.pop("chunk_overlap", None)
        self._chunker = SemanticChunker(embeddings=embeddings, **kwargs)

    def split_text(self, text: str) -> List[str]:
        return self._chunker.split_text(text)


class ParentChildSplitter:
    """
    Splits documents into parent (large) and child (small) chunks.
    Child chunks are indexed for retrieval, parent chunks are stored for context.
    """

    def __init__(
@@ -60,12 +72,12 @@ class ParentChildSplitter:

    def split_documents(self, documents: List[Document]) -> tuple[List[Document], List[Document]]:
        """
        Returns:
            (parent_chunks, child_chunks)
        """
        parent_chunks = self.parent_splitter.split_documents(documents)
        child_chunks = self.child_splitter.split_documents(documents)

        # Link child chunks to parent IDs (optional metadata)
        # In a real implementation, you'd map each child to a parent chunk ID.
        return parent_chunks, child_chunks
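The comment in `split_documents` above leaves the child-to-parent mapping open. Below is a minimal sketch of one way to do it; the `parent_id` metadata key, the chunk sizes, and the use of `uuid4` are illustrative choices, not part of the committed code.

```python
# Hypothetical sketch: tag every child chunk with the ID of the parent chunk it came from.
import uuid
from langchain_text_splitters import RecursiveCharacterTextSplitter

parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)

def split_with_parent_ids(documents):
    parent_chunks, child_chunks = [], []
    for doc in documents:
        for parent in parent_splitter.split_documents([doc]):
            parent_id = str(uuid.uuid4())
            parent.metadata["parent_id"] = parent_id   # illustrative metadata key
            parent_chunks.append(parent)
            # Children are derived from the parent text, so the link is exact
            for child in child_splitter.split_documents([parent]):
                child.metadata["parent_id"] = parent_id
                child_chunks.append(child)
    return parent_chunks, child_chunks
```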
rag_indexer/store/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""
Document store module - parent-document storage for ParentDocumentRetriever.

Provides a PostgreSQL storage backend:
- PostgresDocStore: PostgreSQL database storage (production use)

Example usage:
    >>> from rag_indexer.store import create_docstore

    >>> # Create a PostgreSQL store
    >>> store, conn = create_docstore(
    ...     connection_string="postgresql://user:pass@host:5432/db",
    ...     table_name="parent_docs"
    ... )
"""


from .postgres import PostgresDocStore
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI

__version__ = "2.0.0"

__all__ = [
    # Concrete implementations
    "PostgresDocStore",

    # Factory functions
    "create_docstore",
    "get_docstore_uri",
    "DEFAULT_DB_URI",
]
rag_indexer/store/factory.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""
Document store factory - creates store instances of different types.

Provides a unified interface for creating document stores; currently only the
PostgreSQL backend is implemented.
"""

import os
import logging
from typing import Optional, Tuple

from langchain_core.stores import BaseStore
from .postgres import PostgresDocStore

logger = logging.getLogger(__name__)

# Default connection string (read from the environment)
DEFAULT_DB_URI = os.getenv(
    "DB_URI",
    "postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)


def get_docstore_uri() -> str:
    """Return the docstore-specific database connection string (may be the same as the main DB)."""
    return os.getenv("DOCSTORE_URI", DEFAULT_DB_URI)


def create_docstore(
    store_type: str = "postgres",
    connection_string: Optional[str] = None,
    table_name: str = "parent_documents",
    pool_config: Optional[dict] = None,
    max_concurrency: Optional[int] = None
) -> Tuple[BaseStore, Optional[str]]:
    """
    Factory function that creates a PostgreSQL document store.

    Args:
        store_type: Store type; currently only "postgres" is supported (default)
        connection_string: PostgreSQL connection string
        table_name: PostgreSQL table name (default: parent_documents)
        pool_config: Connection pool configuration
        max_concurrency: Maximum number of concurrent operations; None means unlimited

    Returns:
        Tuple of (store instance, connection string)

    Raises:
        ValueError: Unsupported store type
        ImportError: A required dependency is missing

    Example:
        >>> # Create a PostgreSQL store
        >>> store, conn = create_docstore(
        ...     connection_string="postgresql://user:pass@host:5432/db",
        ...     table_name="parent_docs",
        ...     max_concurrency=10
        ... )
    """
    store_type = store_type.lower()

    if store_type == "postgres":
        conn_str = connection_string or get_docstore_uri()
        store = PostgresDocStore(
            connection_string=conn_str,
            table_name=table_name,
            pool_config=pool_config,
            max_concurrency=max_concurrency
        )
        return store, conn_str

    else:
        raise ValueError(f"Unsupported store type: {store_type}. Currently supported: postgres")
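A short usage sketch for the factory above, driven by the `DOCSTORE_URI` environment variable that `get_docstore_uri()` reads; the DSN value shown is a placeholder for illustration only.

```python
# Sketch: configure the docstore through the environment instead of a hard-coded DSN.
import os

os.environ.setdefault("DOCSTORE_URI", "postgresql://user:pass@localhost:5432/ragdb")  # placeholder DSN

from rag_indexer.store import create_docstore  # exported by store/__init__.py

store, conn_str = create_docstore(table_name="parent_documents", max_concurrency=10)
print("docstore backed by:", conn_str)
```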
rag_indexer/store/postgres.py (new file, 249 lines)
@@ -0,0 +1,249 @@
"""
Async PostgreSQL store implementation - for production use.

Uses asyncpg to implement a truly asynchronous PostgreSQL document store with
support for highly concurrent access.
"""

from __future__ import annotations

import asyncio
import json
import logging
from typing import List, Dict, Any, Optional, Iterator, AsyncIterator, Tuple, Sequence, cast

from langchain_core.documents import Document
from langchain_core.stores import BaseStore

import asyncpg

logger = logging.getLogger(__name__)


class PostgresDocStore(BaseStore[str, Any]):
    """
    Async PostgreSQL document store implementation.

    Uses asyncpg as the async PostgreSQL client and supports:
    - Truly asynchronous operations
    - Connection pool management
    - Automatic table creation
    - Batch operations (amget/amset/amdelete)
    - JSONB data storage
    - Concurrency control

    Intended for production use; provides high-performance async persistence.

    Attributes:
        dsn: PostgreSQL connection string
        table_name: Storage table name, defaults to "parent_documents"
        _pool: asyncpg connection pool instance
        _semaphore: Semaphore limiting concurrency (optional)
    """

    def __init__(
        self,
        connection_string: str,
        table_name: str = "parent_documents",
        pool_config: Optional[Dict[str, Any]] = None,
        max_concurrency: Optional[int] = None
    ):
        """
        Initialize the async PostgreSQL document store.

        Args:
            connection_string: PostgreSQL connection URL, in the form:
                "postgresql://user:password@host:port/database?sslmode=disable"
            table_name: Storage table name, defaults to "parent_documents"
            pool_config: Connection pool configuration dict, containing:
                - min_size: Minimum number of connections (default 2)
                - max_size: Maximum number of connections (default 10)
            max_concurrency: Maximum number of concurrent operations; None means unlimited

        Raises:
            ImportError: Raised when asyncpg is not installed

        Example:
            >>> store = PostgresDocStore(
            ...     "postgresql://user:pass@localhost:5432/mydb",
            ...     table_name="parent_docs",
            ...     pool_config={"min_size": 5, "max_size": 20},
            ...     max_concurrency=10
            ... )
        """
        self.dsn = connection_string
        self.table_name = table_name
        self._pool: Optional["asyncpg.Pool"] = None
        self._pool_config = pool_config or {}

        # Semaphore for concurrency control
        self._semaphore = None
        if max_concurrency is not None and max_concurrency > 0:
            self._semaphore = asyncio.Semaphore(max_concurrency)

        # Note: async initialization of the connection pool is deferred to first use,
        # and table creation is likewise deferred to the first operation.

    async def _get_pool(self):
        """Get or create the asyncpg connection pool."""
        if self._pool is None:
            min_size = self._pool_config.get("min_size", 2)
            max_size = self._pool_config.get("max_size", 10)

            try:
                self._pool = await asyncpg.create_pool(
                    dsn=self.dsn,
                    min_size=min_size,
                    max_size=max_size
                )
                logger.info(f"PostgreSQL async connection pool created: {self.table_name}")

                # Initialize the table schema
                await self._create_table()
            except Exception as e:
                raise RuntimeError(f"Failed to create PostgreSQL async connection pool: {e}")

        return self._pool

    async def _create_table(self):
        """Create the storage table if it does not exist."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.transaction():
                await conn.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.table_name} (
                        key TEXT PRIMARY KEY,
                        value JSONB NOT NULL,
                        created_at TIMESTAMPTZ DEFAULT NOW()
                    )
                """)
        logger.info(f"Table {self.table_name} is ready")

    async def _with_concurrency_control(self, coro):
        """Run a coroutine under the concurrency-limiting semaphore."""
        if self._semaphore is None:
            return await coro
        async with self._semaphore:
            return await coro

    # --- Sync methods (kept for interface compatibility, not supported) ---

    def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Synchronous access is not supported; use the async amget method."""
        raise NotImplementedError("Synchronous access is not supported; use the async amget method.")

    def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Synchronous access is not supported; use the async amset method."""
        raise NotImplementedError("Synchronous access is not supported; use the async amset method.")

    def mdelete(self, keys: Sequence[str]) -> None:
        """Synchronous access is not supported; use the async amdelete method."""
        raise NotImplementedError("Synchronous access is not supported; use the async amdelete method.")

    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
        """Synchronous access is not supported; use the async ayield_keys method."""
        raise NotImplementedError("Synchronous access is not supported; use the async ayield_keys method.")

    # --- Async methods (the real implementation) ---

    async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Asynchronously fetch documents in batch."""
        if not keys:
            return []

        async def _amget():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)",
                    keys
                )
                result_map = {}
                for row in rows:
                    val = row['value']
                    if isinstance(val, str):
                        val = json.loads(val)
                    if isinstance(val, dict) and 'page_content' in val:
                        result_map[row['key']] = Document(**val)
                    else:
                        result_map[row['key']] = val
                return [result_map.get(key) for key in keys]

        return await self._with_concurrency_control(_amget())

    async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Asynchronously set documents in batch."""
        if not key_value_pairs:
            return

        async def _amset():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.executemany(
                        f"""
                        INSERT INTO {self.table_name} (key, value)
                        VALUES ($1, $2)
                        ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
                        """,
                        [
                            (k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False))
                            for k, v in key_value_pairs
                        ]
                    )
            logger.debug(f"Asynchronously set {len(key_value_pairs)} documents in batch")

        await self._with_concurrency_control(_amset())

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Asynchronously delete documents in batch."""
        if not keys:
            return

        async def _amdelete():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.execute(
                        f"DELETE FROM {self.table_name} WHERE key = ANY($1)",
                        keys
                    )
            logger.debug(f"Asynchronously deleted {len(keys)} documents in batch")

        await self._with_concurrency_control(_amdelete())

    async def ayield_keys(self, *, prefix: str | None = None) -> AsyncIterator[str]:
        """Asynchronously iterate over all keys.

        Note: this is an async generator and must be consumed with `async for`.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if prefix:
                rows = await conn.fetch(
                    f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key",
                    f"{prefix}%"
                )
            else:
                rows = await conn.fetch(
                    f"SELECT key FROM {self.table_name} ORDER BY key"
                )

        for row in rows:
            yield row['key']

    async def aclose(self) -> None:
        """Asynchronously close the connection pool and release resources."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("PostgreSQL async connection pool closed")

    def close(self) -> None:
        """Synchronously close the connection pool (limited functionality).

        Note: in an async environment, use the aclose method instead.
        """
        pass
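A minimal async usage sketch for the store above; the DSN, keys, and documents are placeholders chosen for illustration.

```python
# Sketch: store and fetch parent documents asynchronously (placeholder DSN and keys).
import asyncio
from langchain_core.documents import Document
from rag_indexer.store import PostgresDocStore  # exported by store/__init__.py

async def main():
    store = PostgresDocStore(
        "postgresql://user:pass@localhost:5432/ragdb",  # assumed connection string
        table_name="parent_documents",
        max_concurrency=10,
    )
    await store.amset([("doc-1", Document(page_content="parent chunk text", metadata={"source": "demo"}))])
    [doc] = await store.amget(["doc-1"])
    print(doc.page_content)
    async for key in store.ayield_keys(prefix="doc-"):
        print("key:", key)
    await store.aclose()

asyncio.run(main())
```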
@@ -1,5 +1,5 @@
"""
Qdrant vector store wrapper.
"""

import logging
@@ -16,67 +16,85 @@ from .embedders import LlamaCppEmbedder

logger = logging.getLogger(__name__)

QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")


class QdrantVectorStore:
    """Wrapper for Qdrant vector database operations."""

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Any] = None,
        qdrant_url: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        self.collection_name = collection_name
        self.qdrant_url = qdrant_url or QDRANT_URL
        self.api_key = api_key or QDRANT_API_KEY
        self._client: Optional[QdrantClient] = None

        # Embeddings
        if embeddings is None:
            embedder = LlamaCppEmbedder()
            self.embeddings = embedder.as_langchain_embeddings()
        else:
            self.embeddings = embeddings

        # Create the collection first
        self.create_collection()

        # LangChain vector store
        self.vector_store = LangchainQdrantVS(
            client=self.get_client(),
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )

    def get_client(self) -> QdrantClient:
        """Lazily create the client so a usable connection exists on every access."""
        if self._client is None:
            self._client = QdrantClient(
                url=self.qdrant_url,
                api_key=self.api_key,
                timeout=120,
                http2=False,
            )
        return self._client

    def refresh_client(self):
        """Close the old connection and create a new one on next access."""
        if self._client is not None:
            self._client.close()
        self._client = None

    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create the collection with an appropriate vector size."""
        if vector_size is None:
            embedder = LlamaCppEmbedder()
            vector_size = embedder.get_embedding_dimension()

        client = self.get_client()
        collections = client.get_collections().collections
        exists = any(c.name == self.collection_name for c in collections)

        if exists and force_recreate:
            client.delete_collection(self.collection_name)
            exists = False

        if not exists:
            client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )
            logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)
        else:
            logger.info("Collection '%s' already exists", self.collection_name)

    def add_documents(self, documents: List[Document], batch_size: int = 100):
        """Add documents to the vector store."""
        if not documents:
            return []
        self.create_collection()
        ids = self.vector_store.add_documents(documents, batch_size=batch_size)
        logger.info("Added %d documents to '%s'", len(ids), self.collection_name)
        return ids

    def similarity_search(self, query: str, k: int = 5) -> List[Document]:
@@ -86,16 +104,21 @@ class QdrantVectorStore:
        return self.vector_store.similarity_search_with_score(query, k=k)

    def delete_collection(self):
        self.get_client().delete_collection(self.collection_name)
        logger.info("Collection '%s' deleted", self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        info = self.get_client().get_collection(self.collection_name)
        vectors_config = info.config.params.vectors
        if isinstance(vectors_config, dict):
            vector_size = next(iter(vectors_config.values())).size
        else:
            vector_size = vectors_config.size
        return {
            "name": self.collection_name,
            "vectors_count": info.points_count or 0,
            "status": info.status,
            "vector_size": vector_size,
        }

    def as_langchain_vectorstore(self):
@@ -107,4 +130,4 @@ class QdrantVectorStore:

    def get_qdrant_client(self):
        """Return the native Qdrant client (for manual collection management)."""
        return self.get_client()
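A short usage sketch for the wrapper above. The module path `rag_indexer.qdrant_store`, the collection name, and the sample documents are illustrative assumptions; the embeddings argument is omitted because the wrapper falls back to LlamaCppEmbedder.

```python
# Sketch: index a few documents and run a similarity search (illustrative names, assumed import path).
from langchain_core.documents import Document
from rag_indexer.qdrant_store import QdrantVectorStore  # assumed module path

store = QdrantVectorStore(collection_name="rag_demo")  # reads QDRANT_URL / QDRANT_API_KEY from the environment

store.add_documents([
    Document(page_content="Qdrant stores dense vectors.", metadata={"source": "demo"}),
    Document(page_content="BM25 is a sparse keyword retriever.", metadata={"source": "demo"}),
])

for doc in store.similarity_search("what stores dense vectors?", k=2):
    print(doc.page_content)

print(store.get_collection_info())
```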
@@ -49,7 +49,8 @@ python-dotenv==1.2.2
typing-extensions==4.15.0

unstructured>=0.0.1

spacy>=3.7.0
langchain_experimental>=0.0.1
# ============================================================================
# Note:
# 1. This file pins the exact versions of the project's direct dependencies