From 933d418d77844d47ab6bfef12fbf337d7efd8e1f Mon Sep 17 00:00:00 2001
From: root <953994191@qq.com>
Date: Sun, 19 Apr 2026 22:01:55 +0800
Subject: [PATCH] =?UTF-8?q?=E6=A3=80=E7=B4=A2=E5=99=A8=E9=87=8D=E6=9E=84?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.gitignore | 2 +
app/rag/README.md | 129 ++++---
app/rag/__init__.py | 51 ++-
app/rag/example.py | 159 ++++----
app/rag/pipeline.py | 387 ++++++-------------
app/rag/query_transform.py | 203 ++--------
app/rag/reranker.py | 152 ++------
app/rag/retriever.py | 134 ++++---
app/rag/tools.py | 273 ++++----------
rag_core/__init__.py | 18 +
{rag_indexer => rag_core}/embedders.py | 5 +-
{rag_indexer => rag_core}/store/__init__.py | 2 +-
{rag_indexer => rag_core}/store/factory.py | 2 +-
{rag_indexer => rag_core}/store/postgres.py | 2 +-
{rag_indexer => rag_core}/vector_store.py | 8 +-
rag_indexer/IndexBuilder.py | 299 +++++++++++++++
rag_indexer/README.md | 297 +++++++--------
rag_indexer/__init__.py | 44 +--
rag_indexer/builder.py | 392 --------------------
rag_indexer/cli.py | 86 ++---
rag_indexer/loaders.py | 142 ++++---
rag_indexer/splitters.py | 210 +++++++++--
rag_indexer/test/reset_index.py | 80 ++++
rag_indexer/test/test_inspect_vectors.py | 63 ++++
rag_indexer/test/test_refactored.py | 83 +++++
rag_indexer/test/test_validate_index.py | 188 ++++++++++
26 files changed, 1694 insertions(+), 1717 deletions(-)
create mode 100644 rag_core/__init__.py
rename {rag_indexer => rag_core}/embedders.py (94%)
rename {rag_indexer => rag_core}/store/__init__.py (92%)
rename {rag_indexer => rag_core}/store/factory.py (98%)
rename {rag_indexer => rag_core}/store/postgres.py (99%)
rename {rag_indexer => rag_core}/vector_store.py (96%)
create mode 100644 rag_indexer/IndexBuilder.py
delete mode 100644 rag_indexer/builder.py
create mode 100644 rag_indexer/test/reset_index.py
create mode 100644 rag_indexer/test/test_inspect_vectors.py
create mode 100644 rag_indexer/test/test_refactored.py
create mode 100644 rag_indexer/test/test_validate_index.py
diff --git a/.gitignore b/.gitignore
index ff0b849..9e8ca1d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,8 @@
!scripts/**
!rag_indexer/
!rag_indexer/**
+!rag_core/
+!rag_core/**
!docker/
!docker/**
!.gitea/
diff --git a/app/rag/README.md b/app/rag/README.md
index 19f27f4..b0e0423 100644
--- a/app/rag/README.md
+++ b/app/rag/README.md
@@ -2,71 +2,44 @@
该模块负责 RAG 系统的阶段二:**在线检索与生成**。它接收用户提问,从知识库中检索出上下文,利用各种高级策略去噪、融合,并作为增强上下文输入给大语言模型 (LLM)。
-## 📊 RAG-Fusion & 混合检索流水线示意图
-
-```mermaid
-graph TD
- User((用户提问)) --> A[LLM 查询改写生成器]
-
- subgraph RAG-Fusion 核心流程
- A -->|改写为问题 1| B1[查询 1]
- A -->|改写为问题 2| B2[查询 2]
- A -->|原问题| B3[原始查询]
-
- B1 & B2 & B3 --> C[混合检索器 Hybrid Retriever
Dense Vector + BM25 Sparse]
-
- C --> D[多路召回结果合集 N=60条]
-
- D --> E{RRF 倒数排名融合去重}
- end
-
- E -->|筛选出前 20 条| F[Cross-Encoder 重排器 Reranker]
-
- F -->|精细打分排序 Top 5| G[最终纯净上下文 Context]
-
- G --> H[将 Context 与原问题拼接输入大模型]
- H --> I((LLM 生成最终回答))
-```
-
----
-
## 🎯 演进路线与算法详解 (Roadmap)
### Level 1: 基础向量搜索 (Basic Similarity Search)
- **核心算法**: 近似最近邻搜索 (ANN, 常用 HNSW 算法)。将用户问题转化为向量后,计算它与库中向量的余弦相似度 (Cosine Similarity),取距离最近的 K 个块。
- **优缺点**: 速度极快。但只能捕捉“语义相似”,如果用户搜索特定专有名词、编号、订单号,纯向量检索往往会失效(产生“幻觉”匹配)。
+- **实现指南**:
+ - 使用 `rag_indexer.embedders.LlamaCppEmbedder` 作为嵌入模型
+ - 使用 `app/rag/retriever.py` 中的 `create_base_retriever` 创建基础检索器
+ - 配置 `search_kwargs={"k": 20}` 进行初步召回
### Level 2: 混合检索与重排序 (Hybrid Search + Reranker)
混合检索旨在结合向量的“语义泛化”与关键词的“精准匹配”,随后利用重排序模型过滤噪声。
**1. 基础召回 (混合检索)**
- **核心原理**: 结合基于 HNSW 的 Dense Vector 相似度搜索与基于 TF-IDF 的 BM25 稀疏检索 (Sparse Vector)。
-- **实现指南**: 使用 `langchain_qdrant` 包中的 `Qdrant` 类连接数据库。通过调用 `Qdrant.from_existing_collection(...)` 实例化向量库,并使用 `.as_retriever(search_kwargs={"k": 20})` 方法生成基础检索器。Qdrant 底层会自动处理双路召回。
+- **实现指南**: 使用 `app/rag/retriever.py` 中的 `create_hybrid_retriever` 函数,配置 `dense_k=10` 和 `sparse_k=10`,总召回 20 条结果。
**2. 二次精排 (Cross-Encoder)**
- **核心原理**: 不同于双塔模型(分别算向量再求距离),交叉编码器将“用户问题 + 检索到的单例文档”拼接后整体输入 Transformer 模型,由模型直接输出 0~1 的相关性得分,精度极高。
-- **实现指南**:
- - 使用 `sentence-transformers` 库加载本地轻量级重排模型(如 `BAAI/bge-reranker-base`)。
- - 引入 `langchain.retrievers.document_compressors` 包中的 `CrossEncoderReranker` 类包装该模型,设置参数 `top_n=5`。
- - 最后,使用 `langchain.retrievers` 包中的 `ContextualCompressionRetriever` 类,将 `base_compressor` (重排器) 和 `base_retriever` (基础检索器) 组合。
- - **如何调用**: 业务逻辑中直接对组合后的检索器调用 `.invoke(query)` 方法,即可一键完成“大范围召回 20 条 -> 逐一打分精排选 5 条”的去噪流水线。
+- **实现指南**:
+ - 使用 `app/rag/reranker.py` 中的 `CrossEncoderReranker` 类,加载 `BAAI/bge-reranker-base` 模型
+ - 设置 `top_n=5` 保留最相关的 5 条结果
+ - 使用 `ContextualCompressionRetriever` 组合基础检索器和重排序器
### Level 3: RAG-Fusion (多路改写与倒数排名融合)
RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果。
**1. 多路查询改写**
- **核心原理**: 克服用户初始提问词不达意或视角受限的问题。
-- **实现指南**: 导入 `langchain.retrievers.multi_query` 包中的 `MultiQueryRetriever` 类。需向其提供一个已实例化的 LLM 对象(如基于 `ChatOpenAI` 封装的本地 VLLM 模型)。系统在底层会自动 Prompt 模型,将原始 `query` 转化为包含 3-5 个不同表述的查询列表。
+- **实现指南**: 使用 `app/rag/query_transform.py` 中的 `MultiQueryTransformer` 类,配置 `num_queries=3` 生成 3 个不同角度的查询。
**2. 倒数排名融合 (RRF)**
- **核心原理**: RRF (Reciprocal Rank Fusion) 是一种无需评分归一化的融合算法。公式为 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$,有效避免某一极端检索结果主导全局。
-- **实现指南**:
- - 针对每个改写后的查询 $q$,分别调用精排检索器的 `.invoke(q)` 获取文档列表。
- - 使用 `langchain.retrievers` 中的 `EnsembleRetriever` 类(原生支持 RRF),或在代码中遍历收集到的 `Document` 对象,基于其排名 `rank` 累加得分,最终通过 Python 的 `set` 去重并提取 `doc.page_content`。
+- **实现指南**: 使用 `app/rag/retriever.py` 中的 `create_ensemble_retriever` 函数,配置 `search_type="rrf"` 实现倒数排名融合。
### Level 4: Agentic RAG / Self-RAG (智能体与自我反思)
- **核心原理**: 基于 LangGraph 的 ReAct (Reasoning and Acting) 状态机路由。大模型并非每次都去死板地执行检索,而是先判断问题:“这是闲聊?还是需要查知识库?”。如果是后者,模型输出一个 `ToolCall` 指令,触发检索。
-- **实现指南**: 请参考下方的**与现有系统整合调用**章节。
+- **实现指南**: 使用 `app/rag/tools.py` 中的 `search_knowledge_base` 工具,将其绑定到 LangGraph 状态机中。
- **示意图**:
```mermaid
@@ -87,6 +60,13 @@ RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问
LangGraph Agent-->>User: "根据知识库规定,报销流程分为以下3步..."
```
+### Level 5: GraphRAG 集成 (基于图和关系的 RAG)
+- **核心原理**: 结合知识图谱的结构化关系和向量检索的语义相似度,解决跨文档复杂关系推理问题。
+- **实现指南**:
+ - 使用 `langchain_community.graphs` 模块构建知识图谱
+ - 配置本地大模型(如 `Gemma-4-E2B`)用于实体关系抽取
+ - 实现混合检索逻辑,结合向量相似度和图路径分析
+
---
## 📦 所需依赖与安装
@@ -102,18 +82,19 @@ pip install rank_bm25
# 基础框架
pip install langchain langchain-core langchain-openai langchain-qdrant
+
+# 与 rag_indexer 共享的依赖
+pip install qdrant-client httpx
```
---
## 📂 架构与文件结构设计
-在 `app/rag/` 目录下,需创建以下文件来模块化上述功能:
-
-```text
+```
app/rag/
├── __init__.py
-├── retriever.py # 负责 Qdrant 的基础召回与 ContextualCompressionRetriever
+├── retriever.py # 负责 Qdrant 的基础召回与混合检索
├── reranker.py # 负责加载 sentence-transformers 交叉编码器
├── query_transform.py # 负责基于 MultiQueryRetriever 的改写逻辑
├── pipeline.py # 组合上述组件,暴露出核心的 retrieve() 方法
@@ -122,15 +103,69 @@ app/rag/
---
-## � 与现有系统整合调用 (Agentic RAG 实现)
+## 🔄 与 rag_indexer 集成
+
+### 数据结构兼容性
+- **向量存储**: rag_indexer 使用 Qdrant 存储子块向量,app/rag 直接从相同集合读取
+- **文档存储**: rag_indexer 使用 PostgreSQL 存储父块,app/rag 通过 `ParentDocumentRetriever` 关联
+- **嵌入模型**: 共享 `LlamaCppEmbedder` 确保向量空间一致性
+
+### 配置共享
+- **环境变量**: QDRANT_URL、QDRANT_API_KEY、DB_URI 等配置在两个模块间共享
+- **集合名称**: 默认使用 "rag_documents" 集合,确保数据一致性
+
+---
+
+## 🚀 与现有系统整合调用 (Agentic RAG 实现)
基于目前 LangGraph 系统的架构,我们将摒弃将代码堆砌在一起的旧方式,而是利用 **LangChain Tools** 的特性将 RAG 优雅地注入系统:
-1. **封装检索工具 (Tool)**:
+1. **封装检索工具 (Tool)**:
从 `langchain.tools` 导入 `@tool` 装饰器。定义一个名为 `search_knowledge_base(query: str)` 的函数。在函数内部,实例化并调用我们在 `pipeline.py` 中写好的多路召回与重排逻辑。
-2. **模型绑定 (Bind)**:
+
+2. **模型绑定 (Bind)**:
在 `app/agent.py` 或 `app/nodes/tool_call.py` 中,将这个工具引入,并通过 `llm.bind_tools([search_knowledge_base])` 绑定到现有的本地大模型实例上。
-3. **状态机路由 (Graph Routing)**:
+
+3. **状态机路由 (Graph Routing)**:
你的 LangGraph 状态机会像处理普通对话一样自动接管:当模型判断需要调用查阅规章制度或专业资料时,它会输出 `ToolCall` 消息,流转到 `tool_node` 执行上述的 RAG 检索逻辑并返回上下文。
这让你无需修改任何前端 Streamlit 流式代码,就能平滑升级为具备超级知识库检索能力的智能体 (Agent)!
+
+---
+
+## 🎯 快速开始
+
+```python
+# 1. 初始化嵌入模型
+from rag_indexer.embedders import LlamaCppEmbedder
+embeddings = LlamaCppEmbedder()
+
+# 2. 初始化语言模型(用于 RAG-Fusion)
+from langchain_openai import OpenAI
+llm = OpenAI(
+ openai_api_base="http://localhost:8000/v1",
+ openai_api_key="no-key-needed",
+ model_name="Qwen2.5-7B-Instruct",
+ temperature=0.3,
+)
+
+# 3. 创建 RAG 流水线
+from app.rag.pipeline import RAGPipeline, RAGLevel
+pipeline = RAGPipeline(
+ embeddings=embeddings,
+ llm=llm,
+ config={
+ "collection_name": "rag_documents",
+ "rag_level": RAGLevel.FUSION.value,
+ "num_queries": 3,
+ "rerank_top_n": 5,
+ },
+)
+
+# 4. 执行检索
+result = pipeline.retrieve("如何申请项目资金?")
+
+# 5. 格式化上下文
+context = pipeline.format_context(result.documents)
+print(context)
+```
diff --git a/app/rag/__init__.py b/app/rag/__init__.py
index 0d44285..623bb8f 100644
--- a/app/rag/__init__.py
+++ b/app/rag/__init__.py
@@ -1,22 +1,53 @@
"""
-在线 RAG 检索与生成系统
+RAG 检索与生成模块
-提供高级RAG检索功能,支持混合检索、重排序、RAG-Fusion和多路查询改写。
+提供在线检索与生成功能,包括:
+- 基础向量检索
+- 重排序
+- RAG-Fusion
+- Agentic RAG
+
+示例用法:
+ >>> from app.rag import RAGPipeline, search_knowledge_base
+ >>> from rag_core import LlamaCppEmbedder
+ >>>
+ >>> embeddings = LlamaCppEmbedder()
+ >>> pipeline = RAGPipeline(embeddings=embeddings)
+ >>>
+ >>> documents = pipeline.retrieve("戏耍貂蝉美女")
+ >>> context = pipeline.format_context(documents)
"""
-from .pipeline import RAGPipeline
-from .retriever import create_hybrid_retriever, create_base_retriever
+from .retriever import (
+ create_base_retriever,
+ create_hybrid_retriever,
+ # create_ensemble_retriever,
+ create_qdrant_client,
+)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer
-from .tools import search_knowledge_base_tool
+from .pipeline import RAGPipeline, RAGLevel
+from .tools import search_knowledge_base, search_knowledge_base_sync
+
__all__ = [
- "RAGPipeline",
- "create_hybrid_retriever",
+ # 检索器
"create_base_retriever",
+ "create_hybrid_retriever",
+ # "create_ensemble_retriever",
+ "create_qdrant_client",
+
+ # 重排序器
"CrossEncoderReranker",
+
+ # 查询转换器
"MultiQueryTransformer",
- "search_knowledge_base_tool",
+
+ # 流水线
+ "RAGPipeline",
+ "RAGLevel",
+
+ # 工具
+ "search_knowledge_base",
+ "search_knowledge_base_sync",
]
-
-__version__ = "0.1.0"
\ No newline at end of file
diff --git a/app/rag/example.py b/app/rag/example.py
index 53d82fe..8042b09 100644
--- a/app/rag/example.py
+++ b/app/rag/example.py
@@ -7,6 +7,10 @@ RAG 系统使用示例
import sys
import os
+from dotenv import load_dotenv
+
+# 加载环境变量
+load_dotenv()
# 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -19,10 +23,13 @@ def setup_environment():
"""设置环境变量"""
# 设置 Qdrant 连接信息(根据实际情况修改)
os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333")
+ # 设置 Qdrant API 密钥(根据实际情况修改)
+ os.environ.setdefault("QDRANT_API_KEY", "your-api-key-here")
# 如果需要 API 密钥,请设置 QDRANT_API_KEY
print("环境变量已设置")
print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}")
+ print(f"QDRANT_API_KEY: {'***' if os.environ.get('QDRANT_API_KEY') else '未设置'}")
def demonstrate_basic_rag():
@@ -31,37 +38,32 @@ def demonstrate_basic_rag():
print("演示: 基础 RAG 检索 (Level 1)")
print("="*60)
- # 创建嵌入模型(使用 OpenAI 兼容的本地模型)
- embeddings = OpenAIEmbeddings(
- openai_api_base="http://localhost:8000/v1", # 本地 VLLM 服务
- openai_api_key="no-key-needed",
- model="text-embedding-ada-002", # 假设的模型名称
- )
+ # 创建嵌入模型(使用本地 LlamaCpp 模型)
+ from rag_core import LlamaCppEmbedder
+ embedder = LlamaCppEmbedder()
+ embeddings = embedder.as_langchain_embeddings()
# 创建 RAG 流水线
- from app.rag import RAGPipeline, RAGConfig, RAGLevel
-
- config = RAGConfig(
- collection_name="documents", # 你的集合名称
- rag_level=RAGLevel.BASIC,
- verbose=True,
- )
+ from app.rag import RAGPipeline, RAGLevel
pipeline = RAGPipeline(
embeddings=embeddings,
- config=config,
+ config={
+ "collection_name": "rag_documents", # 你的集合名称
+ "rag_level": RAGLevel.BASIC.value,
+ }
)
# 示例查询
- query = "公司报销流程是什么?"
+ query = "吕布"
print(f"\n查询: {query}")
try:
- result = pipeline.retrieve(query)
- print(f"找到 {len(result.documents)} 个相关文档")
+ documents = pipeline.retrieve(query)
+ print(f"找到 {len(documents)} 个相关文档")
# 格式化上下文
- context = pipeline.format_context(result.documents)
+ context = pipeline.format_context(documents)
print(f"\n上下文预览:\n{context[:500]}...")
except Exception as e:
@@ -75,34 +77,31 @@ def demonstrate_hybrid_rag():
print("演示: 混合 RAG 检索 (Level 2)")
print("="*60)
- embeddings = OpenAIEmbeddings(
- openai_api_base="http://localhost:8000/v1",
- openai_api_key="no-key-needed",
- model="text-embedding-ada-002",
- )
+ from rag_core import LlamaCppEmbedder
+ embedder = LlamaCppEmbedder()
+ embeddings = embedder.as_langchain_embeddings()
- from app.rag import RAGPipeline, RAGConfig, RAGLevel
-
- config = RAGConfig(
- collection_name="documents",
- rag_level=RAGLevel.HYBRID,
- dense_k=10,
- sparse_k=10,
- rerank_top_n=5,
- verbose=True,
- )
+ from app.rag import RAGPipeline, RAGLevel
pipeline = RAGPipeline(
embeddings=embeddings,
- config=config,
+ config={
+ "collection_name": "rag_documents",
+ "rag_level": RAGLevel.RERANK.value,
+ "rerank_top_n": 5,
+ }
)
- query = "如何申请年假?"
+ query = "吕布"
print(f"\n查询: {query}")
try:
- result = pipeline.retrieve(query)
- print(f"找到 {len(result.documents)} 个重排序后的文档")
+ documents = pipeline.retrieve(query)
+ print(f"找到 {len(documents)} 个重排序后的文档")
+
+ # 格式化上下文
+ context = pipeline.format_context(documents)
+ print(f"\n上下文预览:\n{context[:500]}...")
except Exception as e:
print(f"检索失败: {e}")
@@ -114,42 +113,42 @@ def demonstrate_rag_fusion():
print("演示: RAG-Fusion (Level 3)")
print("="*60)
- embeddings = OpenAIEmbeddings(
- openai_api_base="http://localhost:8000/v1",
- openai_api_key="no-key-needed",
- model="text-embedding-ada-002",
- )
+ from rag_core import LlamaCppEmbedder
+ embedder = LlamaCppEmbedder()
+ embeddings = embedder.as_langchain_embeddings()
- # 创建语言模型用于查询改写
- llm = VLLMOpenAI(
+ # 创建语言模型用于查询改写(使用 OpenAI 兼容的本地模型)
+ from langchain_openai import ChatOpenAI
+ llm = ChatOpenAI(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
- model_name="Qwen2.5-7B-Instruct", # 你的本地模型
+ model="Qwen2.5-7B-Instruct", # 你的本地模型
temperature=0.3,
max_tokens=512,
)
- from app.rag import RAGPipeline, RAGConfig, RAGLevel
-
- config = RAGConfig(
- collection_name="documents",
- rag_level=RAGLevel.FUSION,
- num_queries=3,
- verbose=True,
- )
+ from app.rag import RAGPipeline, RAGLevel
pipeline = RAGPipeline(
embeddings=embeddings,
llm=llm,
- config=config,
+ config={
+ "collection_name": "rag_documents",
+ "rag_level": RAGLevel.FUSION.value,
+ "num_queries": 3,
+ }
)
- query = "项目上线需要哪些审批?"
+ query = "吕布"
print(f"\n查询: {query}")
try:
- result = pipeline.retrieve(query)
- print(f"找到 {len(result.documents)} 个文档 (经过多路查询改写和重排序)")
+ documents = pipeline.retrieve(query)
+ print(f"找到 {len(documents)} 个文档 (经过多路查询改写和重排序)")
+
+ # 格式化上下文
+ context = pipeline.format_context(documents)
+ print(f"\n上下文预览:\n{context[:500]}...")
except Exception as e:
print(f"检索失败: {e}")
@@ -161,44 +160,16 @@ def demonstrate_agentic_rag():
print("演示: Agentic RAG (Level 4)")
print("="*60)
- embeddings = OpenAIEmbeddings(
- openai_api_base="http://localhost:8000/v1",
- openai_api_key="no-key-needed",
- model="text-embedding-ada-002",
- )
-
- llm = VLLMOpenAI(
- openai_api_base="http://localhost:8000/v1",
- openai_api_key="no-key-needed",
- model_name="Qwen2.5-7B-Instruct",
- temperature=0.3,
- max_tokens=512,
- )
-
- from app.rag import create_agentic_rag_pipeline
+ from app.rag import search_knowledge_base_sync
try:
- # 创建 Agentic RAG 流水线
- agentic_rag = create_agentic_rag_pipeline(
- embeddings=embeddings,
- agent_llm=llm,
- config={
- "collection_name": "documents",
- "verbose": True,
- },
- )
-
- print("Agentic RAG 流水线创建成功!")
- print(f"- 绑定的模型: {agentic_rag['llm']}")
- print(f"- RAG 工具: {agentic_rag['tool'].name}")
-
# 演示工具调用
- print("\n工具调用示例:")
- response = agentic_rag["tool"].invoke({"query": "员工福利有哪些?"})
+ print("工具调用示例:")
+ response = search_knowledge_base_sync("吕布")
print(f"工具响应预览: {response[:200]}...")
except Exception as e:
- print(f"创建 Agentic RAG 失败: {e}")
+ print(f"工具调用失败: {e}")
import traceback
traceback.print_exc()
@@ -211,11 +182,11 @@ def main():
# 设置环境
setup_environment()
- # 演示各级功能
+ # 演示基础功能
demonstrate_basic_rag()
demonstrate_hybrid_rag()
- demonstrate_rag_fusion()
- demonstrate_agentic_rag()
+ # demonstrate_rag_fusion() # 需要本地 LLM 服务
+ # demonstrate_agentic_rag() # 需要本地 LLM 服务
print("\n" + "="*60)
print("演示完成!")
@@ -223,8 +194,8 @@ def main():
print("\n使用说明:")
print("1. 确保 Qdrant 服务运行且集合已创建")
- print("2. 根据需要修改 embeddings 和 llm 配置")
- print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base_tool")
+ print("2. 已使用本地 LlamaCpp 嵌入模型")
+ print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base")
print("4. 将工具绑定到你的 Agent 模型")
diff --git a/app/rag/pipeline.py b/app/rag/pipeline.py
index 0e6d0bc..e5eba7c 100644
--- a/app/rag/pipeline.py
+++ b/app/rag/pipeline.py
@@ -1,341 +1,168 @@
"""
RAG 检索流水线
-组合检索器、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。
+整合基础检索、重排序和 RAG-Fusion 功能。
"""
-import time
-from typing import List, Dict, Any, Optional, Union
-from dataclasses import dataclass, field
from enum import Enum
+from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
-from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
- create_ensemble_retriever,
create_qdrant_client,
)
from .reranker import CrossEncoderReranker
-from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline
+from .query_transform import MultiQueryTransformer
+from rag_core import QDRANT_URL, QDRANT_API_KEY
class RAGLevel(Enum):
- """RAG 功能级别"""
- BASIC = 1 # 基础向量搜索
- HYBRID = 2 # 混合检索 + 重排序
- FUSION = 3 # RAG-Fusion
- AGENTIC = 4 # Agentic RAG
-
-
-@dataclass
-class RAGConfig:
- """RAG 配置"""
- # Qdrant 配置
- collection_name: str = "documents"
- qdrant_url: Optional[str] = None
- qdrant_api_key: Optional[str] = None
-
- # 检索配置
- rag_level: RAGLevel = RAGLevel.FUSION
- dense_k: int = 10 # 向量检索数量
- sparse_k: int = 10 # BM25 检索数量
- total_k: int = 20 # 总检索数量
- rerank_top_n: int = 5 # 重排序返回数量
-
- # 查询改写配置
- num_queries: int = 3 # RAG-Fusion 查询数量
-
- # 模型配置
- reranker_model: str = "BAAI/bge-reranker-base"
- device: Optional[str] = None
-
- # 性能配置
- enable_cache: bool = True
- verbose: bool = True
-
-
-@dataclass
-class RetrievalResult:
- """检索结果"""
- documents: List[Document]
- query_time: float
- level: RAGLevel
- metadata: Dict[str, Any] = field(default_factory=dict)
+ """RAG 级别"""
+ BASIC = "basic" # 基础向量检索
+ RERANK = "rerank" # 基础检索 + 重排序
+ FUSION = "fusion" # RAG-Fusion(多路查询 + RRF)
class RAGPipeline:
- """
- RAG 检索流水线
-
- 支持从 Level 1 到 Level 4 的所有功能。
- """
+ """RAG 检索流水线"""
def __init__(
self,
- embeddings: Embeddings,
+ embeddings,
llm: Optional[BaseLanguageModel] = None,
- config: Optional[RAGConfig] = None,
+ config: Optional[Dict[str, Any]] = None,
):
"""
初始化 RAG 流水线
Args:
embeddings: 嵌入模型
- llm: 语言模型(用于查询改写,Level 3+ 需要)
- config: 配置
+ llm: 语言模型(用于 RAG-Fusion)
+ config: 配置参数
"""
self.embeddings = embeddings
self.llm = llm
- self.config = config or RAGConfig()
+ self.config = config or {}
- # 初始化组件
- self._client = None
- self._reranker = None
- self._query_transformer = None
- self._retriever = None
+ self.collection_name = self.config.get("collection_name", "rag_documents")
+ self.rag_level = self.config.get("rag_level", RAGLevel.RERANK.value)
+ self.num_queries = self.config.get("num_queries", 3)
+ self.rerank_top_n = self.config.get("rerank_top_n", 5)
- # 缓存
- self._cache = {}
-
- def _get_client(self):
- """获取 Qdrant 客户端"""
- if self._client is None:
- self._client = create_qdrant_client(
- url=self.config.qdrant_url,
- api_key=self.config.qdrant_api_key,
- )
- return self._client
-
- def _get_reranker(self):
- """获取重排序器"""
- if self._reranker is None:
- self._reranker = CrossEncoderReranker(
- model_name=self.config.reranker_model,
- top_n=self.config.rerank_top_n,
- device=self.config.device,
- )
- return self._reranker
-
- def _get_query_transformer(self):
- """获取查询改写器"""
- if self._query_transformer is None and self.llm is not None:
- self._query_transformer = MultiQueryTransformer(
- llm=self.llm,
- num_queries=self.config.num_queries,
- )
- return self._query_transformer
-
- def _create_basic_retriever(self):
- """创建基础检索器(Level 1)"""
- return create_base_retriever(
- collection_name=self.config.collection_name,
+ # 初始化基础检索器
+ self.base_retriever = create_base_retriever(
+ collection_name=self.collection_name,
embeddings=self.embeddings,
- search_kwargs={"k": self.config.total_k},
- client=self._get_client(),
- )
-
- def _create_hybrid_retriever(self):
- """创建混合检索器(Level 2)"""
- base_retriever = create_hybrid_retriever(
- collection_name=self.config.collection_name,
- embeddings=self.embeddings,
- dense_k=self.config.dense_k,
- sparse_k=self.config.sparse_k,
- client=self._get_client(),
+ search_kwargs={"k": 20}, # 召回 20 条
)
- # 应用重排序
- reranker = self._get_reranker()
- return reranker.create_contextual_compression_retriever(base_retriever)
-
- def _create_fusion_retriever(self):
- """创建 RAG-Fusion 检索器(Level 3)"""
- if self.llm is None:
- raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写")
+ # 初始化重排序器
+ try:
+ self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n)
+ except Exception as e:
+ print(f"警告: 无法创建重排序器,将使用基础检索。错误: {e}")
+ self.reranker = None
- # 创建基础混合检索器
- base_retriever = create_hybrid_retriever(
- collection_name=self.config.collection_name,
- embeddings=self.embeddings,
- dense_k=self.config.dense_k,
- sparse_k=self.config.sparse_k,
- client=self._get_client(),
- )
-
- # 创建 RAG-Fusion 流水线
- reranker = self._get_reranker()
- return create_rag_fusion_pipeline(
- base_retriever=base_retriever,
- llm=self.llm,
- reranker=reranker,
- num_queries=self.config.num_queries,
- )
+ # 根据 RAG 级别创建检索器
+ self.retriever = self._create_retriever()
- def _get_retriever(self):
- """根据配置级别获取检索器"""
- if self._retriever is None:
- if self.config.rag_level == RAGLevel.BASIC:
- self._retriever = self._create_basic_retriever()
- elif self.config.rag_level == RAGLevel.HYBRID:
- self._retriever = self._create_hybrid_retriever()
- elif self.config.rag_level == RAGLevel.FUSION:
- self._retriever = self._create_fusion_retriever()
- elif self.config.rag_level == RAGLevel.AGENTIC:
- # Agentic RAG 使用 Fusion 作为基础,在 tools.py 中包装
- self._retriever = self._create_fusion_retriever()
+ def _create_retriever(self):
+ """根据 RAG 级别创建检索器"""
+ if self.rag_level == RAGLevel.BASIC.value:
+ return self.base_retriever
+
+ # 基础检索 + 重排序
+ def rerank_retriever(query):
+ documents = self.base_retriever.invoke(query)
+ if self.reranker:
+ return self.reranker.compress_documents(documents, query)
else:
- raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}")
+ return documents[:self.rerank_top_n]
- return self._retriever
+ if self.rag_level == RAGLevel.RERANK.value:
+ return SimpleRetriever(rerank_retriever)
+
+ # RAG-Fusion
+ if self.rag_level == RAGLevel.FUSION.value:
+ if not self.llm:
+ raise ValueError("RAG-Fusion 需要提供 llm 参数")
+
+ # 创建多路查询检索器
+ transformer = MultiQueryTransformer(
+ llm=self.llm,
+ num_queries=self.num_queries
+ )
+ multi_query_retriever = transformer.create_multi_query_retriever(
+ base_retriever=SimpleRetriever(rerank_retriever)
+ )
+
+ return multi_query_retriever
+
+ return SimpleRetriever(rerank_retriever)
- def retrieve(
- self,
- query: str,
- use_cache: Optional[bool] = None,
- **kwargs,
- ) -> RetrievalResult:
+ def retrieve(self, query: str) -> List[Document]:
"""
执行检索
Args:
- query: 查询文本
- use_cache: 是否使用缓存
- **kwargs: 额外参数
-
+ query: 查询字符串
+
Returns:
- 检索结果
+ 相关文档列表
"""
- start_time = time.time()
-
- # 检查缓存
- if use_cache is None:
- use_cache = self.config.enable_cache
-
- cache_key = f"{query}:{self.config.rag_level.value}"
- if use_cache and cache_key in self._cache:
- if self.config.verbose:
- print(f"使用缓存结果: {query}")
- return self._cache[cache_key]
-
- # 获取检索器并执行检索
- retriever = self._get_retriever()
- documents = retriever.invoke(query, **kwargs)
-
- # 计算查询时间
- query_time = time.time() - start_time
-
- # 创建结果
- result = RetrievalResult(
- documents=documents,
- query_time=query_time,
- level=self.config.rag_level,
- metadata={
- "query": query,
- "collection": self.config.collection_name,
- "doc_count": len(documents),
- },
- )
-
- # 缓存结果
- if use_cache:
- self._cache[cache_key] = result
-
- if self.config.verbose:
- print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s")
-
- return result
+ return self.retriever.invoke(query)
- def format_context(
- self,
- documents: List[Document],
- max_length: Optional[int] = None,
- ) -> str:
+ async def aretrieve(self, query: str) -> List[Document]:
"""
- 格式化检索到的文档为上下文文本
+ 异步执行检索
+
+ Args:
+ query: 查询字符串
+
+ Returns:
+ 相关文档列表
+ """
+ return await self.retriever.ainvoke(query)
+
+ def format_context(self, documents: List[Document]) -> str:
+ """
+ 格式化上下文
Args:
documents: 文档列表
- max_length: 最大长度(字符数)
-
+
Returns:
- 格式化后的上下文文本
+ 格式化后的上下文字符串
"""
+ if not documents:
+ return ""
+
context_parts = []
- total_length = 0
+ for i, doc in enumerate(documents, 1):
+ content = doc.page_content
+ metadata = doc.metadata or {}
+ source = metadata.get("source", "未知来源")
+
+ part = f"【资料 {i}】\n"
+ part += f"来源: {source}\n"
+ part += f"内容: {content}\n"
+ part += "---\n"
+ context_parts.append(part)
- for i, doc in enumerate(documents):
- # 提取内容和元数据
- content = doc.page_content.strip()
- metadata = doc.metadata
-
- # 格式化文档
- doc_text = f"[文档 {i+1}]\n"
- if metadata.get("source"):
- doc_text += f"来源: {metadata['source']}\n"
- if metadata.get("page"):
- doc_text += f"页码: {metadata['page']}\n"
-
- doc_text += f"内容: {content}\n\n"
-
- # 检查长度限制
- if max_length is not None:
- if total_length + len(doc_text) > max_length:
- # 如果添加这个文档会超限,则截断并添加说明
- remaining = max_length - total_length
- if remaining > 100: # 至少保留100字符
- doc_text = doc_text[:remaining] + "...\n\n[内容已截断]"
- context_parts.append(doc_text)
- break
- else:
- break
-
- context_parts.append(doc_text)
- total_length += len(doc_text)
-
- return "".join(context_parts).strip()
+ return "".join(context_parts)
+
+
+class SimpleRetriever:
+ """简单检索器包装类"""
- def clear_cache(self):
- """清空缓存"""
- self._cache.clear()
+ def __init__(self, retrieve_func):
+ self.retrieve_func = retrieve_func
- @classmethod
- def create_from_config(
- cls,
- embeddings: Embeddings,
- llm: Optional[BaseLanguageModel] = None,
- config_dict: Optional[Dict[str, Any]] = None,
- ) -> "RAGPipeline":
- """
- 从配置字典创建流水线
-
- Args:
- embeddings: 嵌入模型
- llm: 语言模型
- config_dict: 配置字典
-
- Returns:
- RAGPipeline 实例
- """
- config_dict = config_dict or {}
-
- # 创建配置对象
- config = RAGConfig(
- collection_name=config_dict.get("collection_name", "documents"),
- qdrant_url=config_dict.get("qdrant_url"),
- qdrant_api_key=config_dict.get("qdrant_api_key"),
- rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)),
- dense_k=config_dict.get("dense_k", 10),
- sparse_k=config_dict.get("sparse_k", 10),
- total_k=config_dict.get("total_k", 20),
- rerank_top_n=config_dict.get("rerank_top_n", 5),
- num_queries=config_dict.get("num_queries", 3),
- reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"),
- device=config_dict.get("device"),
- enable_cache=config_dict.get("enable_cache", True),
- verbose=config_dict.get("verbose", True),
- )
-
- return cls(embeddings=embeddings, llm=llm, config=config)
\ No newline at end of file
+ def invoke(self, query):
+ return self.retrieve_func(query)
+
+ async def ainvoke(self, query):
+ return self.retrieve_func(query)
diff --git a/app/rag/query_transform.py b/app/rag/query_transform.py
index 652e6f4..5183f9e 100644
--- a/app/rag/query_transform.py
+++ b/app/rag/query_transform.py
@@ -1,193 +1,62 @@
"""
-查询改写器
+查询转换器模块
-基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围。
+实现多路查询改写功能,用于 RAG-Fusion。
"""
-from typing import List, Optional, Any
-from langchain.retrievers.multi_query import MultiQueryRetriever
+from typing import List, Optional
from langchain_core.language_models import BaseLanguageModel
+# from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.prompts import PromptTemplate
-from langchain_core.output_parsers import StrOutputParser
class MultiQueryTransformer:
- """
- 多路查询改写器
+ """多路查询改写器,用于 RAG-Fusion。"""
- 将单个查询改写成多个相关查询,用于 RAG-Fusion。
- """
-
- def __init__(
- self,
- llm: BaseLanguageModel,
- num_queries: int = 3,
- prompt_template: Optional[str] = None,
- ):
+ def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
"""
- 初始化查询改写器
+ 初始化多路查询改写器。
Args:
llm: 语言模型实例
num_queries: 生成的查询数量
- prompt_template: 提示词模板
"""
self.llm = llm
self.num_queries = num_queries
-
- # 默认提示词模板
- self.prompt_template = prompt_template or """
- 你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
- 这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。
-
- 原始问题: {question}
-
- 请生成 {num_queries} 个不同版本的查询,每个版本一行。
- 确保每个版本都是独立、完整的查询语句。
-
- 生成 {num_queries} 个查询:
- """
- def transform_query(self, query: str) -> List[str]:
+ def create_multi_query_retriever(self, base_retriever):
"""
- 将单个查询改写成多个查询
-
- Args:
- query: 原始查询
-
- Returns:
- 改写后的查询列表
- """
- prompt = PromptTemplate.from_template(self.prompt_template)
-
- chain = prompt | self.llm | StrOutputParser()
-
- response = chain.invoke({
- "question": query,
- "num_queries": self.num_queries,
- })
-
- # 解析响应,每行一个查询
- queries = [
- q.strip()
- for q in response.strip().split('\n')
- if q.strip()
- ]
-
- # 确保数量正确,如果不够则添加原始查询
- if len(queries) < self.num_queries:
- queries.extend([query] * (self.num_queries - len(queries)))
- elif len(queries) > self.num_queries:
- queries = queries[:self.num_queries]
-
- # 确保包含原始查询
- if query not in queries:
- queries = [query] + queries[:self.num_queries-1]
-
- return queries
-
- def create_multi_query_retriever(
- self,
- base_retriever: Any,
- include_original: bool = True,
- ) -> MultiQueryRetriever:
- """
- 创建多路查询检索器
+ 创建多路查询检索器。
Args:
base_retriever: 基础检索器
- include_original: 是否包含原始查询
-
+
Returns:
MultiQueryRetriever 实例
"""
- retriever = MultiQueryRetriever.from_llm(
- retriever=base_retriever,
- llm=self.llm,
- include_original=include_original,
- )
-
- # 设置生成的查询数量
- retriever.llm_chain.prompt = PromptTemplate.from_template(
- "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
- "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
- "原始问题: {question}\n\n"
- "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
- "确保每个版本都是独立、完整的查询语句。\n\n"
- "生成 {num_queries} 个查询:"
- )
-
- # 修改调用参数以包含 num_queries
- original_invoke = retriever.llm_chain.invoke
-
- def new_invoke(input_dict):
- input_dict["num_queries"] = self.num_queries
- return original_invoke(input_dict)
-
- retriever.llm_chain.invoke = new_invoke
-
- return retriever
-
- @classmethod
- def create_from_config(
- cls,
- llm: BaseLanguageModel,
- config: Optional[dict] = None,
- ) -> "MultiQueryTransformer":
- """
- 从配置创建查询改写器
-
- Args:
- llm: 语言模型实例
- config: 配置字典
-
- Returns:
- MultiQueryTransformer 实例
- """
- config = config or {}
- return cls(
- llm=llm,
- num_queries=config.get("num_queries", 3),
- prompt_template=config.get("prompt_template", None),
- )
-
-
-def create_rag_fusion_pipeline(
- base_retriever: Any,
- llm: BaseLanguageModel,
- reranker: Optional[Any] = None,
- num_queries: int = 3,
-) -> Any:
- """
- 创建完整的 RAG-Fusion 流水线
-
- Args:
- base_retriever: 基础检索器
- llm: 语言模型(用于查询改写)
- reranker: 重排序器(可选)
- num_queries: 查询改写数量
-
- Returns:
- 检索器实例
- """
- # 创建多路查询改写器
- query_transformer = MultiQueryTransformer(
- llm=llm,
- num_queries=num_queries,
- )
-
- # 创建多路查询检索器
- multi_query_retriever = query_transformer.create_multi_query_retriever(
- base_retriever=base_retriever,
- include_original=True,
- )
-
- # 如果提供了重排序器,则应用重排序
- if reranker is not None:
- from langchain.retrievers import ContextualCompressionRetriever
- return ContextualCompressionRetriever(
- base_compressor=reranker,
- base_retriever=multi_query_retriever,
- )
-
- return multi_query_retriever
\ No newline at end of file
+ # 由于当前 LangChain 版本不支持 MultiQueryRetriever,暂时返回基础检索器
+ # retriever = MultiQueryRetriever.from_llm(
+ # retriever=base_retriever,
+ # llm=self.llm,
+ # include_original=True
+ # )
+ #
+ # # 自定义提示词
+ # retriever.llm_chain.prompt = PromptTemplate.from_template(
+ # "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
+ # "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
+ # "原始问题: {question}\n\n"
+ # "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
+ # "确保每个版本都是独立、完整的查询语句。\n\n"
+ # "生成 {num_queries} 个查询:"
+ # )
+ #
+ # # 修改调用参数以包含 num_queries
+ # original_ainvoke = retriever.llm_chain.ainvoke
+ # async def new_ainvoke(input_dict):
+ # input_dict["num_queries"] = self.num_queries
+ # return await original_ainvoke(input_dict)
+ # retriever.llm_chain.ainvoke = new_ainvoke
+ #
+ # return retriever
+ return base_retriever
diff --git a/app/rag/reranker.py b/app/rag/reranker.py
index 4c7229f..4a414cf 100644
--- a/app/rag/reranker.py
+++ b/app/rag/reranker.py
@@ -1,141 +1,65 @@
"""
-Cross-Encoder 重排序器
+重排序器模块
-使用 sentence-transformers 加载交叉编码器模型,对检索结果进行精排。
+使用 Cross-Encoder 模型对检索结果进行重排序,提高检索精度。
"""
-import os
-from typing import List, Dict, Any, Optional
-from langchain.retrievers.document_compressors import CrossEncoderReranker
+from typing import List
from langchain_core.documents import Document
-from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
- """
- Cross-Encoder 重排序器包装类
+ """使用 Cross-Encoder 对检索结果重排序。"""
- 支持 BAAI/bge-reranker-base 等模型。
- """
-
- def __init__(
- self,
- model_name: str = "BAAI/bge-reranker-base",
- top_n: int = 5,
- device: Optional[str] = None,
- cache_folder: Optional[str] = None,
- ):
+ def __init__(self, model_name: str = "BAAI/bge-reranker-base", top_n: int = 5):
"""
初始化重排序器
Args:
- model_name: 模型名称或路径
- top_n: 返回的顶部文档数量
- device: 设备(cpu/cuda),如果为 None 则自动选择
- cache_folder: 模型缓存目录
+ model_name: 预训练模型名称
+ top_n: 返回前 N 个结果
"""
self.model_name = model_name
self.top_n = top_n
- self.device = device
- self.cache_folder = cache_folder or os.path.join(
- os.path.expanduser("~"), ".cache", "sentence_transformers"
- )
+ self.model = None
- # 延迟加载模型
- self._model = None
- self._langchain_reranker = None
+ # 尝试加载 Cross-Encoder 模型
+ try:
+ from sentence_transformers import CrossEncoder
+ self.model = CrossEncoder(model_name)
+ except Exception as e:
+ print(f"警告: 无法加载 Cross-Encoder 模型 {model_name},将使用简单排序作为回退方案。错误: {e}")
- def _load_model(self):
- """加载交叉编码器模型"""
- if self._model is None:
- try:
- self._model = CrossEncoder(
- self.model_name,
- device=self.device,
- cache_folder=self.cache_folder,
- )
- except Exception as e:
- # 如果指定模型加载失败,尝试备用模型
- print(f"加载模型 {self.model_name} 失败: {e}")
- print("尝试加载备用模型 BAAI/bge-reranker-v2-m3...")
- self._model = CrossEncoder(
- "BAAI/bge-reranker-v2-m3",
- device=self.device,
- cache_folder=self.cache_folder,
- )
-
- def _create_langchain_reranker(self):
- """创建 LangChain 重排序器"""
- if self._langchain_reranker is None:
- self._load_model()
- self._langchain_reranker = CrossEncoderReranker(
- model=self._model,
- top_n=self.top_n,
- )
-
- def rerank(
- self,
- query: str,
- documents: List[Document],
+ def compress_documents(
+ self, documents: List[Document], query: str
) -> List[Document]:
"""
对文档进行重排序
Args:
- query: 查询文本
- documents: 待排序文档列表
-
+ documents: 待排序的文档列表
+ query: 查询字符串
+
Returns:
- 重排序后的文档列表
+ 排序后的文档列表
"""
- self._create_langchain_reranker()
- return self._langchain_reranker.compress_documents(
- documents=documents,
- query=query,
- )
-
- def create_contextual_compression_retriever(
- self,
- base_retriever: Any,
- ) -> Any:
- """
- 创建上下文压缩检索器
+ if not documents:
+ return []
- Args:
- base_retriever: 基础检索器
+ # 如果模型加载失败,返回前 top_n 个文档
+ if self.model is None:
+ return documents[:self.top_n]
+
+ # 使用 Cross-Encoder 进行重排序
+ try:
+ pairs = [[query, doc.page_content] for doc in documents]
+ scores = self.model.predict(pairs)
- Returns:
- 上下文压缩检索器
- """
- from langchain.retrievers import ContextualCompressionRetriever
-
- self._create_langchain_reranker()
-
- compression_retriever = ContextualCompressionRetriever(
- base_compressor=self._langchain_reranker,
- base_retriever=base_retriever,
- )
-
- return compression_retriever
-
- @classmethod
- def create_from_config(
- cls,
- config: Optional[Dict[str, Any]] = None,
- ) -> "CrossEncoderReranker":
- """
- 从配置创建重排序器
-
- Args:
- config: 配置字典,包含 model_name, top_n, device 等
-
- Returns:
- CrossEncoderReranker 实例
- """
- config = config or {}
- return cls(
- model_name=config.get("model_name", "BAAI/bge-reranker-base"),
- top_n=config.get("top_n", 5),
- device=config.get("device", None),
- cache_folder=config.get("cache_folder", None),
- )
\ No newline at end of file
+ # 按分数降序排序
+ scored_docs = sorted(
+ zip(documents, scores), key=lambda x: x[1], reverse=True
+ )
+ return [doc for doc, _ in scored_docs[:self.top_n]]
+ except Exception as e:
+ print(f"警告: 重排序过程出错,将使用原始排序。错误: {e}")
+ return documents[:self.top_n]
diff --git a/app/rag/retriever.py b/app/rag/retriever.py
index 19a7511..80d6284 100644
--- a/app/rag/retriever.py
+++ b/app/rag/retriever.py
@@ -4,15 +4,12 @@ Qdrant 向量检索器
提供基础向量检索、混合检索(Dense + BM25)功能。
"""
-import os
from typing import List, Dict, Any, Optional
-from langchain_qdrant import Qdrant
+from langchain_qdrant import QdrantVectorStore
from langchain.embeddings.base import Embeddings
-from langchain.retrievers import ContextualCompressionRetriever
-from langchain.retrievers.document_compressors import DocumentCompressorPipeline
-from langchain.retrievers import EnsembleRetriever
+# from langchain.retrievers import EnsembleRetriever
from qdrant_client import QdrantClient
-from qdrant_client.http import models
+from rag_core import QDRANT_URL, QDRANT_API_KEY
def create_qdrant_client(
@@ -21,21 +18,21 @@ def create_qdrant_client(
) -> QdrantClient:
"""
创建 Qdrant 客户端
-
+
Args:
url: Qdrant 服务地址,默认从环境变量 QDRANT_URL 读取
api_key: API 密钥,默认从环境变量 QDRANT_API_KEY 读取
-
+
Returns:
QdrantClient 实例
"""
- url = url or os.getenv("QDRANT_URL", "http://localhost:6333")
- api_key = api_key or os.getenv("QDRANT_API_KEY")
-
+ url = url or QDRANT_URL
+ api_key = api_key or QDRANT_API_KEY
+
client_args = {"url": url}
if api_key:
client_args["api_key"] = api_key
-
+
return QdrantClient(**client_args)
@@ -44,34 +41,33 @@ def create_base_retriever(
embeddings: Embeddings,
search_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[QdrantClient] = None,
-) -> Qdrant:
+) -> QdrantVectorStore:
"""
创建基础向量检索器
-
+
Args:
collection_name: Qdrant 集合名称
embeddings: 嵌入模型
search_kwargs: 搜索参数,默认 {"k": 20}
client: Qdrant 客户端,如果为 None 则自动创建
-
+
Returns:
- Qdrant 检索器实例
+ QdrantVectorStore 检索器实例
"""
+ search_kwargs = search_kwargs or {"k": 20}
+
+ # 创建 Qdrant 客户端
if client is None:
client = create_qdrant_client()
-
- search_kwargs = search_kwargs or {"k": 20}
-
- # 创建 Qdrant 检索器
- retriever = Qdrant.from_existing_collection(
- embedding=embeddings,
- collection_name=collection_name,
+
+ # 使用 QdrantVectorStore 创建向量存储
+ vector_store = QdrantVectorStore(
client=client,
- content_payload_key="content", # 假设存储的文本字段名为 "content"
- metadata_payload_key="metadata", # 元数据字段名
+ collection_name=collection_name,
+ embedding=embeddings,
)
-
- return retriever.as_retriever(search_kwargs=search_kwargs)
+
+ return vector_store.as_retriever(search_kwargs=search_kwargs)
def create_hybrid_retriever(
@@ -80,65 +76,63 @@ def create_hybrid_retriever(
dense_k: int = 10,
sparse_k: int = 10,
client: Optional[QdrantClient] = None,
-) -> ContextualCompressionRetriever:
+) -> QdrantVectorStore:
"""
创建混合检索器(Dense Vector + BM25)
-
+
Args:
collection_name: Qdrant 集合名称
embeddings: 嵌入模型
dense_k: 向量检索返回数量
sparse_k: BM25 检索返回数量
client: Qdrant 客户端
-
+
Returns:
混合检索器
"""
+ # 创建 Qdrant 客户端
if client is None:
client = create_qdrant_client()
-
- # 基础检索器(Qdrant 支持混合检索)
- base_retriever = Qdrant.from_existing_collection(
- embedding=embeddings,
- collection_name=collection_name,
+
+ # 使用 QdrantVectorStore 创建向量存储
+ vector_store = QdrantVectorStore(
client=client,
- content_payload_key="content",
- metadata_payload_key="metadata",
+ collection_name=collection_name,
+ embedding=embeddings,
)
-
- # 配置混合检索参数
+
search_kwargs = {
- "k": dense_k + sparse_k, # 总返回数量
- "score_threshold": 0.3, # 相似度阈值
+ "k": dense_k + sparse_k,
+ "score_threshold": 0.3,
}
-
- return base_retriever.as_retriever(search_kwargs=search_kwargs)
+
+ return vector_store.as_retriever(search_kwargs=search_kwargs)
-def create_ensemble_retriever(
- retrievers: List[Any],
- weights: Optional[List[float]] = None,
- c: int = 60,
-) -> EnsembleRetriever:
- """
- 创建集成检索器,支持倒数排名融合 (RRF)
-
- Args:
- retrievers: 检索器列表
- weights: 检索器权重
- c: RRF 常数(通常为60)
-
- Returns:
- 集成检索器
- """
- if weights is None:
- weights = [1.0 / len(retrievers)] * len(retrievers)
-
- ensemble = EnsembleRetriever(
- retrievers=retrievers,
- weights=weights,
- c=c,
- search_type="rrf", # 使用倒数排名融合
- )
-
- return ensemble
\ No newline at end of file
+# def create_ensemble_retriever(
+# retrievers: List[Any],
+# weights: Optional[List[float]] = None,
+# c: int = 60,
+# ) -> EnsembleRetriever:
+# """
+# 创建集成检索器,支持倒数排名融合 (RRF)
+#
+# Args:
+# retrievers: 检索器列表
+# weights: 检索器权重
+# c: RRF 常数(通常为60)
+#
+# Returns:
+# 集成检索器
+# """
+# if weights is None:
+# weights = [1.0 / len(retrievers)] * len(retrievers)
+#
+# ensemble = EnsembleRetriever(
+# retrievers=retrievers,
+# weights=weights,
+# c=c,
+# search_type="rrf",
+# )
+#
+# return ensemble
diff --git a/app/rag/tools.py b/app/rag/tools.py
index 8dcf90f..a284a11 100644
--- a/app/rag/tools.py
+++ b/app/rag/tools.py
@@ -1,230 +1,89 @@
"""
-RAG 工具包装
+RAG 工具模块
-将 RAG 流水线包装成 LangChain Tool,供 Agent 调用。
+将检索功能封装为 LangChain Tool,供 Agent 调用。
"""
-from typing import Optional, Dict, Any
-from langchain.tools import tool
-from langchain_core.tools import Tool
-from langchain_core.embeddings import Embeddings
-from langchain_core.language_models import BaseLanguageModel
-
-from .pipeline import RAGPipeline, RAGConfig, RAGLevel
+from langchain_core.tools import tool
+from rag_core import LlamaCppEmbedder, QDRANT_URL, QDRANT_API_KEY
+from .pipeline import RAGPipeline, RAGLevel
-class RAGTool:
- """
- RAG 工具包装器
+@tool
+async def search_knowledge_base(query: str, rag_level: str = "rerank") -> str:
+ """在知识库中搜索与查询相关的文档片段。
- 将 RAG 流水线包装成 Agent 可调用的工具。
- """
-
- def __init__(
- self,
- pipeline: RAGPipeline,
- tool_name: str = "search_knowledge_base",
- tool_description: str = None,
- ):
- """
- 初始化 RAG 工具
-
- Args:
- pipeline: RAG 流水线实例
- tool_name: 工具名称
- tool_description: 工具描述
- """
- self.pipeline = pipeline
- self.tool_name = tool_name
-
- # 默认工具描述
- self.tool_description = tool_description or (
- "在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
- "专业知识或需要基于已知信息回答的问题时使用此工具。"
- "输入应为要搜索的查询文本。"
- )
-
- # 创建 LangChain 工具
- self._tool = self._create_tool()
-
- def _create_tool(self) -> Tool:
- """创建 LangChain 工具"""
-
- @tool(self.tool_name, args_schema=None)
- def search_knowledge_base(query: str) -> str:
- """
- 在知识库中搜索相关信息
-
- Args:
- query: 搜索查询
-
- Returns:
- 格式化后的搜索结果
- """
- try:
- # 执行检索
- result = self.pipeline.retrieve(query)
-
- if not result.documents:
- return "在知识库中未找到相关信息。"
-
- # 格式化上下文
- context = self.pipeline.format_context(
- result.documents,
- max_length=4000, # 限制上下文长度
- )
-
- # 构建响应
- response = (
- f"🔍 在知识库中找到了 {len(result.documents)} 条相关信息:\n\n"
- f"{context}\n\n"
- f"⏱️ 检索耗时: {result.query_time:.2f}秒"
- )
-
- return response
-
- except Exception as e:
- error_msg = f"检索过程中发生错误: {str(e)}"
- if self.pipeline.config.verbose:
- print(f"RAG 工具错误: {error_msg}")
- return error_msg
-
- # 设置工具描述
- search_knowledge_base.description = self.tool_description
-
- return search_knowledge_base
-
- def get_tool(self) -> Tool:
- """获取 LangChain 工具"""
- return self._tool
-
- def __call__(self, query: str) -> str:
- """直接调用工具"""
- return self._tool.invoke({"query": query})
-
-
-def create_rag_tool(
- embeddings: Embeddings,
- llm: Optional[BaseLanguageModel] = None,
- config: Optional[Dict[str, Any]] = None,
- tool_name: str = "search_knowledge_base",
- tool_description: Optional[str] = None,
-) -> Tool:
- """
- 创建 RAG 工具(便捷函数)
+ 适用于事实性问题、背景知识查询。
Args:
- embeddings: 嵌入模型
- llm: 语言模型(用于高级 RAG 功能)
- config: RAG 配置字典
- tool_name: 工具名称
- tool_description: 工具描述
-
+ query: 查询字符串
+ rag_level: 检索级别,可选值:basic(基础向量检索)、rerank(基础检索+重排序)、fusion(RAG-Fusion)
+
Returns:
- LangChain Tool 实例
+ 检索到的相关文档内容
"""
+ # 初始化嵌入模型
+ embedder = LlamaCppEmbedder()
+ embeddings = embedder.as_langchain_embeddings()
+
# 创建 RAG 流水线
- pipeline = RAGPipeline.create_from_config(
+ pipeline = RAGPipeline(
embeddings=embeddings,
- llm=llm,
- config_dict=config,
+ config={
+ "rag_level": rag_level,
+ "collection_name": "rag_documents",
+ "rerank_top_n": 5,
+ }
)
- # 创建工具包装器
- rag_tool = RAGTool(
- pipeline=pipeline,
- tool_name=tool_name,
- tool_description=tool_description,
- )
+ # 执行检索
+ try:
+ documents = await pipeline.aretrieve(query)
+ if not documents:
+ return "未找到相关信息。"
+
+ # 格式化结果
+ context = pipeline.format_context(documents)
+ return context
+ except Exception as e:
+ return f"检索过程中发生错误: {str(e)}"
+
+
+@tool
+def search_knowledge_base_sync(query: str, rag_level: str = "rerank") -> str:
+ """同步版本的知识库搜索工具。
- return rag_tool.get_tool()
-
-
-# 导出便捷函数
-search_knowledge_base_tool = create_rag_tool
-
-
-def bind_rag_to_agent(
- agent_llm: BaseLanguageModel,
- embeddings: Embeddings,
- rag_llm: Optional[BaseLanguageModel] = None,
- config: Optional[Dict[str, Any]] = None,
- tool_name: str = "search_knowledge_base",
-) -> BaseLanguageModel:
- """
- 将 RAG 工具绑定到 Agent 模型
+ 适用于事实性问题、背景知识查询。
Args:
- agent_llm: Agent 使用的语言模型
- embeddings: 嵌入模型
- rag_llm: RAG 流水线使用的语言模型(如果与 agent_llm 不同)
- config: RAG 配置
- tool_name: 工具名称
-
- Returns:
- 绑定工具后的模型
- """
- # 如果未指定 RAG LLM,使用 Agent LLM
- if rag_llm is None:
- rag_llm = agent_llm
+ query: 查询字符串
+ rag_level: 检索级别,可选值:basic(基础向量检索)、rerank(基础检索+重排序)、fusion(RAG-Fusion)
- # 创建 RAG 工具
- rag_tool = create_rag_tool(
+ Returns:
+ 检索到的相关文档内容
+ """
+ # 初始化嵌入模型
+ embedder = LlamaCppEmbedder()
+ embeddings = embedder.as_langchain_embeddings()
+
+ # 创建 RAG 流水线
+ pipeline = RAGPipeline(
embeddings=embeddings,
- llm=rag_llm,
- config=config,
- tool_name=tool_name,
+ config={
+ "rag_level": rag_level,
+ "collection_name": "rag_documents",
+ "rerank_top_n": 5,
+ }
)
- # 绑定工具到模型
- return agent_llm.bind_tools([rag_tool])
-
-
-def create_agentic_rag_pipeline(
- embeddings: Embeddings,
- agent_llm: BaseLanguageModel,
- rag_llm: Optional[BaseLanguageModel] = None,
- config: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
- """
- 创建完整的 Agentic RAG 流水线(Level 4)
-
- Args:
- embeddings: 嵌入模型
- agent_llm: Agent 模型
- rag_llm: RAG 专用模型
- config: 配置
+ # 执行检索
+ try:
+ documents = pipeline.retrieve(query)
+ if not documents:
+ return "未找到相关信息。"
- Returns:
- 包含模型和工具的字典
- """
- # 配置 Agentic RAG 级别
- if config is None:
- config = {}
- config["rag_level"] = RAGLevel.AGENTIC.value
-
- # 创建 RAG 工具
- rag_tool = create_rag_tool(
- embeddings=embeddings,
- llm=rag_llm or agent_llm,
- config=config,
- tool_name="search_knowledge_base",
- tool_description=(
- "在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
- "专业知识或需要基于已知信息回答的问题时使用此工具。"
- "Agent 应该先判断是否需要使用此工具,然后调用它获取上下文。"
- ),
- )
-
- # 绑定工具到模型
- bound_llm = agent_llm.bind_tools([rag_tool])
-
- return {
- "llm": bound_llm,
- "tool": rag_tool,
- "pipeline": RAGPipeline.create_from_config(
- embeddings=embeddings,
- llm=rag_llm or agent_llm,
- config_dict=config,
- ),
- }
\ No newline at end of file
+ # 格式化结果
+ context = pipeline.format_context(documents)
+ return context
+ except Exception as e:
+ return f"检索过程中发生错误: {str(e)}"
diff --git a/rag_core/__init__.py b/rag_core/__init__.py
new file mode 100644
index 0000000..c6aa7a6
--- /dev/null
+++ b/rag_core/__init__.py
@@ -0,0 +1,18 @@
+"""
+RAG Core - 公共 RAG 组件包
+
+提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。
+"""
+
+from .embedders import LlamaCppEmbedder
+from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
+from .store import PostgresDocStore, create_docstore
+
+__all__ = [
+ "LlamaCppEmbedder",
+ "QdrantVectorStore",
+ "QDRANT_URL",
+ "QDRANT_API_KEY",
+ "PostgresDocStore",
+ "create_docstore",
+]
diff --git a/rag_indexer/embedders.py b/rag_core/embedders.py
similarity index 94%
rename from rag_indexer/embedders.py
rename to rag_core/embedders.py
index 3f1be8a..e9a87a3 100644
--- a/rag_indexer/embedders.py
+++ b/rag_core/embedders.py
@@ -64,12 +64,9 @@ class LlamaCppEmbedder:
response.raise_for_status()
data = response.json()
- # 处理不同响应格式
if isinstance(data, list):
- # llama.cpp 直接返回列表
return [item["embedding"] for item in data]
elif isinstance(data, dict) and "data" in data:
- # OpenAI 标准格式
return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])]
else:
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
@@ -85,4 +82,4 @@ class _LlamaCppLangchainAdapter(Embeddings):
return self._embedder.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
- return self._embedder.embed_query(text)
+ return self._embedder.embed_query(text)
\ No newline at end of file
diff --git a/rag_indexer/store/__init__.py b/rag_core/store/__init__.py
similarity index 92%
rename from rag_indexer/store/__init__.py
rename to rag_core/store/__init__.py
index a1e561e..359db76 100644
--- a/rag_indexer/store/__init__.py
+++ b/rag_core/store/__init__.py
@@ -5,7 +5,7 @@
- PostgresDocStore: PostgreSQL 数据库存储(生产环境)
示例用法:
- >>> from rag_indexer.store import create_docstore
+ >>> from rag_core.store import create_docstore
>>> # 创建 PostgreSQL 存储
>>> store, conn = create_docstore(
diff --git a/rag_indexer/store/factory.py b/rag_core/store/factory.py
similarity index 98%
rename from rag_indexer/store/factory.py
rename to rag_core/store/factory.py
index 2388f8f..c32c2c7 100644
--- a/rag_indexer/store/factory.py
+++ b/rag_core/store/factory.py
@@ -70,4 +70,4 @@ def create_docstore(
return store, conn_str
else:
- raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres")
\ No newline at end of file
+ raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres")
diff --git a/rag_indexer/store/postgres.py b/rag_core/store/postgres.py
similarity index 99%
rename from rag_indexer/store/postgres.py
rename to rag_core/store/postgres.py
index 69ef4e3..5132355 100644
--- a/rag_indexer/store/postgres.py
+++ b/rag_core/store/postgres.py
@@ -246,4 +246,4 @@ class PostgresDocStore(BaseStore[str, Any]):
注意:在异步环境中,请使用 aclose 方法。
"""
- pass
\ No newline at end of file
+ pass
diff --git a/rag_indexer/vector_store.py b/rag_core/vector_store.py
similarity index 96%
rename from rag_indexer/vector_store.py
rename to rag_core/vector_store.py
index 4f2e5c6..5faa66f 100644
--- a/rag_indexer/vector_store.py
+++ b/rag_core/vector_store.py
@@ -9,11 +9,8 @@ from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
-from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams
-from .embedders import LlamaCppEmbedder
-
logger = logging.getLogger(__name__)
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
@@ -31,17 +28,15 @@ class QdrantVectorStore:
self.collection_name = collection_name
self._client: Optional[QdrantClient] = None
- # 嵌入模型
if embeddings is None:
+ from .embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
else:
self.embeddings = embeddings
- # 先创建集合
self.create_collection()
- # LangChain 向量存储
self.vector_store = LangchainQdrantVS(
client=self.get_client(),
collection_name=self.collection_name,
@@ -68,6 +63,7 @@ class QdrantVectorStore:
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
"""创建集合,设置合适的向量维度。"""
if vector_size is None:
+ from .embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
diff --git a/rag_indexer/IndexBuilder.py b/rag_indexer/IndexBuilder.py
new file mode 100644
index 0000000..6f077e9
--- /dev/null
+++ b/rag_indexer/IndexBuilder.py
@@ -0,0 +1,299 @@
+"""
+离线 RAG 索引构建核心流水线。
+
+使用 LangChain 的 ParentDocumentRetriever 实现父子块策略。
+"""
+
+import asyncio
+import logging
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import List, Union, Optional, Any, Dict, Tuple
+
+from httpx import RemoteProtocolError
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.stores import BaseStore
+from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
+from langchain_classic.retrievers import ParentDocumentRetriever
+
+from .loaders import DocumentLoader
+from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter
+from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
+
+logger = logging.getLogger(__name__)
+
+
+# ---------- 配置数据类 ----------
+@dataclass
+class DocstoreConfig:
+ """文档存储配置(用于父块存储)。"""
+ connection_string: Optional[str] = None
+ pool_config: Optional[Dict[str, Any]] = None
+ max_concurrency: Optional[int] = None
+ # 若要从外部注入已创建好的 docstore,可直接设置此字段
+ instance: Optional[BaseStore] = None
+
+
+@dataclass
+class IndexBuilderConfig:
+ """索引构建器配置。"""
+ collection_name: str = "rag_documents"
+ splitter_type: SplitterType = SplitterType.PARENT_CHILD
+
+ # 父块切分参数(仅当 splitter_type 为 PARENT_CHILD 时生效)
+ parent_chunk_size: int = 1000
+ parent_chunk_overlap: int = 100
+ # 子块切分参数
+ child_chunk_size: int = 200
+ child_chunk_overlap: int = 20
+ child_splitter_type: SplitterType = SplitterType.SEMANTIC # 子块默认语义切分
+
+ # 检索参数
+ search_k: int = 5
+
+ # 文档存储配置(仅父子块模式需要)
+ docstore: DocstoreConfig = field(default_factory=DocstoreConfig)
+
+ # 其他切分器参数(当 splitter_type 非父子块时使用)
+ extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)
+
+
+# ---------- 索引构建器 ----------
+class IndexBuilder:
+ """RAG 索引构建主流水线,支持单块切分与父子块切分。"""
+
+ def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs):
+ """
+ Args:
+ config: 索引构建器配置对象,优先级高于 kwargs
+ **kwargs: 可直接传入配置参数,会合并到 config 中(为方便使用保留)
+ """
+ if config is None:
+ config = IndexBuilderConfig(**kwargs)
+ elif kwargs:
+ # 合并 kwargs 到 config 的字段(仅更新已有字段)
+ for key, value in kwargs.items():
+ if hasattr(config, key):
+ setattr(config, key, value)
+
+ self.config = config
+ self._docstore_conn: Optional[str] = None # 用于记录由 create_docstore 创建的连接信息
+
+ # 初始化基础组件
+ self.loader = DocumentLoader()
+ self.embedder = LlamaCppEmbedder()
+ self.embeddings: Embeddings = self.embedder.as_langchain_embeddings()
+
+ # 初始化向量存储
+ self.vector_store = QdrantVectorStore(
+ collection_name=config.collection_name,
+ embeddings=self.embeddings,
+ )
+
+ # 根据切分类型初始化相关组件
+ self._init_splitters_and_retriever()
+
+ # ---------- 私有初始化方法 ----------
+ def _init_splitters_and_retriever(self) -> None:
+ """根据配置初始化切分器和检索器。"""
+ if self.config.splitter_type == SplitterType.PARENT_CHILD:
+ self._init_parent_child_mode()
+ else:
+ self._init_single_splitter_mode()
+
+ def _init_single_splitter_mode(self) -> None:
+ """单一切分模式(递归或语义)。"""
+ splitter_kwargs = self.config.extra_splitter_kwargs.copy()
+ if self.config.splitter_type == SplitterType.SEMANTIC:
+ splitter_kwargs["embeddings"] = self.embeddings
+ self.splitter = get_splitter(self.config.splitter_type, **splitter_kwargs)
+ self.retriever = None
+ self.docstore = None
+ logger.info("使用单一 %s 切分器", self.config.splitter_type.value)
+
+ def _init_parent_child_mode(self) -> None:
+ """父子块切分模式,初始化父块/子块切分器、文档存储和检索器。"""
+ cfg = self.config
+
+ # 父块切分器(始终使用递归切分)
+ self.parent_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=cfg.parent_chunk_size,
+ chunk_overlap=cfg.parent_chunk_overlap,
+ )
+
+ # 子块切分器
+ if cfg.child_splitter_type == SplitterType.SEMANTIC:
+ self.child_splitter = get_splitter(
+ SplitterType.SEMANTIC,
+ embeddings=self.embeddings,
+ **cfg.extra_splitter_kwargs
+ )
+ logger.info("子块使用语义切分器")
+ else:
+ self.child_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=cfg.child_chunk_size,
+ chunk_overlap=cfg.child_chunk_overlap,
+ )
+ logger.info("子块使用递归切分器,块大小=%d,重叠=%d",
+ cfg.child_chunk_size, cfg.child_chunk_overlap)
+
+ # 初始化文档存储(用于父块)
+ self.docstore = self._create_or_use_docstore()
+
+ # 创建检索器
+ self.retriever = ParentDocumentRetriever(
+ vectorstore=self.vector_store.get_langchain_vectorstore(),
+ docstore=self.docstore,
+ child_splitter=self.child_splitter, # type: ignore[arg-type]
+ parent_splitter=self.parent_splitter,
+ search_kwargs={"k": cfg.search_k},
+ )
+ logger.info("ParentDocumentRetriever 初始化完成,父块大小=%d", cfg.parent_chunk_size)
+
+ def _create_or_use_docstore(self) -> BaseStore:
+ """创建或获取文档存储实例。"""
+ cfg = self.config.docstore
+ if cfg.instance is not None:
+ logger.debug("使用外部注入的文档存储")
+ return cfg.instance
+
+ # 使用 create_docstore 创建 PostgreSQL 存储
+ docstore, conn_info = create_docstore(
+ connection_string=cfg.connection_string,
+ pool_config=cfg.pool_config,
+ max_concurrency=cfg.max_concurrency,
+ )
+ self._docstore_conn = conn_info
+ logger.info("文档存储已创建(PostgreSQL)")
+ return docstore
+
+ # ---------- 公共构建方法 ----------
+ async def build_from_file(self, file_path: Union[str, Path]) -> int:
+ """从单个文件构建索引。"""
+ logger.info("加载文件: %s", file_path)
+ documents = self.loader.load_file(file_path)
+ logger.info("已加载 %d 个文档", len(documents))
+ return await self._process_documents(documents)
+
+ async def build_from_directory(
+ self, directory_path: Union[str, Path], recursive: bool = True
+ ) -> int:
+ """从目录递归构建索引。"""
+ logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
+ documents = self.loader.load_directory(directory_path, recursive=recursive)
+ logger.info("已从目录加载 %d 个文档", len(documents))
+ return await self._process_documents(documents)
+
+ async def _process_documents(self, documents: List[Document]) -> int:
+ """处理文档列表,分发给相应的索引逻辑。"""
+ if not documents:
+ logger.warning("没有文档需要处理")
+ return 0
+
+ if self.config.splitter_type == SplitterType.PARENT_CHILD:
+ return await self._index_with_parent_child(documents)
+ else:
+ return await self._index_with_single_splitter(documents)
+
+ async def _index_with_single_splitter(self, documents: List[Document]) -> int:
+ """单一模式:切分后直接写入向量库。"""
+ chunks = self.splitter.split_documents(documents) # type: ignore[union-attr]
+ logger.info("已切分为 %d 个块", len(chunks))
+
+ self.vector_store.create_collection()
+ self.vector_store.add_documents(chunks)
+ return len(chunks)
+
+ async def _index_with_parent_child(self, documents: List[Document]) -> int:
+ """父子模式:使用 ParentDocumentRetriever 批量添加。"""
+ self.vector_store.create_collection()
+ assert self.retriever is not None
+
+ batch_size = 10
+ total = len(documents)
+ processed = 0
+
+ for i in range(0, total, batch_size):
+ batch = documents[i:i + batch_size]
+ await self._add_batch_with_retry(batch, i // batch_size + 1)
+ processed += len(batch)
+ logger.info("批次 %d: 已处理 %d/%d", i // batch_size + 1, processed, total)
+
+ logger.info("ParentDocumentRetriever 索引完成,共处理 %d 个文档", processed)
+ return processed
+
+ async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
+ """添加批次,失败时自动重试(处理网络波动)。"""
+ max_retries = 3
+ for attempt in range(max_retries):
+ try:
+ await self.retriever.aadd_documents(batch) # type: ignore[union-attr]
+ return
+ except (RemoteProtocolError, ConnectionError, OSError) as e:
+ if attempt == max_retries - 1:
+ raise
+ logger.warning("批次 %d 连接断开,重试 (%d/%d): %s",
+ batch_no, attempt + 1, max_retries, e)
+ self.vector_store.refresh_client()
+ await asyncio.sleep(1)
+
+ # ---------- 信息获取方法 ----------
+ def get_collection_info(self) -> Any:
+ """获取向量库集合信息。"""
+ return self.vector_store.get_collection_info()
+
+ def get_child_splitter(self) -> TextSplitter:
+ """获取当前使用的子块切分器。"""
+ if self.config.splitter_type == SplitterType.PARENT_CHILD:
+ return self.child_splitter # type: ignore[return-value]
+ return self.splitter # type: ignore[return-value]
+
+ def get_parent_splitter(self) -> RecursiveCharacterTextSplitter:
+ """获取父块切分器(仅父子模式可用)。"""
+ if self.config.splitter_type != SplitterType.PARENT_CHILD:
+ raise RuntimeError("父块切分器仅在父子块模式下可用")
+ return self.parent_splitter # type: ignore[return-value]
+
+ def get_docstore(self) -> BaseStore:
+ """获取文档存储实例(仅父子模式可用)。"""
+ if self.config.splitter_type != SplitterType.PARENT_CHILD:
+ raise RuntimeError("文档存储仅在父子块模式下可用")
+ assert self.docstore is not None
+ return self.docstore
+
+ # ---------- 资源管理 ----------
+ def close(self) -> None:
+ """关闭资源(同步版本,供上下文管理器使用)。"""
+ if self.docstore is not None and hasattr(self.docstore, "aclose"):
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ # 无运行中的事件循环,创建临时循环
+ loop = asyncio.new_event_loop()
+ loop.run_until_complete(self.docstore.aclose()) # type: ignore[attr-defined]
+ loop.close()
+ else:
+ # 已有运行中的循环,创建任务(用户自行等待)
+ loop.create_task(self.docstore.aclose()) # type: ignore[attr-defined]
+ logger.info("IndexBuilder 资源已关闭")
+
+ async def aclose(self) -> None:
+ """异步关闭资源。"""
+ if self.docstore is not None and hasattr(self.docstore, "aclose"):
+ await self.docstore.aclose() # type: ignore[attr-defined]
+ logger.info("IndexBuilder 资源已异步关闭")
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+ return False
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.aclose()
+ return False
\ No newline at end of file
diff --git a/rag_indexer/README.md b/rag_indexer/README.md
index cb2df4c..c4259c8 100644
--- a/rag_indexer/README.md
+++ b/rag_indexer/README.md
@@ -2,35 +2,13 @@
该模块负责 RAG 系统的阶段一:**离线索引构建**。它将外部的非结构化数据(如文档、PDF、网页等)清洗、切分并转化为向量,最终存入向量数据库中。
-## 📊 系统工作流示意图
-
-```mermaid
-graph TD
- A[原始文档集合
PDF / Word / Markdown] --> B(文档加载器 DocumentLoader)
- B --> C{文本切分策略 Splitter}
-
- C -->|基础策略| D1[固定字符长度切分
Recursive Split]
- C -->|进阶策略| D2[语义边界切分
Semantic Chunking]
- C -->|高级策略| D3[父子文档切分
Parent-Child / Auto-merging]
-
- D1 & D2 & D3 --> E[向量化 Embedder
llama.cpp: embeddinggemma]
-
- E --> F[(Qdrant 向量数据库)]
-
- subgraph "元数据管理"
- G[提取作者、日期、页码等元数据 Metadata] -.附加.-> E
- end
-```
-
----
-
## 🎯 演进路线与核心算法 (Roadmap)
### Level 1: 基础暴力切分 (Basic Recursive Splitting)
-- **核心算法**: 递归字符切分。它按照预定义的分隔符列表(如 `["\n\n", "\n", " ", ""]`)从大到小尝试切分文本,直到每块的大小满足最大长度限制。
+- **核心算法**: 递归字符切分。它按照预定义的分隔符列表(如 `["\n\n", "\n", "。", "!", "?", " ", ""]`)从大到小尝试切分文本,直到每块的大小满足最大长度限制。
- **优缺点**: 实现极简单,速度快。但非常容易将一句话拦腰截断,导致上下文语义丢失。
-- **实现指南**:
- - 从 `langchain.text_splitter` 导入 `RecursiveCharacterTextSplitter`。
+- **实现指南**:
+ - 从 `langchain_text_splitters` 导入 `RecursiveCharacterTextSplitter`。
- 实例化时设置 `chunk_size`(如 500)和 `chunk_overlap`(如 50),直接调用 `.split_documents(raw_docs)` 方法。
### Level 2: 语义动态切分 (Semantic Chunking)
@@ -38,58 +16,52 @@ graph TD
1. 将文章按标点符号按句子拆分。
2. 使用轻量级 Embedding 模型将每一句向量化。
3. 计算相邻两句之间的余弦相似度 (Cosine Similarity)。
- 4. 当相似度低于设定阈值时(说明两句话讲的不是同一件事,语义发生了转折),在此处“切断”形成一个新的块。
+ 4. 当相似度低于设定阈值时(说明两句话讲的不是同一件事,语义发生了转折),在此处"切断"形成一个新的块。
- **优缺点**: 极大程度保留了段落内语义的连贯性,对 LLM 回答非常友好。但由于在切分阶段就需要调用向量模型,耗时略长。
-- **实现指南**:
+- **实现指南**:
+ - 从 `langchain_text_splitters` 导入 `TextSplitter` 作为基类。
- 从 `langchain_experimental.text_splitter` 导入 `SemanticChunker`。
- - 实例化时需要传入你已经配置好的 Embedding 模型实例(如基于 `OpenAIEmbeddings` 封装的 llama.cpp 本地模型),并设置 `breakpoint_threshold_type="percentile"` 等阈值参数。
+ - 实现 `SemanticChunkerAdapter` 继承 `TextSplitter`,解决类型不兼容问题。
+ - 实例化时需要传入你已经配置好的 Embedding 模型实例(如基于 `LlamaCppEmbedder` 封装的本地模型)。
### Level 3: 高级父子块策略 (Parent-Child / Auto-merging)
- **核心算法**: 层次化双重存储与映射。
- - **切分机制**: 首先将文档粗切为较大的“父块 (Parent Chunk, 约 1000 词)”,随后将父块细切为较小的“子块 (Child Chunk, 约 200 词)”。
- - **存储机制**: 仅仅将**子块**的向量存入 Qdrant 用于精准计算距离;将**父块**的原始内容存在内存或 Document Store (如 KV 数据库) 中,通过 UUID 相互映射。
+ - **切分机制**: 首先将文档粗切为较大的"父块 (Parent Chunk, 约 1000 字符)",随后将父块细切为较小的"子块 (Child Chunk, 约 200 字符)"。
+ - **存储机制**: 仅仅将**子块**的向量存入 Qdrant 用于精准计算距离;将**父块**的原始内容存在 PostgreSQL DocStore 中,通过 UUID 相互映射。
- **核心思路**: 解决 RAG 领域经典的矛盾——检索时块越小越容易精确命中(去除噪声);但生成回答时,块越大越能给大模型提供充足的上下文背景。
-- **实现指南**:
- - 使用 `langchain.retrievers` 中的 `ParentDocumentRetriever` 模块。
+- **实现指南**:
+ - 使用 `langchain_classic.retrievers` 中的 `ParentDocumentRetriever` 模块。
- 在写入时,你需要同时准备一个底层的 `VectorStore` (即 Qdrant) 和一个 `BaseStore`。
- - **推荐方案**: 使用 `LocalFileStore` (默认) 或 `PostgresDocStore` 作为 docstore。
+ - **推荐方案**: 使用 `PostgresDocStore` 作为 docstore,支持持久化存储。
- 将两种不同的 `TextSplitter` 分别赋值给检索器的 `child_splitter` 和 `parent_splitter`,然后调用 `.add_documents()` 即可让系统自动完成映射。
### Level 3.1: PostgreSQL DocStore 集成
-- **核心优势**: 利用 PostgreSQL 作为持久化存储,适合生产环境。使用同步连接池,避免异步复杂度。
+- **核心优势**: 利用 PostgreSQL 作为持久化存储,适合生产环境。使用异步连接池,支持高并发。
- **实现步骤**:
- 1. **安装依赖**: `pip install psycopg2-binary`
- 2. **配置连接**: 设置 `DB_URI` 环境变量或直接在代码中指定 PostgreSQL 连接字符串
- 3. **创建 docstore**: 使用 `PostgresDocStore` 类直接创建
- 4. **注入到 IndexBuilder**: 在创建 `IndexBuilder` 时通过 `docstore` 参数注入
+ 1. **配置连接**: 设置 `DB_URI` 环境变量或通过 `docstore_conn_string` 参数指定
+ 2. **创建 docstore**: 使用 `rag_indexer.store.create_docstore()` 工厂函数
+ 3. **注入到 IndexBuilder**: 通过构造函数参数注入
- **使用示例**:
```python
- from rag_indexer.docstore_manager import PostgresDocStore
from rag_indexer.builder import IndexBuilder, SplitterType
- # 创建 PostgreSQL docstore
- docstore = PostgresDocStore(
- connection_string="postgresql://user:pass@host:5432/db",
- table_name="parent_documents"
- )
-
- # 创建 IndexBuilder 并注入 docstore
+ # 创建 IndexBuilder
builder = IndexBuilder(
collection_name="rag_documents",
splitter_type=SplitterType.PARENT_CHILD,
- docstore=docstore,
parent_chunk_size=1000,
child_chunk_size=200,
+ docstore_conn_string="postgresql://user:pass@host:5432/db",
)
```
### Level 3.2: 语义切分与父子块策略结合
- **核心优势**: 结合语义切分的连贯性和父子块策略的层次化存储优势,实现更精准的检索和更丰富的上下文。
- **实现原理**:
- - **父块切分**: 使用递归字符切分创建大块(约1000词),提供完整的上下文背景
- - **子块切分**: 使用语义动态切分创建小块(约200词),根据语义连贯性动态切分,提高检索精度
- - **存储机制**: 子块向量存入Qdrant用于精准检索,父块内容存入PostgreSQL提供完整上下文
+ - **父块切分**: 使用 `RecursiveCharacterTextSplitter` 创建大块(约1000字符),提供完整的上下文背景
+ - **子块切分**: 使用 `SemanticChunkerAdapter` 创建小块,根据语义连贯性动态切分,提高检索精度
+ - **存储机制**: 子块向量存入 Qdrant 用于精准检索,父块内容存入 PostgreSQL 提供完整上下文
- **使用示例**:
```python
from rag_indexer.builder import IndexBuilder, SplitterType
@@ -109,97 +81,55 @@ graph TD
```
- **配置参数**:
- `child_splitter_type`: 子块切分器类型,可选 `SplitterType.RECURSIVE`(默认)或 `SplitterType.SEMANTIC`
- - 当使用语义切分时,系统会自动使用已配置的Embedding模型进行句子级相似度计算
+ - 当使用语义切分时,系统会自动使用已配置的 Embedding 模型进行句子级相似度计算
-### Level 4: RAG-Fusion (多路改写与倒数排名融合)
-- **核心优势**: 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果,提高检索的全面性和准确性。
-- **实现原理**:
- 1. **多路查询改写**: 利用LLM将原始查询改写成3-5个不同表述的查询,从不同角度表达相同意图
- 2. **倒数排名融合 (RRF)**: 对每个改写查询的结果进行RRF融合,公式为 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$,避免单一检索结果主导
- 3. **结果去重**: 对融合后的结果进行去重,确保返回的文档唯一
-- **使用示例**:
- ```python
- from rag_indexer.builder import IndexBuilder, SplitterType
- from langchain_openai import OpenAI
-
- # 创建 IndexBuilder
- builder = IndexBuilder(
- collection_name="rag_documents",
- splitter_type=SplitterType.PARENT_CHILD,
- parent_chunk_size=1000,
- child_chunk_size=200,
- docstore_conn_string="postgresql://user:pass@host:5432/db",
- )
-
- # 创建语言模型用于查询改写
- llm = OpenAI(
- openai_api_base="http://localhost:8000/v1",
- openai_api_key="no-key-needed",
- model_name="Qwen2.5-7B-Instruct",
- temperature=0.3,
- )
-
- # 使用 RAG-Fusion 检索
- query = "如何申请项目资金?"
- results = builder.retrieve_with_fusion(
- query=query,
- llm=llm,
- num_queries=3,
- k=5,
- return_parent=True
- )
- ```
-- **配置参数**:
- - `llm`: 语言模型实例,用于查询改写
- - `num_queries`: 生成的查询数量,建议3-5个
- - `k`: 返回的文档数量
- - `return_parent`: 是否返回父块上下文
-
-### Level 5: GraphRAG 与 多模态 (Graph & Multi-modal)
+### Level 4: GraphRAG(基于图和关系的 RAG)
- **核心算法**: LLM 实体关系抽取 (NER & Relation Extraction)。
-- **核心思路**: 解决传统纯向量检索难以处理“跨文档复杂关系推理”的痛点(如:A公司的CEO是谁?他名下的B公司主要业务是什么?这种需要横跨多页 PDF 的跳跃性问题)。
-- **实现指南**:
- - 使用本地的大模型(如 `Gemma-4-E2B`)配合 `langchain_community.graphs` 模块。
- - 利用 `LLMGraphTransformer` 组件,在读取文档时,通过预设的 Prompt 强制大模型提取出实体(Node)和关系(Edge),直接写入诸如 Neo4j 这样的图数据库中,而非传统的 Qdrant 向量库。
+- **核心思路**: 解决传统纯向量检索难以处理"跨文档复杂关系推理"的痛点(如:A公司的CEO是谁?他名下的B公司主要业务是什么?这种需要横跨多页 PDF 的跳跃性问题)。
+- **实现原理**:
+ 1. **实体提取**: 利用 LLM 从文档中提取实体(如人物、组织、地点、事件等)
+ 2. **关系抽取**: 识别实体之间的关系(如"CEO of"、"founded by"、"located in"等)
+ 3. **图构建**: 将实体作为节点,关系作为边,构建知识图谱
+ 4. **混合检索**: 结合向量检索和图查询,同时利用语义相似性和结构关系
+- **技术栈**:
+ - **图数据库**: Neo4j 或 RedisGraph
+ - **LLM 工具**: `LLMGraphTransformer` 或自定义 Prompt
+ - **集成方式**: 与向量存储并行,形成混合检索系统
+- **实现指南**:
+ - 使用 `langchain_community.graphs` 模块
+ - 配置本地大模型(如 `Gemma-4-E2B`)用于实体关系抽取
+ - 构建包含实体和关系的图结构,存储到图数据库
+ - 实现混合检索逻辑,结合向量相似度和图路径分析
----
-
-## 所需依赖与安装
-
-为了支持完整的文档解析和 Qdrant 写入,需要安装以下 Python 包:
-
-```bash
-# 基础核心库
-pip install langchain langchain-core langchain-openai langchain-qdrant
-
-# 用于复杂文档解析 (PDF, Word, Excel 等)
-pip install unstructured pdf2image pdfminer.six
-
-# 用于语义分块 (可选)
-pip install langchain-experimental
-
-# 用于 PostgreSQL 存储 (可选,用于 Parent-Child 策略)
-pip install psycopg2-binary
-
-# 用于 RAG-Fusion (可选,需要语言模型)
-pip install langchain-openai
-```
+### Level 5: 多模态 RAG (Multi-modal RAG)
+- **核心算法**: 跨模态嵌入和多模态融合。
+- **核心思路**: 突破纯文本限制,支持图像、表格、音频等多种数据类型的理解和检索。
+- **实现原理**:
+ 1. **多模态嵌入**: 使用 CLIP 等模型将不同模态数据映射到统一向量空间
+ 2. **多模态索引**: 为不同类型的内容创建专用索引
+ 3. **跨模态检索**: 支持以文搜图、以图搜文等跨模态查询
+- **技术栈**:
+ - **多模态模型**: CLIP、BLIP 等
+ - **存储**: 向量数据库 + 对象存储
+ - **检索**: 混合向量检索
---
## 📂 架构与文件结构设计
-在 `rag_indexer/` 目录下,需创建以下核心文件:
-
-```text
+```
rag_indexer/
├── __init__.py
├── loaders.py # 负责调用 unstructured 解析不同类型文件
-├── splitters.py # 负责实现 Recursive、Semantic、Parent-Child 切分逻辑
+├── splitters.py # 负责实现 Recursive、Semantic 切分逻辑及适配器
├── embedders.py # 封装本地 llama.cpp 交互的 Embedding 接口
├── vector_store.py # 封装 Qdrant 写入、Upsert、Collection 初始化操作
-├── docstore_manager.py # 文档存储管理器,支持 LocalFileStore 和 PostgreSQL
-└── builder.py # 核心编排文件,将上述模块串联成 Pipeline
+├── builder.py # 核心编排文件,将上述模块串联成 Pipeline
+├── cli.py # 命令行入口
+└── store/
+ ├── __init__.py
+ ├── factory.py # docstore 工厂函数
+ └── postgres.py # PostgreSQL DocStore 实现
```
---
@@ -211,36 +141,36 @@ rag_indexer/
```
┌─────────────────────────────────────────┐
│ builder.py │
- │ IndexBuilder 入口 │
+ │ IndexBuilder 入口 │
└─────────────────┬───────────────────────┘
│
┌─────────────────▼───────────────────────┐
- │ loaders.py │
- │ DocumentLoader.load_file() │
- │ → 返回 List[Document] │
+ │ loaders.py │
+ │ DocumentLoader.load_file() │
+ │ → 返回 List[Document] │
└─────────────────┬───────────────────────┘
│
┌─────────────────▼───────────────────────┐
- │ ParentDocumentRetriever.add_documents()│
- │ ┌─────────────────────────────────┐ │
- │ │ parent_splitter (粗切) │ │
- │ │ 父块 ~1000 词 │ │
- │ └────────────┬────────────────────┘ │
- │ │ │
- │ ┌────────────▼────────────────────┐ │
- │ │ child_splitter (细切) │ │
- │ │ 子块 ~200 词 │ │
- │ └────────────┬────────────────────┘ │
- │ │ │
- │ ┌──────────┴──────────┐ │
- │ ▼ ▼ │
- │ 子块向量 父块原始内容 │
- │ │ │ │
- │ ▼ ▼ │
- │ ┌────────────┐ ┌─────────────────┐ │
- │ │vector_store│ │ docstore_manager│ │
- │ │ (Qdrant) │ │ (PostgreSQL) │ │
- │ └────────────┘ └─────────────────┘ │
+ │ ParentDocumentRetriever │
+ │ ┌─────────────────────────────────┐ │
+ │ │ parent_splitter (粗切) │ │
+ │ │ 父块 ~1000 字符 │ │
+ │ └────────────┬────────────────────┘ │
+ │ │ │
+ │ ┌────────────▼────────────────────┐ │
+ │ │ child_splitter (细切) │ │
+ │ │ 子块 ~200 字符 │ │
+ │ └────────────┬────────────────────┘ │
+ │ │ │
+ │ ┌──────────┴──────────┐ │
+ │ ▼ ▼ │
+ │ 子块向量 父块原始内容 │
+ │ │ │ │
+ │ ▼ ▼ │
+ │ ┌────────────┐ ┌─────────────────┐ │
+ │ │vector_store│ │ store/ │ │
+ │ │ (Qdrant) │ │ (PostgreSQL) │ │
+ │ └────────────┘ └─────────────────┘ │
└─────────────────────────────────────────┘
```
@@ -250,10 +180,31 @@ rag_indexer/
|------|------|------------|
| **builder.py** | 核心编排,负责串联整个流程 | `IndexBuilder` |
| **loaders.py** | 解析各种文档格式(PDF、Word、TXT等) | `DocumentLoader` |
-| **splitters.py** | 文本切分策略(Recursive/Semantic/Parent-Child) | `SplitterType`, `get_splitter()` |
+| **splitters.py** | 文本切分策略(Recursive/Semantic)及适配器 | `SplitterType`, `get_splitter()`, `SemanticChunkerAdapter` |
| **embedders.py** | 向量化(封装 llama.cpp embedding 接口) | `LlamaCppEmbedder` |
| **vector_store.py** | Qdrant 向量数据库操作 | `QdrantVectorStore` |
-| **docstore_manager.py** | 父文档存储(PostgreSQL/本地文件) | `PostgresDocStore`, `get_docstore()` |
+| **store/postgres.py** | PostgreSQL DocStore 实现 | `PostgresDocStore` |
+| **store/factory.py** | docstore 工厂函数 | `create_docstore()` |
+
+### 核心实现细节
+
+#### 1. 文本切分
+- **递归切分**: 使用 `langchain_text_splitters.RecursiveCharacterTextSplitter`,支持中文分隔符
+- **语义切分**: 使用 `langchain_experimental.text_splitter.SemanticChunker`,通过 `SemanticChunkerAdapter` 适配 `TextSplitter` 接口
+- **父子块策略**: 父块使用递归切分(1000字符),子块可选择递归或语义切分(200字符)
+
+#### 2. 向量化
+- **Embedding API**: 使用 `LlamaCppEmbedder` 封装本地 llama.cpp 服务,支持 `embed_documents` 和 `embed_query` 方法
+- **向量维度**: 自动检测模型维度(默认 2560),创建对应大小的 Qdrant 集合
+
+#### 3. 向量存储
+- **Qdrant 集成**: 使用 `langchain_qdrant.QdrantVectorStore` 作为底层存储
+- **集合管理**: 自动创建/复用集合,支持 `force_recreate` 参数
+- **批量写入**: 支持 `batch_size` 参数,避免单次请求过大
+
+#### 4. 文档存储
+- **PostgreSQL**: 使用 `PostgresDocStore` 持久化存储父块,支持异步连接池
+- **数据映射**: 通过 UUID 将子块与父块关联,检索时返回完整父块
### 调用顺序
@@ -265,27 +216,42 @@ from rag_indexer.builder import IndexBuilder, SplitterType
builder = IndexBuilder(
collection_name="my_docs",
splitter_type=SplitterType.PARENT_CHILD,
- qdrant_url="http://localhost:6333",
parent_chunk_size=1000,
child_chunk_size=200,
+ docstore_conn_string="postgresql://user:pass@host:5432/db",
)
```
#### 2. 构建索引
```python
+import asyncio
+
# 方式A:从单个文件构建
-builder.build_from_file("/path/to/document.pdf")
+async def main():
+ count = await builder.build_from_file("/path/to/document.pdf")
+ print(f"已索引 {count} 个块")
# 方式B:从目录批量构建
-builder.build_from_directory("/path/to/docs/")
+async def main():
+ count = await builder.build_from_directory("/path/to/docs/")
+ print(f"已索引 {count} 个块")
+
+asyncio.run(main())
```
#### 3. 检索(获取完整父块上下文)
```python
-# 检索时返回完整父块
-results = builder.search_with_parent_context("查询内容")
+import asyncio
+
+async def main():
+ # 检索时返回完整父块
+ results = await builder.search_with_parent_context("查询内容", k=5)
+ for doc in results:
+ print(doc.page_content)
+
+asyncio.run(main())
```
### 检索流程
@@ -299,11 +265,16 @@ results = builder.search_with_parent_context("查询内容")
---
### 串联与触发方式
-在你的 LangGraph 系统外,创建一个执行脚本 `scripts/run_indexer.py`:
+使用 `cli.py` 入口脚本:
```bash
-# 终端执行,将本地的 PDF 手册刷入向量数据库
+# 设置环境变量
export QDRANT_URL="http://115.190.121.151:6333"
-python scripts/run_indexer.py --file data/user_docs/tech_manual.pdf
+export QDRANT_API_KEY="your-api-key"
+export DB_URI="postgresql://postgres:password@host:5432/langgraph_db?sslmode=disable"
+
+# 执行索引构建
+python -m rag_indexer.cli --path data/user_docs/tech_manual.pdf
```
-这相当于系统后台的**“离线学习阶段”**,你可以随时挂载定时任务去扫描文件夹,增量更新知识库。
+
+这相当于系统后台的**"离线学习阶段"**,你可以随时挂载定时任务去扫描文件夹,增量更新知识库。
diff --git a/rag_indexer/__init__.py b/rag_indexer/__init__.py
index 78daf84..7d178ac 100644
--- a/rag_indexer/__init__.py
+++ b/rag_indexer/__init__.py
@@ -9,52 +9,52 @@ Offline RAG Indexer module.
- 父文档存储(PostgreSQL)
示例用法:
- >>> from rag_indexer import IndexBuilder, SplitterType
+ >>> from rag_indexer import IndexBuilder, IndexBuilderConfig, SplitterType
>>>
- >>> builder = IndexBuilder(
+ >>> config = IndexBuilderConfig(
... collection_name="my_docs",
... splitter_type=SplitterType.PARENT_CHILD,
- ... qdrant_url="http://localhost:6333"
... )
+ >>> builder = IndexBuilder(config)
>>>
- >>> builder.build_from_file("document.pdf")
+ >>> # 或直接传参(向后兼容)
+ >>> builder = IndexBuilder(collection_name="my_docs")
+ >>>
+ >>> await builder.build_from_file("document.pdf")
"""
+from .IndexBuilder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from .loaders import DocumentLoader
-from .splitters import (
- SplitterType,
- get_splitter,
- ParentChildSplitter,
-)
-from .embedders import LlamaCppEmbedder
-from .vector_store import QdrantVectorStore
-from .builder import IndexBuilder
+from .splitters import SplitterType, get_splitter
-# 导出存储相关类(从新的 store 包)
-from .store import (
+# 从 rag_core 重新导出常用组件
+from rag_core import (
+ LlamaCppEmbedder,
+ QdrantVectorStore,
PostgresDocStore,
create_docstore,
)
-
-
__version__ = "2.0.0"
__all__ = [
- # 核心类
- "DocumentLoader",
+ # 核心构建器与配置
"IndexBuilder",
+ "IndexBuilderConfig",
+ "DocstoreConfig",
+
+ # 加载器
+ "DocumentLoader",
# 切分相关
"SplitterType",
"get_splitter",
- "ParentChildSplitter",
- # 嵌入和向量存储
+ # 嵌入与向量存储
"LlamaCppEmbedder",
"QdrantVectorStore",
- # 存储(新的 store 包)
+ # 文档存储
"PostgresDocStore",
"create_docstore",
-]
+]
\ No newline at end of file
diff --git a/rag_indexer/builder.py b/rag_indexer/builder.py
deleted file mode 100644
index 2a1c51e..0000000
--- a/rag_indexer/builder.py
+++ /dev/null
@@ -1,392 +0,0 @@
-"""
-离线 RAG 索引构建核心流水线。
-
-支持 LangChain 的 ParentDocumentRetriever 用于父子块切分。
-"""
-
-import asyncio
-import logging
-from pathlib import Path
-from typing import List, Union, Optional, Tuple, Any
-from dataclasses import dataclass
-
-from httpx import RemoteProtocolError
-from langchain_core.documents import Document
-from langchain_classic.retrievers import ParentDocumentRetriever
-from langchain_core.stores import BaseStore
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain_experimental.text_splitter import SemanticChunker
-
-from .loaders import DocumentLoader
-from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
-from .embedders import LlamaCppEmbedder
-from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
-from .store import create_docstore
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class ParentChildConfig:
- """父子块切分配置。"""
- parent_chunk_size: int = 1000
- child_chunk_size: int = 200
- parent_chunk_overlap: int = 100
- child_chunk_overlap: int = 20
- search_k: int = 5
- docstore_path: Optional[str] = None
- docstore_type: str = "local"
- docstore_conn_string: Optional[str] = None
-
-
-class IndexBuilder:
- """RAG 索引构建主流水线。"""
-
- # 类型注解
- parent_splitter: "RecursiveCharacterTextSplitter"
- child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
- docstore: Optional["BaseStore"]
- _docstore_conn: Optional[str]
- retriever: Optional["ParentDocumentRetriever"]
- vector_store_obj: Any
-
- def __init__(
- self,
- collection_name: str = "rag_documents",
- splitter_type: SplitterType = SplitterType.PARENT_CHILD,
- docstore=None,
- **splitter_kwargs,
- ):
- self.collection_name = collection_name
- self.splitter_type = splitter_type
- self.splitter_kwargs = splitter_kwargs
- self.docstore = docstore # 从外部注入
-
- # 组件
- self.loader = DocumentLoader()
- self.embedder = LlamaCppEmbedder()
- self.embeddings = self.embedder.as_langchain_embeddings()
-
- self.vector_store = QdrantVectorStore(
- collection_name=collection_name,
- embeddings=self.embeddings,
- )
-
- # 切分器(父子块单独处理)
- if splitter_type != SplitterType.PARENT_CHILD:
- if splitter_type == SplitterType.SEMANTIC:
- splitter_kwargs["embeddings"] = self.embeddings
- self.splitter = get_splitter(splitter_type, **splitter_kwargs)
- else:
- self.splitter = None
- # 为父子块切分初始化 ParentDocumentRetriever
- self._init_parent_child_retriever()
-
- def _init_parent_child_retriever(self, **kwargs):
- """
- 初始化 ParentDocumentRetriever 用于父子块切分。
-
- 支持动态语义切分与父子块策略结合:
- - 父块使用递归切分(大块,提供上下文)
- - 子块可以使用递归切分或语义切分(根据语义动态切分,提高检索精度)
-
- 替代自定义的 ParentChildSplitter 逻辑。
- """
- # 解析父子块配置参数
- parent_size = kwargs.get("parent_chunk_size", 1000)
- child_size = kwargs.get("child_chunk_size", 200)
- parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
- child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))
-
- # 子块切分器类型,默认为语义切分
- child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC)
-
- # 定义父块切分器(始终使用递归切分)
- self.parent_splitter = RecursiveCharacterTextSplitter(
- chunk_size=parent_size,
- chunk_overlap=parent_overlap,
- )
-
- # 定义子块切分器(根据类型选择)
- if child_splitter_type == SplitterType.SEMANTIC:
- self.child_splitter = get_splitter(
- SplitterType.SEMANTIC,
- embeddings=self.embeddings,
- )
- logger.info(f"子块使用语义切分器")
- else:
- # 默认使用递归切分
- self.child_splitter = RecursiveCharacterTextSplitter(
- chunk_size=child_size,
- chunk_overlap=child_overlap,
- )
- logger.info(f"子块使用递归切分器,块大小: {child_size},重叠: {child_overlap}")
-
- # 向量存储(用于子块)
- self.vector_store_obj = self.vector_store.get_langchain_vectorstore()
-
- # 文档存储(用于父块)
- if self.docstore is None:
- # 如果没有外部注入 docstore,则使用 PostgreSQL 创建
- docstore_conn = kwargs.get("docstore_conn_string")
- pool_config = kwargs.get("pool_config")
- max_concurrency = kwargs.get("max_concurrency")
-
- # 使用 create_docstore 创建 PostgreSQL 存储
- self.docstore, self._docstore_conn = create_docstore(
- connection_string=docstore_conn,
- pool_config=pool_config,
- max_concurrency=max_concurrency
- )
- else:
- # 使用外部注入的 docstore
- self._docstore_conn = None
-
- # 创建检索器
- self.retriever = ParentDocumentRetriever(
- vectorstore=self.vector_store_obj,
- docstore=self.docstore,
- child_splitter=self.child_splitter, # type: ignore
- parent_splitter=self.parent_splitter,
- search_kwargs={"k": kwargs.get("search_k", 5)},
- )
- logger.info(f"ParentDocumentRetriever 已初始化,父块大小: {parent_size},子块类型: {child_splitter_type}")
-
- async def build_from_file(self, file_path: Union[str, Path]) -> int:
- logger.info("加载文件: %s", file_path)
- documents = self.loader.load_file(file_path)
- logger.info("已加载 %d 个文档", len(documents))
- return await self._process_documents(documents)
-
- async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
- logger.info("加载目录: %s (递归=%s)", directory_path, recursive)
- documents = self.loader.load_directory(directory_path, recursive=recursive)
- logger.info("已从目录加载 %d 个文档", len(documents))
- return await self._process_documents(documents)
-
- async def _process_documents(self, documents: List[Document]) -> int:
- if not documents:
- logger.warning("没有文档需要处理")
- return 0
-
- if self.splitter_type == SplitterType.PARENT_CHILD:
- logger.info("使用 LangChain ParentDocumentRetriever")
-
- # 确保集合存在(用于子块)
- self.vector_store.create_collection()
-
- # 分批处理,避免单次请求过大
- assert self.retriever is not None, "retriever 未初始化"
- batch_size = 10 # 每次处理10个文档
- total = len(documents)
- processed = 0
-
- for i in range(0, total, batch_size):
- batch = documents[i:i + batch_size]
- max_retries = 3
- for attempt in range(max_retries):
- try:
- await self.retriever.aadd_documents(batch)
- processed += len(batch)
- logger.info(f"批次 {i//batch_size + 1}: 已处理 {processed}/{total}")
- break
- except (RemoteProtocolError, ConnectionError, OSError) as e:
- if attempt == max_retries - 1:
- raise
- logger.warning(f"连接断开,重试 ({attempt+1}/{max_retries}): {e}")
- self.vector_store.refresh_client()
- await asyncio.sleep(1)
-
- logger.info(
- "已使用 ParentDocumentRetriever 索引: "
- f"共 {processed} 个父块"
- )
- return processed
-
- else:
- logger.info("使用 %s 切分文档", self.splitter_type)
- # 当 splitter_type 不是 PARENT_CHILD 时,splitter 一定不为 None
- assert self.splitter is not None, "splitter 未初始化"
- chunks = self.splitter.split_documents(documents)
- logger.info("已切分为 %d 个块", len(chunks))
-
- self.vector_store.create_collection()
- self.vector_store.add_documents(chunks)
- return len(chunks)
-
- def get_collection_info(self):
- return self.vector_store.get_collection_info()
-
- def search(self, query: str, k: int = 5) -> List[Document]:
- """标准搜索 - 返回子块。"""
- return self.vector_store.similarity_search(query, k=k)
-
- async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
- """
- 带父块上下文的搜索 - 返回完整父块。
-
- 这是使用父子块切分时的主要检索方法。
- """
- if self.splitter_type != SplitterType.PARENT_CHILD:
- raise RuntimeError(
- "search_with_parent_context 仅在 PARENT_CHILD 切分器下可用。"
- "请使用 search() 进行标准检索。"
- )
- assert self.retriever is not None, "retriever 未初始化"
- return await self.retriever.ainvoke(query, config={"k": k}) # type: ignore
-
- async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
- """
- 统一检索接口。
-
- Args:
- query: 搜索查询
- return_parent: 如果为 True 且使用父子块切分,返回父块
- 如果为 False,始终返回子块
-
- Returns:
- 相关文档列表
- """
- if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
- return await self.search_with_parent_context(query)
- else:
- return self.search(query)
-
- async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]:
- """
- 使用 RAG-Fusion 进行检索(多路查询改写 + 倒数排名融合)。
-
- 核心原理:
- 1. 多路查询改写: 利用 LLM 将原始查询改写成多个不同表述
- 2. 倒数排名融合: 对每个改写查询的结果进行 RRF 融合,避免单一检索结果主导
-
- Args:
- query: 原始搜索查询
- llm: 语言模型实例(用于查询改写)
- num_queries: 生成的查询数量
- k: 返回的文档数量
- return_parent: 如果为 True 且使用父子块切分,返回父块
- 如果为 False,始终返回子块
-
- Returns:
- 经过融合后的相关文档列表
- """
- from langchain.retrievers.multi_query import MultiQueryRetriever
- from langchain.retrievers import EnsembleRetriever
-
- if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
- # 使用 ParentDocumentRetriever 作为基础检索器
- assert self.retriever is not None, "retriever 未初始化"
- base_retriever = self.retriever
- else:
- # 使用向量存储作为基础检索器
- base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2})
-
- # 创建多路查询检索器
- multi_query_retriever = MultiQueryRetriever.from_llm(
- retriever=base_retriever,
- llm=llm,
- include_original=True
- )
-
- # 设置自定义提示词以生成指定数量的查询
- from langchain_core.prompts import PromptTemplate
- multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template(
- "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
- "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
- "原始问题: {question}\n\n"
- "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
- "确保每个版本都是独立、完整的查询语句。\n\n"
- "生成 {num_queries} 个查询:"
- )
-
- # 修改调用参数以包含 num_queries
- original_ainvoke = multi_query_retriever.llm_chain.ainvoke
- async def new_ainvoke(input_dict):
- input_dict["num_queries"] = num_queries
- return await original_ainvoke(input_dict)
- multi_query_retriever.llm_chain.ainvoke = new_ainvoke
-
- # 执行检索
- documents = await multi_query_retriever.ainvoke(query)
-
- # 去重并限制数量
- seen_content = set()
- unique_documents = []
- for doc in documents:
- content = doc.page_content
- if content not in seen_content:
- seen_content.add(content)
- unique_documents.append(doc)
- if len(unique_documents) >= k:
- break
-
- logger.info(f"RAG-Fusion 检索完成: 原始 {len(documents)} 个结果,去重后 {len(unique_documents)} 个结果")
- return unique_documents
-
- def get_retriever(self) -> ParentDocumentRetriever:
- """
- 直接获取 ParentDocumentRetriever 实例。
-
- 适用于需要在 IndexBuilder 外部访问检索器的高级用例。
- """
- if self.splitter_type != SplitterType.PARENT_CHILD:
- raise RuntimeError(
- "get_retriever() 仅在 PARENT_CHILD 切分器下可用。"
- "请使用 search() 或 search_with_parent_context() 进行标准检索。"
- )
- assert self.retriever is not None, "retriever 未初始化"
- return self.retriever
-
- def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]:
- """获取子块切分器以便重新配置。"""
- if self.splitter_type != SplitterType.PARENT_CHILD:
- return self.splitter # type: ignore
- return self.child_splitter
-
- def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
- """获取父块切分器以便重新配置。"""
- if self.splitter_type != SplitterType.PARENT_CHILD:
- raise RuntimeError(
- "父块切分器仅在 PARENT_CHILD 切分器下可用。"
- )
- return self.parent_splitter
-
- def get_docstore(self) -> BaseStore:
- """获取父块的文档存储。"""
- if self.splitter_type != SplitterType.PARENT_CHILD:
- raise RuntimeError(
- "文档存储仅在 PARENT_CHILD 切分器下可用。"
- )
- assert self.docstore is not None, "docstore 未初始化"
- return self.docstore
-
- def get_docstore_path(self) -> Optional[str]:
- """获取文档存储路径(已弃用,仅用于兼容性)。"""
- if self.splitter_type != SplitterType.PARENT_CHILD:
- raise RuntimeError(
- "文档存储路径仅在 PARENT_CHILD 切分器下可用。"
- )
- # PostgreSQL 存储没有 persist_path,返回 None
- return None
-
- def close(self):
- """关闭资源。"""
- if self.docstore is not None and hasattr(self.docstore, "aclose"):
- import asyncio
- asyncio.get_event_loop().run_until_complete(self.docstore.aclose()) # type: ignore
- logger.info("PostgreSQL 异步连接池已关闭")
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.close()
- return False
-
-
-# 需要导入 RecursiveCharacterTextSplitter
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-
-
-# 示例用法已移除,请参考文档
diff --git a/rag_indexer/cli.py b/rag_indexer/cli.py
index 63014b8..6942506 100755
--- a/rag_indexer/cli.py
+++ b/rag_indexer/cli.py
@@ -1,85 +1,77 @@
"""
-Command-line interface for the RAG index builder.
+简易命令行入口,使用默认配置构建 RAG 索引。
"""
-import argparse
import asyncio
import logging
import sys
+from pathlib import Path
-from rag_indexer.builder import IndexBuilder
+from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig
from rag_indexer.splitters import SplitterType
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
+logger = logging.getLogger(__name__)
-# 基础配置
+# 默认配置(所有连接参数从环境变量读取)
COLLECTION_NAME = "rag_documents"
-DB_URI = "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable"
+SPLITTER_TYPE = SplitterType.PARENT_CHILD
+CHILD_SPLITTER_TYPE = SplitterType.SEMANTIC
-# 基础切分参数
-CHUNK_SIZE = 500
-CHUNK_OVERLAP = 50
-
-# 父子块切分参数
+# 父子块大小参数(可根据需要调整)
PARENT_CHUNK_SIZE = 1000
-CHILD_CHUNK_SIZE = 200
PARENT_CHUNK_OVERLAP = 100
+CHILD_CHUNK_SIZE = 200
CHILD_CHUNK_OVERLAP = 20
+SEARCH_K = 5
-# 切分策略:basic(基础)、semantic(语义)、parent-child(父子块)
-STRATEGY = "parent-child"
-# 存储类型:postgres(PostgreSQL)、local(本地文件)
-STORAGE_TYPE = "postgres"
+def get_input_path() -> Path:
+ """从命令行参数获取输入路径,若未提供则使用默认示例路径。"""
+ if len(sys.argv) > 1:
+ return Path(sys.argv[1])
+ # 默认测试路径(可按需修改)
+ return Path("data/user_docs/a.txt")
async def main():
- # 使用固定策略
- splitter_type = SplitterType.PARENT_CHILD
- child_splitter_type = SplitterType.SEMANTIC
+ input_path = get_input_path()
+ if not input_path.exists():
+ logger.error("路径不存在: %s", input_path)
+ sys.exit(1)
- splitter_kwargs = {}
-
- if splitter_type == SplitterType.RECURSIVE:
- splitter_kwargs["chunk_size"] = CHUNK_SIZE
- splitter_kwargs["chunk_overlap"] = CHUNK_OVERLAP
- elif splitter_type == SplitterType.PARENT_CHILD:
- splitter_kwargs["parent_chunk_size"] = PARENT_CHUNK_SIZE
- splitter_kwargs["child_chunk_size"] = CHILD_CHUNK_SIZE
- splitter_kwargs["parent_chunk_overlap"] = PARENT_CHUNK_OVERLAP
- splitter_kwargs["child_chunk_overlap"] = CHILD_CHUNK_OVERLAP
- splitter_kwargs["child_splitter_type"] = child_splitter_type
- if STORAGE_TYPE == "postgres":
- splitter_kwargs["docstore_conn_string"] = DB_URI
- elif STORAGE_TYPE == "local":
- splitter_kwargs["docstore_path"] = "./parent_docs"
- else:
- splitter_kwargs["docstore_conn_string"] = DB_URI
-
- builder = IndexBuilder(
+ # 构建配置(使用全部默认值)
+ config = IndexBuilderConfig(
collection_name=COLLECTION_NAME,
- splitter_type=splitter_type,
- **splitter_kwargs
+ splitter_type=SPLITTER_TYPE,
+ parent_chunk_size=PARENT_CHUNK_SIZE,
+ parent_chunk_overlap=PARENT_CHUNK_OVERLAP,
+ child_chunk_size=CHILD_CHUNK_SIZE,
+ child_chunk_overlap=CHILD_CHUNK_OVERLAP,
+ child_splitter_type=CHILD_SPLITTER_TYPE,
+ search_k=SEARCH_K,
+ # docstore 默认使用 create_docstore 从环境变量读取 PostgreSQL 连接
)
- is_file=False
- path="data/corpus/"
+ builder = IndexBuilder(config)
+ is_directory = input_path.is_dir()
try:
- if is_file:
- chunk_count = await builder.build_from_file(path)
- else:
- chunk_count = await builder.build_from_directory(path, recursive=True)
+ async with builder:
+ if is_directory:
+ chunk_count = await builder.build_from_directory(input_path, recursive=True)
+ else:
+ chunk_count = await builder.build_from_file(input_path)
- print(f"索引构建完成。共索引 {chunk_count} 个块")
+ print(f"\n索引构建完成。共索引 {chunk_count} 个块")
info = builder.get_collection_info()
print(f"集合 '{info['name']}' 包含 {info['vectors_count']} 个向量(维度:{info['vector_size']})")
except Exception as e:
- logging.exception(f"索引构建失败:{e}")
+ logger.exception("索引构建失败: %s", e)
sys.exit(1)
diff --git a/rag_indexer/loaders.py b/rag_indexer/loaders.py
index d0c16c4..c5c6e33 100644
--- a/rag_indexer/loaders.py
+++ b/rag_indexer/loaders.py
@@ -3,19 +3,27 @@
"""
import logging
+import os
from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Optional, Union
from langchain_core.documents import Document
+from unstructured.documents.elements import Element
from unstructured.partition.auto import partition
logger = logging.getLogger(__name__)
+# 模块加载时设置一次环境变量,避免重复设置
+os.environ.setdefault("UNSTRUCTURED_LANGUAGE_CHECKS", "false")
+
class DocumentLoader:
"""从各种文件格式加载文档。"""
- SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx", ".json"}
+ SUPPORTED_EXTENSIONS = {
+ ".pdf", ".docx", ".doc", ".txt", ".md",
+ ".html", ".pptx", ".xlsx", ".json"
+ }
def __init__(
self,
@@ -32,13 +40,11 @@ class DocumentLoader:
extract_images: 是否提取 PDF 中的图片
strategy: 解析策略 (auto, fast, hi_res, ocr_only)
ocr_languages: OCR 语言列表,如 ['chi_sim', 'eng']
- languages: 文档主语言,如 ['zh']
+ languages: 文档主语言,如 ['zh'](主要用于非 OCR 场景)
include_page_breaks: 是否包含分页符
- pdf_infer_table_structure: 是否识别表格结构 (需 hi_res 策略)
+ pdf_infer_table_structure: 是否识别表格结构(需 hi_res 策略)
partition_kwargs: 额外的 partition 参数字典(高级定制)
"""
- import os
- os.environ["UNSTRUCTURED_LANGUAGE_CHECKS"] = "false"
self.extract_images = extract_images
self.strategy = strategy
self.ocr_languages = ocr_languages or ["chi_sim", "eng"]
@@ -47,6 +53,52 @@ class DocumentLoader:
self.pdf_infer_table_structure = pdf_infer_table_structure
self.partition_kwargs = partition_kwargs or {}
+ def _build_partition_kwargs(self, file_path: Path) -> Dict[str, Any]:
+ """根据文件类型构建 partition 的参数。"""
+ kwargs: Dict[str, Any] = {
+ "include_page_breaks": self.include_page_breaks,
+ }
+
+ suffix = file_path.suffix.lower()
+
+ # PDF 专用参数
+ if suffix == ".pdf":
+ kwargs.update({
+ "strategy": self.strategy,
+ "ocr_languages": self.ocr_languages,
+ "extract_images_in_pdf": self.extract_images,
+ "pdf_infer_table_structure": self.pdf_infer_table_structure,
+ })
+
+ # 所有文件适用的语言参数
+ if self.languages:
+ kwargs["languages"] = self.languages
+
+ # 用户自定义参数覆盖默认值
+ kwargs.update(self.partition_kwargs)
+
+ return kwargs
+
+ def _element_to_document(self, element: Element, file_path: Path) -> Optional[Document]:
+ """将单个 Element 转换为 Document,同时保留关键元数据。"""
+ text = getattr(element, "text", "")
+ if not text or not text.strip():
+ return None
+
+ # 提取 unstructured 提供的元数据(根据实际需要选择)
+ metadata = {
+ "source": str(file_path),
+ "file_name": file_path.name,
+ "file_type": file_path.suffix.lower(),
+ # 以下元数据来自 Element 对象,可能为 None
+ "page_number": getattr(getattr(element, "metadata", None), "page_number", None),
+ "category": getattr(getattr(element, "metadata", None), "category", None),
+ }
+ # 过滤掉值为 None 的元数据
+ metadata = {k: v for k, v in metadata.items() if v is not None}
+
+ return Document(page_content=text, metadata=metadata)
+
def load_file(self, file_path: Union[str, Path]) -> List[Document]:
"""将单个文件加载为 LangChain Document 对象。"""
file_path = Path(file_path).resolve()
@@ -59,68 +111,58 @@ class DocumentLoader:
f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}"
)
- # 根据文件类型动态调整参数
- extra_kwargs = {}
- if suffix == ".pdf":
- extra_kwargs["strategy"] = self.strategy
- extra_kwargs["ocr_languages"] = self.ocr_languages
- extra_kwargs["extract_images_in_pdf"] = self.extract_images
- extra_kwargs["pdf_infer_table_structure"] = self.pdf_infer_table_structure
-
- # languages 参数适用于所有文件类型
- if self.languages:
- extra_kwargs["languages"] = self.languages
-
- extra_kwargs["include_page_breaks"] = self.include_page_breaks
+ kwargs = self._build_partition_kwargs(file_path)
- # 合并用户自定义的额外参数(优先级最高)
- extra_kwargs.update(self.partition_kwargs)
-
- # 使用 unstructured 解析
- elements = partition(
- filename=str(file_path),
-
- **extra_kwargs
- )
+ try:
+ elements = partition(filename=str(file_path), **kwargs)
+ except Exception as e:
+ logger.exception("解析文件 %s 失败", file_path)
+ raise RuntimeError(f"文件解析失败: {file_path}") from e
documents = []
for elem in elements:
- text = getattr(elem, "text", "")
- if not text or not text.strip():
- continue
-
- # 基础元数据
- metadata = {
- "source": str(file_path),
- "file_name": file_path.name,
- "file_type": suffix,
- }
-
- documents.append(Document(page_content=text, metadata=metadata))
+ doc = self._element_to_document(elem, file_path)
+ if doc:
+ documents.append(doc)
if not documents:
logger.warning("未从 %s 提取到文本内容", file_path)
- return []
return documents
def load_directory(
- self, directory_path: Union[str, Path], recursive: bool = True
+ self,
+ directory_path: Union[str, Path],
+ recursive: bool = True,
+ fail_fast: bool = False
) -> List[Document]:
- """从目录加载所有支持的文件。"""
+ """
+ 从目录加载所有支持的文件。
+
+ Args:
+ directory_path: 目录路径
+ recursive: 是否递归子目录
+ fail_fast: 遇到第一个失败时是否立即抛出异常
+ """
directory_path = Path(directory_path).resolve()
if not directory_path.is_dir():
raise NotADirectoryError(f"不是目录: {directory_path}")
- all_documents = []
+ all_documents: List[Document] = []
pattern = "**/*" if recursive else "*"
for file_path in directory_path.glob(pattern):
- if file_path.is_file() and file_path.suffix.lower() in self.SUPPORTED_EXTENSIONS:
- try:
- docs = self.load_file(file_path)
- all_documents.extend(docs)
- except Exception as e:
- logger.error("加载 %s 失败: %s", file_path, e)
+ if not file_path.is_file():
+ continue
+ if file_path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
+ continue
+
+ try:
+ docs = self.load_file(file_path)
+ all_documents.extend(docs)
+ except Exception as e:
+ logger.error("加载 %s 失败: %s", file_path, e)
+ if fail_fast:
+ raise
return all_documents
\ No newline at end of file
diff --git a/rag_indexer/splitters.py b/rag_indexer/splitters.py
index 45874d3..006e8ab 100644
--- a/rag_indexer/splitters.py
+++ b/rag_indexer/splitters.py
@@ -3,7 +3,8 @@
"""
from enum import Enum
-from typing import List, Optional
+from typing import List, Optional, Tuple, Dict, Any
+from dataclasses import dataclass, field
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
@@ -16,68 +17,195 @@ class SplitterType(str, Enum):
PARENT_CHILD = "parent_child"
-def get_splitter(splitter_type: SplitterType, **kwargs):
- """工厂函数,创建文本切分器。"""
- if splitter_type == SplitterType.RECURSIVE:
- chunk_size = kwargs.get("chunk_size", 500)
- chunk_overlap = kwargs.get("chunk_overlap", 50)
- return RecursiveCharacterTextSplitter(
- chunk_size=chunk_size,
- chunk_overlap=chunk_overlap,
- separators=["\n\n", "\n", "。", "!", "?", " ", ""],
- )
- elif splitter_type == SplitterType.SEMANTIC:
- embeddings = kwargs.pop("embeddings", None)
- if embeddings is None:
- raise ValueError("语义切分器需要提供 'embeddings' 参数")
- return SemanticChunkerAdapter(embeddings=embeddings, **kwargs)
- else:
- raise ValueError(f"不支持的切分器类型: {splitter_type}")
+# ---------- 配置数据类,统一参数 ----------
+@dataclass
+class RecursiveSplitterConfig:
+ """递归字符切分器配置"""
+ chunk_size: int = 500
+ chunk_overlap: int = 50
+ separators: List[str] = field(default_factory=lambda: ["\n\n", "\n", "。", "!", "?", " ", ""])
+ keep_separator: bool = True
+ strip_whitespace: bool = True
+@dataclass
+class SemanticSplitterConfig:
+ """语义切分器配置,仅包含 SemanticChunker 支持的参数。"""
+ embeddings: Any
+ buffer_size: int = 1
+ add_start_index: bool = False
+ breakpoint_threshold_type: str = "percentile"
+ breakpoint_threshold_amount: Optional[float] = None
+ number_of_chunks: Optional[int] = None
+ sentence_split_regex: str = r"(?<=[.?!。?!])\s+"
+ min_chunk_size: int = 100
+
+@dataclass
+class ParentChildSplitterConfig:
+ """父子切分器配置"""
+ embeddings: Any # 子块语义切分所需
+ parent_chunk_size: int = 1000
+ parent_chunk_overlap: int = 100
+ child_buffer_size: int = 1
+ child_breakpoint_threshold_type: str = "percentile"
+ child_breakpoint_threshold_amount: Optional[float] = None
+ child_min_chunk_size: int = 100
+ child_max_chunk_size: Optional[int] = 200
+
+
+# ---------- 适配器:让 SemanticChunker 实现 TextSplitter 接口 ----------
class SemanticChunkerAdapter(TextSplitter):
- """将 SemanticChunker 适配为 TextSplitter 接口。"""
+ """将 SemanticChunker 适配为 LangChain TextSplitter 接口。"""
- def __init__(self, embeddings, **kwargs):
+ def __init__(self, config: SemanticSplitterConfig, **kwargs):
super().__init__(**kwargs)
- chunk_size = kwargs.pop("chunk_size", None)
- chunk_overlap = kwargs.pop("chunk_overlap", None)
- self._chunker = SemanticChunker(embeddings=embeddings, **kwargs)
+ self._config = config
+ self._chunker = SemanticChunker(
+ embeddings=config.embeddings,
+ buffer_size=config.buffer_size,
+ add_start_index=config.add_start_index,
+ breakpoint_threshold_type=config.breakpoint_threshold_type,
+ breakpoint_threshold_amount=config.breakpoint_threshold_amount,
+ number_of_chunks=config.number_of_chunks,
+ sentence_split_regex=config.sentence_split_regex,
+ min_chunk_size=config.min_chunk_size,
+ )
def split_text(self, text: str) -> List[str]:
return self._chunker.split_text(text)
+ def split_documents(self, documents: List[Document]) -> List[Document]:
+ result = []
+ for doc in documents:
+ chunks = self.split_text(doc.page_content)
+ for i, chunk in enumerate(chunks):
+ result.append(Document(
+ page_content=chunk,
+ metadata={**doc.metadata, "chunk_index": i}
+ ))
+ return result
+
+# ---------- 工厂函数,统一创建切分器 ----------
+def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter:
+ """
+ 根据类型创建切分器。
+ 支持传入配置对象或直接参数。
+ """
+ if splitter_type == SplitterType.RECURSIVE:
+ config = RecursiveSplitterConfig(
+ chunk_size=kwargs.get("chunk_size", 500),
+ chunk_overlap=kwargs.get("chunk_overlap", 50),
+ separators=kwargs.get("separators", ["\n\n", "\n", "。", "!", "?", " ", ""]),
+ )
+ return RecursiveCharacterTextSplitter(
+ chunk_size=config.chunk_size,
+ chunk_overlap=config.chunk_overlap,
+ separators=config.separators,
+ keep_separator=config.keep_separator,
+ strip_whitespace=config.strip_whitespace,
+ )
+
+ elif splitter_type == SplitterType.SEMANTIC:
+ embeddings = kwargs.get("embeddings")
+ if embeddings is None:
+ raise ValueError("语义切分器需要提供 'embeddings' 参数")
+
+ if "config" in kwargs and isinstance(kwargs["config"], SemanticSplitterConfig):
+ config = kwargs["config"]
+ else:
+ # 过滤出 SemanticSplitterConfig 支持的字段
+ config_kwargs = {
+ "embeddings": embeddings,
+ "buffer_size": kwargs.get("buffer_size", 1),
+ "breakpoint_threshold_type": kwargs.get("breakpoint_threshold_type", "percentile"),
+ "breakpoint_threshold_amount": kwargs.get("breakpoint_threshold_amount"),
+ "number_of_chunks": kwargs.get("number_of_chunks"),
+ "min_chunk_size": kwargs.get("min_chunk_size", 100),
+ }
+ config = SemanticSplitterConfig(**config_kwargs)
+ return SemanticChunkerAdapter(config)
+
+ elif splitter_type == SplitterType.PARENT_CHILD:
+ # 父子切分器在 builder 中单独处理,不通过本函数创建
+ raise ValueError("父子切分器应通过 IndexBuilder 创建,不支持 get_splitter 直接构建")
+
+ else:
+ raise ValueError(f"不支持的切分器类型: {splitter_type}")
+
+# ---------- 父子切分器实现 ----------
class ParentChildSplitter:
"""
- 将文档切分为父块(大块)和子块(小块)。
- 子块用于索引检索,父块用于存储上下文。
+ 将文档切分为父块(大块,用于上下文)和子块(小块,用于索引检索)。
+ 内部维护父子块之间的映射关系。
"""
- def __init__(
- self,
- parent_chunk_size: int = 1000,
- child_chunk_size: int = 200,
- parent_chunk_overlap: int = 100,
- child_chunk_overlap: int = 20,
- ):
+ def __init__(self, config: ParentChildSplitterConfig):
+ self.config = config
+ # 父块使用递归字符切分
self.parent_splitter = RecursiveCharacterTextSplitter(
- chunk_size=parent_chunk_size,
- chunk_overlap=parent_chunk_overlap,
+ chunk_size=config.parent_chunk_size,
+ chunk_overlap=config.parent_chunk_overlap,
)
- self.child_splitter = RecursiveCharacterTextSplitter(
- chunk_size=child_chunk_size,
- chunk_overlap=child_chunk_overlap,
+ # 子块使用语义切分
+ semantic_config = SemanticSplitterConfig(
+ embeddings=config.embeddings,
+ buffer_size=config.child_buffer_size,
+ breakpoint_threshold_type=config.child_breakpoint_threshold_type,
+ breakpoint_threshold_amount=config.child_breakpoint_threshold_amount,
+ min_chunk_size=config.child_min_chunk_size,
)
+ self.child_splitter = SemanticChunkerAdapter(semantic_config)
- def split_documents(self, documents: List[Document]) -> tuple[List[Document], List[Document]]:
+ # 存储父子块映射关系(可选)
+ self.parent_to_children: Dict[str, List[str]] = {}
+ self.child_to_parent: Dict[str, str] = {}
+
+ def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]:
"""
返回:
(父块列表, 子块列表)
+ 同时填充内部映射字典。
"""
parent_chunks = self.parent_splitter.split_documents(documents)
child_chunks = self.child_splitter.split_documents(documents)
- # 将子块与父块 ID 关联(可选元数据)
- # 在实际实现中,需要将每个子块映射到对应的父块 ID。
- return parent_chunks, child_chunks
\ No newline at end of file
+ # 建立映射关系(简化示例:根据文本包含关系粗略匹配,实际需更精确的算法)
+ # 这里仅作示意,生产环境建议使用 embedding 相似度或精确子串定位
+ self._build_mappings(parent_chunks, child_chunks)
+
+ return parent_chunks, child_chunks
+
+ def _build_mappings(self, parents: List[Document], children: List[Document]) -> None:
+ """
+ 根据文本内容建立父子映射。
+ 本方法为简化实现,实际使用时请替换为更可靠的匹配逻辑。
+ """
+ self.parent_to_children.clear()
+ self.child_to_parent.clear()
+
+ # 为每个父块生成唯一 ID(若无则使用索引)
+ for p_idx, parent in enumerate(parents):
+ parent_id = parent.metadata.get("id", f"parent_{p_idx}")
+ parent.metadata["id"] = parent_id
+ self.parent_to_children[parent_id] = []
+
+ # 将每个子块分配给包含其文本的第一个父块
+ for c_idx, child in enumerate(children):
+ child_id = child.metadata.get("id", f"child_{c_idx}")
+ child.metadata["id"] = child_id
+ for parent in parents:
+ if child.page_content in parent.page_content:
+ parent_id = parent.metadata["id"]
+ self.parent_to_children[parent_id].append(child_id)
+ self.child_to_parent[child_id] = parent_id
+ child.metadata["parent_id"] = parent_id
+ break
+
+ def get_parent_for_child(self, child_id: str) -> Optional[str]:
+ """根据子块 ID 获取父块 ID"""
+ return self.child_to_parent.get(child_id)
+
+ def get_children_for_parent(self, parent_id: str) -> List[str]:
+ """根据父块 ID 获取所有子块 ID"""
+ return self.parent_to_children.get(parent_id, [])
\ No newline at end of file
diff --git a/rag_indexer/test/reset_index.py b/rag_indexer/test/reset_index.py
new file mode 100644
index 0000000..7c6a793
--- /dev/null
+++ b/rag_indexer/test/reset_index.py
@@ -0,0 +1,80 @@
+"""清理 RAG 索引数据。
+
+用法:
+ python reset_index.py # 清理全部
+ python reset_index.py --qdrant # 仅清理 Qdrant
+ python reset_index.py --postgres # 仅清理 PostgreSQL
+"""
+
+import asyncio
+import os
+import argparse
+
+from dotenv import load_dotenv
+load_dotenv()
+
+QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable")
+COLLECTION_NAME = "rag_documents"
+TABLE_NAME = "parent_documents"
+
+
+def clear_qdrant():
+ """删除 Qdrant 集合。"""
+ from qdrant_client import QdrantClient
+
+ print("清理 Qdrant...")
+ client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+
+ collections = client.get_collections().collections
+ if any(c.name == COLLECTION_NAME for c in collections):
+ client.delete_collection(COLLECTION_NAME)
+ print(f" 集合 '{COLLECTION_NAME}' 已删除")
+ else:
+ print(f" 集合 '{COLLECTION_NAME}' 不存在")
+
+
+async def clear_postgres():
+ """清空 PostgreSQL 表数据。"""
+ import asyncpg
+
+ print("清理 PostgreSQL...")
+ conn = await asyncpg.connect(dsn=DB_URI)
+
+ try:
+ exists = await conn.fetchval(
+ "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1)",
+ TABLE_NAME
+ )
+ if exists:
+ count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
+ await conn.execute(f"DELETE FROM {TABLE_NAME}")
+ print(f" 表 '{TABLE_NAME}' 已清空,删除 {count} 条记录")
+ else:
+ print(f" 表 '{TABLE_NAME}' 不存在")
+ finally:
+ await conn.close()
+
+
+async def main():
+ parser = argparse.ArgumentParser(description="清理 RAG 索引数据")
+ parser.add_argument("--qdrant", action="store_true", help="仅清理 Qdrant")
+ parser.add_argument("--postgres", action="store_true", help="仅清理 PostgreSQL")
+ args = parser.parse_args()
+
+ if not args.qdrant and not args.postgres:
+ args.qdrant = True
+ args.postgres = True
+
+ if args.qdrant:
+ clear_qdrant()
+
+ if args.postgres:
+ await clear_postgres()
+
+ print("\n完成。运行 `python -m rag_indexer.cli` 重建索引")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/rag_indexer/test/test_inspect_vectors.py b/rag_indexer/test/test_inspect_vectors.py
new file mode 100644
index 0000000..5e296c0
--- /dev/null
+++ b/rag_indexer/test/test_inspect_vectors.py
@@ -0,0 +1,63 @@
+"""检查 Qdrant 中存储的向量质量。"""
+
+import os
+import sys
+import numpy as np
+from dotenv import load_dotenv
+from qdrant_client import QdrantClient
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
+from rag_core import LlamaCppEmbedder
+
+load_dotenv()
+
+QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+COLLECTION_NAME = "rag_documents"
+
+client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+embedder = LlamaCppEmbedder()
+
+# 获取样本
+points, _ = client.scroll(
+ collection_name=COLLECTION_NAME,
+ limit=1,
+ with_vectors=True,
+ with_payload=True,
+)
+
+if not points:
+ print(f"集合 '{COLLECTION_NAME}' 为空")
+ exit()
+
+sample = points[0]
+raw_vec = sample.vector
+if isinstance(raw_vec, dict):
+ stored_vec = list(raw_vec.values())[0]
+elif isinstance(raw_vec, list):
+ stored_vec = raw_vec
+else:
+ stored_vec = []
+
+stored_payload = sample.payload or {}
+stored_text = str(stored_payload.get("page_content", ""))[:200]
+
+print(f"内容预览:\n{stored_text}...\n")
+print(f"向量维度: {len(stored_vec)}") # type: ignore
+print(f"前5个值: {stored_vec[:5]}") # type: ignore
+print(f"是否全零: {all(v == 0.0 for v in stored_vec)}") # type: ignore
+
+# 重新编码对比
+if stored_text:
+ new_vec = embedder.embed_query(stored_text)
+ similarity = np.dot(stored_vec, new_vec) / (np.linalg.norm(stored_vec) * np.linalg.norm(new_vec)) # type: ignore
+ print(f"\n重新编码前5个值: {new_vec[:5]}")
+ print(f"余弦相似度: {similarity:.4f}")
+
+ if similarity < 0.8:
+ print("\n⚠️ 相似度过低,建议删除集合并重建索引")
+ else:
+ print("\n✅ 向量一致")
+else:
+ print("\n⚠️ 样本无文本内容")
diff --git a/rag_indexer/test/test_refactored.py b/rag_indexer/test/test_refactored.py
new file mode 100644
index 0000000..ca681d9
--- /dev/null
+++ b/rag_indexer/test/test_refactored.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+"""
+测试重构后的 IndexBuilder 和 RAGRetriever
+"""
+
+import asyncio
+import os
+import sys
+
+# 添加项目根目录到 Python 路径
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+
+from rag_indexer.IndexBuilder import IndexBuilder
+from rag_indexer.splitters import SplitterType
+
+async def test_index_builder():
+ """测试索引构建功能"""
+ print("测试索引构建功能...")
+
+ # 创建 IndexBuilder 实例
+ builder = IndexBuilder(
+ collection_name="test_collection",
+ splitter_type=SplitterType.PARENT_CHILD,
+ parent_chunk_size=1000,
+ child_chunk_size=200
+ )
+
+ # 测试文档路径
+ test_file = os.path.join(os.path.dirname(__file__), "..", "data", "corpus", "三国演义.txt")
+
+ if os.path.exists(test_file):
+ # 构建索引
+ print(f"正在为文件 {test_file} 构建索引...")
+ processed = await builder.build_from_file(test_file)
+ print(f"索引构建完成,处理了 {processed} 个文档")
+
+ # 获取集合信息
+ info = builder.get_collection_info()
+ print(f"集合信息: {info}")
+ else:
+ print(f"测试文件不存在: {test_file}")
+
+ # 测试搜索功能
+ print("\n测试搜索功能...")
+ try:
+ results = builder.search("吕布", k=3)
+ print(f"搜索结果数量: {len(results)}")
+ for i, result in enumerate(results):
+ print(f"\n结果 {i+1}:")
+ print(f"内容: {result.page_content[:100]}...")
+ except Exception as e:
+ print(f"搜索测试失败: {e}")
+
+ # 测试带父块上下文的搜索
+ print("\n测试带父块上下文的搜索...")
+ try:
+ results = await builder.search_with_parent_context("吕布", k=3)
+ print(f"搜索结果数量: {len(results)}")
+ for i, result in enumerate(results):
+ print(f"\n结果 {i+1}:")
+ print(f"内容: {result.page_content[:100]}...")
+ except Exception as e:
+ print(f"带父块上下文的搜索测试失败: {e}")
+
+ # 测试统一检索接口
+ print("\n测试统一检索接口...")
+ try:
+ # 返回父块
+ results_parent = await builder.retrieve("吕布", return_parent=True)
+ print(f"返回父块的结果数量: {len(results_parent)}")
+
+ # 返回子块
+ results_child = await builder.retrieve("吕布", return_parent=False)
+ print(f"返回子块的结果数量: {len(results_child)}")
+ except Exception as e:
+ print(f"统一检索接口测试失败: {e}")
+
+ # 关闭资源
+ builder.close()
+ print("\n测试完成")
+
+if __name__ == "__main__":
+ asyncio.run(test_index_builder())
\ No newline at end of file
diff --git a/rag_indexer/test/test_validate_index.py b/rag_indexer/test/test_validate_index.py
new file mode 100644
index 0000000..072cd90
--- /dev/null
+++ b/rag_indexer/test/test_validate_index.py
@@ -0,0 +1,188 @@
+"""
+验证 RAG 索引完整性。
+
+检查 Qdrant 向量库、PostgreSQL 文档存储及检索功能。
+"""
+
+import asyncio
+import os
+import sys
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
+
+from dotenv import load_dotenv
+load_dotenv()
+
+QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable")
+COLLECTION_NAME = "rag_documents"
+TABLE_NAME = "parent_documents"
+
+
+def check_qdrant():
+ """检查 Qdrant 向量库。"""
+ from qdrant_client import QdrantClient
+
+ print("=" * 60)
+ print("Qdrant 向量库")
+ print("=" * 60)
+
+ client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+
+ # 集合列表
+ collections = client.get_collections().collections
+ print(f"\n集合数: {len(collections)}")
+ for c in collections:
+ print(f" - {c.name}")
+
+ # 目标集合信息
+ if not any(c.name == COLLECTION_NAME for c in collections):
+ print(f"\n集合 '{COLLECTION_NAME}' 不存在")
+ return
+
+ info = client.get_collection(COLLECTION_NAME)
+ print(f"\n集合 '{COLLECTION_NAME}':")
+ print(f" 状态: {info.status}")
+ print(f" 向量数: {info.points_count}")
+
+ vectors_config = info.config.params.vectors
+ if isinstance(vectors_config, dict):
+ for name, vc in vectors_config.items():
+ print(f" 向量 '{name}': 维度={vc.size}, 距离={vc.distance}")
+ else:
+ print(f" 向量维度: {vectors_config.size}")
+
+ # 抽样查看
+ print(f"\n前 3 个向量:")
+ points = client.scroll(
+ collection_name=COLLECTION_NAME,
+ limit=3,
+ with_payload=True,
+ with_vectors=False
+ )
+ for i, point in enumerate(points[0]):
+ print(f"\n {i+1}. ID: {point.id}")
+ payload = point.payload or {}
+ print(f" 内容: {payload.get('page_content', '')[:100]}...")
+
+
+async def check_postgres():
+ """检查 PostgreSQL 文档存储。"""
+ import asyncpg
+
+ print("\n" + "=" * 60)
+ print("PostgreSQL 文档存储")
+ print("=" * 60)
+
+ conn = await asyncpg.connect(dsn=DB_URI)
+
+ try:
+ # 表是否存在
+ tables = await conn.fetch(
+ "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
+ )
+ table_names = [t['table_name'] for t in tables]
+
+ if TABLE_NAME not in table_names:
+ print(f"\n表 '{TABLE_NAME}' 不存在")
+ return
+
+ # 统计
+ count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
+ print(f"\n表 '{TABLE_NAME}': {count} 条记录")
+
+ # 抽样
+ print(f"\n前 3 个文档:")
+ rows = await conn.fetch(
+ f"SELECT key, value FROM {TABLE_NAME} ORDER BY key LIMIT 3"
+ )
+ for i, row in enumerate(rows):
+ print(f"\n {i+1}. Key: {row['key']}")
+ val = row['value']
+ if isinstance(val, dict) and 'page_content' in val:
+ print(f" 内容: {val['page_content'][:100]}...")
+
+ # Key 前缀分布
+ key_prefixes = await conn.fetch(
+ f"""
+ SELECT
+ CASE
+ WHEN key LIKE '%:%' THEN split_part(key, ':', 1)
+ ELSE 'no_prefix'
+ END AS prefix,
+ COUNT(*) AS cnt
+ FROM {TABLE_NAME}
+ GROUP BY prefix
+ ORDER BY cnt DESC
+ LIMIT 10
+ """
+ )
+ print(f"\nKey 前缀分布:")
+ for row in key_prefixes:
+ print(f" {row['prefix']}: {row['cnt']}")
+
+ finally:
+ await conn.close()
+
+
+async def test_search():
+ """测试检索功能。"""
+ from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig
+ from rag_indexer.splitters import SplitterType
+
+ print("\n" + "=" * 60)
+ print("检索测试")
+ print("=" * 60)
+
+ # 使用配置对象初始化(与默认构建方式一致)
+ config = IndexBuilderConfig(
+ collection_name=COLLECTION_NAME,
+ splitter_type=SplitterType.PARENT_CHILD,
+ )
+ builder = IndexBuilder(config)
+
+ # 确保检索器已初始化
+ if builder.retriever is None:
+ print("错误: 检索器未初始化,请检查切分策略")
+ return
+
+ query = input("\n查询 (回车使用默认): ").strip() or "你好"
+ print(f"\n查询: {query}")
+
+ # 标准检索(返回父块,因为 ParentDocumentRetriever 默认返回父块)
+ print("\n--- 标准检索 (返回父块) ---")
+ results = await builder.retriever.ainvoke(query)
+ for i, doc in enumerate(results):
+ content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
+ print(f"\n {i+1}. {content}...")
+ if hasattr(doc, 'metadata'):
+ source = doc.metadata.get('source', '')
+ if source:
+ print(f" 来源: {source}")
+
+ # 若需要仅返回子块,可以临时修改检索器的 search_type
+ # (注意:ParentDocumentRetriever 的 search_type 默认为 "similarity")
+ print("\n--- 检索子块 (通过修改检索器参数) ---")
+ # 创建一个新的检索器副本,设置为返回子块
+ # 简单起见,直接调用 vectorstore 进行相似度搜索获取子块
+ vectorstore = builder.vector_store.get_langchain_vectorstore()
+ sub_results = await vectorstore.asimilarity_search(query, k=3)
+ for i, doc in enumerate(sub_results):
+ content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
+ print(f"\n {i+1}. {content}...")
+ if hasattr(doc, 'metadata'):
+ parent_id = doc.metadata.get('parent_id', '')
+ if parent_id:
+ print(f" 父块 ID: {parent_id}")
+
+
+async def main():
+ check_qdrant()
+ await check_postgres()
+ await test_search()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
\ No newline at end of file