Compare commits

...

2 Commits

Author SHA1 Message Date
c18e8a9860 Vector database
Some checks failed
Build and deploy AI Agent service / deploy (push) Failing after 32m6s
2026-04-18 16:56:23 +08:00
0470afce13 Local RAG experiment 2026-04-18 16:31:48 +08:00
23 changed files with 2708 additions and 4 deletions

2
.gitignore vendored
View File

@@ -13,6 +13,8 @@
!frontend/**
!scripts/
!scripts/**
!rag_indexer/
!rag_indexer/**
!docker/
!docker/**
!.gitea/

app/agent.py
View File

@@ -6,14 +6,40 @@ AI Agent service class - supports dynamic multi-model switching
import os
import json
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
try:
from langchain_community.chat_models import ChatZhipuAI
HAS_ZHIPUAI = True
except ImportError:
HAS_ZHIPUAI = False
ChatZhipuAI = None
try:
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
HAS_OPENAI = True
except ImportError:
HAS_OPENAI = False
ChatOpenAI = None
OpenAIEmbeddings = None
from pydantic import SecretStr
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
HAS_POSTGRES_CHECKPOINT = True
except ImportError:
HAS_POSTGRES_CHECKPOINT = False
AsyncPostgresSaver = None
# Local modules
from app.graph_builder import GraphBuilder, GraphContext
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
try:
from app.rag import RAGPipeline
from app.rag.tools import RAGTool
HAS_RAG = True
except ImportError as e:
HAS_RAG = False
RAGPipeline = None
RAGTool = None
from app.logger import debug, info, warning, error
@@ -31,9 +57,13 @@ class AIAgentService:
"""
self.checkpointer = checkpointer
self.graphs = {} # graph instances keyed by model name
self.rag = None # RAG retrieval pipeline instance
self.rag_tool = None # RAG tool instance
def _create_zhipu_llm(self):
"""Create the ZhipuAI hosted LLM."""
if not HAS_ZHIPUAI:
raise ImportError("ZhipuAI support unavailable; install the langchain-community package")
api_key = os.getenv("ZHIPUAI_API_KEY")
if not api_key:
raise ValueError("ZHIPUAI_API_KEY not set in environment")
@@ -49,6 +79,8 @@ class AIAgentService:
def _create_deepseek_llm(self):
"""创建 DeepSeek LLM使用 OpenAI 兼容 API"""
if not HAS_OPENAI:
raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
raise ValueError("DEEPSEEK_API_KEY not set in environment")
@@ -65,6 +97,8 @@ class AIAgentService:
def _create_local_llm(self):
"""创建本地 vLLM 服务 LLM"""
if not HAS_OPENAI:
raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
vllm_base_url = os.getenv(
"VLLM_BASE_URL",
@@ -80,8 +114,39 @@ class AIAgentService:
streaming=True, # make sure streaming output is enabled
)
def _create_embeddings(self):
"""Create the embedding model."""
if not HAS_OPENAI:
raise ImportError("OpenAI-compatible support unavailable; install the langchain-openai package")
embedding_url = os.getenv(
"LLAMACPP_EMBEDDING_URL",
"http://127.0.0.1:8082/v1"
)
return OpenAIEmbeddings(
openai_api_base=embedding_url,
openai_api_key=os.getenv("LLAMACPP_API_KEY", "token-abc123"),
model="text-embedding-ada-002", # the name only needs to be API-compatible; the backend ignores it
)
async def initialize(self):
"""预编译所有模型的 graph使用传入的 checkpointer"""
# 先初始化 RAG 检索系统
if HAS_RAG and RAGPipeline is not None and RAGTool is not None:
try:
info("🔄 正在初始化 RAG 检索系统...")
embeddings = self._create_embeddings()
self.rag = RAGPipeline(embeddings=embeddings)
self.rag_tool = RAGTool(self.rag).get_tool()
info("✅ RAG 检索系统初始化成功")
except Exception as e:
warning(f"⚠️ RAG 检索系统初始化失败: {e}")
self.rag = None
self.rag_tool = None
else:
info("⏭️ RAG 检索系统不可用,跳过初始化")
self.rag = None
self.rag_tool = None
model_configs = {
"local": self._create_local_llm, # 本地模型作为第一个
"deepseek": self._create_deepseek_llm, # DeepSeek 作为中间
@@ -92,7 +157,16 @@ class AIAgentService:
try:
info(f"🔄 正在初始化模型 '{model_name}'...")
llm = llm_creator()
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
# Build the tool list: base tools plus the RAG tool (when available)
tools = AVAILABLE_TOOLS.copy()
tools_by_name = TOOLS_BY_NAME.copy()
if self.rag_tool is not None:
tools.append(self.rag_tool)
tools_by_name[self.rag_tool.name] = self.rag_tool
builder = GraphBuilder(llm, tools, tools_by_name).build()
graph = builder.compile(checkpointer=self.checkpointer)
self.graphs[model_name] = graph
info(f"✅ 模型 '{model_name}' 初始化成功")

136
app/rag/README.md Normal file
View File

@@ -0,0 +1,136 @@
# Online RAG Retrieval & Generation System (Online RAG Retriever)
This module implements stage two of the RAG system: **online retrieval and generation**. It takes a user question, retrieves context from the knowledge base, denoises and fuses the candidates with several advanced strategies, and feeds the result to the large language model (LLM) as augmented context.
## 📊 RAG-Fusion & Hybrid Retrieval Pipeline
```mermaid
graph TD
    User((User question)) --> A[LLM query rewriter]
    subgraph "RAG-Fusion core flow"
        A -->|rewrite 1| B1[Query 1]
        A -->|rewrite 2| B2[Query 2]
        A -->|original| B3[Original query]
        B1 & B2 & B3 --> C[Hybrid Retriever <br> Dense Vector + BM25 Sparse]
        C --> D[Pooled recall set, N=60]
        D --> E{RRF reciprocal rank fusion & dedup}
    end
    E -->|top 20| F[Cross-Encoder Reranker]
    F -->|fine-grained scoring, top 5| G[Final clean context]
    G --> H[Concatenate context with the original question]
    H --> I((LLM generates the final answer))
```
---
## 🎯 Roadmap & Algorithm Details
### Level 1: Basic Similarity Search
- **Core algorithm**: approximate nearest-neighbor search (ANN, typically HNSW). The user question is embedded, its cosine similarity against the stored vectors is computed, and the K nearest chunks are returned (a minimal sketch follows below).
- **Trade-offs**: extremely fast, but it captures only "semantic similarity"; when users search for specific proper nouns, IDs, or order numbers, pure vector retrieval often fails (producing "hallucinated" matches).
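As a concrete starting point, here is a minimal Level-1 sketch. The embedding endpoint, collection name, and query are illustrative placeholders, not fixed values from this repo:

```python
from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient

# OpenAI-compatible local embedding server (hypothetical URL)
embeddings = OpenAIEmbeddings(
    openai_api_base="http://127.0.0.1:8082/v1",
    openai_api_key="no-key-needed",
    model="text-embedding-ada-002",  # name only needs to be API-compatible
)
client = QdrantClient(url="http://localhost:6333")
vectorstore = Qdrant(client=client, collection_name="documents", embeddings=embeddings)

# Plain cosine-similarity top-K search
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
for doc in retriever.invoke("What is the expense reimbursement process?"):
    print(doc.page_content[:80])
```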
### Level 2: Hybrid Search + Reranker
Hybrid retrieval combines the "semantic generalization" of vectors with the "exact matching" of keywords, then uses a reranking model to filter out noise.
**1. First-stage recall (hybrid retrieval)**
- **Core idea**: combine HNSW-based dense-vector similarity search with BM25 (TF-IDF-family) sparse retrieval.
- **Implementation guide**: use the `Qdrant` class from the `langchain_qdrant` package to connect to the database, instantiate the vector store over the existing collection, and call `.as_retriever(search_kwargs={"k": 20})` to obtain the base retriever. Note that dual-path recall is not automatic: true hybrid search in `langchain-qdrant` requires a configured sparse embedding, or you can pair a `BM25Retriever` with an `EnsembleRetriever`.
**2. Second-stage reranking (Cross-Encoder)**
- **Core idea**: unlike bi-encoders (embed each side separately, then measure distance), a cross-encoder concatenates "user question + one candidate document", feeds the pair through a Transformer, and outputs a relevance score between 0 and 1 directly, with much higher precision.
- **Implementation guide**:
    - Load a lightweight local reranking model (e.g. `BAAI/bge-reranker-base`) with the `sentence-transformers` library.
    - Wrap the model with the `CrossEncoderReranker` class from `langchain.retrievers.document_compressors`, setting `top_n=5`.
    - Finally, combine the `base_compressor` (the reranker) and the `base_retriever` (the base retriever) with the `ContextualCompressionRetriever` class from `langchain.retrievers`.
- **Usage**: call `.invoke(query)` on the combined retriever and the whole denoising pipeline runs in one shot: broad recall of 20 candidates -> score each one and keep the top 5 (see the sketch below).
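A minimal sketch of the two stages, reusing the `vectorstore` from the Level-1 sketch (model and collection names remain placeholders):

```python
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# Stage 1: broad recall (20 candidates)
base_retriever = vectorstore.as_retriever(search_kwargs={"k": 20})

# Stage 2: the cross-encoder scores each (question, document) pair directly
cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
reranker = CrossEncoderReranker(model=cross_encoder, top_n=5)

compression_retriever = ContextualCompressionRetriever(
    base_compressor=reranker,
    base_retriever=base_retriever,
)
top_docs = compression_retriever.invoke("How do I apply for annual leave?")
```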
### Level 3: RAG-Fusion (Multi-Query Rewriting + Reciprocal Rank Fusion)
RAG-Fusion lets the LLM think divergently: it rewrites a single question into several similar questions to widen the search surface, then merges the results with a rank-based statistical algorithm.
**1. Multi-query rewriting**
- **Core idea**: overcome imprecise wording or a narrow perspective in the user's original question.
- **Implementation guide**: import the `MultiQueryRetriever` class from `langchain.retrievers.multi_query`. It needs an instantiated LLM (e.g. a local vLLM model wrapped with `ChatOpenAI`). Under the hood it prompts the model to turn the original `query` into a list of 3-5 differently phrased queries.
**2. Reciprocal Rank Fusion (RRF)**
- **Core idea**: RRF is a fusion algorithm that needs no score normalization. The formula is $\mathrm{RRF}(d) = \sum_{q \in Q} \frac{1}{k + \mathrm{rank}_q(d)}$, which keeps any single extreme retrieval result from dominating the final ranking.
- **Implementation guide** (see the sketch below):
    - For each rewritten query $q$, call the reranking retriever's `.invoke(q)` to get a document list.
    - Use the `EnsembleRetriever` class from `langchain.retrievers` (which implements RRF natively), or iterate over the collected `Document` objects yourself, accumulate each document's score from its rank, deduplicate, and extract `doc.page_content`.
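For the manual route, here is a small RRF helper that mirrors the formula above (using `page_content` as the dedup key is a simplification; a stable document ID would be more robust):

```python
from collections import defaultdict
from typing import Dict, List
from langchain_core.documents import Document

def rrf_fuse(ranked_lists: List[List[Document]], k: int = 60, top_n: int = 5) -> List[Document]:
    """Fuse several ranked result lists with reciprocal rank fusion."""
    scores: Dict[str, float] = defaultdict(float)
    by_content: Dict[str, Document] = {}
    for docs in ranked_lists:
        for rank, doc in enumerate(docs, start=1):
            key = doc.page_content          # dedup key
            scores[key] += 1.0 / (k + rank)  # RRF contribution of this ranking
            by_content.setdefault(key, doc)
    best = sorted(scores, key=scores.get, reverse=True)[:top_n]
    return [by_content[key] for key in best]

# Usage: one .invoke() per rewritten query, then fuse
# fused = rrf_fuse([retriever.invoke(q) for q in rewritten_queries])
```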
### Level 4: Agentic RAG / Self-RAG
- **Core idea**: ReAct (Reasoning and Acting) state-machine routing on top of LangGraph. Instead of rigidly retrieving on every turn, the model first classifies the question: "Is this small talk, or does it need the knowledge base?" If the latter, it emits a `ToolCall` instruction that triggers retrieval.
- **Implementation guide**: see the **Integration with the Existing System** section below.
- **Diagram**:
```mermaid
sequenceDiagram
    participant User
    participant LangGraph Agent
    participant RAG_Tool
    participant Qdrant
    User->>LangGraph Agent: "What is the expense reimbursement process?"
    LangGraph Agent->>LangGraph Agent: Think: internal policy question, need to look it up
    LangGraph Agent->>RAG_Tool: ToolCall(search_knowledge_base, "expense reimbursement process")
    RAG_Tool->>Qdrant: RAG-Fusion & hybrid retrieval
    Qdrant-->>RAG_Tool: raw chunks
    RAG_Tool->>RAG_Tool: Cross-Encoder rerank & filter
    RAG_Tool-->>LangGraph Agent: 5 most relevant reimbursement rules
    LangGraph Agent->>LangGraph Agent: Think: enough material, draft the answer
    LangGraph Agent-->>User: "Per the knowledge base, reimbursement has 3 steps..."
```
---
## 📦 Dependencies & Installation
Beyond the base LangChain packages, the online retrieval module needs the following for reranking and sparse retrieval:
```bash
# Cross-Encoder reranking models (e.g. BAAI/bge-reranker-base)
pip install sentence-transformers
# BM25 keyword retrieval for hybrid search
pip install rank_bm25
# Base framework
pip install langchain langchain-core langchain-openai langchain-qdrant
```
---
## 📂 Architecture & File Layout
Create the following files under `app/rag/` to modularize the features above:
```text
app/rag/
├── __init__.py
├── retriever.py        # Qdrant first-stage recall and ContextualCompressionRetriever
├── reranker.py         # loads the sentence-transformers cross-encoder
├── query_transform.py  # MultiQueryRetriever-based query rewriting
├── pipeline.py         # wires the components together and exposes retrieve()
└── tools.py            # wraps the pipeline as a LangChain Tool for the Agent
```
---
## 🔗 Integration with the Existing System (Agentic RAG)
Given the current LangGraph architecture, we avoid the old approach of piling code together and instead use **LangChain Tools** to inject RAG cleanly:
1. **Wrap retrieval as a Tool**:
   Import the `@tool` decorator from `langchain.tools`. Define a function named `search_knowledge_base(query: str)` that instantiates and invokes the multi-query recall and reranking logic from `pipeline.py`.
2. **Bind to the model**:
   In `app/agent.py` or `app/nodes/tool_call.py`, import the tool and bind it to the existing local model instance via `llm.bind_tools([search_knowledge_base])`.
3. **Graph routing**:
   The LangGraph state machine takes over automatically, just as for plain conversation: when the model decides it needs to consult policies or reference material, it emits a `ToolCall` message, which flows to `tool_node`, runs the RAG retrieval above, and returns the context.
This upgrades the system into an agent with knowledge-base retrieval without touching any of the Streamlit streaming code on the frontend. A minimal end-to-end sketch of the three steps follows.
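In the sketch below, the model endpoint and model name are placeholders for your deployment, and `embeddings` is the instance from the Level-1 sketch:

```python
from langchain.tools import tool
from langchain_openai import ChatOpenAI

from app.rag import RAGPipeline  # the pipeline this module exposes

pipeline = RAGPipeline(embeddings=embeddings)

@tool
def search_knowledge_base(query: str) -> str:
    """Search the knowledge base for relevant context."""
    result = pipeline.retrieve(query)  # multi-query recall + rerank
    return "\n\n".join(doc.page_content for doc in result.documents)

# Bind the tool to the local model (hypothetical endpoint). The LangGraph
# tool_node then executes search_knowledge_base whenever the model emits a
# matching ToolCall.
llm = ChatOpenAI(base_url="http://localhost:8000/v1",
                 api_key="no-key-needed", model="Qwen2.5-7B-Instruct")
llm_with_tools = llm.bind_tools([search_knowledge_base])
```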

22
app/rag/__init__.py Normal file
View File

@@ -0,0 +1,22 @@
"""
在线 RAG 检索与生成系统
提供高级RAG检索功能支持混合检索、重排序、RAG-Fusion和多路查询改写。
"""
from .pipeline import RAGPipeline
from .retriever import create_hybrid_retriever, create_base_retriever
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer
from .tools import search_knowledge_base_tool
__all__ = [
"RAGPipeline",
"create_hybrid_retriever",
"create_base_retriever",
"CrossEncoderReranker",
"MultiQueryTransformer",
"search_knowledge_base_tool",
]
__version__ = "0.1.0"

232
app/rag/example.py Normal file
View File

@@ -0,0 +1,232 @@
#!/usr/bin/env python3
"""
RAG 系统使用示例
演示如何使用 app/rag 模块进行知识检索。
"""
import sys
import os
# 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from langchain_openai import OpenAIEmbeddings
from langchain_community.llms import VLLMOpenAI
def setup_environment():
"""设置环境变量"""
# 设置 Qdrant 连接信息(根据实际情况修改)
os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333")
# 如果需要 API 密钥,请设置 QDRANT_API_KEY
print("环境变量已设置")
print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}")
def demonstrate_basic_rag():
"""Demonstrate basic RAG retrieval (Level 1)."""
print("\n" + "="*60)
print("Demo: basic RAG retrieval (Level 1)")
print("="*60)
# Create the embedding model (an OpenAI-compatible local model)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1", # local vLLM service
openai_api_key="no-key-needed",
model="text-embedding-ada-002", # placeholder model name
)
# Create the RAG pipeline
from app.rag import RAGPipeline, RAGConfig, RAGLevel
config = RAGConfig(
collection_name="documents", # your collection name
rag_level=RAGLevel.BASIC,
verbose=True,
)
pipeline = RAGPipeline(
embeddings=embeddings,
config=config,
)
# Sample query
query = "What is the company's expense reimbursement process?"
print(f"\nQuery: {query}")
try:
result = pipeline.retrieve(query)
print(f"Found {len(result.documents)} relevant documents")
# Format the context
context = pipeline.format_context(result.documents)
print(f"\nContext preview:\n{context[:500]}...")
except Exception as e:
print(f"Retrieval failed: {e}")
print("Make sure the Qdrant service is running and the collection exists")
def demonstrate_hybrid_rag():
"""Demonstrate hybrid RAG retrieval (Level 2)."""
print("\n" + "="*60)
print("Demo: hybrid RAG retrieval (Level 2)")
print("="*60)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model="text-embedding-ada-002",
)
from app.rag import RAGPipeline, RAGConfig, RAGLevel
config = RAGConfig(
collection_name="documents",
rag_level=RAGLevel.HYBRID,
dense_k=10,
sparse_k=10,
rerank_top_n=5,
verbose=True,
)
pipeline = RAGPipeline(
embeddings=embeddings,
config=config,
)
query = "How do I apply for annual leave?"
print(f"\nQuery: {query}")
try:
result = pipeline.retrieve(query)
print(f"Found {len(result.documents)} reranked documents")
except Exception as e:
print(f"Retrieval failed: {e}")
def demonstrate_rag_fusion():
"""Demonstrate RAG-Fusion (Level 3)."""
print("\n" + "="*60)
print("Demo: RAG-Fusion (Level 3)")
print("="*60)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model="text-embedding-ada-002",
)
# Language model for query rewriting
llm = VLLMOpenAI(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model_name="Qwen2.5-7B-Instruct", # your local model
temperature=0.3,
max_tokens=512,
)
from app.rag import RAGPipeline, RAGConfig, RAGLevel
config = RAGConfig(
collection_name="documents",
rag_level=RAGLevel.FUSION,
num_queries=3,
verbose=True,
)
pipeline = RAGPipeline(
embeddings=embeddings,
llm=llm,
config=config,
)
query = "Which approvals are needed before a project goes live?"
print(f"\nQuery: {query}")
try:
result = pipeline.retrieve(query)
print(f"Found {len(result.documents)} documents (after multi-query rewriting and reranking)")
except Exception as e:
print(f"Retrieval failed: {e}")
def demonstrate_agentic_rag():
"""Demonstrate Agentic RAG (Level 4)."""
print("\n" + "="*60)
print("Demo: Agentic RAG (Level 4)")
print("="*60)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model="text-embedding-ada-002",
)
# bind_tools requires a chat-model interface, so use ChatOpenAI here
# rather than the completion-style VLLMOpenAI wrapper
llm = ChatOpenAI(
base_url="http://localhost:8000/v1",
api_key="no-key-needed",
model="Qwen2.5-7B-Instruct",
temperature=0.3,
max_tokens=512,
)
from app.rag import create_agentic_rag_pipeline
try:
# Create the Agentic RAG pipeline
agentic_rag = create_agentic_rag_pipeline(
embeddings=embeddings,
agent_llm=llm,
config={
"collection_name": "documents",
"verbose": True,
},
)
print("Agentic RAG pipeline created!")
print(f"- Bound model: {agentic_rag['llm']}")
print(f"- RAG tool: {agentic_rag['tool'].name}")
# Demonstrate a tool call
print("\nTool call example:")
response = agentic_rag["tool"].invoke({"query": "What employee benefits are available?"})
print(f"Tool response preview: {response[:200]}...")
except Exception as e:
print(f"Failed to create Agentic RAG: {e}")
import traceback
traceback.print_exc()
def main():
"""Entry point."""
print("RAG system demo")
print("="*60)
# Configure the environment
setup_environment()
# Run each level's demo
demonstrate_basic_rag()
demonstrate_hybrid_rag()
demonstrate_rag_fusion()
demonstrate_agentic_rag()
print("\n" + "="*60)
print("Demo finished!")
print("="*60)
print("\nUsage notes:")
print("1. Make sure Qdrant is running and the collection has been created")
print("2. Adjust the embeddings and llm configuration as needed")
print("3. In the Agent system, import and use app.rag.tools.search_knowledge_base_tool")
print("4. Bind the tool to your Agent model")
if __name__ == "__main__":
main()

341
app/rag/pipeline.py Normal file
View File

@@ -0,0 +1,341 @@
"""
RAG 检索流水线
组合检索器、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。
"""
import time
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
from enum import Enum
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
create_ensemble_retriever,
create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline
class RAGLevel(Enum):
"""RAG 功能级别"""
BASIC = 1 # 基础向量搜索
HYBRID = 2 # 混合检索 + 重排序
FUSION = 3 # RAG-Fusion
AGENTIC = 4 # Agentic RAG
@dataclass
class RAGConfig:
"""RAG 配置"""
# Qdrant 配置
collection_name: str = "documents"
qdrant_url: Optional[str] = None
qdrant_api_key: Optional[str] = None
# 检索配置
rag_level: RAGLevel = RAGLevel.FUSION
dense_k: int = 10 # 向量检索数量
sparse_k: int = 10 # BM25 检索数量
total_k: int = 20 # 总检索数量
rerank_top_n: int = 5 # 重排序返回数量
# 查询改写配置
num_queries: int = 3 # RAG-Fusion 查询数量
# 模型配置
reranker_model: str = "BAAI/bge-reranker-base"
device: Optional[str] = None
# 性能配置
enable_cache: bool = True
verbose: bool = True
@dataclass
class RetrievalResult:
"""检索结果"""
documents: List[Document]
query_time: float
level: RAGLevel
metadata: Dict[str, Any] = field(default_factory=dict)
class RAGPipeline:
"""
RAG retrieval pipeline.
Supports every capability from Level 1 through Level 4.
"""
def __init__(
self,
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config: Optional[RAGConfig] = None,
):
"""
Initialize the RAG pipeline.
Args:
embeddings: embedding model
llm: language model for query rewriting (required for Level 3+)
config: configuration
"""
self.embeddings = embeddings
self.llm = llm
self.config = config or RAGConfig()
# Components
self._client = None
self._reranker = None
self._query_transformer = None
self._retriever = None
# Cache
self._cache = {}
def _get_client(self):
"""Get the Qdrant client."""
if self._client is None:
self._client = create_qdrant_client(
url=self.config.qdrant_url,
api_key=self.config.qdrant_api_key,
)
return self._client
def _get_reranker(self):
"""Get the reranker."""
if self._reranker is None:
self._reranker = CrossEncoderReranker(
model_name=self.config.reranker_model,
top_n=self.config.rerank_top_n,
device=self.config.device,
)
return self._reranker
def _get_query_transformer(self):
"""Get the query rewriter."""
if self._query_transformer is None and self.llm is not None:
self._query_transformer = MultiQueryTransformer(
llm=self.llm,
num_queries=self.config.num_queries,
)
return self._query_transformer
def _create_basic_retriever(self):
"""Create the basic retriever (Level 1)."""
return create_base_retriever(
collection_name=self.config.collection_name,
embeddings=self.embeddings,
search_kwargs={"k": self.config.total_k},
client=self._get_client(),
)
def _create_hybrid_retriever(self):
"""Create the hybrid retriever (Level 2)."""
base_retriever = create_hybrid_retriever(
collection_name=self.config.collection_name,
embeddings=self.embeddings,
dense_k=self.config.dense_k,
sparse_k=self.config.sparse_k,
client=self._get_client(),
)
# Apply reranking
reranker = self._get_reranker()
return reranker.create_contextual_compression_retriever(base_retriever)
def _create_fusion_retriever(self):
"""Create the RAG-Fusion retriever (Level 3)."""
if self.llm is None:
raise ValueError("Level 3 (RAG-Fusion) requires a language model for query rewriting")
# Create the base hybrid retriever
base_retriever = create_hybrid_retriever(
collection_name=self.config.collection_name,
embeddings=self.embeddings,
dense_k=self.config.dense_k,
sparse_k=self.config.sparse_k,
client=self._get_client(),
)
# Create the RAG-Fusion pipeline
reranker = self._get_reranker()
return create_rag_fusion_pipeline(
base_retriever=base_retriever,
llm=self.llm,
reranker=reranker,
num_queries=self.config.num_queries,
)
def _get_retriever(self):
"""Get the retriever for the configured level."""
if self._retriever is None:
if self.config.rag_level == RAGLevel.BASIC:
self._retriever = self._create_basic_retriever()
elif self.config.rag_level == RAGLevel.HYBRID:
self._retriever = self._create_hybrid_retriever()
elif self.config.rag_level == RAGLevel.FUSION:
self._retriever = self._create_fusion_retriever()
elif self.config.rag_level == RAGLevel.AGENTIC:
# Agentic RAG builds on Fusion; the tool wrapper lives in tools.py
self._retriever = self._create_fusion_retriever()
else:
raise ValueError(f"Unsupported RAG level: {self.config.rag_level}")
return self._retriever
def retrieve(
self,
query: str,
use_cache: Optional[bool] = None,
**kwargs,
) -> RetrievalResult:
"""
Run retrieval.
Args:
query: the query text
use_cache: whether to use the cache
**kwargs: extra arguments
Returns:
the retrieval result
"""
start_time = time.time()
# Check the cache
if use_cache is None:
use_cache = self.config.enable_cache
cache_key = f"{query}:{self.config.rag_level.value}"
if use_cache and cache_key in self._cache:
if self.config.verbose:
print(f"Cache hit: {query}")
return self._cache[cache_key]
# Get the retriever and run the query
retriever = self._get_retriever()
documents = retriever.invoke(query, **kwargs)
# Measure query time
query_time = time.time() - start_time
# Build the result
result = RetrievalResult(
documents=documents,
query_time=query_time,
level=self.config.rag_level,
metadata={
"query": query,
"collection": self.config.collection_name,
"doc_count": len(documents),
},
)
# Cache the result
if use_cache:
self._cache[cache_key] = result
if self.config.verbose:
print(f"Retrieved {len(documents)} documents in {query_time:.2f}s")
return result
def format_context(
self,
documents: List[Document],
max_length: Optional[int] = None,
) -> str:
"""
Format retrieved documents into a context string.
Args:
documents: the document list
max_length: maximum length in characters
Returns:
the formatted context text
"""
context_parts = []
total_length = 0
for i, doc in enumerate(documents):
# Extract content and metadata
content = doc.page_content.strip()
metadata = doc.metadata
# Format this document
doc_text = f"[Document {i+1}]\n"
if metadata.get("source"):
doc_text += f"Source: {metadata['source']}\n"
if metadata.get("page"):
doc_text += f"Page: {metadata['page']}\n"
doc_text += f"Content: {content}\n\n"
# Enforce the length limit
if max_length is not None:
if total_length + len(doc_text) > max_length:
# Adding this document would exceed the limit: truncate and note it
remaining = max_length - total_length
if remaining > 100: # keep at least 100 characters
doc_text = doc_text[:remaining] + "...\n\n[content truncated]"
context_parts.append(doc_text)
break
else:
break
context_parts.append(doc_text)
total_length += len(doc_text)
return "".join(context_parts).strip()
def clear_cache(self):
"""Clear the result cache."""
self._cache.clear()
@classmethod
def create_from_config(
cls,
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config_dict: Optional[Dict[str, Any]] = None,
) -> "RAGPipeline":
"""
Create a pipeline from a config dict.
Args:
embeddings: embedding model
llm: language model
config_dict: configuration dictionary
Returns:
a RAGPipeline instance
"""
config_dict = config_dict or {}
# Build the config object
config = RAGConfig(
collection_name=config_dict.get("collection_name", "documents"),
qdrant_url=config_dict.get("qdrant_url"),
qdrant_api_key=config_dict.get("qdrant_api_key"),
rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)),
dense_k=config_dict.get("dense_k", 10),
sparse_k=config_dict.get("sparse_k", 10),
total_k=config_dict.get("total_k", 20),
rerank_top_n=config_dict.get("rerank_top_n", 5),
num_queries=config_dict.get("num_queries", 3),
reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"),
device=config_dict.get("device"),
enable_cache=config_dict.get("enable_cache", True),
verbose=config_dict.get("verbose", True),
)
return cls(embeddings=embeddings, llm=llm, config=config)

193
app/rag/query_transform.py Normal file
View File

@@ -0,0 +1,193 @@
"""
查询改写器
基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围。
"""
from typing import List, Optional, Any
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
class MultiQueryTransformer:
"""
多路查询改写器
将单个查询改写成多个相关查询,用于 RAG-Fusion。
"""
def __init__(
self,
llm: BaseLanguageModel,
num_queries: int = 3,
prompt_template: Optional[str] = None,
):
"""
初始化查询改写器
Args:
llm: 语言模型实例
num_queries: 生成的查询数量
prompt_template: 提示词模板
"""
self.llm = llm
self.num_queries = num_queries
# 默认提示词模板
self.prompt_template = prompt_template or """
你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。
原始问题: {question}
请生成 {num_queries} 个不同版本的查询,每个版本一行。
确保每个版本都是独立、完整的查询语句。
生成 {num_queries} 个查询:
"""
def transform_query(self, query: str) -> List[str]:
"""
Rewrite a single query into several queries.
Args:
query: the original query
Returns:
the rewritten query list
"""
prompt = PromptTemplate.from_template(self.prompt_template)
chain = prompt | self.llm | StrOutputParser()
response = chain.invoke({
"question": query,
"num_queries": self.num_queries,
})
# Parse the response: one query per line
queries = [
q.strip()
for q in response.strip().split('\n')
if q.strip()
]
# Pad with the original query if too few were produced
if len(queries) < self.num_queries:
queries.extend([query] * (self.num_queries - len(queries)))
elif len(queries) > self.num_queries:
queries = queries[:self.num_queries]
# Make sure the original query is included
if query not in queries:
queries = [query] + queries[:self.num_queries-1]
return queries
def create_multi_query_retriever(
self,
base_retriever: Any,
include_original: bool = True,
) -> MultiQueryRetriever:
"""
Create the multi-query retriever.
Args:
base_retriever: the base retriever
include_original: whether to include the original query
Returns:
a MultiQueryRetriever instance
"""
# Bake num_queries into the prompt up front with .partial() instead of
# monkey-patching the retriever's internal chain: MultiQueryRetriever
# invokes its chain with only a "question" variable.
prompt = PromptTemplate.from_template(
"You are a professional query-rewriting assistant. Rewrite the user's question into {num_queries} different versions.\n"
"The versions should express the same or closely related intent from different angles and with different keywords.\n\n"
"Original question: {question}\n\n"
"Generate {num_queries} different query versions, one per line.\n"
"Each version must be an independent, complete query.\n\n"
"{num_queries} queries:"
).partial(num_queries=str(self.num_queries))
return MultiQueryRetriever.from_llm(
retriever=base_retriever,
llm=self.llm,
prompt=prompt,
include_original=include_original,
)
@classmethod
def create_from_config(
cls,
llm: BaseLanguageModel,
config: Optional[dict] = None,
) -> "MultiQueryTransformer":
"""
Create a query rewriter from a config dict.
Args:
llm: language model instance
config: configuration dictionary
Returns:
a MultiQueryTransformer instance
"""
config = config or {}
return cls(
llm=llm,
num_queries=config.get("num_queries", 3),
prompt_template=config.get("prompt_template", None),
)
def create_rag_fusion_pipeline(
base_retriever: Any,
llm: BaseLanguageModel,
reranker: Optional[Any] = None,
num_queries: int = 3,
) -> Any:
"""
Create the full RAG-Fusion pipeline.
Args:
base_retriever: the base retriever
llm: language model (for query rewriting)
reranker: reranker (optional)
num_queries: number of rewritten queries
Returns:
a retriever instance
"""
# Create the multi-query rewriter
query_transformer = MultiQueryTransformer(
llm=llm,
num_queries=num_queries,
)
# Create the multi-query retriever
multi_query_retriever = query_transformer.create_multi_query_retriever(
base_retriever=base_retriever,
include_original=True,
)
# Apply reranking when a reranker is provided. The project's wrapper class
# exposes create_contextual_compression_retriever(); otherwise assume a raw
# LangChain compressor was passed in.
if reranker is not None:
if hasattr(reranker, "create_contextual_compression_retriever"):
return reranker.create_contextual_compression_retriever(multi_query_retriever)
from langchain.retrievers import ContextualCompressionRetriever
return ContextualCompressionRetriever(
base_compressor=reranker,
base_retriever=multi_query_retriever,
)
return multi_query_retriever

23
app/rag/requirements.txt Normal file
View File

@@ -0,0 +1,23 @@
# RAG system dependencies
# Base framework
langchain>=0.1.0
langchain-core>=0.1.0
langchain-openai>=0.0.1
langchain-qdrant>=0.1.0
# Cross-Encoder reranking models
sentence-transformers>=2.2.0
# BM25 keyword retrieval for hybrid search
rank-bm25>=0.2.2
# Qdrant client
qdrant-client>=1.6.0
# Optional local model support
# vllm>=0.5.0 # for local model inference
# transformers>=4.35.0 # for other model backends
# Dev dependencies (testing)
pytest>=7.0.0
pytest-asyncio>=0.21.0

141
app/rag/reranker.py Normal file
View File

@@ -0,0 +1,141 @@
"""
Cross-Encoder 重排序器
使用 sentence-transformers 加载交叉编码器模型,对检索结果进行精排。
"""
import os
from typing import List, Dict, Any, Optional
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
"""
Cross-Encoder 重排序器包装类
支持 BAAI/bge-reranker-base 等模型。
"""
def __init__(
self,
model_name: str = "BAAI/bge-reranker-base",
top_n: int = 5,
device: Optional[str] = None,
cache_folder: Optional[str] = None,
):
"""
Initialize the reranker.
Args:
model_name: model name or path
top_n: number of top documents to return
device: device (cpu/cuda); auto-selected when None
cache_folder: model cache directory
"""
self.model_name = model_name
self.top_n = top_n
self.device = device
self.cache_folder = cache_folder or os.path.join(
os.path.expanduser("~"), ".cache", "sentence_transformers"
)
# Load the model lazily
self._model = None
self._langchain_reranker = None
def _load_model(self):
"""Load the cross-encoder model."""
if self._model is None:
# sentence-transformers reads its cache location from this env var
if self.cache_folder:
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", self.cache_folder)
model_kwargs = {"device": self.device} if self.device else {}
try:
self._model = HuggingFaceCrossEncoder(
model_name=self.model_name,
model_kwargs=model_kwargs,
)
except Exception as e:
# Fall back to an alternative model if the requested one fails to load
print(f"Failed to load model {self.model_name}: {e}")
print("Trying fallback model BAAI/bge-reranker-v2-m3...")
self._model = HuggingFaceCrossEncoder(
model_name="BAAI/bge-reranker-v2-m3",
model_kwargs=model_kwargs,
)
def _create_langchain_reranker(self):
"""Create the LangChain reranker."""
if self._langchain_reranker is None:
self._load_model()
self._langchain_reranker = LangChainCrossEncoderReranker(
model=self._model,
top_n=self.top_n,
)
def rerank(
self,
query: str,
documents: List[Document],
) -> List[Document]:
"""
Rerank the documents.
Args:
query: the query text
documents: the documents to rerank
Returns:
the reranked document list
"""
self._create_langchain_reranker()
return self._langchain_reranker.compress_documents(
documents=documents,
query=query,
)
def create_contextual_compression_retriever(
self,
base_retriever: Any,
) -> Any:
"""
Create a contextual compression retriever.
Args:
base_retriever: the base retriever
Returns:
a contextual compression retriever
"""
from langchain.retrievers import ContextualCompressionRetriever
self._create_langchain_reranker()
compression_retriever = ContextualCompressionRetriever(
base_compressor=self._langchain_reranker,
base_retriever=base_retriever,
)
return compression_retriever
@classmethod
def create_from_config(
cls,
config: Optional[Dict[str, Any]] = None,
) -> "CrossEncoderReranker":
"""
Create a reranker from a config dict.
Args:
config: configuration dict with model_name, top_n, device, etc.
Returns:
a CrossEncoderReranker instance
"""
config = config or {}
return cls(
model_name=config.get("model_name", "BAAI/bge-reranker-base"),
top_n=config.get("top_n", 5),
device=config.get("device", None),
cache_folder=config.get("cache_folder", None),
)

144
app/rag/retriever.py Normal file
View File

@@ -0,0 +1,144 @@
"""
Qdrant 向量检索器
提供基础向量检索、混合检索Dense + BM25功能。
"""
import os
from typing import List, Dict, Any, Optional
from langchain_qdrant import Qdrant
from langchain.embeddings.base import Embeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.retrievers import EnsembleRetriever
from qdrant_client import QdrantClient
from qdrant_client.http import models
def create_qdrant_client(
url: Optional[str] = None,
api_key: Optional[str] = None,
) -> QdrantClient:
"""
创建 Qdrant 客户端
Args:
url: Qdrant 服务地址,默认从环境变量 QDRANT_URL 读取
api_key: API 密钥,默认从环境变量 QDRANT_API_KEY 读取
Returns:
QdrantClient 实例
"""
url = url or os.getenv("QDRANT_URL", "http://localhost:6333")
api_key = api_key or os.getenv("QDRANT_API_KEY")
client_args = {"url": url}
if api_key:
client_args["api_key"] = api_key
return QdrantClient(**client_args)
def create_base_retriever(
collection_name: str,
embeddings: Embeddings,
search_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[QdrantClient] = None,
) -> BaseRetriever:
"""
Create the basic vector retriever.
Args:
collection_name: Qdrant collection name
embeddings: embedding model
search_kwargs: search parameters, default {"k": 20}
client: Qdrant client; created automatically when None
Returns:
a retriever over the Qdrant collection
"""
if client is None:
client = create_qdrant_client()
search_kwargs = search_kwargs or {"k": 20}
# Wrap the existing collection; the constructor accepts an existing client,
# unlike from_existing_collection, which builds its own connection
vectorstore = Qdrant(
client=client,
collection_name=collection_name,
embeddings=embeddings,
content_payload_key="content", # assumes the text field is stored as "content"
metadata_payload_key="metadata", # metadata field name
)
return vectorstore.as_retriever(search_kwargs=search_kwargs)
def create_hybrid_retriever(
collection_name: str,
embeddings: Embeddings,
dense_k: int = 10,
sparse_k: int = 10,
client: Optional[QdrantClient] = None,
) -> BaseRetriever:
"""
Create the "hybrid" retriever (dense vector + BM25 budget).
Note: this currently runs dense retrieval with a combined budget of
dense_k + sparse_k. True dual-path recall requires a sparse embedding
configured in langchain-qdrant, or an EnsembleRetriever over a separate
BM25Retriever (see create_ensemble_retriever).
Args:
collection_name: Qdrant collection name
embeddings: embedding model
dense_k: dense retrieval budget
sparse_k: BM25 retrieval budget
client: Qdrant client
Returns:
the retriever
"""
if client is None:
client = create_qdrant_client()
vectorstore = Qdrant(
client=client,
collection_name=collection_name,
embeddings=embeddings,
content_payload_key="content",
metadata_payload_key="metadata",
)
search_kwargs = {
"k": dense_k + sparse_k, # total recall budget
"score_threshold": 0.3, # similarity threshold
}
return vectorstore.as_retriever(search_kwargs=search_kwargs)
def create_ensemble_retriever(
retrievers: List[Any],
weights: Optional[List[float]] = None,
c: int = 60,
) -> EnsembleRetriever:
"""
Create an ensemble retriever using reciprocal rank fusion (RRF).
Args:
retrievers: retriever list
weights: retriever weights
c: RRF constant (usually 60)
Returns:
the ensemble retriever
"""
if weights is None:
weights = [1.0 / len(retrievers)] * len(retrievers)
# EnsembleRetriever fuses with weighted RRF by default; it takes no
# search_type argument
ensemble = EnsembleRetriever(
retrievers=retrievers,
weights=weights,
c=c,
)
return ensemble

230
app/rag/tools.py Normal file
View File

@@ -0,0 +1,230 @@
"""
RAG 工具包装
将 RAG 流水线包装成 LangChain Tool供 Agent 调用。
"""
from typing import Optional, Dict, Any
from langchain.tools import tool
from langchain_core.tools import Tool
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .pipeline import RAGPipeline, RAGConfig, RAGLevel
class RAGTool:
"""
RAG 工具包装器
将 RAG 流水线包装成 Agent 可调用的工具。
"""
def __init__(
self,
pipeline: RAGPipeline,
tool_name: str = "search_knowledge_base",
tool_description: str = None,
):
"""
初始化 RAG 工具
Args:
pipeline: RAG 流水线实例
tool_name: 工具名称
tool_description: 工具描述
"""
self.pipeline = pipeline
self.tool_name = tool_name
# 默认工具描述
self.tool_description = tool_description or (
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
"专业知识或需要基于已知信息回答的问题时使用此工具。"
"输入应为要搜索的查询文本。"
)
# 创建 LangChain 工具
self._tool = self._create_tool()
def _create_tool(self) -> Tool:
"""Create the LangChain tool."""
@tool(self.tool_name, args_schema=None)
def search_knowledge_base(query: str) -> str:
"""
Search the knowledge base for relevant information.
Args:
query: the search query
Returns:
the formatted search results
"""
try:
# Run retrieval
result = self.pipeline.retrieve(query)
if not result.documents:
return "No relevant information found in the knowledge base."
# Format the context
context = self.pipeline.format_context(
result.documents,
max_length=4000, # cap the context length
)
# Build the response
response = (
f"🔍 Found {len(result.documents)} relevant entries in the knowledge base:\n\n"
f"{context}\n\n"
f"⏱️ Retrieval took {result.query_time:.2f}s"
)
return response
except Exception as e:
error_msg = f"An error occurred during retrieval: {str(e)}"
if self.pipeline.config.verbose:
print(f"RAG tool error: {error_msg}")
return error_msg
# Set the tool description
search_knowledge_base.description = self.tool_description
return search_knowledge_base
def get_tool(self) -> Tool:
"""Get the LangChain tool."""
return self._tool
def __call__(self, query: str) -> str:
"""Invoke the tool directly."""
return self._tool.invoke({"query": query})
def create_rag_tool(
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
tool_name: str = "search_knowledge_base",
tool_description: Optional[str] = None,
) -> Tool:
"""
Create a RAG tool (convenience function).
Args:
embeddings: embedding model
llm: language model (for advanced RAG features)
config: RAG config dict
tool_name: tool name
tool_description: tool description
Returns:
a LangChain Tool instance
"""
# Create the RAG pipeline
pipeline = RAGPipeline.create_from_config(
embeddings=embeddings,
llm=llm,
config_dict=config,
)
# Wrap it as a tool
rag_tool = RAGTool(
pipeline=pipeline,
tool_name=tool_name,
tool_description=tool_description,
)
return rag_tool.get_tool()
# Convenience alias
search_knowledge_base_tool = create_rag_tool
def bind_rag_to_agent(
agent_llm: BaseLanguageModel,
embeddings: Embeddings,
rag_llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
tool_name: str = "search_knowledge_base",
) -> BaseLanguageModel:
"""
Bind the RAG tool to the Agent model.
Args:
agent_llm: the language model the Agent uses
embeddings: embedding model
rag_llm: the language model for the RAG pipeline (if different from agent_llm)
config: RAG config
tool_name: tool name
Returns:
the model with the tool bound
"""
# Fall back to the Agent LLM when no RAG LLM is given
if rag_llm is None:
rag_llm = agent_llm
# Create the RAG tool
rag_tool = create_rag_tool(
embeddings=embeddings,
llm=rag_llm,
config=config,
tool_name=tool_name,
)
# Bind the tool to the model
return agent_llm.bind_tools([rag_tool])
def create_agentic_rag_pipeline(
embeddings: Embeddings,
agent_llm: BaseLanguageModel,
rag_llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Create the full Agentic RAG pipeline (Level 4).
Args:
embeddings: embedding model
agent_llm: the Agent model
rag_llm: a dedicated RAG model
config: configuration
Returns:
a dict containing the bound model, the tool, and the pipeline
"""
# Configure the Agentic RAG level
if config is None:
config = {}
config["rag_level"] = RAGLevel.AGENTIC.value
# Create the RAG tool
rag_tool = create_rag_tool(
embeddings=embeddings,
llm=rag_llm or agent_llm,
config=config,
tool_name="search_knowledge_base",
tool_description=(
"Search the knowledge base for relevant information. Use this tool when the "
"user asks about documents, policies, specialist knowledge, or anything that "
"must be answered from known material. The Agent should first decide whether "
"the tool is needed, then call it to fetch context."
),
)
# Bind the tool to the model
bound_llm = agent_llm.bind_tools([rag_tool])
return {
"llm": bound_llm,
"tool": rag_tool,
"pipeline": RAGPipeline.create_from_config(
embeddings=embeddings,
llm=rag_llm or agent_llm,
config_dict=config,
),
}

View File

@@ -119,6 +119,7 @@ def _handle_ai_response():
api_thought = ""
display_text = ""
display_thought = ""
rag_sources = None # RAG retrieval source info
# Call the streaming API
stream = api_client.chat_stream(
@@ -213,6 +214,25 @@ def _handle_ai_response():
last_msg = messages_update[-1] if messages_update else {}
if isinstance(last_msg, dict) and last_msg.get("role") == "tool":
tool_name = last_msg.get("name", "unknown")
tool_content = last_msg.get("content", "")
# Store the RAG retrieval results
if tool_name == "search_knowledge_base":
# Try to parse tool_content (it may be a JSON string)
sources = []
try:
if isinstance(tool_content, str):
import json
data = json.loads(tool_content)
else:
data = tool_content
# Extract the source list
if isinstance(data, dict) and "sources" in data:
sources = data["sources"]
else:
sources = [str(data)]
except Exception:
sources = [str(tool_content)]
rag_sources = sources
tool_status_placeholder.success(f"✅ Tool {tool_name} finished")
# Show briefly, then clear to keep the UI tidy
import time
@@ -270,6 +290,31 @@ def _handle_ai_response():
# Remove the cursor
message_placeholder.markdown(display_text)
# Show RAG retrieval sources (if any)
if rag_sources:
with st.expander("🔍 Retrieval sources", expanded=False):
# Format the source list
if isinstance(rag_sources, list):
for i, source in enumerate(rag_sources, 1):
if isinstance(source, dict):
content = source.get("page_content", source.get("content", str(source)))
metadata = source.get("metadata", {})
filename = metadata.get("filename", metadata.get("source", "unknown file"))
page = metadata.get("page", metadata.get("page_number", ""))
if page:
source_info = f"**Source {i}:** {filename} (page {page})"
else:
source_info = f"**Source {i}:** {filename}"
st.markdown(source_info)
# Preview the first 200 characters
preview = content[:200] + "..." if len(content) > 200 else content
st.markdown(f"> {preview}")
st.markdown("---")
else:
st.markdown(f"**Source {i}:** {str(source)}")
else:
st.markdown(str(rag_sources))
# Assemble the full content including the thinking trace so it renders correctly in history
final_content = display_text
if display_thought:

109
rag_indexer/README.md Normal file
View File

@@ -0,0 +1,109 @@
# Offline RAG Indexer
This module implements stage one of the RAG system: **offline index construction**. It cleans, splits, and embeds external unstructured data (documents, PDFs, web pages, etc.) and writes the resulting vectors into the vector database.
## 📊 Workflow Diagram
```mermaid
graph TD
A[Raw document set <br> PDF / Word / Markdown] --> B(DocumentLoader)
B --> C{Text splitting strategy}
C -->|basic| D1[Fixed-length character splitting <br> Recursive Split]
C -->|intermediate| D2[Semantic boundary splitting <br> Semantic Chunking]
C -->|advanced| D3[Parent-child splitting <br> Parent-Child / Auto-merging]
D1 & D2 & D3 --> E[Embedder <br> llama.cpp: embeddinggemma]
E --> F[(Qdrant vector database)]
subgraph "Metadata management"
G[Extract author, date, page number, etc.] -.attach.-> E
end
```
---
## 🎯 Roadmap & Core Algorithms
### Level 1: Basic Recursive Splitting
- **Core algorithm**: recursive character splitting. It tries a predefined list of separators (e.g. `["\n\n", "\n", " ", ""]`) from coarse to fine until every chunk fits the maximum length.
- **Trade-offs**: trivially simple and fast, but it readily cuts sentences in half and loses contextual meaning.
- **Implementation guide** (see the sketch below):
    - Import `RecursiveCharacterTextSplitter` from `langchain.text_splitter`.
    - Instantiate it with `chunk_size` (e.g. 500) and `chunk_overlap` (e.g. 50), then call `.split_documents(raw_docs)`.
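A minimal Level-1 sketch; `raw_docs` stands in for documents produced by your loader:

```python
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

raw_docs = [Document(page_content="..." * 500, metadata={"source": "manual.pdf"})]

splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,    # max characters per chunk
    chunk_overlap=50,  # overlap preserves context across boundaries
)
chunks = splitter.split_documents(raw_docs)
print(f"{len(chunks)} chunks")
```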
### Level 2: Semantic Chunking
- **Core algorithm**: sentence-level similarity thresholding.
1. Split the text into sentences at punctuation boundaries.
2. Embed each sentence with a lightweight embedding model.
3. Compute the cosine similarity between adjacent sentences.
4. When the similarity drops below a set threshold (the two sentences are about different things, i.e. the meaning shifts), "cut" there to start a new chunk.
- **Trade-offs**: preserves semantic coherence within a chunk remarkably well, which is very LLM-friendly. But because the embedding model is already called at split time, it is noticeably slower.
- **Implementation guide** (see the sketch below):
    - Import `SemanticChunker` from `langchain_experimental.text_splitter`.
    - Pass in your configured embedding model instance (e.g. a llama.cpp local model wrapped with `OpenAIEmbeddings`) and set threshold parameters such as `breakpoint_threshold_type="percentile"`.
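A minimal Level-2 sketch; the embedding endpoint URL is illustrative, and `raw_docs` is from the Level-1 sketch:

```python
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(
    openai_api_base="http://127.0.0.1:8082/v1",  # hypothetical local llama.cpp server
    openai_api_key="no-key-needed",
    model="embeddinggemma-300M-Q8_0",
)
chunker = SemanticChunker(
    embeddings,
    breakpoint_threshold_type="percentile",  # cut where similarity drops below the percentile
)
chunks = chunker.split_documents(raw_docs)
```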
### Level 3: Parent-Child / Auto-merging
- **Core algorithm**: hierarchical dual storage with mapping.
- **Splitting**: first coarse-split documents into larger "parent chunks" (~1000 tokens), then fine-split each parent into smaller "child chunks" (~200 tokens).
- **Storage**: only the **child** vectors go into Qdrant for precise distance computation; the **parent** contents live in memory or a Document Store (e.g. a KV database), linked by UUID.
- **Core insight**: resolves the classic RAG tension: smaller chunks hit more precisely at retrieval time (less noise), while larger chunks give the LLM richer context at generation time.
- **Implementation guide** (see the sketch below):
    - Use the `ParentDocumentRetriever` from `langchain.retrievers`.
    - At write time you need both an underlying `VectorStore` (Qdrant) and a `BaseStore` (e.g. the built-in `InMemoryStore` or `Redis`).
    - Assign two different `TextSplitter`s to the retriever's `child_splitter` and `parent_splitter`, then call `.add_documents()` and the mapping is handled automatically.
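A minimal Level-3 sketch with an in-memory parent store; the Qdrant URL and collection name are placeholders, and `embeddings`/`raw_docs` come from the sketches above:

```python
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain_qdrant import Qdrant
from langchain_text_splitters import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient

vectorstore = Qdrant(client=QdrantClient(url="http://localhost:6333"),
                     collection_name="rag_documents", embeddings=embeddings)

retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,       # child chunks -> Qdrant
    docstore=InMemoryStore(),      # parent chunks -> KV store
    parent_splitter=RecursiveCharacterTextSplitter(chunk_size=1000),
    child_splitter=RecursiveCharacterTextSplitter(chunk_size=200),
)
retriever.add_documents(raw_docs)  # splits, maps, and indexes in one call

# Retrieval matches on child chunks but returns the full parent chunks
docs = retriever.invoke("warranty period")
```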
### Level 4: GraphRAG & Multi-modal
- **Core algorithm**: LLM-driven entity and relation extraction (NER & relation extraction).
- **Core insight**: addresses the classic weakness of pure vector retrieval on "cross-document relational reasoning" ("Who is company A's CEO, and what is the main business of his company B?", questions that jump across many PDF pages).
- **Implementation guide** (see the sketch below):
    - Use a local large model (e.g. `Gemma-4-E2B`) together with LangChain's graph components (`LLMGraphTransformer` lives in `langchain_experimental.graph_transformers`; graph stores such as `Neo4jGraph` in `langchain_community.graphs`).
    - While reading documents, `LLMGraphTransformer` prompts the model to extract entities (nodes) and relations (edges), which are written into a graph database such as Neo4j instead of the Qdrant vector store.
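A minimal Level-4 sketch; the model endpoint and Neo4j credentials are placeholders, and it requires `langchain-experimental` plus a running Neo4j instance:

```python
from langchain_community.graphs import Neo4jGraph
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(base_url="http://localhost:8000/v1", api_key="no-key-needed",
                 model="local-model", temperature=0)
transformer = LLMGraphTransformer(llm=llm)

# Extract nodes and edges from the loaded documents
graph_docs = transformer.convert_to_graph_documents(raw_docs)

graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="...")
graph.add_graph_documents(graph_docs)
```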
---
## 📦 Dependencies & Installation
Full document parsing and Qdrant writing require the following Python packages:
```bash
# Core libraries
pip install langchain langchain-core langchain-openai langchain-qdrant
# Complex document parsing (PDF, Word, Excel, ...)
pip install unstructured pdf2image pdfminer.six
# Semantic chunking (optional)
pip install langchain-experimental
```
---
## 📂 Architecture & File Layout
Create the following core files under `rag_indexer/`:
```text
rag_indexer/
├── __init__.py
├── loaders.py      # parses different file types via unstructured
├── splitters.py    # implements Recursive, Semantic, and Parent-Child splitting
├── embedders.py    # wraps the local llama.cpp embedding interface
├── vector_store.py # wraps Qdrant writes, upserts, and collection setup
└── builder.py      # orchestration: wires the modules into a pipeline
```
---
### Wiring & Triggering
Outside your LangGraph system, create an execution script `scripts/run_indexer.py`:
```bash
# Run from a terminal to push a local PDF manual into the vector database
export QDRANT_URL="http://115.190.121.151:6333"
python scripts/run_indexer.py --file data/user_docs/tech_manual.pdf
```
This is the system's background **"offline learning phase"**; you can schedule a recurring job to scan folders and update the knowledge base incrementally.

25
rag_indexer/__init__.py Normal file
View File

@@ -0,0 +1,25 @@
"""
Offline RAG Indexer module.
"""
from .loaders import DocumentLoader
from .splitters import (
RecursiveSplitter,
SemanticSplitter,
ParentChildSplitter,
SplitterType,
)
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore
from .builder import IndexBuilder
__all__ = [
"DocumentLoader",
"RecursiveSplitter",
"SemanticSplitter",
"ParentChildSplitter",
"SplitterType",
"LlamaCppEmbedder",
"QdrantVectorStore",
"IndexBuilder",
]

277
rag_indexer/builder.py Normal file
View File

@@ -0,0 +1,277 @@
"""
Core pipeline builder for offline RAG index construction.
Now supports LangChain's ParentDocumentRetriever for parent-child chunking.
"""
import logging
from pathlib import Path
from typing import List, Union, Optional, Tuple
from dataclasses import dataclass
from langchain_core.documents import Document
from langchain.retrievers import ParentDocumentRetriever
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, ParentChildSplitter
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore
from .docstore_manager import get_docstore, PostgresDocStore, create_docstore
logger = logging.getLogger(__name__)
@dataclass
class ParentChildConfig:
"""Configuration for parent-child splitting."""
parent_chunk_size: int = 1000
child_chunk_size: int = 200
parent_chunk_overlap: int = 100
child_chunk_overlap: int = 20
search_k: int = 5
docstore_path: str = None
docstore_type: str = "local"
docstore_conn_string: str = None
class IndexBuilder:
"""Main pipeline for RAG index construction."""
def __init__(
self,
collection_name: str = "rag_documents",
qdrant_url: str = None,
splitter_type: SplitterType = SplitterType.RECURSIVE,
**splitter_kwargs,
):
self.collection_name = collection_name
self.qdrant_url = qdrant_url
self.splitter_type = splitter_type
self.splitter_kwargs = splitter_kwargs
# Components
self.loader = DocumentLoader()
self.embedder = LlamaCppEmbedder()
self.embeddings = self.embedder.as_langchain_embeddings()
self.vector_store = QdrantVectorStore(
collection_name=collection_name,
embeddings=self.embeddings,
qdrant_url=qdrant_url,
)
# Splitter (except parent-child which is handled separately)
if splitter_type != SplitterType.PARENT_CHILD:
if splitter_type == SplitterType.SEMANTIC:
splitter_kwargs["embeddings"] = self.embeddings
self.splitter = get_splitter(splitter_type, **splitter_kwargs)
else:
self.splitter = None
# Initialize ParentDocumentRetriever for parent-child splitting,
# forwarding the splitter kwargs (previously they were silently dropped)
self._init_parent_child_retriever(**splitter_kwargs)
def _init_parent_child_retriever(self, **kwargs):
"""
Initialize ParentDocumentRetriever for parent-child chunking.
This replaces the custom ParentChildSplitter logic.
"""
# Parse kwargs for parent-child config
parent_size = kwargs.get("parent_chunk_size", 1000)
child_size = kwargs.get("child_chunk_size", 200)
parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))
# Define splitters
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=parent_size,
chunk_overlap=parent_overlap,
)
self.child_splitter = RecursiveCharacterTextSplitter(
chunk_size=child_size,
chunk_overlap=child_overlap,
)
# Vector store (for child chunks)
self.vector_store_obj = self.vector_store.get_langchain_vectorstore()
# Document store (for parent chunks)
docstore_path = kwargs.get("docstore_path")
docstore_type = kwargs.get("docstore_type", "local")
docstore_conn = kwargs.get("docstore_conn_string")
if docstore_type == "postgres" and docstore_conn:
self.docstore = PostgresDocStore(docstore_conn)
self._docstore_conn = docstore_conn
else:
self.docstore = get_docstore(docstore_path)
self._docstore_conn = None
# Create retriever
self.retriever = ParentDocumentRetriever(
vectorstore=self.vector_store_obj,
docstore=self.docstore,
child_splitter=self.child_splitter,
parent_splitter=self.parent_splitter,
search_kwargs={"k": kwargs.get("search_k", 5)},
)
def build_from_file(self, file_path: Union[str, Path]) -> int:
logger.info("Loading file: %s", file_path)
documents = self.loader.load_file(file_path)
logger.info("Loaded %d documents", len(documents))
return self._process_documents(documents)
def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
logger.info("Loading directory: %s (recursive=%s)", directory_path, recursive)
documents = self.loader.load_directory(directory_path, recursive=recursive)
logger.info("Loaded %d documents from directory", len(documents))
return self._process_documents(documents)
def _process_documents(self, documents: List[Document]) -> int:
if not documents:
logger.warning("No documents to process")
return 0
if self.splitter_type == SplitterType.PARENT_CHILD:
logger.info("Using LangChain ParentDocumentRetriever")
# Ensure collection exists for child chunks
self.vector_store.create_collection()
# Use ParentDocumentRetriever to add documents
# This automatically handles parent-child splitting, mapping, and retrieval
self.retriever.add_documents(documents)
# Log estimated chunk counts
estimated_child_chunks = len(documents) * (self.parent_splitter._chunk_size // self.child_splitter._chunk_size)
logger.info(
"Indexed with ParentDocumentRetriever: "
f"~{len(documents)} parent documents, ~{estimated_child_chunks} child chunks"
)
return len(documents)
else:
logger.info("Splitting documents using %s", self.splitter_type)
chunks = self.splitter.split_documents(documents)
logger.info("Split into %d chunks", len(chunks))
self.vector_store.create_collection()
self.vector_store.add_documents(chunks)
return len(chunks)
def get_collection_info(self):
return self.vector_store.get_collection_info()
def search(self, query: str, k: int = 5) -> List[Document]:
"""Standard search - returns child chunks."""
return self.vector_store.similarity_search(query, k=k)
def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
"""
Search with parent context - returns full parent chunks.
This is the main retrieval method when using parent-child splitting.
Note: the child-hit count is fixed by search_kwargs at retriever
construction time; k is accepted only for interface symmetry.
"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"search_with_parent_context only available with PARENT_CHILD splitter. "
"Use search() for standard retrieval."
)
return self.retriever.invoke(query)
def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
"""
Unified retrieval interface.
Args:
query: Search query
return_parent: If True and using parent-child splitter, return parent chunks
If False, always return child chunks
Returns:
List of relevant documents
"""
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
return self.search_with_parent_context(query)
else:
return self.search(query)
def get_retriever(self) -> ParentDocumentRetriever:
"""
Get the ParentDocumentRetriever instance directly.
Useful for advanced use cases where you want to access the retriever
outside of IndexBuilder.
"""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"get_retriever() only available with PARENT_CHILD splitter. "
"Use search() or search_with_parent_context() for standard retrieval."
)
return self.retriever
def get_child_splitter(self) -> "RecursiveCharacterTextSplitter":
"""Get the child splitter for reconfiguration."""
if self.splitter_type != SplitterType.PARENT_CHILD:
return self.splitter
return self.child_splitter
def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
"""Get the parent splitter for reconfiguration."""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"Parent splitter only available with PARENT_CHILD splitter."
)
return self.parent_splitter
def get_docstore(self) -> BaseStore:
"""Get the document store for parent chunks."""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"Docstore only available with PARENT_CHILD splitter."
)
return self.docstore
def get_docstore_path(self) -> str:
"""Get the document store path."""
if self.splitter_type != SplitterType.PARENT_CHILD:
raise RuntimeError(
"Docstore path only available with PARENT_CHILD splitter."
)
return self.docstore.persist_path
def close(self):
"""Close resources."""
# Close the PostgreSQL-backed docstore if one was created; the old code
# opened a brand-new connection only to close it again
docstore = getattr(self, "docstore", None)
if isinstance(docstore, PostgresDocStore):
docstore.close()
logger.info("Closed PostgreSQL connection")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
if __name__ == "__main__":
# Example usage
builder = IndexBuilder(
splitter_type=SplitterType.PARENT_CHILD,
parent_chunk_size=1000,
child_chunk_size=200,
docstore_path="./my_parent_docs",
)
print("Parent splitter:", builder.get_parent_splitter().chunk_size)
print("Child splitter:", builder.get_child_splitter().chunk_size)
print("Docstore path:", builder.get_docstore_path())
print("Retriever:", builder.get_retriever())

102
rag_indexer/cli.py Executable file
View File

@@ -0,0 +1,102 @@
"""
Command-line interface for the RAG index builder.
"""
import argparse
import logging
import sys
from builder import IndexBuilder
from splitters import SplitterType
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
def main():
parser = argparse.ArgumentParser(description="Offline RAG Index Builder")
parser.add_argument("--file", type=str, help="Path to file to index")
parser.add_argument("--dir", type=str, help="Path to directory to index")
parser.add_argument("--recursive", action="store_true", default=True,
help="Recursively process directories (default: True)")
parser.add_argument("--collection", type=str, default="rag_documents",
help="Qdrant collection name (default: rag_documents)")
parser.add_argument("--qdrant-url", type=str,
help="Qdrant server URL (default: http://127.0.0.1:6333)")
parser.add_argument("--splitter", type=str,
choices=["recursive", "semantic", "parent_child"],
default="recursive",
help="Text splitting strategy (default: recursive)")
parser.add_argument("--chunk-size", type=int, default=500,
help="Chunk size for recursive/parent splitter (default: 500)")
parser.add_argument("--chunk-overlap", type=int, default=50,
parser.add_argument("--docstore-path", type=str,
default=None,
help="Path to store parent documents for parent-child splitter (default: ./parent_docs or HERMES_HOME/parent_docs)")
parser.add_argument("--docstore-type", type=str,
choices=["local", "postgres"],
default="local",
help="Type of docstore: 'local' (default) or 'postgres' for PostgreSQL-backed storage")
parser.add_argument("--docstore-conn", type=str,
default=None,
help="PostgreSQL connection string for postgres docstore")
help="Chunk overlap (default: 50)")
parser.add_argument("--parent-size", type=int, default=1000,
help="Parent chunk size for parent-child splitter (default: 1000)")
parser.add_argument("--child-size", type=int, default=200,
help="Child chunk size for parent-child splitter (default: 200)")
args = parser.parse_args()
if not args.file and not args.dir:
print("Error: Either --file or --dir must be specified", file=sys.stderr)
parser.print_help()
sys.exit(1)
splitter_map = {
"recursive": SplitterType.RECURSIVE,
"semantic": SplitterType.SEMANTIC,
"parent_child": SplitterType.PARENT_CHILD,
}
splitter_type = splitter_map[args.splitter]
splitter_kwargs = {}
if splitter_type == SplitterType.RECURSIVE:
splitter_kwargs["chunk_size"] = args.chunk_size
splitter_kwargs["chunk_overlap"] = args.chunk_overlap
elif splitter_type == SplitterType.PARENT_CHILD:
splitter_kwargs["parent_chunk_size"] = args.parent_size
splitter_kwargs["child_chunk_size"] = args.child_size
splitter_kwargs["parent_chunk_overlap"] = args.chunk_overlap
splitter_kwargs["child_chunk_overlap"] = args.chunk_overlap // 2
splitter_kwargs["docstore_path"] = args.docstore_path
splitter_kwargs["docstore_type"] = args.docstore_type
splitter_kwargs["docstore_conn_string"] = args.docstore_conn
builder = IndexBuilder(
collection_name=args.collection,
qdrant_url=args.qdrant_url,
splitter_type=splitter_type,
**splitter_kwargs
)
try:
if args.file:
chunk_count = builder.build_from_file(args.file)
else:
chunk_count = builder.build_from_directory(args.dir, args.recursive)
print(f"Indexing completed. Total chunks indexed: {chunk_count}")
info = builder.get_collection_info()
print(f"Collection '{info['name']}' has {info['vectors_count']} vectors (dim={info['vector_size']})")
except Exception as e:
logging.exception("Indexing failed")
sys.exit(1)
if __name__ == "__main__":
main()

rag_indexer/docstore_manager.py Normal file
View File

@@ -0,0 +1,142 @@
"""
Document store manager for ParentDocumentRetriever.
Supports both LocalFileStore (default) and custom PostgreSQL-backed stores.
"""
import os
from typing import Optional
from langchain.storage import BaseStore, LocalFileStore
def get_docstore(persist_path: str = None) -> LocalFileStore:
"""
Create and return a document store for parent chunks.
Args:
persist_path: Path to store parent documents. Defaults to ./parent_docs
or HERMES_HOME/parent_docs if set.
"""
if persist_path is None:
# Use HERMES_HOME if available, otherwise default to current directory
persist_path = os.getenv("HERMES_HOME")
if persist_path:
persist_path = os.path.join(persist_path, "parent_docs")
else:
persist_path = "./parent_docs"
os.makedirs(persist_path, exist_ok=True)
return LocalFileStore(persist_path)
class PostgresDocStore(BaseStore):
"""
PostgreSQL-backed document store for parent chunks.
Implements the BaseStore batch interface (mget/mset/mdelete/yield_keys)
that ParentDocumentRetriever actually calls; the earlier get/set pair
left the abstract methods unimplemented, so the class could not even be
instantiated. Parent Documents are serialized as JSON.
This is an optional advanced feature. For most use cases,
LocalFileStore is sufficient and simpler.
"""
def __init__(self, connection_string: str):
"""
Initialize PostgreSQL document store.
Args:
connection_string: PostgreSQL connection URL
"""
self.conn_string = connection_string
self._conn = None
self.persist_path = None # no filesystem path for this backend
# Create table if not exists
self._create_table()
def _ensure_connection(self):
"""Ensure we have an open connection."""
import psycopg2 # imported lazily so the dependency stays optional
if self._conn is None or self._conn.closed:
self._conn = psycopg2.connect(self.conn_string)
def _create_table(self):
"""Create the parent documents table if not exists."""
try:
self._ensure_connection()
cursor = self._conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS parent_documents (
key TEXT PRIMARY KEY,
value JSONB NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW()
)
""")
self._conn.commit()
cursor.close()
except Exception as e:
raise RuntimeError(f"Failed to create PostgreSQL table: {e}")
def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
"""Retrieve Documents for the given keys (None where missing)."""
self._ensure_connection()
cursor = self._conn.cursor()
results: List[Optional[Document]] = []
for key in keys:
cursor.execute("SELECT value FROM parent_documents WHERE key = %s", (key,))
row = cursor.fetchone()
if row:
data = row[0] if isinstance(row[0], dict) else json.loads(row[0])
results.append(Document(page_content=data["page_content"], metadata=data.get("metadata", {})))
else:
results.append(None)
cursor.close()
return results
def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
"""Upsert (key, Document) pairs via ON CONFLICT."""
self._ensure_connection()
cursor = self._conn.cursor()
for key, doc in key_value_pairs:
value = json.dumps({"page_content": doc.page_content, "metadata": doc.metadata})
cursor.execute(
"INSERT INTO parent_documents (key, value) VALUES (%s, %s) "
"ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value",
(key, value),
)
self._conn.commit()
cursor.close()
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys."""
self._ensure_connection()
cursor = self._conn.cursor()
for key in keys:
cursor.execute("DELETE FROM parent_documents WHERE key = %s", (key,))
self._conn.commit()
cursor.close()
def yield_keys(self, prefix: Optional[str] = None):
"""Iterate over stored keys, optionally filtered by prefix."""
self._ensure_connection()
cursor = self._conn.cursor()
if prefix:
cursor.execute("SELECT key FROM parent_documents WHERE key LIKE %s", (prefix + "%",))
else:
cursor.execute("SELECT key FROM parent_documents")
rows = cursor.fetchall()
cursor.close()
for (key,) in rows:
yield key
def close(self):
"""Close the connection."""
if self._conn and not self._conn.closed:
self._conn.close()
# Factory function for creating custom docstores
# Returns a tuple: (BaseStore instance, connection_string or None)
def create_docstore(
store_type: str = "local",
persist_path: str = None,
connection_string: str = None
) -> tuple:
"""
Factory function to create different types of document stores.
Args:
store_type: "local" (default), "postgres"
persist_path: Path for local file store
connection_string: PostgreSQL connection string
Returns:
Tuple of (BaseStore instance, connection_string or None)
"""
if store_type == "postgres" and connection_string:
return (PostgresDocStore(connection_string), connection_string)
else:
return (get_docstore(persist_path), None)

68
rag_indexer/embedders.py Normal file
View File

@@ -0,0 +1,68 @@
"""
Embedding model wrapper for llama.cpp service.
"""
import os
from typing import List, Optional
from urllib.parse import urljoin
from langchain_openai import OpenAIEmbeddings
class LlamaCppEmbedder:
"""Wrapper for llama.cpp embedding service via OpenAI-compatible API."""
def __init__(
self,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
model: str = "embeddinggemma-300M-Q8_0",
):
self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
self.model = model
# Ensure URL ends with /v1
self.base_url = urljoin(self.base_url.rstrip("/") + "/", "v1")
def as_langchain_embeddings(self) -> OpenAIEmbeddings:
"""Create (and cache) a LangChain OpenAIEmbeddings instance."""
if getattr(self, "_cached_embeddings", None) is None:
self._cached_embeddings = OpenAIEmbeddings(
openai_api_base=self.base_url,
openai_api_key=self.api_key,
model=self.model,
)
return self._cached_embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of documents."""
return self.as_langchain_embeddings().embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
"""Embed a single query."""
return self.as_langchain_embeddings().embed_query(text)
def get_embedding_dimension(self) -> int:
"""Get embedding dimension by embedding a test string."""
test_embedding = self.embed_query("test")
return len(test_embedding)
class MockEmbedder:
"""Mock embedder for testing without a real service."""
def __init__(self, dimension: int = 768):
self.dimension = dimension
def as_langchain_embeddings(self) -> OpenAIEmbeddings:
raise NotImplementedError("MockEmbedder cannot be used as LangChain embeddings")
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [[0.0] * self.dimension for _ in texts]
def embed_query(self, text: str) -> List[float]:
return [0.0] * self.dimension
def get_embedding_dimension(self) -> int:
return self.dimension
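A quick sanity check, hedged: it assumes the llama.cpp embedding service is reachable at its default URL and falls back to MockEmbedder otherwise.
try:
    embedder = LlamaCppEmbedder()
    dim = embedder.get_embedding_dimension()
except Exception:
    # Service unavailable: fall back to the mock for offline testing
    embedder = MockEmbedder(dimension=768)
    dim = embedder.get_embedding_dimension()
vectors = embedder.embed_documents(["你好,世界", "hello world"])
assert len(vectors) == 2 and len(vectors[0]) == dim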

View File

@@ -0,0 +1,124 @@
"""
Example demonstrating ParentDocumentRetriever usage.
This script shows how to:
1. Build an index with parent-child chunking
2. Search with child chunks (fast, precise)
3. Search with parent context (large context)
4. Access the retriever directly for advanced use cases
"""
import logging
from pathlib import Path
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
from builder import IndexBuilder
from splitters import SplitterType
def main():
print("=" * 70)
print("ParentDocumentRetriever Example")
print("=" * 70)
# Step 1: Create IndexBuilder with parent-child splitting
print("\n1. Creating IndexBuilder with parent-child splitting...")
builder = IndexBuilder(
collection_name="parent_child_demo",
splitter_type=SplitterType.PARENT_CHILD,
parent_chunk_size=1000, # Parent chunks: larger context
child_chunk_size=200, # Child chunks: smaller for precision
docstore_path="./my_parent_docs", # Where to store parent chunks
search_k=5, # Number of child chunks to retrieve
)
print(f" Parent splitter: chunk_size={builder.get_parent_splitter().chunk_size}")
print(f" Child splitter: chunk_size={builder.get_child_splitter().chunk_size}")
print(f" Docstore path: {builder.get_docstore_path()}")
print(f" Search k: {builder.retriever.search_kwargs['k']}")
# Step 2: Build index from a sample file
print("\n2. Building index from sample file...")
# Create a test document
test_content = """
This is a test document for demonstrating ParentDocumentRetriever.
Parent chunks contain larger portions of text (1000 characters),
while child chunks are smaller (200 characters) for precise retrieval.
When you search with ParentDocumentRetriever:
- It first retrieves relevant child chunks
- Then replaces them with their corresponding parent chunks
- This gives you large context while maintaining precision
Example search queries:
- "ParentDocumentRetriever"
- "child chunks"
- "large context"
- "precise retrieval"
"""
test_file = Path("./test_document.txt")
test_file.write_text(test_content, encoding="utf-8")
chunk_count = builder.build_from_file(str(test_file))
print(f" Indexed {chunk_count} documents")
# Step 3: Search with child chunks (fast, precise)
print("\n3. Searching with child chunks (fast, precise)...")
child_results = builder.search("ParentDocumentRetriever", k=3)
print(f" Found {len(child_results)} child chunks:")
for i, doc in enumerate(child_results, 1):
print(f" [{i}] {doc.page_content[:100]}...")
# Step 4: Search with parent context (large context)
print("\n4. Searching with parent context (large context)...")
parent_results = builder.search_with_parent_context("ParentDocumentRetriever", k=3)
print(f" Found {len(parent_results)} parent chunks:")
for i, doc in enumerate(parent_results, 1):
print(f" [{i}] {doc.page_content[:150]}...")
# Step 5: Compare results
print("\n5. Comparing child vs parent results...")
print(f" Child chunks total length: {sum(len(d.page_content) for d in child_results)}")
print(f" Parent chunks total length: {sum(len(d.page_content) for d in parent_results)}")
print(f" Ratio: parent/child = {sum(len(d.page_content) for d in parent_results) / max(sum(len(d.page_content) for d in child_results), 1):.2f}x larger")
# Step 6: Access retriever directly
print("\n6. Accessing retriever directly...")
retriever = builder.get_retriever()
print(f" Retriever type: {type(retriever).__name__}")
print(f" Vectorstore: {retriever.vectorstore}")
print(f" Docstore: {retriever.docstore}")
# Step 7: Unified retrieval interface
print("\n7. Using unified retrieval interface...")
unified_results = builder.retrieve("ParentDocumentRetriever", return_parent=True)
print(f" Retrieved {len(unified_results)} documents (with parent context)")
# Step 8: Collection info
print("\n8. Collection info...")
info = builder.get_collection_info()
print(f" Collection: {info['name']}")
print(f" Vectors: {info['vectors_count']}")
print(f" Vector size: {info['vector_size']}")
# Cleanup
print("\n9. Cleaning up...")
test_file.unlink(missing_ok=True)  # remove the temporary test document
builder.close()
print("\n" + "=" * 70)
print("Example completed successfully!")
print("=" * 70)
return builder
if __name__ == "__main__":
builder = main()

91
rag_indexer/loaders.py Normal file
View File

@@ -0,0 +1,91 @@
"""
Document loaders using unstructured library.
"""
import logging
from pathlib import Path
from typing import List, Union
from langchain_core.documents import Document
from unstructured.partition.auto import partition
logger = logging.getLogger(__name__)
class DocumentLoader:
"""Load documents from various file formats."""
SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx"}
def __init__(self, extract_images: bool = False):
"""
Args:
extract_images: Whether to extract images from documents (requires additional dependencies)
"""
self.extract_images = extract_images
def load_file(self, file_path: Union[str, Path]) -> List[Document]:
"""Load a single file into LangChain Document objects."""
file_path = Path(file_path).resolve()
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
suffix = file_path.suffix.lower()
if suffix not in self.SUPPORTED_EXTENSIONS:
raise ValueError(
f"Unsupported file extension: {suffix}. Supported: {self.SUPPORTED_EXTENSIONS}"
)
# Parse with unstructured
elements = partition(
filename=str(file_path),
extract_images_in_pdf=self.extract_images,
)
documents = []
for elem in elements:
text = getattr(elem, "text", "")
if not text or not text.strip():
continue
# Base metadata
metadata = {
"source": str(file_path),
"file_name": file_path.name,
"file_type": suffix,
}
# Merge element-specific metadata without overwriting base fields;
# unstructured returns an ElementMetadata object, not a plain dict
elem_meta = getattr(elem, "metadata", None)
elem_meta = elem_meta.to_dict() if elem_meta is not None else {}
for key, value in elem_meta.items():
if value and key not in metadata:
metadata[key] = value
documents.append(Document(page_content=text, metadata=metadata))
if not documents:
logger.warning("No text content extracted from %s", file_path)
return []
return documents
def load_directory(
self, directory_path: Union[str, Path], recursive: bool = True
) -> List[Document]:
"""Load all supported files from a directory."""
directory_path = Path(directory_path).resolve()
if not directory_path.is_dir():
raise NotADirectoryError(f"Not a directory: {directory_path}")
all_documents = []
pattern = "**/*" if recursive else "*"
for file_path in directory_path.glob(pattern):
if file_path.is_file() and file_path.suffix.lower() in self.SUPPORTED_EXTENSIONS:
try:
docs = self.load_file(file_path)
all_documents.extend(docs)
except Exception as e:
logger.error("Failed to load %s: %s", file_path, e)
return all_documents
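A short usage sketch, hedged: ./docs is a placeholder directory of mixed-format files.
loader = DocumentLoader()
docs = loader.load_directory("./docs", recursive=True)  # placeholder path
print(f"Loaded {len(docs)} document elements")
for doc in docs[:3]:
    print(doc.metadata["file_name"], "->", doc.page_content[:60])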

71
rag_indexer/splitters.py Normal file
View File

@@ -0,0 +1,71 @@
"""
Text splitters for chunking documents.
"""
from enum import Enum
from typing import List
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
class SplitterType(str, Enum):
RECURSIVE = "recursive"
SEMANTIC = "semantic"
PARENT_CHILD = "parent_child"
def get_splitter(splitter_type: SplitterType, **kwargs):
"""Factory function to create a text splitter."""
if splitter_type == SplitterType.RECURSIVE:
chunk_size = kwargs.get("chunk_size", 500)
chunk_overlap = kwargs.get("chunk_overlap", 50)
return RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=["\n\n", "\n", "。", "！", "？", " ", ""],
)
elif splitter_type == SplitterType.SEMANTIC:
# Requires embeddings for semantic splitting
embeddings = kwargs.get("embeddings")
if embeddings is None:
raise ValueError("Semantic splitter requires 'embeddings' parameter")
return SemanticChunker(embeddings=embeddings)
elif splitter_type == SplitterType.PARENT_CHILD:
return ParentChildSplitter(
parent_chunk_size=kwargs.get("parent_chunk_size", 1000),
child_chunk_size=kwargs.get("child_chunk_size", 200),
)
else:
raise ValueError(f"Unsupported splitter type: {splitter_type}")
class ParentChildSplitter:
"""
Splits documents into parent (large) and child (small) chunks.
Child chunks are indexed for retrieval, parent chunks are stored for context.
"""
def __init__(
self,
parent_chunk_size: int = 1000,
child_chunk_size: int = 200,
parent_chunk_overlap: int = 100,
child_chunk_overlap: int = 20,
):
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=parent_chunk_size,
chunk_overlap=parent_chunk_overlap,
)
self.child_splitter = RecursiveCharacterTextSplitter(
chunk_size=child_chunk_size,
chunk_overlap=child_chunk_overlap,
)
def split_documents(self, documents: List[Document]) -> tuple[List[Document], List[Document]]:
"""
Returns:
(parent_chunks, child_chunks)
"""
parent_chunks = self.parent_splitter.split_documents(documents)
child_chunks = self.child_splitter.split_documents(documents)
# NOTE: children are not linked to their parents here; a real
# implementation would stamp each child with its parent chunk ID
# (a hedged sketch of that linking follows this file)
return parent_chunks, child_chunks
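A hedged sketch of the parent-child linking noted above. The parent_id metadata key is an illustrative choice, not an existing convention in this codebase; the trick is to split each parent individually so every child unambiguously inherits its parent's ID.
import uuid
from typing import List
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

def split_with_links(documents: List[Document], parent_size: int = 1000, child_size: int = 200):
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=parent_size, chunk_overlap=100)
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=child_size, chunk_overlap=20)
    parents, children = [], []
    for parent in parent_splitter.split_documents(documents):
        parent_id = str(uuid.uuid4())  # hypothetical ID scheme
        parent.metadata["parent_id"] = parent_id
        parents.append(parent)
        # Splitting one parent at a time keeps the child-to-parent mapping exact
        for child in child_splitter.split_documents([parent]):
            child.metadata["parent_id"] = parent_id
            children.append(child)
    return parents, children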

110
rag_indexer/vector_store.py Normal file
View File

@@ -0,0 +1,110 @@
"""
Qdrant vector store wrapper.
"""
import logging
import os
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from .embedders import LlamaCppEmbedder
logger = logging.getLogger(__name__)
class QdrantVectorStore:
"""Wrapper for Qdrant vector database operations."""
def __init__(
self,
collection_name: str,
embeddings: Optional[Any] = None,
qdrant_url: Optional[str] = None,
api_key: Optional[str] = None,
):
self.collection_name = collection_name
self.qdrant_url = qdrant_url or os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
self.api_key = api_key
# Embeddings
if embeddings is None:
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
else:
self.embeddings = embeddings
# Qdrant client
self.client = QdrantClient(url=self.qdrant_url, api_key=self.api_key)
# LangChain vector store (QdrantVectorStore takes `embedding`, singular;
# validation is deferred until create_collection() has run)
self.vector_store = LangchainQdrantVS(
client=self.client,
collection_name=self.collection_name,
embedding=self.embeddings,
validate_collection_config=False,
)
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
"""Create collection with appropriate vector size."""
if vector_size is None:
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
collections = self.client.get_collections().collections
exists = any(c.name == self.collection_name for c in collections)
if exists and force_recreate:
self.client.delete_collection(self.collection_name)
exists = False
if not exists:
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)
logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)
else:
logger.info("Collection '%s' already exists", self.collection_name)
def add_documents(self, documents: List[Document], batch_size: int = 100):
"""Add documents to vector store."""
if not documents:
return []
self.create_collection()
ids = self.vector_store.add_documents(documents, batch_size=batch_size)
logger.info("Added %d documents to '%s'", len(ids), self.collection_name)
return ids
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
return self.vector_store.similarity_search(query, k=k)
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
return self.vector_store.similarity_search_with_score(query, k=k)
def delete_collection(self):
self.client.delete_collection(self.collection_name)
logger.info("Collection '%s' deleted", self.collection_name)
def get_collection_info(self) -> Dict[str, Any]:
info = self.client.get_collection(self.collection_name)
return {
"name": self.collection_name,  # CollectionInfo carries no name field
"vectors_count": info.vectors_count,
"status": info.status,
"vector_size": info.config.params.vectors.size,
}
def as_langchain_vectorstore(self):
return self.vector_store
def get_langchain_vectorstore(self):
"""Return the LangChain Qdrant vector store object (alias)."""
return self.vector_store
def get_qdrant_client(self):
"""Return the native Qdrant client (for manual collection management)."""
return self.client
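An end-to-end sketch, hedged: it assumes Qdrant on its default local port plus the llama.cpp embedding service, and demo_docs is an illustrative collection name.
from langchain_core.documents import Document

store = QdrantVectorStore(collection_name="demo_docs")  # illustrative name
store.create_collection(force_recreate=True)
store.add_documents([
    Document(page_content="Qdrant stores dense vectors.", metadata={"source": "demo"}),
    Document(page_content="RAG retrieves context before generation.", metadata={"source": "demo"}),
])
for doc, score in store.similarity_search_with_score("vector database", k=1):
    print(f"{score:.3f} {doc.page_content}")
store.get_qdrant_client().close()  # QdrantClient.close() releases the connection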

View File

@@ -48,6 +48,8 @@ pydantic==2.12.5
python-dotenv==1.2.2
typing-extensions==4.15.0
unstructured>=0.0.1
# ============================================================================
# NOTE:
# 1. This file pins the exact versions of the project's direct dependencies