Compare commits
2 Commits
6042d4a476
...
c18e8a9860
| Author | SHA1 | Date | |
|---|---|---|---|
| c18e8a9860 | |||
| 0470afce13 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -13,6 +13,8 @@
|
||||
!frontend/**
|
||||
!scripts/
|
||||
!scripts/**
|
||||
!rag_indexer/
|
||||
!rag_indexer/**
|
||||
!docker/
|
||||
!docker/**
|
||||
!.gitea/
|
||||
|
||||
82
app/agent.py
82
app/agent.py
@@ -6,14 +6,40 @@ AI Agent 服务类 - 支持多模型动态切换
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
try:
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
HAS_ZHIPUAI = True
|
||||
except ImportError:
|
||||
HAS_ZHIPUAI = False
|
||||
ChatZhipuAI = None
|
||||
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
HAS_OPENAI = True
|
||||
except ImportError:
|
||||
HAS_OPENAI = False
|
||||
ChatOpenAI = None
|
||||
OpenAIEmbeddings = None
|
||||
|
||||
from pydantic import SecretStr
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
try:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
HAS_POSTGRES_CHECKPOINT = True
|
||||
except ImportError:
|
||||
HAS_POSTGRES_CHECKPOINT = False
|
||||
AsyncPostgresSaver = None
|
||||
|
||||
# 本地模块
|
||||
from app.graph_builder import GraphBuilder, GraphContext
|
||||
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
try:
|
||||
from app.rag import RAGPipeline
|
||||
from app.rag.tools import RAGTool
|
||||
HAS_RAG = True
|
||||
except ImportError as e:
|
||||
HAS_RAG = False
|
||||
RAGPipeline = None
|
||||
RAGTool = None
|
||||
from app.logger import debug, info, warning, error
|
||||
|
||||
|
||||
@@ -31,9 +57,13 @@ class AIAgentService:
|
||||
"""
|
||||
self.checkpointer = checkpointer
|
||||
self.graphs = {} # 存储不同模型对应的 graph 实例
|
||||
self.rag = None # RAG 检索实例
|
||||
self.rag_tool = None # RAG 工具实例
|
||||
|
||||
def _create_zhipu_llm(self):
|
||||
"""创建智谱在线 LLM"""
|
||||
if not HAS_ZHIPUAI:
|
||||
raise ImportError("智谱AI支持不可用,请安装langchain-community包")
|
||||
api_key = os.getenv("ZHIPUAI_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("ZHIPUAI_API_KEY not set in environment")
|
||||
@@ -49,6 +79,8 @@ class AIAgentService:
|
||||
|
||||
def _create_deepseek_llm(self):
|
||||
"""创建 DeepSeek LLM(使用 OpenAI 兼容 API)"""
|
||||
if not HAS_OPENAI:
|
||||
raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包")
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("DEEPSEEK_API_KEY not set in environment")
|
||||
@@ -65,6 +97,8 @@ class AIAgentService:
|
||||
|
||||
def _create_local_llm(self):
|
||||
"""创建本地 vLLM 服务 LLM"""
|
||||
if not HAS_OPENAI:
|
||||
raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包")
|
||||
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
|
||||
vllm_base_url = os.getenv(
|
||||
"VLLM_BASE_URL",
|
||||
@@ -80,8 +114,39 @@ class AIAgentService:
|
||||
streaming=True, # 确保开启流式输出
|
||||
)
|
||||
|
||||
def _create_embeddings(self):
|
||||
"""创建嵌入模型"""
|
||||
if not HAS_OPENAI:
|
||||
raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包")
|
||||
embedding_url = os.getenv(
|
||||
"LLAMACPP_EMBEDDING_URL",
|
||||
"http://127.0.0.1:8082/v1"
|
||||
)
|
||||
return OpenAIEmbeddings(
|
||||
openai_api_base=embedding_url,
|
||||
openai_api_key=os.getenv("LLAMACPP_API_KEY", "token-abc123"),
|
||||
model="text-embedding-ada-002", # 模型名称不重要,兼容即可
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""预编译所有模型的 graph(使用传入的 checkpointer)"""
|
||||
# 先初始化 RAG 检索系统
|
||||
if HAS_RAG and RAGPipeline is not None and RAGTool is not None:
|
||||
try:
|
||||
info("🔄 正在初始化 RAG 检索系统...")
|
||||
embeddings = self._create_embeddings()
|
||||
self.rag = RAGPipeline(embeddings=embeddings)
|
||||
self.rag_tool = RAGTool(self.rag).get_tool()
|
||||
info("✅ RAG 检索系统初始化成功")
|
||||
except Exception as e:
|
||||
warning(f"⚠️ RAG 检索系统初始化失败: {e}")
|
||||
self.rag = None
|
||||
self.rag_tool = None
|
||||
else:
|
||||
info("⏭️ RAG 检索系统不可用,跳过初始化")
|
||||
self.rag = None
|
||||
self.rag_tool = None
|
||||
|
||||
model_configs = {
|
||||
"local": self._create_local_llm, # 本地模型作为第一个
|
||||
"deepseek": self._create_deepseek_llm, # DeepSeek 作为中间
|
||||
@@ -92,7 +157,16 @@ class AIAgentService:
|
||||
try:
|
||||
info(f"🔄 正在初始化模型 '{model_name}'...")
|
||||
llm = llm_creator()
|
||||
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
|
||||
|
||||
# 构建工具列表:基础工具 + RAG工具(如果可用)
|
||||
tools = AVAILABLE_TOOLS.copy()
|
||||
tools_by_name = TOOLS_BY_NAME.copy()
|
||||
|
||||
if self.rag_tool is not None:
|
||||
tools.append(self.rag_tool)
|
||||
tools_by_name[self.rag_tool.name] = self.rag_tool
|
||||
|
||||
builder = GraphBuilder(llm, tools, tools_by_name).build()
|
||||
graph = builder.compile(checkpointer=self.checkpointer)
|
||||
self.graphs[model_name] = graph
|
||||
info(f"✅ 模型 '{model_name}' 初始化成功")
|
||||
|
||||
136
app/rag/README.md
Normal file
136
app/rag/README.md
Normal file
@@ -0,0 +1,136 @@
|
||||
# 在线 RAG 检索与生成系统 (Online RAG Retriever)
|
||||
|
||||
该模块负责 RAG 系统的阶段二:**在线检索与生成**。它接收用户提问,从知识库中检索出上下文,利用各种高级策略去噪、融合,并作为增强上下文输入给大语言模型 (LLM)。
|
||||
|
||||
## 📊 RAG-Fusion & 混合检索流水线示意图
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
User((用户提问)) --> A[LLM 查询改写生成器]
|
||||
|
||||
subgraph RAG-Fusion 核心流程
|
||||
A -->|改写为问题 1| B1[查询 1]
|
||||
A -->|改写为问题 2| B2[查询 2]
|
||||
A -->|原问题| B3[原始查询]
|
||||
|
||||
B1 & B2 & B3 --> C[混合检索器 Hybrid Retriever <br> Dense Vector + BM25 Sparse]
|
||||
|
||||
C --> D[多路召回结果合集 N=60条]
|
||||
|
||||
D --> E{RRF 倒数排名融合去重}
|
||||
end
|
||||
|
||||
E -->|筛选出前 20 条| F[Cross-Encoder 重排器 Reranker]
|
||||
|
||||
F -->|精细打分排序 Top 5| G[最终纯净上下文 Context]
|
||||
|
||||
G --> H[将 Context 与原问题拼接输入大模型]
|
||||
H --> I((LLM 生成最终回答))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🎯 演进路线与算法详解 (Roadmap)
|
||||
|
||||
### Level 1: 基础向量搜索 (Basic Similarity Search)
|
||||
- **核心算法**: 近似最近邻搜索 (ANN, 常用 HNSW 算法)。将用户问题转化为向量后,计算它与库中向量的余弦相似度 (Cosine Similarity),取距离最近的 K 个块。
|
||||
- **优缺点**: 速度极快。但只能捕捉“语义相似”,如果用户搜索特定专有名词、编号、订单号,纯向量检索往往会失效(产生“幻觉”匹配)。
|
||||
|
||||
### Level 2: 混合检索与重排序 (Hybrid Search + Reranker)
|
||||
混合检索旨在结合向量的“语义泛化”与关键词的“精准匹配”,随后利用重排序模型过滤噪声。
|
||||
|
||||
**1. 基础召回 (混合检索)**
|
||||
- **核心原理**: 结合基于 HNSW 的 Dense Vector 相似度搜索与基于 TF-IDF 的 BM25 稀疏检索 (Sparse Vector)。
|
||||
- **实现指南**: 使用 `langchain_qdrant` 包中的 `Qdrant` 类连接数据库。通过调用 `Qdrant.from_existing_collection(...)` 实例化向量库,并使用 `.as_retriever(search_kwargs={"k": 20})` 方法生成基础检索器。Qdrant 底层会自动处理双路召回。
|
||||
|
||||
**2. 二次精排 (Cross-Encoder)**
|
||||
- **核心原理**: 不同于双塔模型(分别算向量再求距离),交叉编码器将“用户问题 + 检索到的单例文档”拼接后整体输入 Transformer 模型,由模型直接输出 0~1 的相关性得分,精度极高。
|
||||
- **实现指南**:
|
||||
- 使用 `sentence-transformers` 库加载本地轻量级重排模型(如 `BAAI/bge-reranker-base`)。
|
||||
- 引入 `langchain.retrievers.document_compressors` 包中的 `CrossEncoderReranker` 类包装该模型,设置参数 `top_n=5`。
|
||||
- 最后,使用 `langchain.retrievers` 包中的 `ContextualCompressionRetriever` 类,将 `base_compressor` (重排器) 和 `base_retriever` (基础检索器) 组合。
|
||||
- **如何调用**: 业务逻辑中直接对组合后的检索器调用 `.invoke(query)` 方法,即可一键完成“大范围召回 20 条 -> 逐一打分精排选 5 条”的去噪流水线。
|
||||
|
||||
### Level 3: RAG-Fusion (多路改写与倒数排名融合)
|
||||
RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果。
|
||||
|
||||
**1. 多路查询改写**
|
||||
- **核心原理**: 克服用户初始提问词不达意或视角受限的问题。
|
||||
- **实现指南**: 导入 `langchain.retrievers.multi_query` 包中的 `MultiQueryRetriever` 类。需向其提供一个已实例化的 LLM 对象(如基于 `ChatOpenAI` 封装的本地 VLLM 模型)。系统在底层会自动 Prompt 模型,将原始 `query` 转化为包含 3-5 个不同表述的查询列表。
|
||||
|
||||
**2. 倒数排名融合 (RRF)**
|
||||
- **核心原理**: RRF (Reciprocal Rank Fusion) 是一种无需评分归一化的融合算法。公式为 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$,有效避免某一极端检索结果主导全局。
|
||||
- **实现指南**:
|
||||
- 针对每个改写后的查询 $q$,分别调用精排检索器的 `.invoke(q)` 获取文档列表。
|
||||
- 使用 `langchain.retrievers` 中的 `EnsembleRetriever` 类(原生支持 RRF),或在代码中遍历收集到的 `Document` 对象,基于其排名 `rank` 累加得分,最终通过 Python 的 `set` 去重并提取 `doc.page_content`。
|
||||
|
||||
### Level 4: Agentic RAG / Self-RAG (智能体与自我反思)
|
||||
- **核心原理**: 基于 LangGraph 的 ReAct (Reasoning and Acting) 状态机路由。大模型并非每次都去死板地执行检索,而是先判断问题:“这是闲聊?还是需要查知识库?”。如果是后者,模型输出一个 `ToolCall` 指令,触发检索。
|
||||
- **实现指南**: 请参考下方的**与现有系统整合调用**章节。
|
||||
|
||||
- **示意图**:
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant User
|
||||
participant LangGraph Agent
|
||||
participant RAG_Tool
|
||||
participant Qdrant
|
||||
|
||||
User->>LangGraph Agent: "公司报销流程是什么?"
|
||||
LangGraph Agent->>LangGraph Agent: 思考: 这是一个内部规章问题,需要查资料
|
||||
LangGraph Agent->>RAG_Tool: ToolCall(search_knowledge_base, "公司报销流程")
|
||||
RAG_Tool->>Qdrant: RAG-Fusion & 混合检索
|
||||
Qdrant-->>RAG_Tool: 原始分块
|
||||
RAG_Tool->>RAG_Tool: Cross-Encoder 重排过滤
|
||||
RAG_Tool-->>LangGraph Agent: 返回最相关的5条报销规定
|
||||
LangGraph Agent->>LangGraph Agent: 思考: 资料充分,开始撰写回答
|
||||
LangGraph Agent-->>User: "根据知识库规定,报销流程分为以下3步..."
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📦 所需依赖与安装
|
||||
|
||||
除了基础的 LangChain 包外,在线检索模块为了支持重排和稀疏检索,还需要安装:
|
||||
|
||||
```bash
|
||||
# 用于 Cross-Encoder 重排序模型 (如 BAAI/bge-reranker-base)
|
||||
pip install sentence-transformers
|
||||
|
||||
# 用于 BM25 关键词混合检索
|
||||
pip install rank_bm25
|
||||
|
||||
# 基础框架
|
||||
pip install langchain langchain-core langchain-openai langchain-qdrant
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📂 架构与文件结构设计
|
||||
|
||||
在 `app/rag/` 目录下,需创建以下文件来模块化上述功能:
|
||||
|
||||
```text
|
||||
app/rag/
|
||||
├── __init__.py
|
||||
├── retriever.py # 负责 Qdrant 的基础召回与 ContextualCompressionRetriever
|
||||
├── reranker.py # 负责加载 sentence-transformers 交叉编码器
|
||||
├── query_transform.py # 负责基于 MultiQueryRetriever 的改写逻辑
|
||||
├── pipeline.py # 组合上述组件,暴露出核心的 retrieve() 方法
|
||||
└── tools.py # 将 Pipeline 包装成 LangChain Tool 供 Agent 调用
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## <20> 与现有系统整合调用 (Agentic RAG 实现)
|
||||
|
||||
基于目前 LangGraph 系统的架构,我们将摒弃将代码堆砌在一起的旧方式,而是利用 **LangChain Tools** 的特性将 RAG 优雅地注入系统:
|
||||
|
||||
1. **封装检索工具 (Tool)**:
|
||||
从 `langchain.tools` 导入 `@tool` 装饰器。定义一个名为 `search_knowledge_base(query: str)` 的函数。在函数内部,实例化并调用我们在 `pipeline.py` 中写好的多路召回与重排逻辑。
|
||||
2. **模型绑定 (Bind)**:
|
||||
在 `app/agent.py` 或 `app/nodes/tool_call.py` 中,将这个工具引入,并通过 `llm.bind_tools([search_knowledge_base])` 绑定到现有的本地大模型实例上。
|
||||
3. **状态机路由 (Graph Routing)**:
|
||||
你的 LangGraph 状态机会像处理普通对话一样自动接管:当模型判断需要调用查阅规章制度或专业资料时,它会输出 `ToolCall` 消息,流转到 `tool_node` 执行上述的 RAG 检索逻辑并返回上下文。
|
||||
|
||||
这让你无需修改任何前端 Streamlit 流式代码,就能平滑升级为具备超级知识库检索能力的智能体 (Agent)!
|
||||
22
app/rag/__init__.py
Normal file
22
app/rag/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
在线 RAG 检索与生成系统
|
||||
|
||||
提供高级RAG检索功能,支持混合检索、重排序、RAG-Fusion和多路查询改写。
|
||||
"""
|
||||
|
||||
from .pipeline import RAGPipeline
|
||||
from .retriever import create_hybrid_retriever, create_base_retriever
|
||||
from .reranker import CrossEncoderReranker
|
||||
from .query_transform import MultiQueryTransformer
|
||||
from .tools import search_knowledge_base_tool
|
||||
|
||||
__all__ = [
|
||||
"RAGPipeline",
|
||||
"create_hybrid_retriever",
|
||||
"create_base_retriever",
|
||||
"CrossEncoderReranker",
|
||||
"MultiQueryTransformer",
|
||||
"search_knowledge_base_tool",
|
||||
]
|
||||
|
||||
__version__ = "0.1.0"
|
||||
232
app/rag/example.py
Normal file
232
app/rag/example.py
Normal file
@@ -0,0 +1,232 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
RAG 系统使用示例
|
||||
|
||||
演示如何使用 app/rag 模块进行知识检索。
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_community.llms import VLLMOpenAI
|
||||
|
||||
|
||||
def setup_environment():
|
||||
"""设置环境变量"""
|
||||
# 设置 Qdrant 连接信息(根据实际情况修改)
|
||||
os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333")
|
||||
# 如果需要 API 密钥,请设置 QDRANT_API_KEY
|
||||
|
||||
print("环境变量已设置")
|
||||
print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}")
|
||||
|
||||
|
||||
def demonstrate_basic_rag():
|
||||
"""演示基础 RAG 功能"""
|
||||
print("\n" + "="*60)
|
||||
print("演示: 基础 RAG 检索 (Level 1)")
|
||||
print("="*60)
|
||||
|
||||
# 创建嵌入模型(使用 OpenAI 兼容的本地模型)
|
||||
embeddings = OpenAIEmbeddings(
|
||||
openai_api_base="http://localhost:8000/v1", # 本地 VLLM 服务
|
||||
openai_api_key="no-key-needed",
|
||||
model="text-embedding-ada-002", # 假设的模型名称
|
||||
)
|
||||
|
||||
# 创建 RAG 流水线
|
||||
from app.rag import RAGPipeline, RAGConfig, RAGLevel
|
||||
|
||||
config = RAGConfig(
|
||||
collection_name="documents", # 你的集合名称
|
||||
rag_level=RAGLevel.BASIC,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
pipeline = RAGPipeline(
|
||||
embeddings=embeddings,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# 示例查询
|
||||
query = "公司报销流程是什么?"
|
||||
print(f"\n查询: {query}")
|
||||
|
||||
try:
|
||||
result = pipeline.retrieve(query)
|
||||
print(f"找到 {len(result.documents)} 个相关文档")
|
||||
|
||||
# 格式化上下文
|
||||
context = pipeline.format_context(result.documents)
|
||||
print(f"\n上下文预览:\n{context[:500]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f"检索失败: {e}")
|
||||
print("请确保 Qdrant 服务正常运行且集合存在")
|
||||
|
||||
|
||||
def demonstrate_hybrid_rag():
|
||||
"""演示混合 RAG 功能"""
|
||||
print("\n" + "="*60)
|
||||
print("演示: 混合 RAG 检索 (Level 2)")
|
||||
print("="*60)
|
||||
|
||||
embeddings = OpenAIEmbeddings(
|
||||
openai_api_base="http://localhost:8000/v1",
|
||||
openai_api_key="no-key-needed",
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
|
||||
from app.rag import RAGPipeline, RAGConfig, RAGLevel
|
||||
|
||||
config = RAGConfig(
|
||||
collection_name="documents",
|
||||
rag_level=RAGLevel.HYBRID,
|
||||
dense_k=10,
|
||||
sparse_k=10,
|
||||
rerank_top_n=5,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
pipeline = RAGPipeline(
|
||||
embeddings=embeddings,
|
||||
config=config,
|
||||
)
|
||||
|
||||
query = "如何申请年假?"
|
||||
print(f"\n查询: {query}")
|
||||
|
||||
try:
|
||||
result = pipeline.retrieve(query)
|
||||
print(f"找到 {len(result.documents)} 个重排序后的文档")
|
||||
|
||||
except Exception as e:
|
||||
print(f"检索失败: {e}")
|
||||
|
||||
|
||||
def demonstrate_rag_fusion():
|
||||
"""演示 RAG-Fusion 功能"""
|
||||
print("\n" + "="*60)
|
||||
print("演示: RAG-Fusion (Level 3)")
|
||||
print("="*60)
|
||||
|
||||
embeddings = OpenAIEmbeddings(
|
||||
openai_api_base="http://localhost:8000/v1",
|
||||
openai_api_key="no-key-needed",
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
|
||||
# 创建语言模型用于查询改写
|
||||
llm = VLLMOpenAI(
|
||||
openai_api_base="http://localhost:8000/v1",
|
||||
openai_api_key="no-key-needed",
|
||||
model_name="Qwen2.5-7B-Instruct", # 你的本地模型
|
||||
temperature=0.3,
|
||||
max_tokens=512,
|
||||
)
|
||||
|
||||
from app.rag import RAGPipeline, RAGConfig, RAGLevel
|
||||
|
||||
config = RAGConfig(
|
||||
collection_name="documents",
|
||||
rag_level=RAGLevel.FUSION,
|
||||
num_queries=3,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
pipeline = RAGPipeline(
|
||||
embeddings=embeddings,
|
||||
llm=llm,
|
||||
config=config,
|
||||
)
|
||||
|
||||
query = "项目上线需要哪些审批?"
|
||||
print(f"\n查询: {query}")
|
||||
|
||||
try:
|
||||
result = pipeline.retrieve(query)
|
||||
print(f"找到 {len(result.documents)} 个文档 (经过多路查询改写和重排序)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"检索失败: {e}")
|
||||
|
||||
|
||||
def demonstrate_agentic_rag():
|
||||
"""演示 Agentic RAG 功能"""
|
||||
print("\n" + "="*60)
|
||||
print("演示: Agentic RAG (Level 4)")
|
||||
print("="*60)
|
||||
|
||||
embeddings = OpenAIEmbeddings(
|
||||
openai_api_base="http://localhost:8000/v1",
|
||||
openai_api_key="no-key-needed",
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
|
||||
llm = VLLMOpenAI(
|
||||
openai_api_base="http://localhost:8000/v1",
|
||||
openai_api_key="no-key-needed",
|
||||
model_name="Qwen2.5-7B-Instruct",
|
||||
temperature=0.3,
|
||||
max_tokens=512,
|
||||
)
|
||||
|
||||
from app.rag import create_agentic_rag_pipeline
|
||||
|
||||
try:
|
||||
# 创建 Agentic RAG 流水线
|
||||
agentic_rag = create_agentic_rag_pipeline(
|
||||
embeddings=embeddings,
|
||||
agent_llm=llm,
|
||||
config={
|
||||
"collection_name": "documents",
|
||||
"verbose": True,
|
||||
},
|
||||
)
|
||||
|
||||
print("Agentic RAG 流水线创建成功!")
|
||||
print(f"- 绑定的模型: {agentic_rag['llm']}")
|
||||
print(f"- RAG 工具: {agentic_rag['tool'].name}")
|
||||
|
||||
# 演示工具调用
|
||||
print("\n工具调用示例:")
|
||||
response = agentic_rag["tool"].invoke({"query": "员工福利有哪些?"})
|
||||
print(f"工具响应预览: {response[:200]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建 Agentic RAG 失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("RAG 系统演示")
|
||||
print("="*60)
|
||||
|
||||
# 设置环境
|
||||
setup_environment()
|
||||
|
||||
# 演示各级功能
|
||||
demonstrate_basic_rag()
|
||||
demonstrate_hybrid_rag()
|
||||
demonstrate_rag_fusion()
|
||||
demonstrate_agentic_rag()
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("演示完成!")
|
||||
print("="*60)
|
||||
|
||||
print("\n使用说明:")
|
||||
print("1. 确保 Qdrant 服务运行且集合已创建")
|
||||
print("2. 根据需要修改 embeddings 和 llm 配置")
|
||||
print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base_tool")
|
||||
print("4. 将工具绑定到你的 Agent 模型")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
341
app/rag/pipeline.py
Normal file
341
app/rag/pipeline.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""
|
||||
RAG 检索流水线
|
||||
|
||||
组合检索器、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from .retriever import (
|
||||
create_base_retriever,
|
||||
create_hybrid_retriever,
|
||||
create_ensemble_retriever,
|
||||
create_qdrant_client,
|
||||
)
|
||||
from .reranker import CrossEncoderReranker
|
||||
from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline
|
||||
|
||||
|
||||
class RAGLevel(Enum):
|
||||
"""RAG 功能级别"""
|
||||
BASIC = 1 # 基础向量搜索
|
||||
HYBRID = 2 # 混合检索 + 重排序
|
||||
FUSION = 3 # RAG-Fusion
|
||||
AGENTIC = 4 # Agentic RAG
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGConfig:
|
||||
"""RAG 配置"""
|
||||
# Qdrant 配置
|
||||
collection_name: str = "documents"
|
||||
qdrant_url: Optional[str] = None
|
||||
qdrant_api_key: Optional[str] = None
|
||||
|
||||
# 检索配置
|
||||
rag_level: RAGLevel = RAGLevel.FUSION
|
||||
dense_k: int = 10 # 向量检索数量
|
||||
sparse_k: int = 10 # BM25 检索数量
|
||||
total_k: int = 20 # 总检索数量
|
||||
rerank_top_n: int = 5 # 重排序返回数量
|
||||
|
||||
# 查询改写配置
|
||||
num_queries: int = 3 # RAG-Fusion 查询数量
|
||||
|
||||
# 模型配置
|
||||
reranker_model: str = "BAAI/bge-reranker-base"
|
||||
device: Optional[str] = None
|
||||
|
||||
# 性能配置
|
||||
enable_cache: bool = True
|
||||
verbose: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult:
|
||||
"""检索结果"""
|
||||
documents: List[Document]
|
||||
query_time: float
|
||||
level: RAGLevel
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class RAGPipeline:
|
||||
"""
|
||||
RAG 检索流水线
|
||||
|
||||
支持从 Level 1 到 Level 4 的所有功能。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings: Embeddings,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
config: Optional[RAGConfig] = None,
|
||||
):
|
||||
"""
|
||||
初始化 RAG 流水线
|
||||
|
||||
Args:
|
||||
embeddings: 嵌入模型
|
||||
llm: 语言模型(用于查询改写,Level 3+ 需要)
|
||||
config: 配置
|
||||
"""
|
||||
self.embeddings = embeddings
|
||||
self.llm = llm
|
||||
self.config = config or RAGConfig()
|
||||
|
||||
# 初始化组件
|
||||
self._client = None
|
||||
self._reranker = None
|
||||
self._query_transformer = None
|
||||
self._retriever = None
|
||||
|
||||
# 缓存
|
||||
self._cache = {}
|
||||
|
||||
def _get_client(self):
|
||||
"""获取 Qdrant 客户端"""
|
||||
if self._client is None:
|
||||
self._client = create_qdrant_client(
|
||||
url=self.config.qdrant_url,
|
||||
api_key=self.config.qdrant_api_key,
|
||||
)
|
||||
return self._client
|
||||
|
||||
def _get_reranker(self):
|
||||
"""获取重排序器"""
|
||||
if self._reranker is None:
|
||||
self._reranker = CrossEncoderReranker(
|
||||
model_name=self.config.reranker_model,
|
||||
top_n=self.config.rerank_top_n,
|
||||
device=self.config.device,
|
||||
)
|
||||
return self._reranker
|
||||
|
||||
def _get_query_transformer(self):
|
||||
"""获取查询改写器"""
|
||||
if self._query_transformer is None and self.llm is not None:
|
||||
self._query_transformer = MultiQueryTransformer(
|
||||
llm=self.llm,
|
||||
num_queries=self.config.num_queries,
|
||||
)
|
||||
return self._query_transformer
|
||||
|
||||
def _create_basic_retriever(self):
|
||||
"""创建基础检索器(Level 1)"""
|
||||
return create_base_retriever(
|
||||
collection_name=self.config.collection_name,
|
||||
embeddings=self.embeddings,
|
||||
search_kwargs={"k": self.config.total_k},
|
||||
client=self._get_client(),
|
||||
)
|
||||
|
||||
def _create_hybrid_retriever(self):
|
||||
"""创建混合检索器(Level 2)"""
|
||||
base_retriever = create_hybrid_retriever(
|
||||
collection_name=self.config.collection_name,
|
||||
embeddings=self.embeddings,
|
||||
dense_k=self.config.dense_k,
|
||||
sparse_k=self.config.sparse_k,
|
||||
client=self._get_client(),
|
||||
)
|
||||
|
||||
# 应用重排序
|
||||
reranker = self._get_reranker()
|
||||
return reranker.create_contextual_compression_retriever(base_retriever)
|
||||
|
||||
def _create_fusion_retriever(self):
|
||||
"""创建 RAG-Fusion 检索器(Level 3)"""
|
||||
if self.llm is None:
|
||||
raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写")
|
||||
|
||||
# 创建基础混合检索器
|
||||
base_retriever = create_hybrid_retriever(
|
||||
collection_name=self.config.collection_name,
|
||||
embeddings=self.embeddings,
|
||||
dense_k=self.config.dense_k,
|
||||
sparse_k=self.config.sparse_k,
|
||||
client=self._get_client(),
|
||||
)
|
||||
|
||||
# 创建 RAG-Fusion 流水线
|
||||
reranker = self._get_reranker()
|
||||
return create_rag_fusion_pipeline(
|
||||
base_retriever=base_retriever,
|
||||
llm=self.llm,
|
||||
reranker=reranker,
|
||||
num_queries=self.config.num_queries,
|
||||
)
|
||||
|
||||
def _get_retriever(self):
|
||||
"""根据配置级别获取检索器"""
|
||||
if self._retriever is None:
|
||||
if self.config.rag_level == RAGLevel.BASIC:
|
||||
self._retriever = self._create_basic_retriever()
|
||||
elif self.config.rag_level == RAGLevel.HYBRID:
|
||||
self._retriever = self._create_hybrid_retriever()
|
||||
elif self.config.rag_level == RAGLevel.FUSION:
|
||||
self._retriever = self._create_fusion_retriever()
|
||||
elif self.config.rag_level == RAGLevel.AGENTIC:
|
||||
# Agentic RAG 使用 Fusion 作为基础,在 tools.py 中包装
|
||||
self._retriever = self._create_fusion_retriever()
|
||||
else:
|
||||
raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}")
|
||||
|
||||
return self._retriever
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> RetrievalResult:
|
||||
"""
|
||||
执行检索
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
use_cache: 是否使用缓存
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
检索结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# 检查缓存
|
||||
if use_cache is None:
|
||||
use_cache = self.config.enable_cache
|
||||
|
||||
cache_key = f"{query}:{self.config.rag_level.value}"
|
||||
if use_cache and cache_key in self._cache:
|
||||
if self.config.verbose:
|
||||
print(f"使用缓存结果: {query}")
|
||||
return self._cache[cache_key]
|
||||
|
||||
# 获取检索器并执行检索
|
||||
retriever = self._get_retriever()
|
||||
documents = retriever.invoke(query, **kwargs)
|
||||
|
||||
# 计算查询时间
|
||||
query_time = time.time() - start_time
|
||||
|
||||
# 创建结果
|
||||
result = RetrievalResult(
|
||||
documents=documents,
|
||||
query_time=query_time,
|
||||
level=self.config.rag_level,
|
||||
metadata={
|
||||
"query": query,
|
||||
"collection": self.config.collection_name,
|
||||
"doc_count": len(documents),
|
||||
},
|
||||
)
|
||||
|
||||
# 缓存结果
|
||||
if use_cache:
|
||||
self._cache[cache_key] = result
|
||||
|
||||
if self.config.verbose:
|
||||
print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s")
|
||||
|
||||
return result
|
||||
|
||||
def format_context(
|
||||
self,
|
||||
documents: List[Document],
|
||||
max_length: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
格式化检索到的文档为上下文文本
|
||||
|
||||
Args:
|
||||
documents: 文档列表
|
||||
max_length: 最大长度(字符数)
|
||||
|
||||
Returns:
|
||||
格式化后的上下文文本
|
||||
"""
|
||||
context_parts = []
|
||||
total_length = 0
|
||||
|
||||
for i, doc in enumerate(documents):
|
||||
# 提取内容和元数据
|
||||
content = doc.page_content.strip()
|
||||
metadata = doc.metadata
|
||||
|
||||
# 格式化文档
|
||||
doc_text = f"[文档 {i+1}]\n"
|
||||
if metadata.get("source"):
|
||||
doc_text += f"来源: {metadata['source']}\n"
|
||||
if metadata.get("page"):
|
||||
doc_text += f"页码: {metadata['page']}\n"
|
||||
|
||||
doc_text += f"内容: {content}\n\n"
|
||||
|
||||
# 检查长度限制
|
||||
if max_length is not None:
|
||||
if total_length + len(doc_text) > max_length:
|
||||
# 如果添加这个文档会超限,则截断并添加说明
|
||||
remaining = max_length - total_length
|
||||
if remaining > 100: # 至少保留100字符
|
||||
doc_text = doc_text[:remaining] + "...\n\n[内容已截断]"
|
||||
context_parts.append(doc_text)
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
context_parts.append(doc_text)
|
||||
total_length += len(doc_text)
|
||||
|
||||
return "".join(context_parts).strip()
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
self._cache.clear()
|
||||
|
||||
@classmethod
|
||||
def create_from_config(
|
||||
cls,
|
||||
embeddings: Embeddings,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
config_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> "RAGPipeline":
|
||||
"""
|
||||
从配置字典创建流水线
|
||||
|
||||
Args:
|
||||
embeddings: 嵌入模型
|
||||
llm: 语言模型
|
||||
config_dict: 配置字典
|
||||
|
||||
Returns:
|
||||
RAGPipeline 实例
|
||||
"""
|
||||
config_dict = config_dict or {}
|
||||
|
||||
# 创建配置对象
|
||||
config = RAGConfig(
|
||||
collection_name=config_dict.get("collection_name", "documents"),
|
||||
qdrant_url=config_dict.get("qdrant_url"),
|
||||
qdrant_api_key=config_dict.get("qdrant_api_key"),
|
||||
rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)),
|
||||
dense_k=config_dict.get("dense_k", 10),
|
||||
sparse_k=config_dict.get("sparse_k", 10),
|
||||
total_k=config_dict.get("total_k", 20),
|
||||
rerank_top_n=config_dict.get("rerank_top_n", 5),
|
||||
num_queries=config_dict.get("num_queries", 3),
|
||||
reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"),
|
||||
device=config_dict.get("device"),
|
||||
enable_cache=config_dict.get("enable_cache", True),
|
||||
verbose=config_dict.get("verbose", True),
|
||||
)
|
||||
|
||||
return cls(embeddings=embeddings, llm=llm, config=config)
|
||||
193
app/rag/query_transform.py
Normal file
193
app/rag/query_transform.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
查询改写器
|
||||
|
||||
基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围。
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Any
|
||||
from langchain.retrievers.multi_query import MultiQueryRetriever
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
|
||||
|
||||
class MultiQueryTransformer:
|
||||
"""
|
||||
多路查询改写器
|
||||
|
||||
将单个查询改写成多个相关查询,用于 RAG-Fusion。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLanguageModel,
|
||||
num_queries: int = 3,
|
||||
prompt_template: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
初始化查询改写器
|
||||
|
||||
Args:
|
||||
llm: 语言模型实例
|
||||
num_queries: 生成的查询数量
|
||||
prompt_template: 提示词模板
|
||||
"""
|
||||
self.llm = llm
|
||||
self.num_queries = num_queries
|
||||
|
||||
# 默认提示词模板
|
||||
self.prompt_template = prompt_template or """
|
||||
你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
|
||||
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。
|
||||
|
||||
原始问题: {question}
|
||||
|
||||
请生成 {num_queries} 个不同版本的查询,每个版本一行。
|
||||
确保每个版本都是独立、完整的查询语句。
|
||||
|
||||
生成 {num_queries} 个查询:
|
||||
"""
|
||||
|
||||
def transform_query(self, query: str) -> List[str]:
|
||||
"""
|
||||
将单个查询改写成多个查询
|
||||
|
||||
Args:
|
||||
query: 原始查询
|
||||
|
||||
Returns:
|
||||
改写后的查询列表
|
||||
"""
|
||||
prompt = PromptTemplate.from_template(self.prompt_template)
|
||||
|
||||
chain = prompt | self.llm | StrOutputParser()
|
||||
|
||||
response = chain.invoke({
|
||||
"question": query,
|
||||
"num_queries": self.num_queries,
|
||||
})
|
||||
|
||||
# 解析响应,每行一个查询
|
||||
queries = [
|
||||
q.strip()
|
||||
for q in response.strip().split('\n')
|
||||
if q.strip()
|
||||
]
|
||||
|
||||
# 确保数量正确,如果不够则添加原始查询
|
||||
if len(queries) < self.num_queries:
|
||||
queries.extend([query] * (self.num_queries - len(queries)))
|
||||
elif len(queries) > self.num_queries:
|
||||
queries = queries[:self.num_queries]
|
||||
|
||||
# 确保包含原始查询
|
||||
if query not in queries:
|
||||
queries = [query] + queries[:self.num_queries-1]
|
||||
|
||||
return queries
|
||||
|
||||
def create_multi_query_retriever(
|
||||
self,
|
||||
base_retriever: Any,
|
||||
include_original: bool = True,
|
||||
) -> MultiQueryRetriever:
|
||||
"""
|
||||
创建多路查询检索器
|
||||
|
||||
Args:
|
||||
base_retriever: 基础检索器
|
||||
include_original: 是否包含原始查询
|
||||
|
||||
Returns:
|
||||
MultiQueryRetriever 实例
|
||||
"""
|
||||
retriever = MultiQueryRetriever.from_llm(
|
||||
retriever=base_retriever,
|
||||
llm=self.llm,
|
||||
include_original=include_original,
|
||||
)
|
||||
|
||||
# 设置生成的查询数量
|
||||
retriever.llm_chain.prompt = PromptTemplate.from_template(
|
||||
"你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
|
||||
"这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
|
||||
"原始问题: {question}\n\n"
|
||||
"请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
|
||||
"确保每个版本都是独立、完整的查询语句。\n\n"
|
||||
"生成 {num_queries} 个查询:"
|
||||
)
|
||||
|
||||
# 修改调用参数以包含 num_queries
|
||||
original_invoke = retriever.llm_chain.invoke
|
||||
|
||||
def new_invoke(input_dict):
|
||||
input_dict["num_queries"] = self.num_queries
|
||||
return original_invoke(input_dict)
|
||||
|
||||
retriever.llm_chain.invoke = new_invoke
|
||||
|
||||
return retriever
|
||||
|
||||
@classmethod
|
||||
def create_from_config(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
config: Optional[dict] = None,
|
||||
) -> "MultiQueryTransformer":
|
||||
"""
|
||||
从配置创建查询改写器
|
||||
|
||||
Args:
|
||||
llm: 语言模型实例
|
||||
config: 配置字典
|
||||
|
||||
Returns:
|
||||
MultiQueryTransformer 实例
|
||||
"""
|
||||
config = config or {}
|
||||
return cls(
|
||||
llm=llm,
|
||||
num_queries=config.get("num_queries", 3),
|
||||
prompt_template=config.get("prompt_template", None),
|
||||
)
|
||||
|
||||
|
||||
def create_rag_fusion_pipeline(
|
||||
base_retriever: Any,
|
||||
llm: BaseLanguageModel,
|
||||
reranker: Optional[Any] = None,
|
||||
num_queries: int = 3,
|
||||
) -> Any:
|
||||
"""
|
||||
创建完整的 RAG-Fusion 流水线
|
||||
|
||||
Args:
|
||||
base_retriever: 基础检索器
|
||||
llm: 语言模型(用于查询改写)
|
||||
reranker: 重排序器(可选)
|
||||
num_queries: 查询改写数量
|
||||
|
||||
Returns:
|
||||
检索器实例
|
||||
"""
|
||||
# 创建多路查询改写器
|
||||
query_transformer = MultiQueryTransformer(
|
||||
llm=llm,
|
||||
num_queries=num_queries,
|
||||
)
|
||||
|
||||
# 创建多路查询检索器
|
||||
multi_query_retriever = query_transformer.create_multi_query_retriever(
|
||||
base_retriever=base_retriever,
|
||||
include_original=True,
|
||||
)
|
||||
|
||||
# 如果提供了重排序器,则应用重排序
|
||||
if reranker is not None:
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
return ContextualCompressionRetriever(
|
||||
base_compressor=reranker,
|
||||
base_retriever=multi_query_retriever,
|
||||
)
|
||||
|
||||
return multi_query_retriever
|
||||
23
app/rag/requirements.txt
Normal file
23
app/rag/requirements.txt
Normal file
@@ -0,0 +1,23 @@
|
||||
# RAG 系统依赖
|
||||
# 基础框架
|
||||
langchain>=0.1.0
|
||||
langchain-core>=0.1.0
|
||||
langchain-openai>=0.0.1
|
||||
langchain-qdrant>=0.1.0
|
||||
|
||||
# 用于 Cross-Encoder 重排序模型
|
||||
sentence-transformers>=2.2.0
|
||||
|
||||
# 用于 BM25 关键词混合检索
|
||||
rank-bm25>=0.2.2
|
||||
|
||||
# Qdrant 客户端
|
||||
qdrant-client>=1.6.0
|
||||
|
||||
# 可选的本地模型支持
|
||||
# vllm>=0.5.0 # 如果需要本地模型推理
|
||||
# transformers>=4.35.0 # 如果需要其他模型支持
|
||||
|
||||
# 开发依赖(测试用)
|
||||
pytest>=7.0.0
|
||||
pytest-asyncio>=0.21.0
|
||||
141
app/rag/reranker.py
Normal file
141
app/rag/reranker.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Cross-Encoder 重排序器
|
||||
|
||||
使用 sentence-transformers 加载交叉编码器模型,对检索结果进行精排。
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
||||
from langchain_core.documents import Document
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
|
||||
class CrossEncoderReranker:
|
||||
"""
|
||||
Cross-Encoder 重排序器包装类
|
||||
|
||||
支持 BAAI/bge-reranker-base 等模型。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "BAAI/bge-reranker-base",
|
||||
top_n: int = 5,
|
||||
device: Optional[str] = None,
|
||||
cache_folder: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
初始化重排序器
|
||||
|
||||
Args:
|
||||
model_name: 模型名称或路径
|
||||
top_n: 返回的顶部文档数量
|
||||
device: 设备(cpu/cuda),如果为 None 则自动选择
|
||||
cache_folder: 模型缓存目录
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.top_n = top_n
|
||||
self.device = device
|
||||
self.cache_folder = cache_folder or os.path.join(
|
||||
os.path.expanduser("~"), ".cache", "sentence_transformers"
|
||||
)
|
||||
|
||||
# 延迟加载模型
|
||||
self._model = None
|
||||
self._langchain_reranker = None
|
||||
|
||||
def _load_model(self):
|
||||
"""加载交叉编码器模型"""
|
||||
if self._model is None:
|
||||
try:
|
||||
self._model = CrossEncoder(
|
||||
self.model_name,
|
||||
device=self.device,
|
||||
cache_folder=self.cache_folder,
|
||||
)
|
||||
except Exception as e:
|
||||
# 如果指定模型加载失败,尝试备用模型
|
||||
print(f"加载模型 {self.model_name} 失败: {e}")
|
||||
print("尝试加载备用模型 BAAI/bge-reranker-v2-m3...")
|
||||
self._model = CrossEncoder(
|
||||
"BAAI/bge-reranker-v2-m3",
|
||||
device=self.device,
|
||||
cache_folder=self.cache_folder,
|
||||
)
|
||||
|
||||
def _create_langchain_reranker(self):
|
||||
"""创建 LangChain 重排序器"""
|
||||
if self._langchain_reranker is None:
|
||||
self._load_model()
|
||||
self._langchain_reranker = CrossEncoderReranker(
|
||||
model=self._model,
|
||||
top_n=self.top_n,
|
||||
)
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
query: str,
|
||||
documents: List[Document],
|
||||
) -> List[Document]:
|
||||
"""
|
||||
对文档进行重排序
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
documents: 待排序文档列表
|
||||
|
||||
Returns:
|
||||
重排序后的文档列表
|
||||
"""
|
||||
self._create_langchain_reranker()
|
||||
return self._langchain_reranker.compress_documents(
|
||||
documents=documents,
|
||||
query=query,
|
||||
)
|
||||
|
||||
def create_contextual_compression_retriever(
|
||||
self,
|
||||
base_retriever: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
创建上下文压缩检索器
|
||||
|
||||
Args:
|
||||
base_retriever: 基础检索器
|
||||
|
||||
Returns:
|
||||
上下文压缩检索器
|
||||
"""
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
|
||||
self._create_langchain_reranker()
|
||||
|
||||
compression_retriever = ContextualCompressionRetriever(
|
||||
base_compressor=self._langchain_reranker,
|
||||
base_retriever=base_retriever,
|
||||
)
|
||||
|
||||
return compression_retriever
|
||||
|
||||
@classmethod
|
||||
def create_from_config(
|
||||
cls,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> "CrossEncoderReranker":
|
||||
"""
|
||||
从配置创建重排序器
|
||||
|
||||
Args:
|
||||
config: 配置字典,包含 model_name, top_n, device 等
|
||||
|
||||
Returns:
|
||||
CrossEncoderReranker 实例
|
||||
"""
|
||||
config = config or {}
|
||||
return cls(
|
||||
model_name=config.get("model_name", "BAAI/bge-reranker-base"),
|
||||
top_n=config.get("top_n", 5),
|
||||
device=config.get("device", None),
|
||||
cache_folder=config.get("cache_folder", None),
|
||||
)
|
||||
144
app/rag/retriever.py
Normal file
144
app/rag/retriever.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Qdrant 向量检索器
|
||||
|
||||
提供基础向量检索、混合检索(Dense + BM25)功能。
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_qdrant import Qdrant
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
|
||||
from langchain.retrievers import EnsembleRetriever
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models
|
||||
|
||||
|
||||
def create_qdrant_client(
|
||||
url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> QdrantClient:
|
||||
"""
|
||||
创建 Qdrant 客户端
|
||||
|
||||
Args:
|
||||
url: Qdrant 服务地址,默认从环境变量 QDRANT_URL 读取
|
||||
api_key: API 密钥,默认从环境变量 QDRANT_API_KEY 读取
|
||||
|
||||
Returns:
|
||||
QdrantClient 实例
|
||||
"""
|
||||
url = url or os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
api_key = api_key or os.getenv("QDRANT_API_KEY")
|
||||
|
||||
client_args = {"url": url}
|
||||
if api_key:
|
||||
client_args["api_key"] = api_key
|
||||
|
||||
return QdrantClient(**client_args)
|
||||
|
||||
|
||||
def create_base_retriever(
|
||||
collection_name: str,
|
||||
embeddings: Embeddings,
|
||||
search_kwargs: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[QdrantClient] = None,
|
||||
) -> Qdrant:
|
||||
"""
|
||||
创建基础向量检索器
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
embeddings: 嵌入模型
|
||||
search_kwargs: 搜索参数,默认 {"k": 20}
|
||||
client: Qdrant 客户端,如果为 None 则自动创建
|
||||
|
||||
Returns:
|
||||
Qdrant 检索器实例
|
||||
"""
|
||||
if client is None:
|
||||
client = create_qdrant_client()
|
||||
|
||||
search_kwargs = search_kwargs or {"k": 20}
|
||||
|
||||
# 创建 Qdrant 检索器
|
||||
retriever = Qdrant.from_existing_collection(
|
||||
embedding=embeddings,
|
||||
collection_name=collection_name,
|
||||
client=client,
|
||||
content_payload_key="content", # 假设存储的文本字段名为 "content"
|
||||
metadata_payload_key="metadata", # 元数据字段名
|
||||
)
|
||||
|
||||
return retriever.as_retriever(search_kwargs=search_kwargs)
|
||||
|
||||
|
||||
def create_hybrid_retriever(
|
||||
collection_name: str,
|
||||
embeddings: Embeddings,
|
||||
dense_k: int = 10,
|
||||
sparse_k: int = 10,
|
||||
client: Optional[QdrantClient] = None,
|
||||
) -> ContextualCompressionRetriever:
|
||||
"""
|
||||
创建混合检索器(Dense Vector + BM25)
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
embeddings: 嵌入模型
|
||||
dense_k: 向量检索返回数量
|
||||
sparse_k: BM25 检索返回数量
|
||||
client: Qdrant 客户端
|
||||
|
||||
Returns:
|
||||
混合检索器
|
||||
"""
|
||||
if client is None:
|
||||
client = create_qdrant_client()
|
||||
|
||||
# 基础检索器(Qdrant 支持混合检索)
|
||||
base_retriever = Qdrant.from_existing_collection(
|
||||
embedding=embeddings,
|
||||
collection_name=collection_name,
|
||||
client=client,
|
||||
content_payload_key="content",
|
||||
metadata_payload_key="metadata",
|
||||
)
|
||||
|
||||
# 配置混合检索参数
|
||||
search_kwargs = {
|
||||
"k": dense_k + sparse_k, # 总返回数量
|
||||
"score_threshold": 0.3, # 相似度阈值
|
||||
}
|
||||
|
||||
return base_retriever.as_retriever(search_kwargs=search_kwargs)
|
||||
|
||||
|
||||
def create_ensemble_retriever(
|
||||
retrievers: List[Any],
|
||||
weights: Optional[List[float]] = None,
|
||||
c: int = 60,
|
||||
) -> EnsembleRetriever:
|
||||
"""
|
||||
创建集成检索器,支持倒数排名融合 (RRF)
|
||||
|
||||
Args:
|
||||
retrievers: 检索器列表
|
||||
weights: 检索器权重
|
||||
c: RRF 常数(通常为60)
|
||||
|
||||
Returns:
|
||||
集成检索器
|
||||
"""
|
||||
if weights is None:
|
||||
weights = [1.0 / len(retrievers)] * len(retrievers)
|
||||
|
||||
ensemble = EnsembleRetriever(
|
||||
retrievers=retrievers,
|
||||
weights=weights,
|
||||
c=c,
|
||||
search_type="rrf", # 使用倒数排名融合
|
||||
)
|
||||
|
||||
return ensemble
|
||||
230
app/rag/tools.py
Normal file
230
app/rag/tools.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
RAG 工具包装
|
||||
|
||||
将 RAG 流水线包装成 LangChain Tool,供 Agent 调用。
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from langchain.tools import tool
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from .pipeline import RAGPipeline, RAGConfig, RAGLevel
|
||||
|
||||
|
||||
class RAGTool:
|
||||
"""
|
||||
RAG 工具包装器
|
||||
|
||||
将 RAG 流水线包装成 Agent 可调用的工具。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline: RAGPipeline,
|
||||
tool_name: str = "search_knowledge_base",
|
||||
tool_description: str = None,
|
||||
):
|
||||
"""
|
||||
初始化 RAG 工具
|
||||
|
||||
Args:
|
||||
pipeline: RAG 流水线实例
|
||||
tool_name: 工具名称
|
||||
tool_description: 工具描述
|
||||
"""
|
||||
self.pipeline = pipeline
|
||||
self.tool_name = tool_name
|
||||
|
||||
# 默认工具描述
|
||||
self.tool_description = tool_description or (
|
||||
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
|
||||
"专业知识或需要基于已知信息回答的问题时使用此工具。"
|
||||
"输入应为要搜索的查询文本。"
|
||||
)
|
||||
|
||||
# 创建 LangChain 工具
|
||||
self._tool = self._create_tool()
|
||||
|
||||
def _create_tool(self) -> Tool:
|
||||
"""创建 LangChain 工具"""
|
||||
|
||||
@tool(self.tool_name, args_schema=None)
|
||||
def search_knowledge_base(query: str) -> str:
|
||||
"""
|
||||
在知识库中搜索相关信息
|
||||
|
||||
Args:
|
||||
query: 搜索查询
|
||||
|
||||
Returns:
|
||||
格式化后的搜索结果
|
||||
"""
|
||||
try:
|
||||
# 执行检索
|
||||
result = self.pipeline.retrieve(query)
|
||||
|
||||
if not result.documents:
|
||||
return "在知识库中未找到相关信息。"
|
||||
|
||||
# 格式化上下文
|
||||
context = self.pipeline.format_context(
|
||||
result.documents,
|
||||
max_length=4000, # 限制上下文长度
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response = (
|
||||
f"🔍 在知识库中找到了 {len(result.documents)} 条相关信息:\n\n"
|
||||
f"{context}\n\n"
|
||||
f"⏱️ 检索耗时: {result.query_time:.2f}秒"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"检索过程中发生错误: {str(e)}"
|
||||
if self.pipeline.config.verbose:
|
||||
print(f"RAG 工具错误: {error_msg}")
|
||||
return error_msg
|
||||
|
||||
# 设置工具描述
|
||||
search_knowledge_base.description = self.tool_description
|
||||
|
||||
return search_knowledge_base
|
||||
|
||||
def get_tool(self) -> Tool:
|
||||
"""获取 LangChain 工具"""
|
||||
return self._tool
|
||||
|
||||
def __call__(self, query: str) -> str:
|
||||
"""直接调用工具"""
|
||||
return self._tool.invoke({"query": query})
|
||||
|
||||
|
||||
def create_rag_tool(
|
||||
embeddings: Embeddings,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
tool_name: str = "search_knowledge_base",
|
||||
tool_description: Optional[str] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
创建 RAG 工具(便捷函数)
|
||||
|
||||
Args:
|
||||
embeddings: 嵌入模型
|
||||
llm: 语言模型(用于高级 RAG 功能)
|
||||
config: RAG 配置字典
|
||||
tool_name: 工具名称
|
||||
tool_description: 工具描述
|
||||
|
||||
Returns:
|
||||
LangChain Tool 实例
|
||||
"""
|
||||
# 创建 RAG 流水线
|
||||
pipeline = RAGPipeline.create_from_config(
|
||||
embeddings=embeddings,
|
||||
llm=llm,
|
||||
config_dict=config,
|
||||
)
|
||||
|
||||
# 创建工具包装器
|
||||
rag_tool = RAGTool(
|
||||
pipeline=pipeline,
|
||||
tool_name=tool_name,
|
||||
tool_description=tool_description,
|
||||
)
|
||||
|
||||
return rag_tool.get_tool()
|
||||
|
||||
|
||||
# 导出便捷函数
|
||||
search_knowledge_base_tool = create_rag_tool
|
||||
|
||||
|
||||
def bind_rag_to_agent(
|
||||
agent_llm: BaseLanguageModel,
|
||||
embeddings: Embeddings,
|
||||
rag_llm: Optional[BaseLanguageModel] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
tool_name: str = "search_knowledge_base",
|
||||
) -> BaseLanguageModel:
|
||||
"""
|
||||
将 RAG 工具绑定到 Agent 模型
|
||||
|
||||
Args:
|
||||
agent_llm: Agent 使用的语言模型
|
||||
embeddings: 嵌入模型
|
||||
rag_llm: RAG 流水线使用的语言模型(如果与 agent_llm 不同)
|
||||
config: RAG 配置
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
绑定工具后的模型
|
||||
"""
|
||||
# 如果未指定 RAG LLM,使用 Agent LLM
|
||||
if rag_llm is None:
|
||||
rag_llm = agent_llm
|
||||
|
||||
# 创建 RAG 工具
|
||||
rag_tool = create_rag_tool(
|
||||
embeddings=embeddings,
|
||||
llm=rag_llm,
|
||||
config=config,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
|
||||
# 绑定工具到模型
|
||||
return agent_llm.bind_tools([rag_tool])
|
||||
|
||||
|
||||
def create_agentic_rag_pipeline(
|
||||
embeddings: Embeddings,
|
||||
agent_llm: BaseLanguageModel,
|
||||
rag_llm: Optional[BaseLanguageModel] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
创建完整的 Agentic RAG 流水线(Level 4)
|
||||
|
||||
Args:
|
||||
embeddings: 嵌入模型
|
||||
agent_llm: Agent 模型
|
||||
rag_llm: RAG 专用模型
|
||||
config: 配置
|
||||
|
||||
Returns:
|
||||
包含模型和工具的字典
|
||||
"""
|
||||
# 配置 Agentic RAG 级别
|
||||
if config is None:
|
||||
config = {}
|
||||
config["rag_level"] = RAGLevel.AGENTIC.value
|
||||
|
||||
# 创建 RAG 工具
|
||||
rag_tool = create_rag_tool(
|
||||
embeddings=embeddings,
|
||||
llm=rag_llm or agent_llm,
|
||||
config=config,
|
||||
tool_name="search_knowledge_base",
|
||||
tool_description=(
|
||||
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
|
||||
"专业知识或需要基于已知信息回答的问题时使用此工具。"
|
||||
"Agent 应该先判断是否需要使用此工具,然后调用它获取上下文。"
|
||||
),
|
||||
)
|
||||
|
||||
# 绑定工具到模型
|
||||
bound_llm = agent_llm.bind_tools([rag_tool])
|
||||
|
||||
return {
|
||||
"llm": bound_llm,
|
||||
"tool": rag_tool,
|
||||
"pipeline": RAGPipeline.create_from_config(
|
||||
embeddings=embeddings,
|
||||
llm=rag_llm or agent_llm,
|
||||
config_dict=config,
|
||||
),
|
||||
}
|
||||
@@ -119,6 +119,7 @@ def _handle_ai_response():
|
||||
api_thought = ""
|
||||
display_text = ""
|
||||
display_thought = ""
|
||||
rag_sources = None # 存储 RAG 检索来源信息
|
||||
|
||||
# 调用流式 API
|
||||
stream = api_client.chat_stream(
|
||||
@@ -213,6 +214,25 @@ def _handle_ai_response():
|
||||
last_msg = messages_update[-1] if messages_update else {}
|
||||
if isinstance(last_msg, dict) and last_msg.get("role") == "tool":
|
||||
tool_name = last_msg.get("name", "unknown")
|
||||
tool_content = last_msg.get("content", "")
|
||||
# 存储 RAG 检索结果
|
||||
if tool_name == "search_knowledge_base":
|
||||
# 尝试解析 tool_content,它可能是 JSON 字符串
|
||||
sources = []
|
||||
try:
|
||||
if isinstance(tool_content, str):
|
||||
import json
|
||||
data = json.loads(tool_content)
|
||||
else:
|
||||
data = tool_content
|
||||
# 提取来源列表
|
||||
if isinstance(data, dict) and "sources" in data:
|
||||
sources = data["sources"]
|
||||
else:
|
||||
sources = [str(data)]
|
||||
except Exception:
|
||||
sources = [str(tool_content)]
|
||||
rag_sources = sources
|
||||
tool_status_placeholder.success(f"✅ 工具 {tool_name} 执行完成")
|
||||
# 短暂显示后清除,保持界面清爽
|
||||
import time
|
||||
@@ -270,6 +290,31 @@ def _handle_ai_response():
|
||||
# 移除光标
|
||||
message_placeholder.markdown(display_text)
|
||||
|
||||
# 显示 RAG 检索来源(如果有)
|
||||
if rag_sources:
|
||||
with st.expander("🔍 检索来源", expanded=False):
|
||||
# 格式化来源列表
|
||||
if isinstance(rag_sources, list):
|
||||
for i, source in enumerate(rag_sources, 1):
|
||||
if isinstance(source, dict):
|
||||
content = source.get("page_content", source.get("content", str(source)))
|
||||
metadata = source.get("metadata", {})
|
||||
filename = metadata.get("filename", metadata.get("source", "未知文件"))
|
||||
page = metadata.get("page", metadata.get("page_number", ""))
|
||||
if page:
|
||||
source_info = f"**来源 {i}:** {filename} (第{page}页)"
|
||||
else:
|
||||
source_info = f"**来源 {i}:** {filename}"
|
||||
st.markdown(source_info)
|
||||
# 显示内容预览(前200字符)
|
||||
preview = content[:200] + "..." if len(content) > 200 else content
|
||||
st.markdown(f"> {preview}")
|
||||
st.markdown("---")
|
||||
else:
|
||||
st.markdown(f"**来源 {i}:** {str(source)}")
|
||||
else:
|
||||
st.markdown(str(rag_sources))
|
||||
|
||||
# 拼装包含思考过程的完整内容,以便后续在历史中正确渲染
|
||||
final_content = display_text
|
||||
if display_thought:
|
||||
|
||||
109
rag_indexer/README.md
Normal file
109
rag_indexer/README.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# 离线 RAG 索引构建系统 (Offline RAG Indexer)
|
||||
|
||||
该模块负责 RAG 系统的阶段一:**离线索引构建**。它将外部的非结构化数据(如文档、PDF、网页等)清洗、切分并转化为向量,最终存入向量数据库中。
|
||||
|
||||
## 📊 系统工作流示意图
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[原始文档集合 <br> PDF / Word / Markdown] --> B(文档加载器 DocumentLoader)
|
||||
B --> C{文本切分策略 Splitter}
|
||||
|
||||
C -->|基础策略| D1[固定字符长度切分 <br> Recursive Split]
|
||||
C -->|进阶策略| D2[语义边界切分 <br> Semantic Chunking]
|
||||
C -->|高级策略| D3[父子文档切分 <br> Parent-Child / Auto-merging]
|
||||
|
||||
D1 & D2 & D3 --> E[向量化 Embedder <br> llama.cpp: embeddinggemma]
|
||||
|
||||
E --> F[(Qdrant 向量数据库)]
|
||||
|
||||
subgraph "元数据管理"
|
||||
G[提取作者、日期、页码等元数据 Metadata] -.附加.-> E
|
||||
end
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🎯 演进路线与核心算法 (Roadmap)
|
||||
|
||||
### Level 1: 基础暴力切分 (Basic Recursive Splitting)
|
||||
- **核心算法**: 递归字符切分。它按照预定义的分隔符列表(如 `["\n\n", "\n", " ", ""]`)从大到小尝试切分文本,直到每块的大小满足最大长度限制。
|
||||
- **优缺点**: 实现极简单,速度快。但非常容易将一句话拦腰截断,导致上下文语义丢失。
|
||||
- **实现指南**:
|
||||
- 从 `langchain.text_splitter` 导入 `RecursiveCharacterTextSplitter`。
|
||||
- 实例化时设置 `chunk_size`(如 500)和 `chunk_overlap`(如 50),直接调用 `.split_documents(raw_docs)` 方法。
|
||||
|
||||
### Level 2: 语义动态切分 (Semantic Chunking)
|
||||
- **核心算法**: 句子级相似度阈值算法。
|
||||
1. 将文章按标点符号按句子拆分。
|
||||
2. 使用轻量级 Embedding 模型将每一句向量化。
|
||||
3. 计算相邻两句之间的余弦相似度 (Cosine Similarity)。
|
||||
4. 当相似度低于设定阈值时(说明两句话讲的不是同一件事,语义发生了转折),在此处“切断”形成一个新的块。
|
||||
- **优缺点**: 极大程度保留了段落内语义的连贯性,对 LLM 回答非常友好。但由于在切分阶段就需要调用向量模型,耗时略长。
|
||||
- **实现指南**:
|
||||
- 从 `langchain_experimental.text_splitter` 导入 `SemanticChunker`。
|
||||
- 实例化时需要传入你已经配置好的 Embedding 模型实例(如基于 `OpenAIEmbeddings` 封装的 llama.cpp 本地模型),并设置 `breakpoint_threshold_type="percentile"` 等阈值参数。
|
||||
|
||||
### Level 3: 高级父子块策略 (Parent-Child / Auto-merging)
|
||||
- **核心算法**: 层次化双重存储与映射。
|
||||
- **切分机制**: 首先将文档粗切为较大的“父块 (Parent Chunk, 约 1000 词)”,随后将父块细切为较小的“子块 (Child Chunk, 约 200 词)”。
|
||||
- **存储机制**: 仅仅将**子块**的向量存入 Qdrant 用于精准计算距离;将**父块**的原始内容存在内存或 Document Store (如 KV 数据库) 中,通过 UUID 相互映射。
|
||||
- **核心思路**: 解决 RAG 领域经典的矛盾——检索时块越小越容易精确命中(去除噪声);但生成回答时,块越大越能给大模型提供充足的上下文背景。
|
||||
- **实现指南**:
|
||||
- 使用 `langchain.retrievers` 中的 `ParentDocumentRetriever` 模块。
|
||||
- 在写入时,你需要同时准备一个底层的 `VectorStore` (即 Qdrant) 和一个 `BaseStore` (比如原生的 `InMemoryStore` 或 `Redis`)。
|
||||
- 将两种不同的 `TextSplitter` 分别赋值给检索器的 `child_splitter` 和 `parent_splitter`,然后调用 `.add_documents()` 即可让系统自动完成映射。
|
||||
|
||||
### Level 4: GraphRAG 与 多模态 (Graph & Multi-modal)
|
||||
- **核心算法**: LLM 实体关系抽取 (NER & Relation Extraction)。
|
||||
- **核心思路**: 解决传统纯向量检索难以处理“跨文档复杂关系推理”的痛点(如:A公司的CEO是谁?他名下的B公司主要业务是什么?这种需要横跨多页 PDF 的跳跃性问题)。
|
||||
- **实现指南**:
|
||||
- 使用本地的大模型(如 `Gemma-4-E2B`)配合 `langchain_community.graphs` 模块。
|
||||
- 利用 `LLMGraphTransformer` 组件,在读取文档时,通过预设的 Prompt 强制大模型提取出实体(Node)和关系(Edge),直接写入诸如 Neo4j 这样的图数据库中,而非传统的 Qdrant 向量库。
|
||||
|
||||
---
|
||||
|
||||
## <20> 所需依赖与安装
|
||||
|
||||
为了支持完整的文档解析和 Qdrant 写入,需要安装以下 Python 包:
|
||||
|
||||
```bash
|
||||
# 基础核心库
|
||||
pip install langchain langchain-core langchain-openai langchain-qdrant
|
||||
|
||||
# 用于复杂文档解析 (PDF, Word, Excel 等)
|
||||
pip install unstructured pdf2image pdfminer.six
|
||||
|
||||
# 用于语义分块 (可选)
|
||||
pip install langchain-experimental
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📂 架构与文件结构设计
|
||||
|
||||
在 `rag_indexer/` 目录下,需创建以下核心文件:
|
||||
|
||||
```text
|
||||
rag_indexer/
|
||||
├── __init__.py
|
||||
├── loaders.py # 负责调用 unstructured 解析不同类型文件
|
||||
├── splitters.py # 负责实现 Recursive、Semantic、Parent-Child 切分逻辑
|
||||
├── embedders.py # 封装本地 llama.cpp 交互的 Embedding 接口
|
||||
├── vector_store.py # 封装 Qdrant 写入、Upsert、Collection 初始化操作
|
||||
└── builder.py # 核心编排文件,将上述模块串联成 Pipeline
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
|
||||
|
||||
### 串联与触发方式
|
||||
在你的 LangGraph 系统外,创建一个执行脚本 `scripts/run_indexer.py`:
|
||||
|
||||
```bash
|
||||
# 终端执行,将本地的 PDF 手册刷入向量数据库
|
||||
export QDRANT_URL="http://115.190.121.151:6333"
|
||||
python scripts/run_indexer.py --file data/user_docs/tech_manual.pdf
|
||||
```
|
||||
这相当于系统后台的**“离线学习阶段”**,你可以随时挂载定时任务去扫描文件夹,增量更新知识库。
|
||||
25
rag_indexer/__init__.py
Normal file
25
rag_indexer/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Offline RAG Indexer module.
|
||||
"""
|
||||
|
||||
from .loaders import DocumentLoader
|
||||
from .splitters import (
|
||||
RecursiveSplitter,
|
||||
SemanticSplitter,
|
||||
ParentChildSplitter,
|
||||
SplitterType,
|
||||
)
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .vector_store import QdrantVectorStore
|
||||
from .builder import IndexBuilder
|
||||
|
||||
__all__ = [
|
||||
"DocumentLoader",
|
||||
"RecursiveSplitter",
|
||||
"SemanticSplitter",
|
||||
"ParentChildSplitter",
|
||||
"SplitterType",
|
||||
"LlamaCppEmbedder",
|
||||
"QdrantVectorStore",
|
||||
"IndexBuilder",
|
||||
]
|
||||
277
rag_indexer/builder.py
Normal file
277
rag_indexer/builder.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
Core pipeline builder for offline RAG index construction.
|
||||
|
||||
Now supports LangChain's ParentDocumentRetriever for parent-child chunking.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain.retrievers import ParentDocumentRetriever
|
||||
from langchain.storage import LocalFileStore, BaseStore
|
||||
|
||||
from .loaders import DocumentLoader
|
||||
from .splitters import SplitterType, get_splitter, ParentChildSplitter
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .vector_store import QdrantVectorStore
|
||||
from .docstore_manager import get_docstore, PostgresDocStore, create_docstore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParentChildConfig:
|
||||
"""Configuration for parent-child splitting."""
|
||||
parent_chunk_size: int = 1000
|
||||
child_chunk_size: int = 200
|
||||
parent_chunk_overlap: int = 100
|
||||
child_chunk_overlap: int = 20
|
||||
search_k: int = 5
|
||||
docstore_path: str = None
|
||||
docstore_type: str = "local"
|
||||
docstore_conn_string: str = None
|
||||
|
||||
|
||||
class IndexBuilder:
|
||||
"""Main pipeline for RAG index construction."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str = "rag_documents",
|
||||
qdrant_url: str = None,
|
||||
splitter_type: SplitterType = SplitterType.RECURSIVE,
|
||||
**splitter_kwargs,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self.qdrant_url = qdrant_url
|
||||
self.splitter_type = splitter_type
|
||||
self.splitter_kwargs = splitter_kwargs
|
||||
|
||||
# Components
|
||||
self.loader = DocumentLoader()
|
||||
self.embedder = LlamaCppEmbedder()
|
||||
self.embeddings = self.embedder.as_langchain_embeddings()
|
||||
|
||||
self.vector_store = QdrantVectorStore(
|
||||
collection_name=collection_name,
|
||||
embeddings=self.embeddings,
|
||||
qdrant_url=qdrant_url,
|
||||
)
|
||||
|
||||
# Splitter (except parent-child which is handled separately)
|
||||
if splitter_type != SplitterType.PARENT_CHILD:
|
||||
if splitter_type == SplitterType.SEMANTIC:
|
||||
splitter_kwargs["embeddings"] = self.embeddings
|
||||
self.splitter = get_splitter(splitter_type, **splitter_kwargs)
|
||||
else:
|
||||
self.splitter = None
|
||||
# Initialize ParentDocumentRetriever for parent-child splitting
|
||||
self._init_parent_child_retriever()
|
||||
|
||||
def _init_parent_child_retriever(self, **kwargs):
|
||||
"""
|
||||
Initialize ParentDocumentRetriever for parent-child chunking.
|
||||
|
||||
This replaces the custom ParentChildSplitter logic.
|
||||
"""
|
||||
# Parse kwargs for parent-child config
|
||||
parent_size = kwargs.get("parent_chunk_size", 1000)
|
||||
child_size = kwargs.get("child_chunk_size", 200)
|
||||
parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100))
|
||||
child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20))
|
||||
|
||||
# Define splitters
|
||||
self.parent_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=parent_size,
|
||||
chunk_overlap=parent_overlap,
|
||||
)
|
||||
self.child_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=child_size,
|
||||
chunk_overlap=child_overlap,
|
||||
)
|
||||
|
||||
# Vector store (for child chunks)
|
||||
self.vector_store_obj = self.vector_store.get_langchain_vectorstore()
|
||||
|
||||
# Document store (for parent chunks)
|
||||
docstore_path = kwargs.get("docstore_path")
|
||||
docstore_type = kwargs.get("docstore_type", "local")
|
||||
docstore_conn = kwargs.get("docstore_conn_string")
|
||||
|
||||
if docstore_type == "postgres" and docstore_conn:
|
||||
self.docstore = PostgresDocStore(docstore_conn)
|
||||
self._docstore_conn = docstore_conn
|
||||
else:
|
||||
self.docstore = get_docstore(docstore_path)
|
||||
self._docstore_conn = None
|
||||
|
||||
# Create retriever
|
||||
self.retriever = ParentDocumentRetriever(
|
||||
vectorstore=self.vector_store_obj,
|
||||
docstore=self.docstore,
|
||||
child_splitter=self.child_splitter,
|
||||
parent_splitter=self.parent_splitter,
|
||||
search_kwargs={"k": kwargs.get("search_k", 5)},
|
||||
)
|
||||
|
||||
def build_from_file(self, file_path: Union[str, Path]) -> int:
|
||||
logger.info("Loading file: %s", file_path)
|
||||
documents = self.loader.load_file(file_path)
|
||||
logger.info("Loaded %d documents", len(documents))
|
||||
return self._process_documents(documents)
|
||||
|
||||
def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int:
|
||||
logger.info("Loading directory: %s (recursive=%s)", directory_path, recursive)
|
||||
documents = self.loader.load_directory(directory_path, recursive=recursive)
|
||||
logger.info("Loaded %d documents from directory", len(documents))
|
||||
return self._process_documents(documents)
|
||||
|
||||
def _process_documents(self, documents: List[Document]) -> int:
|
||||
if not documents:
|
||||
logger.warning("No documents to process")
|
||||
return 0
|
||||
|
||||
if self.splitter_type == SplitterType.PARENT_CHILD:
|
||||
logger.info("Using LangChain ParentDocumentRetriever")
|
||||
|
||||
# Ensure collection exists for child chunks
|
||||
self.vector_store.create_collection()
|
||||
|
||||
# Use ParentDocumentRetriever to add documents
|
||||
# This automatically handles parent-child splitting, mapping, and retrieval
|
||||
self.retriever.add_documents(documents)
|
||||
|
||||
# Log estimated chunk counts
|
||||
estimated_parent_chunks = len(documents) * (self.parent_splitter._chunk_size // self.child_splitter._chunk_size)
|
||||
logger.info(
|
||||
"Indexed with ParentDocumentRetriever: "
|
||||
f"~{len(documents)} parent chunks, ~{estimated_parent_chunks} child chunks"
|
||||
)
|
||||
return len(documents)
|
||||
|
||||
else:
|
||||
logger.info("Splitting documents using %s", self.splitter_type)
|
||||
chunks = self.splitter.split_documents(documents)
|
||||
logger.info("Split into %d chunks", len(chunks))
|
||||
|
||||
self.vector_store.create_collection()
|
||||
self.vector_store.add_documents(chunks)
|
||||
return len(chunks)
|
||||
|
||||
def get_collection_info(self):
|
||||
return self.vector_store.get_collection_info()
|
||||
|
||||
def search(self, query: str, k: int = 5) -> List[Document]:
|
||||
"""Standard search - returns child chunks."""
|
||||
return self.vector_store.similarity_search(query, k=k)
|
||||
|
||||
def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]:
|
||||
"""
|
||||
Search with parent context - returns full parent chunks.
|
||||
|
||||
This is the main retrieval method when using parent-child splitting.
|
||||
"""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
raise RuntimeError(
|
||||
"search_with_parent_context only available with PARENT_CHILD splitter. "
|
||||
"Use search() for standard retrieval."
|
||||
)
|
||||
return self.retriever.get_relevant_documents(query, k=k)
|
||||
|
||||
def retrieve(self, query: str, return_parent: bool = True) -> List[Document]:
|
||||
"""
|
||||
Unified retrieval interface.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
return_parent: If True and using parent-child splitter, return parent chunks
|
||||
If False, always return child chunks
|
||||
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
if self.splitter_type == SplitterType.PARENT_CHILD and return_parent:
|
||||
return self.search_with_parent_context(query)
|
||||
else:
|
||||
return self.search(query)
|
||||
|
||||
def get_retriever(self) -> ParentDocumentRetriever:
|
||||
"""
|
||||
Get the ParentDocumentRetriever instance directly.
|
||||
|
||||
Useful for advanced use cases where you want to access the retriever
|
||||
outside of IndexBuilder.
|
||||
"""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
raise RuntimeError(
|
||||
"get_retriever() only available with PARENT_CHILD splitter. "
|
||||
"Use search() or search_with_parent_context() for standard retrieval."
|
||||
)
|
||||
return self.retriever
|
||||
|
||||
def get_child_splitter(self) -> "RecursiveCharacterTextSplitter":
|
||||
"""Get the child splitter for reconfiguration."""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
return self.splitter
|
||||
return self.child_splitter
|
||||
|
||||
def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter":
|
||||
"""Get the parent splitter for reconfiguration."""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
raise RuntimeError(
|
||||
"Parent splitter only available with PARENT_CHILD splitter."
|
||||
)
|
||||
return self.parent_splitter
|
||||
|
||||
def get_docstore(self) -> BaseStore:
|
||||
"""Get the document store for parent chunks."""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
raise RuntimeError(
|
||||
"Docstore only available with PARENT_CHILD splitter."
|
||||
)
|
||||
return self.docstore
|
||||
|
||||
def get_docstore_path(self) -> str:
|
||||
"""Get the document store path."""
|
||||
if self.splitter_type != SplitterType.PARENT_CHILD:
|
||||
raise RuntimeError(
|
||||
"Docstore path only available with PARENT_CHILD splitter."
|
||||
)
|
||||
return self.docstore.persist_path
|
||||
|
||||
def close(self):
|
||||
"""Close resources."""
|
||||
if hasattr(self, "_docstore_conn") and self._docstore_conn:
|
||||
import psycopg2
|
||||
conn = psycopg2.connect(self._docstore_conn)
|
||||
conn.close()
|
||||
logger.info("Closed PostgreSQL connection")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
return False
|
||||
|
||||
|
||||
# RecursiveCharacterTextSplitter needs to be imported
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
builder = IndexBuilder(
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
parent_chunk_size=1000,
|
||||
child_chunk_size=200,
|
||||
docstore_path="./my_parent_docs",
|
||||
)
|
||||
|
||||
print("Parent splitter:", builder.get_parent_splitter().chunk_size)
|
||||
print("Child splitter:", builder.get_child_splitter().chunk_size)
|
||||
print("Docstore path:", builder.get_docstore_path())
|
||||
print("Retriever:", builder.get_retriever())
|
||||
102
rag_indexer/cli.py
Executable file
102
rag_indexer/cli.py
Executable file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Command-line interface for the RAG index builder.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from builder import IndexBuilder
|
||||
from splitters import SplitterType
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Offline RAG Index Builder")
|
||||
parser.add_argument("--file", type=str, help="Path to file to index")
|
||||
parser.add_argument("--dir", type=str, help="Path to directory to index")
|
||||
parser.add_argument("--recursive", action="store_true", default=True,
|
||||
help="Recursively process directories (default: True)")
|
||||
parser.add_argument("--collection", type=str, default="rag_documents",
|
||||
help="Qdrant collection name (default: rag_documents)")
|
||||
parser.add_argument("--qdrant-url", type=str,
|
||||
help="Qdrant server URL (default: http://127.0.0.1:6333)")
|
||||
parser.add_argument("--splitter", type=str,
|
||||
choices=["recursive", "semantic", "parent_child"],
|
||||
default="recursive",
|
||||
help="Text splitting strategy (default: recursive)")
|
||||
parser.add_argument("--chunk-size", type=int, default=500,
|
||||
help="Chunk size for recursive/parent splitter (default: 500)")
|
||||
parser.add_argument("--chunk-overlap", type=int, default=50,
|
||||
parser.add_argument("--docstore-path", type=str,
|
||||
default=None,
|
||||
help="Path to store parent documents for parent-child splitter (default: ./parent_docs or HERMES_HOME/parent_docs)")
|
||||
parser.add_argument("--docstore-type", type=str,
|
||||
choices=["local", "postgres"],
|
||||
default="local",
|
||||
help="Type of docstore: 'local' (default) or 'postgres' for PostgreSQL-backed storage")
|
||||
parser.add_argument("--docstore-conn", type=str,
|
||||
default=None,
|
||||
help="PostgreSQL connection string for postgres docstore")
|
||||
|
||||
help="Chunk overlap (default: 50)")
|
||||
parser.add_argument("--parent-size", type=int, default=1000,
|
||||
help="Parent chunk size for parent-child splitter (default: 1000)")
|
||||
parser.add_argument("--child-size", type=int, default=200,
|
||||
help="Child chunk size for parent-child splitter (default: 200)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.file and not args.dir:
|
||||
print("Error: Either --file or --dir must be specified", file=sys.stderr)
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
splitter_map = {
|
||||
"recursive": SplitterType.RECURSIVE,
|
||||
"semantic": SplitterType.SEMANTIC,
|
||||
"parent_child": SplitterType.PARENT_CHILD,
|
||||
}
|
||||
splitter_type = splitter_map[args.splitter]
|
||||
|
||||
splitter_kwargs = {}
|
||||
if splitter_type == SplitterType.RECURSIVE:
|
||||
splitter_kwargs["chunk_size"] = args.chunk_size
|
||||
splitter_kwargs["chunk_overlap"] = args.chunk_overlap
|
||||
elif splitter_type == SplitterType.PARENT_CHILD:
|
||||
splitter_kwargs["parent_chunk_size"] = args.parent_size
|
||||
splitter_kwargs["child_chunk_size"] = args.child_size
|
||||
splitter_kwargs["parent_chunk_overlap"] = args.chunk_overlap
|
||||
splitter_kwargs["child_chunk_overlap"] = args.chunk_overlap // 2
|
||||
splitter_kwargs["docstore_path"] = args.docstore_path
|
||||
splitter_kwargs["docstore_type"] = args.docstore_type
|
||||
splitter_kwargs["docstore_conn_string"] = args.docstore_conn
|
||||
|
||||
builder = IndexBuilder(
|
||||
collection_name=args.collection,
|
||||
qdrant_url=args.qdrant_url,
|
||||
splitter_type=splitter_type,
|
||||
**splitter_kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
if args.file:
|
||||
chunk_count = builder.build_from_file(args.file)
|
||||
else:
|
||||
chunk_count = builder.build_from_directory(args.dir, args.recursive)
|
||||
|
||||
print(f"Indexing completed. Total chunks indexed: {chunk_count}")
|
||||
info = builder.get_collection_info()
|
||||
print(f"Collection '{info['name']}' has {info['vectors_count']} vectors (dim={info['vector_size']})")
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Indexing failed")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
142
rag_indexer/docstore_manager.py
Normal file
142
rag_indexer/docstore_manager.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Document store manager for ParentDocumentRetriever.
|
||||
|
||||
Supports both LocalFileStore (default) and custom PostgreSQL-backed stores.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from langchain.storage import BaseStore, LocalFileStore
|
||||
|
||||
|
||||
def get_docstore(persist_path: str = None) -> LocalFileStore:
|
||||
"""
|
||||
Create and return a document store for parent chunks.
|
||||
|
||||
Args:
|
||||
persist_path: Path to store parent documents. Defaults to ./parent_docs
|
||||
or HERMES_HOME/parent_docs if set.
|
||||
"""
|
||||
if persist_path is None:
|
||||
# Use HERMES_HOME if available, otherwise default to current directory
|
||||
persist_path = os.getenv("HERMES_HOME")
|
||||
if persist_path:
|
||||
persist_path = os.path.join(persist_path, "parent_docs")
|
||||
else:
|
||||
persist_path = "./parent_docs"
|
||||
|
||||
os.makedirs(persist_path, exist_ok=True)
|
||||
return LocalFileStore(persist_path)
|
||||
|
||||
|
||||
class PostgresDocStore(BaseStore):
|
||||
"""
|
||||
PostgreSQL-backed document store for parent chunks.
|
||||
|
||||
This is an optional advanced feature. For most use cases,
|
||||
LocalFileStore is sufficient and simpler.
|
||||
"""
|
||||
|
||||
def __init__(self, connection_string: str):
|
||||
"""
|
||||
Initialize PostgreSQL document store.
|
||||
|
||||
Args:
|
||||
connection_string: PostgreSQL connection URL
|
||||
"""
|
||||
import psycopg2
|
||||
from psycopg2 import sql
|
||||
|
||||
self.conn_string = connection_string
|
||||
self._conn = None
|
||||
|
||||
# Create table if not exists
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
"""Create the parent documents table if not exists."""
|
||||
try:
|
||||
self._conn = psycopg2.connect(self.conn_string)
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS parent_documents (
|
||||
key TEXT PRIMARY KEY,
|
||||
value JSONB NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)
|
||||
""")
|
||||
self._conn.commit()
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to create PostgreSQL table: {e}")
|
||||
|
||||
def get(self, key: str) -> Optional[dict]:
|
||||
"""Retrieve a document by key."""
|
||||
try:
|
||||
self._ensure_connection()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute("SELECT value FROM parent_documents WHERE key = %s", (key,))
|
||||
row = cursor.fetchone()
|
||||
cursor.close()
|
||||
if row:
|
||||
import json
|
||||
return json.loads(row[0])
|
||||
return None
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to retrieve document: {e}")
|
||||
|
||||
def set(self, key: str, value: dict) -> None:
|
||||
"""Store a document."""
|
||||
try:
|
||||
self._ensure_connection()
|
||||
cursor = self._conn.cursor()
|
||||
# Upsert
|
||||
insert_query = sql.SQL(
|
||||
"INSERT INTO parent_documents (key, value) VALUES (%s, %s)"
|
||||
)
|
||||
update_query = sql.SQL(
|
||||
"UPDATE parent_documents SET value = %s WHERE key = %s"
|
||||
)
|
||||
cursor.execute(insert_query, (key, json.dumps(value)))
|
||||
try:
|
||||
cursor.execute(update_query, (key, json.dumps(value)))
|
||||
except psycopg2.IntegrityError:
|
||||
pass # Key exists, ignore
|
||||
self._conn.commit()
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to store document: {e}")
|
||||
|
||||
def _ensure_connection(self):
|
||||
"""Ensure we have an open connection."""
|
||||
if self._conn is None or self._conn.closed:
|
||||
self._conn = psycopg2.connect(self.conn_string)
|
||||
|
||||
def close(self):
|
||||
"""Close the connection."""
|
||||
if self._conn and not self._conn.closed:
|
||||
self._conn.close()
|
||||
|
||||
|
||||
# Factory function for creating custom docstores
|
||||
# Returns a tuple: (BaseStore instance, connection_string or None)
|
||||
def create_docstore(
|
||||
store_type: str = "local",
|
||||
persist_path: str = None,
|
||||
connection_string: str = None
|
||||
) -> tuple:
|
||||
"""
|
||||
Factory function to create different types of document stores.
|
||||
|
||||
Args:
|
||||
store_type: "local" (default), "postgres"
|
||||
persist_path: Path for local file store
|
||||
connection_string: PostgreSQL connection string
|
||||
|
||||
Returns:
|
||||
Tuple of (BaseStore instance, connection_string or None)
|
||||
"""
|
||||
if store_type == "postgres" and connection_string:
|
||||
return (PostgresDocStore(connection_string), connection_string)
|
||||
else:
|
||||
return (get_docstore(persist_path), None)
|
||||
68
rag_indexer/embedders.py
Normal file
68
rag_indexer/embedders.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Embedding model wrapper for llama.cpp service.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
|
||||
class LlamaCppEmbedder:
|
||||
"""Wrapper for llama.cpp embedding service via OpenAI-compatible API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "embeddinggemma-300M-Q8_0",
|
||||
):
|
||||
self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
|
||||
self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
|
||||
self.model = model
|
||||
|
||||
# Ensure URL ends with /v1
|
||||
self.base_url = urljoin(self.base_url.rstrip("/") + "/", "v1")
|
||||
|
||||
def as_langchain_embeddings(self) -> OpenAIEmbeddings:
|
||||
"""Create LangChain OpenAIEmbeddings instance."""
|
||||
return OpenAIEmbeddings(
|
||||
openai_api_base=self.base_url,
|
||||
openai_api_key=self.api_key,
|
||||
model=self.model,
|
||||
)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of documents."""
|
||||
emb = self.as_langchain_embeddings()
|
||||
return emb.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a single query."""
|
||||
emb = self.as_langchain_embeddings()
|
||||
return emb.embed_query(text)
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Get embedding dimension by embedding a test string."""
|
||||
test_embedding = self.embed_query("test")
|
||||
return len(test_embedding)
|
||||
|
||||
|
||||
class MockEmbedder:
|
||||
"""Mock embedder for testing without a real service."""
|
||||
|
||||
def __init__(self, dimension: int = 768):
|
||||
self.dimension = dimension
|
||||
|
||||
def as_langchain_embeddings(self) -> OpenAIEmbeddings:
|
||||
raise NotImplementedError("MockEmbedder cannot be used as LangChain embeddings")
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return [[0.0] * self.dimension for _ in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return [0.0] * self.dimension
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
return self.dimension
|
||||
124
rag_indexer/example_parent_child.py
Normal file
124
rag_indexer/example_parent_child.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Example demonstrating ParentDocumentRetriever usage.
|
||||
|
||||
This script shows how to:
|
||||
1. Build an index with parent-child chunking
|
||||
2. Search with child chunks (fast, precise)
|
||||
3. Search with parent context (large context)
|
||||
4. Access the retriever directly for advanced use cases
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
from builder import IndexBuilder
|
||||
from splitters import SplitterType
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 70)
|
||||
print("ParentDocumentRetriever Example")
|
||||
print("=" * 70)
|
||||
|
||||
# Step 1: Create IndexBuilder with parent-child splitting
|
||||
print("\n1. Creating IndexBuilder with parent-child splitting...")
|
||||
builder = IndexBuilder(
|
||||
collection_name="parent_child_demo",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
parent_chunk_size=1000, # Parent chunks: larger context
|
||||
child_chunk_size=200, # Child chunks: smaller for precision
|
||||
docstore_path="./my_parent_docs", # Where to store parent chunks
|
||||
search_k=5, # Number of child chunks to retrieve
|
||||
)
|
||||
|
||||
print(f" Parent splitter: chunk_size={builder.get_parent_splitter().chunk_size}")
|
||||
print(f" Child splitter: chunk_size={builder.get_child_splitter().chunk_size}")
|
||||
print(f" Docstore path: {builder.get_docstore_path()}")
|
||||
print(f" Search k: {builder.retriever.search_kwargs['k']}")
|
||||
|
||||
# Step 2: Build index from a sample file
|
||||
print("\n2. Building index from sample file...")
|
||||
|
||||
# Create a test document
|
||||
test_content = """
|
||||
This is a test document for demonstrating ParentDocumentRetriever.
|
||||
|
||||
Parent chunks contain larger portions of text (1000 characters),
|
||||
while child chunks are smaller (200 characters) for precise retrieval.
|
||||
|
||||
When you search with ParentDocumentRetriever:
|
||||
- It first retrieves relevant child chunks
|
||||
- Then replaces them with their corresponding parent chunks
|
||||
- This gives you large context while maintaining precision
|
||||
|
||||
Example search queries:
|
||||
- "ParentDocumentRetriever"
|
||||
- "child chunks"
|
||||
- "large context"
|
||||
- "precise retrieval"
|
||||
"""
|
||||
|
||||
test_file = Path("./test_document.txt")
|
||||
test_file.write_text(test_content)
|
||||
|
||||
chunk_count = builder.build_from_file(str(test_file))
|
||||
print(f" Indexed {chunk_count} documents")
|
||||
|
||||
# Step 3: Search with child chunks (fast, precise)
|
||||
print("\n3. Searching with child chunks (fast, precise)...")
|
||||
child_results = builder.search("ParentDocumentRetriever", k=3)
|
||||
print(f" Found {len(child_results)} child chunks:")
|
||||
for i, doc in enumerate(child_results, 1):
|
||||
print(f" [{i}] {doc.page_content[:100]}...")
|
||||
|
||||
# Step 4: Search with parent context (large context)
|
||||
print("\n4. Searching with parent context (large context)...")
|
||||
parent_results = builder.search_with_parent_context("ParentDocumentRetriever", k=3)
|
||||
print(f" Found {len(parent_results)} parent chunks:")
|
||||
for i, doc in enumerate(parent_results, 1):
|
||||
print(f" [{i}] {doc.page_content[:150]}...")
|
||||
|
||||
# Step 5: Compare results
|
||||
print("\n5. Comparing child vs parent results...")
|
||||
print(f" Child chunks total length: {sum(len(d.page_content) for d in child_results)}")
|
||||
print(f" Parent chunks total length: {sum(len(d.page_content) for d in parent_results)}")
|
||||
print(f" Ratio: parent/child = {sum(len(d.page_content) for d in parent_results) / max(sum(len(d.page_content) for d in child_results), 1):.2f}x larger")
|
||||
|
||||
# Step 6: Access retriever directly
|
||||
print("\n6. Accessing retriever directly...")
|
||||
retriever = builder.get_retriever()
|
||||
print(f" Retriever type: {type(retriever).__name__}")
|
||||
print(f" Vectorstore: {retriever.vectorstore}")
|
||||
print(f" Docstore: {retriever.docstore}")
|
||||
|
||||
# Step 7: Unified retrieval interface
|
||||
print("\n7. Using unified retrieval interface...")
|
||||
unified_results = builder.retrieve("ParentDocumentRetriever", return_parent=True)
|
||||
print(f" Retrieved {len(unified_results)} documents (with parent context)")
|
||||
|
||||
# Step 8: Collection info
|
||||
print("\n8. Collection info...")
|
||||
info = builder.get_collection_info()
|
||||
print(f" Collection: {info['name']}")
|
||||
print(f" Vectors: {info['vectors_count']}")
|
||||
print(f" Vector size: {info['vector_size']}")
|
||||
|
||||
# Cleanup
|
||||
print("\n9. Cleaning up...")
|
||||
builder.close()
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Example completed successfully!")
|
||||
print("=" * 70)
|
||||
|
||||
return builder
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
builder = main()
|
||||
91
rag_indexer/loaders.py
Normal file
91
rag_indexer/loaders.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Document loaders using unstructured library.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from unstructured.partition.auto import partition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentLoader:
|
||||
"""Load documents from various file formats."""
|
||||
|
||||
SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx"}
|
||||
|
||||
def __init__(self, extract_images: bool = False):
|
||||
"""
|
||||
Args:
|
||||
extract_images: Whether to extract images from documents (requires additional dependencies)
|
||||
"""
|
||||
self.extract_images = extract_images
|
||||
|
||||
def load_file(self, file_path: Union[str, Path]) -> List[Document]:
|
||||
"""Load a single file into LangChain Document objects."""
|
||||
file_path = Path(file_path).resolve()
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
suffix = file_path.suffix.lower()
|
||||
if suffix not in self.SUPPORTED_EXTENSIONS:
|
||||
raise ValueError(
|
||||
f"Unsupported file extension: {suffix}. Supported: {self.SUPPORTED_EXTENSIONS}"
|
||||
)
|
||||
|
||||
# Parse with unstructured
|
||||
elements = partition(
|
||||
filename=str(file_path),
|
||||
extract_images_in_pdf=self.extract_images,
|
||||
)
|
||||
|
||||
documents = []
|
||||
for elem in elements:
|
||||
text = getattr(elem, "text", "")
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
|
||||
# Base metadata
|
||||
metadata = {
|
||||
"source": str(file_path),
|
||||
"file_name": file_path.name,
|
||||
"file_type": suffix,
|
||||
}
|
||||
|
||||
# Merge element-specific metadata without overwriting base fields
|
||||
elem_meta = getattr(elem, "metadata", {}) or {}
|
||||
for key, value in elem_meta.items():
|
||||
if value and key not in metadata:
|
||||
metadata[key] = value
|
||||
|
||||
documents.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
if not documents:
|
||||
logger.warning("No text content extracted from %s", file_path)
|
||||
return []
|
||||
|
||||
return documents
|
||||
|
||||
def load_directory(
|
||||
self, directory_path: Union[str, Path], recursive: bool = True
|
||||
) -> List[Document]:
|
||||
"""Load all supported files from a directory."""
|
||||
directory_path = Path(directory_path).resolve()
|
||||
if not directory_path.is_dir():
|
||||
raise NotADirectoryError(f"Not a directory: {directory_path}")
|
||||
|
||||
all_documents = []
|
||||
pattern = "**/*" if recursive else "*"
|
||||
|
||||
for file_path in directory_path.glob(pattern):
|
||||
if file_path.is_file() and file_path.suffix.lower() in self.SUPPORTED_EXTENSIONS:
|
||||
try:
|
||||
docs = self.load_file(file_path)
|
||||
all_documents.extend(docs)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load %s: %s", file_path, e)
|
||||
|
||||
return all_documents
|
||||
71
rag_indexer/splitters.py
Normal file
71
rag_indexer/splitters.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Text splitters for chunking documents.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from langchain_experimental.text_splitter import SemanticChunker
|
||||
|
||||
|
||||
class SplitterType(str, Enum):
|
||||
RECURSIVE = "recursive"
|
||||
SEMANTIC = "semantic"
|
||||
PARENT_CHILD = "parent_child"
|
||||
|
||||
|
||||
def get_splitter(splitter_type: SplitterType, **kwargs):
|
||||
"""Factory function to create a text splitter."""
|
||||
if splitter_type == SplitterType.RECURSIVE:
|
||||
chunk_size = kwargs.get("chunk_size", 500)
|
||||
chunk_overlap = kwargs.get("chunk_overlap", 50)
|
||||
return RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separators=["\n\n", "\n", "。", "!", "?", " ", ""],
|
||||
)
|
||||
elif splitter_type == SplitterType.SEMANTIC:
|
||||
# Requires embeddings for semantic splitting
|
||||
embeddings = kwargs.get("embeddings")
|
||||
if embeddings is None:
|
||||
raise ValueError("Semantic splitter requires 'embeddings' parameter")
|
||||
return SemanticChunker(embeddings=embeddings)
|
||||
else:
|
||||
raise ValueError(f"Unsupported splitter type: {splitter_type}")
|
||||
|
||||
|
||||
class ParentChildSplitter:
|
||||
"""
|
||||
Splits documents into parent (large) and child (small) chunks.
|
||||
Child chunks are indexed for retrieval, parent chunks are stored for context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_chunk_size: int = 1000,
|
||||
child_chunk_size: int = 200,
|
||||
parent_chunk_overlap: int = 100,
|
||||
child_chunk_overlap: int = 20,
|
||||
):
|
||||
self.parent_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=parent_chunk_size,
|
||||
chunk_overlap=parent_chunk_overlap,
|
||||
)
|
||||
self.child_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=child_chunk_size,
|
||||
chunk_overlap=child_chunk_overlap,
|
||||
)
|
||||
|
||||
def split_documents(self, documents: List[Document]) -> tuple[List[Document], List[Document]]:
|
||||
"""
|
||||
Returns:
|
||||
(parent_chunks, child_chunks)
|
||||
"""
|
||||
parent_chunks = self.parent_splitter.split_documents(documents)
|
||||
child_chunks = self.child_splitter.split_documents(documents)
|
||||
|
||||
# Link child chunks to parent IDs (optional metadata)
|
||||
# In a real implementation, you'd map each child to a parent chunk ID.
|
||||
return parent_chunks, child_chunks
|
||||
110
rag_indexer/vector_store.py
Normal file
110
rag_indexer/vector_store.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Qdrant vector store wrapper.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.models import Distance, VectorParams
|
||||
|
||||
from .embedders import LlamaCppEmbedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QdrantVectorStore:
|
||||
"""Wrapper for Qdrant vector database operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: Optional[Any] = None,
|
||||
qdrant_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self.qdrant_url = qdrant_url or os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
self.api_key = api_key
|
||||
|
||||
# Embeddings
|
||||
if embeddings is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
self.embeddings = embedder.as_langchain_embeddings()
|
||||
else:
|
||||
self.embeddings = embeddings
|
||||
|
||||
# Qdrant client
|
||||
self.client = QdrantClient(url=self.qdrant_url, api_key=self.api_key)
|
||||
|
||||
# LangChain vector store
|
||||
self.vector_store = LangchainQdrantVS(
|
||||
client=self.client,
|
||||
collection_name=self.collection_name,
|
||||
embeddings=self.embeddings,
|
||||
)
|
||||
|
||||
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
|
||||
"""Create collection with appropriate vector size."""
|
||||
if vector_size is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
vector_size = embedder.get_embedding_dimension()
|
||||
|
||||
collections = self.client.get_collections().collections
|
||||
exists = any(c.name == self.collection_name for c in collections)
|
||||
|
||||
if exists and force_recreate:
|
||||
self.client.delete_collection(self.collection_name)
|
||||
exists = False
|
||||
|
||||
if not exists:
|
||||
self.client.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
|
||||
)
|
||||
logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)
|
||||
else:
|
||||
logger.info("Collection '%s' already exists", self.collection_name)
|
||||
|
||||
def add_documents(self, documents: List[Document], batch_size: int = 100):
|
||||
"""Add documents to vector store."""
|
||||
if not documents:
|
||||
return []
|
||||
self.create_collection()
|
||||
ids = self.vector_store.add_documents(documents, batch_size=batch_size)
|
||||
logger.info("Added %d documents to '%s'", len(ids), self.collection_name)
|
||||
return ids
|
||||
|
||||
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
|
||||
return self.vector_store.similarity_search(query, k=k)
|
||||
|
||||
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
|
||||
return self.vector_store.similarity_search_with_score(query, k=k)
|
||||
|
||||
def delete_collection(self):
|
||||
self.client.delete_collection(self.collection_name)
|
||||
logger.info("Collection '%s' deleted", self.collection_name)
|
||||
|
||||
def get_collection_info(self) -> Dict[str, Any]:
|
||||
info = self.client.get_collection(self.collection_name)
|
||||
return {
|
||||
"name": info.name,
|
||||
"vectors_count": info.vectors_count,
|
||||
"status": info.status,
|
||||
"vector_size": info.config.params.vectors.size,
|
||||
}
|
||||
|
||||
def as_langchain_vectorstore(self):
|
||||
return self.vector_store
|
||||
|
||||
def get_langchain_vectorstore(self):
|
||||
"""返回 LangChain Qdrant 向量存储对象(别名)"""
|
||||
return self.vector_store
|
||||
|
||||
def get_qdrant_client(self):
|
||||
"""返回原生 Qdrant 客户端(如需手动管理 collection)"""
|
||||
return self.client
|
||||
@@ -48,6 +48,8 @@ pydantic==2.12.5
|
||||
python-dotenv==1.2.2
|
||||
typing-extensions==4.15.0
|
||||
|
||||
unstructured>=0.0.1
|
||||
|
||||
# ============================================================================
|
||||
# 注意:
|
||||
# 1. 此文件包含项目直接依赖的精确版本
|
||||
|
||||
Reference in New Issue
Block a user