检索器重构
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s

This commit is contained in:
2026-04-19 22:01:55 +08:00
parent cc8ef41ef9
commit 933d418d77
26 changed files with 1694 additions and 1717 deletions

View File

@@ -2,71 +2,44 @@
该模块负责 RAG 系统的阶段二:**在线检索与生成**。它接收用户提问,从知识库中检索出上下文,利用各种高级策略去噪、融合,并作为增强上下文输入给大语言模型 (LLM)。
## 📊 RAG-Fusion & 混合检索流水线示意图
```mermaid
graph TD
User((用户提问)) --> A[LLM 查询改写生成器]
subgraph RAG-Fusion 核心流程
A -->|改写为问题 1| B1[查询 1]
A -->|改写为问题 2| B2[查询 2]
A -->|原问题| B3[原始查询]
B1 & B2 & B3 --> C[混合检索器 Hybrid Retriever <br> Dense Vector + BM25 Sparse]
C --> D[多路召回结果合集 N=60条]
D --> E{RRF 倒数排名融合去重}
end
E -->|筛选出前 20 条| F[Cross-Encoder 重排器 Reranker]
F -->|精细打分排序 Top 5| G[最终纯净上下文 Context]
G --> H[将 Context 与原问题拼接输入大模型]
H --> I((LLM 生成最终回答))
```
---
## 🎯 演进路线与算法详解 (Roadmap)
### Level 1: 基础向量搜索 (Basic Similarity Search)
- **核心算法**: 近似最近邻搜索 (ANN, 常用 HNSW 算法)。将用户问题转化为向量后,计算它与库中向量的余弦相似度 (Cosine Similarity),取距离最近的 K 个块。
- **优缺点**: 速度极快。但只能捕捉“语义相似”,如果用户搜索特定专有名词、编号、订单号,纯向量检索往往会失效(产生“幻觉”匹配)。
- **实现指南**:
- 使用 `rag_indexer.embedders.LlamaCppEmbedder` 作为嵌入模型
- 使用 `app/rag/retriever.py` 中的 `create_base_retriever` 创建基础检索器
- 配置 `search_kwargs={"k": 20}` 进行初步召回
### Level 2: 混合检索与重排序 (Hybrid Search + Reranker)
混合检索旨在结合向量的“语义泛化”与关键词的“精准匹配”,随后利用重排序模型过滤噪声。
**1. 基础召回 (混合检索)**
- **核心原理**: 结合基于 HNSW 的 Dense Vector 相似度搜索与基于 TF-IDF 的 BM25 稀疏检索 (Sparse Vector)。
- **实现指南**: 使用 `langchain_qdrant` 中的 `Qdrant` 类连接数据库。通过调用 `Qdrant.from_existing_collection(...)` 实例化向量库,并使用 `.as_retriever(search_kwargs={"k": 20})` 方法生成基础检索器。Qdrant 底层会自动处理双路召回
- **实现指南**: 使用 `app/rag/retriever.py` 中的 `create_hybrid_retriever` 函数,配置 `dense_k=10``sparse_k=10`,总召回 20 条结果
**2. 二次精排 (Cross-Encoder)**
- **核心原理**: 不同于双塔模型(分别算向量再求距离),交叉编码器将“用户问题 + 检索到的单例文档”拼接后整体输入 Transformer 模型,由模型直接输出 0~1 的相关性得分,精度极高。
- **实现指南**:
- 使用 `sentence-transformers` 库加载本地轻量级重排模型(如 `BAAI/bge-reranker-base`)。
- 引入 `langchain.retrievers.document_compressors` 包中的 `CrossEncoderReranker` 类包装该模型,设置参数 `top_n=5`
- 最后,使用 `langchain.retrievers` 包中的 `ContextualCompressionRetriever` 类,将 `base_compressor` (重排器) 和 `base_retriever` (基础检索器) 组合。
- **如何调用**: 业务逻辑中直接对组合后的检索器调用 `.invoke(query)` 方法,即可一键完成“大范围召回 20 条 -> 逐一打分精排选 5 条”的去噪流水线。
- **实现指南**:
- 使用 `app/rag/reranker.py` 中的 `CrossEncoderReranker` 类,加载 `BAAI/bge-reranker-base` 模型
- 设置 `top_n=5` 保留最相关的 5 条结果
- 使用 `ContextualCompressionRetriever` 组合基础检索器和重排序器
### Level 3: RAG-Fusion (多路改写与倒数排名融合)
RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果。
**1. 多路查询改写**
- **核心原理**: 克服用户初始提问词不达意或视角受限的问题。
- **实现指南**: 导入 `langchain.retrievers.multi_query` 中的 `MultiQueryRetriever` 类。需向其提供一个已实例化的 LLM 对象(如基于 `ChatOpenAI` 封装的本地 VLLM 模型)。系统在底层会自动 Prompt 模型,将原始 `query` 转化为包含 3-5 个不同表述的查询列表
- **实现指南**: 使用 `app/rag/query_transform.py` 中的 `MultiQueryTransformer` 类,配置 `num_queries=3` 生成 3 个不同角度的查询。
**2. 倒数排名融合 (RRF)**
- **核心原理**: RRF (Reciprocal Rank Fusion) 是一种无需评分归一化的融合算法。公式为 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$,有效避免某一极端检索结果主导全局。
- **实现指南**:
- 针对每个改写后的查询 $q$,分别调用精排检索器的 `.invoke(q)` 获取文档列表。
- 使用 `langchain.retrievers` 中的 `EnsembleRetriever` 类(原生支持 RRF或在代码中遍历收集到的 `Document` 对象,基于其排名 `rank` 累加得分,最终通过 Python 的 `set` 去重并提取 `doc.page_content`
- **实现指南**: 使用 `app/rag/retriever.py` 中的 `create_ensemble_retriever` 函数,配置 `search_type="rrf"` 实现倒数排名融合。
### Level 4: Agentic RAG / Self-RAG (智能体与自我反思)
- **核心原理**: 基于 LangGraph 的 ReAct (Reasoning and Acting) 状态机路由。大模型并非每次都去死板地执行检索,而是先判断问题:“这是闲聊?还是需要查知识库?”。如果是后者,模型输出一个 `ToolCall` 指令,触发检索。
- **实现指南**: 请参考下方的**与现有系统整合调用**章节
- **实现指南**: 使用 `app/rag/tools.py` 中的 `search_knowledge_base` 工具,将其绑定到 LangGraph 状态机中
- **示意图**:
```mermaid
@@ -87,6 +60,13 @@ RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问
LangGraph Agent-->>User: "根据知识库规定报销流程分为以下3步..."
```
### Level 5: GraphRAG 集成 (基于图和关系的 RAG)
- **核心原理**: 结合知识图谱的结构化关系和向量检索的语义相似度,解决跨文档复杂关系推理问题。
- **实现指南**:
- 使用 `langchain_community.graphs` 模块构建知识图谱
- 配置本地大模型(如 `Gemma-4-E2B`)用于实体关系抽取
- 实现混合检索逻辑,结合向量相似度和图路径分析
---
## 📦 所需依赖与安装
@@ -102,18 +82,19 @@ pip install rank_bm25
# 基础框架
pip install langchain langchain-core langchain-openai langchain-qdrant
# 与 rag_indexer 共享的依赖
pip install qdrant-client httpx
```
---
## 📂 架构与文件结构设计
在 `app/rag/` 目录下,需创建以下文件来模块化上述功能:
```text
```
app/rag/
├── __init__.py
├── retriever.py # 负责 Qdrant 的基础召回与 ContextualCompressionRetriever
├── retriever.py # 负责 Qdrant 的基础召回与混合检索
├── reranker.py # 负责加载 sentence-transformers 交叉编码器
├── query_transform.py # 负责基于 MultiQueryRetriever 的改写逻辑
├── pipeline.py # 组合上述组件,暴露出核心的 retrieve() 方法
@@ -122,15 +103,69 @@ app/rag/
---
## <EFBFBD>现有系统整合调用 (Agentic RAG 实现)
## 🔄 rag_indexer 集成
### 数据结构兼容性
- **向量存储**: rag_indexer 使用 Qdrant 存储子块向量app/rag 直接从相同集合读取
- **文档存储**: rag_indexer 使用 PostgreSQL 存储父块app/rag 通过 `ParentDocumentRetriever` 关联
- **嵌入模型**: 共享 `LlamaCppEmbedder` 确保向量空间一致性
### 配置共享
- **环境变量**: QDRANT_URL、QDRANT_API_KEY、DB_URI 等配置在两个模块间共享
- **集合名称**: 默认使用 "rag_documents" 集合,确保数据一致性
---
## 🚀 与现有系统整合调用 (Agentic RAG 实现)
基于目前 LangGraph 系统的架构,我们将摒弃将代码堆砌在一起的旧方式,而是利用 **LangChain Tools** 的特性将 RAG 优雅地注入系统:
1. **封装检索工具 (Tool)**:
1. **封装检索工具 (Tool)**:
从 `langchain.tools` 导入 `@tool` 装饰器。定义一个名为 `search_knowledge_base(query: str)` 的函数。在函数内部,实例化并调用我们在 `pipeline.py` 中写好的多路召回与重排逻辑。
2. **模型绑定 (Bind)**:
2. **模型绑定 (Bind)**:
在 `app/agent.py` 或 `app/nodes/tool_call.py` 中,将这个工具引入,并通过 `llm.bind_tools([search_knowledge_base])` 绑定到现有的本地大模型实例上。
3. **状态机路由 (Graph Routing)**:
3. **状态机路由 (Graph Routing)**:
你的 LangGraph 状态机会像处理普通对话一样自动接管:当模型判断需要调用查阅规章制度或专业资料时,它会输出 `ToolCall` 消息,流转到 `tool_node` 执行上述的 RAG 检索逻辑并返回上下文。
这让你无需修改任何前端 Streamlit 流式代码,就能平滑升级为具备超级知识库检索能力的智能体 (Agent)
---
## 🎯 快速开始
```python
# 1. 初始化嵌入模型
from rag_indexer.embedders import LlamaCppEmbedder
embeddings = LlamaCppEmbedder()
# 2. 初始化语言模型(用于 RAG-Fusion
from langchain_openai import OpenAI
llm = OpenAI(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model_name="Qwen2.5-7B-Instruct",
temperature=0.3,
)
# 3. 创建 RAG 流水线
from app.rag.pipeline import RAGPipeline, RAGLevel
pipeline = RAGPipeline(
embeddings=embeddings,
llm=llm,
config={
"collection_name": "rag_documents",
"rag_level": RAGLevel.FUSION.value,
"num_queries": 3,
"rerank_top_n": 5,
},
)
# 4. 执行检索
result = pipeline.retrieve("如何申请项目资金?")
# 5. 格式化上下文
context = pipeline.format_context(result.documents)
print(context)
```

View File

@@ -1,22 +1,53 @@
"""
在线 RAG 检索与生成系统
RAG 检索与生成模块
提供高级RAG检索功能支持混合检索、重排序、RAG-Fusion和多路查询改写。
提供在线检索与生成功能,包括:
- 基础向量检索
- 重排序
- RAG-Fusion
- Agentic RAG
示例用法:
>>> from app.rag import RAGPipeline, search_knowledge_base
>>> from rag_core import LlamaCppEmbedder
>>>
>>> embeddings = LlamaCppEmbedder()
>>> pipeline = RAGPipeline(embeddings=embeddings)
>>>
>>> documents = pipeline.retrieve("戏耍貂蝉美女")
>>> context = pipeline.format_context(documents)
"""
from .pipeline import RAGPipeline
from .retriever import create_hybrid_retriever, create_base_retriever
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
# create_ensemble_retriever,
create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer
from .tools import search_knowledge_base_tool
from .pipeline import RAGPipeline, RAGLevel
from .tools import search_knowledge_base, search_knowledge_base_sync
__all__ = [
"RAGPipeline",
"create_hybrid_retriever",
# 检索器
"create_base_retriever",
"create_hybrid_retriever",
# "create_ensemble_retriever",
"create_qdrant_client",
# 重排序器
"CrossEncoderReranker",
# 查询转换器
"MultiQueryTransformer",
"search_knowledge_base_tool",
# 流水线
"RAGPipeline",
"RAGLevel",
# 工具
"search_knowledge_base",
"search_knowledge_base_sync",
]
__version__ = "0.1.0"

View File

@@ -7,6 +7,10 @@ RAG 系统使用示例
import sys
import os
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
# 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -19,10 +23,13 @@ def setup_environment():
"""设置环境变量"""
# 设置 Qdrant 连接信息(根据实际情况修改)
os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333")
# 设置 Qdrant API 密钥(根据实际情况修改)
os.environ.setdefault("QDRANT_API_KEY", "your-api-key-here")
# 如果需要 API 密钥,请设置 QDRANT_API_KEY
print("环境变量已设置")
print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}")
print(f"QDRANT_API_KEY: {'***' if os.environ.get('QDRANT_API_KEY') else '未设置'}")
def demonstrate_basic_rag():
@@ -31,37 +38,32 @@ def demonstrate_basic_rag():
print("演示: 基础 RAG 检索 (Level 1)")
print("="*60)
# 创建嵌入模型(使用 OpenAI 兼容的本地模型)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1", # 本地 VLLM 服务
openai_api_key="no-key-needed",
model="text-embedding-ada-002", # 假设的模型名称
)
# 创建嵌入模型(使用本地 LlamaCpp 模型)
from rag_core import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 创建 RAG 流水线
from app.rag import RAGPipeline, RAGConfig, RAGLevel
config = RAGConfig(
collection_name="documents", # 你的集合名称
rag_level=RAGLevel.BASIC,
verbose=True,
)
from app.rag import RAGPipeline, RAGLevel
pipeline = RAGPipeline(
embeddings=embeddings,
config=config,
config={
"collection_name": "rag_documents", # 你的集合名称
"rag_level": RAGLevel.BASIC.value,
}
)
# 示例查询
query = "公司报销流程是什么?"
query = "吕布"
print(f"\n查询: {query}")
try:
result = pipeline.retrieve(query)
print(f"找到 {len(result.documents)} 个相关文档")
documents = pipeline.retrieve(query)
print(f"找到 {len(documents)} 个相关文档")
# 格式化上下文
context = pipeline.format_context(result.documents)
context = pipeline.format_context(documents)
print(f"\n上下文预览:\n{context[:500]}...")
except Exception as e:
@@ -75,34 +77,31 @@ def demonstrate_hybrid_rag():
print("演示: 混合 RAG 检索 (Level 2)")
print("="*60)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model="text-embedding-ada-002",
)
from rag_core import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
from app.rag import RAGPipeline, RAGConfig, RAGLevel
config = RAGConfig(
collection_name="documents",
rag_level=RAGLevel.HYBRID,
dense_k=10,
sparse_k=10,
rerank_top_n=5,
verbose=True,
)
from app.rag import RAGPipeline, RAGLevel
pipeline = RAGPipeline(
embeddings=embeddings,
config=config,
config={
"collection_name": "rag_documents",
"rag_level": RAGLevel.RERANK.value,
"rerank_top_n": 5,
}
)
query = "如何申请年假?"
query = "吕布"
print(f"\n查询: {query}")
try:
result = pipeline.retrieve(query)
print(f"找到 {len(result.documents)} 个重排序后的文档")
documents = pipeline.retrieve(query)
print(f"找到 {len(documents)} 个重排序后的文档")
# 格式化上下文
context = pipeline.format_context(documents)
print(f"\n上下文预览:\n{context[:500]}...")
except Exception as e:
print(f"检索失败: {e}")
@@ -114,42 +113,42 @@ def demonstrate_rag_fusion():
print("演示: RAG-Fusion (Level 3)")
print("="*60)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model="text-embedding-ada-002",
)
from rag_core import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 创建语言模型用于查询改写
llm = VLLMOpenAI(
# 创建语言模型用于查询改写(使用 OpenAI 兼容的本地模型)
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model_name="Qwen2.5-7B-Instruct", # 你的本地模型
model="Qwen2.5-7B-Instruct", # 你的本地模型
temperature=0.3,
max_tokens=512,
)
from app.rag import RAGPipeline, RAGConfig, RAGLevel
config = RAGConfig(
collection_name="documents",
rag_level=RAGLevel.FUSION,
num_queries=3,
verbose=True,
)
from app.rag import RAGPipeline, RAGLevel
pipeline = RAGPipeline(
embeddings=embeddings,
llm=llm,
config=config,
config={
"collection_name": "rag_documents",
"rag_level": RAGLevel.FUSION.value,
"num_queries": 3,
}
)
query = "项目上线需要哪些审批?"
query = "吕布"
print(f"\n查询: {query}")
try:
result = pipeline.retrieve(query)
print(f"找到 {len(result.documents)} 个文档 (经过多路查询改写和重排序)")
documents = pipeline.retrieve(query)
print(f"找到 {len(documents)} 个文档 (经过多路查询改写和重排序)")
# 格式化上下文
context = pipeline.format_context(documents)
print(f"\n上下文预览:\n{context[:500]}...")
except Exception as e:
print(f"检索失败: {e}")
@@ -161,44 +160,16 @@ def demonstrate_agentic_rag():
print("演示: Agentic RAG (Level 4)")
print("="*60)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model="text-embedding-ada-002",
)
llm = VLLMOpenAI(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model_name="Qwen2.5-7B-Instruct",
temperature=0.3,
max_tokens=512,
)
from app.rag import create_agentic_rag_pipeline
from app.rag import search_knowledge_base_sync
try:
# 创建 Agentic RAG 流水线
agentic_rag = create_agentic_rag_pipeline(
embeddings=embeddings,
agent_llm=llm,
config={
"collection_name": "documents",
"verbose": True,
},
)
print("Agentic RAG 流水线创建成功!")
print(f"- 绑定的模型: {agentic_rag['llm']}")
print(f"- RAG 工具: {agentic_rag['tool'].name}")
# 演示工具调用
print("\n工具调用示例:")
response = agentic_rag["tool"].invoke({"query": "员工福利有哪些?"})
print("工具调用示例:")
response = search_knowledge_base_sync("吕布")
print(f"工具响应预览: {response[:200]}...")
except Exception as e:
print(f"创建 Agentic RAG 失败: {e}")
print(f"工具调用失败: {e}")
import traceback
traceback.print_exc()
@@ -211,11 +182,11 @@ def main():
# 设置环境
setup_environment()
# 演示各级功能
# 演示基础功能
demonstrate_basic_rag()
demonstrate_hybrid_rag()
demonstrate_rag_fusion()
demonstrate_agentic_rag()
# demonstrate_rag_fusion() # 需要本地 LLM 服务
# demonstrate_agentic_rag() # 需要本地 LLM 服务
print("\n" + "="*60)
print("演示完成!")
@@ -223,8 +194,8 @@ def main():
print("\n使用说明:")
print("1. 确保 Qdrant 服务运行且集合已创建")
print("2. 根据需要修改 embeddings 和 llm 配置")
print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base_tool")
print("2. 已使用本地 LlamaCpp 嵌入模型")
print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base")
print("4. 将工具绑定到你的 Agent 模型")

View File

@@ -1,341 +1,168 @@
"""
RAG 检索流水线
组合检索、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。
整合基础检索、重排序和 RAG-Fusion 功能。
"""
import time
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
create_ensemble_retriever,
create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline
from .query_transform import MultiQueryTransformer
from rag_core import QDRANT_URL, QDRANT_API_KEY
class RAGLevel(Enum):
"""RAG 功能级别"""
BASIC = 1 # 基础向量
HYBRID = 2 # 混合检索 + 重排序
FUSION = 3 # RAG-Fusion
AGENTIC = 4 # Agentic RAG
@dataclass
class RAGConfig:
"""RAG 配置"""
# Qdrant 配置
collection_name: str = "documents"
qdrant_url: Optional[str] = None
qdrant_api_key: Optional[str] = None
# 检索配置
rag_level: RAGLevel = RAGLevel.FUSION
dense_k: int = 10 # 向量检索数量
sparse_k: int = 10 # BM25 检索数量
total_k: int = 20 # 总检索数量
rerank_top_n: int = 5 # 重排序返回数量
# 查询改写配置
num_queries: int = 3 # RAG-Fusion 查询数量
# 模型配置
reranker_model: str = "BAAI/bge-reranker-base"
device: Optional[str] = None
# 性能配置
enable_cache: bool = True
verbose: bool = True
@dataclass
class RetrievalResult:
"""检索结果"""
documents: List[Document]
query_time: float
level: RAGLevel
metadata: Dict[str, Any] = field(default_factory=dict)
"""RAG 级别"""
BASIC = "basic" # 基础向量
RERANK = "rerank" # 基础检索 + 重排序
FUSION = "fusion" # RAG-Fusion(多路查询 + RRF
class RAGPipeline:
"""
RAG 检索流水线
支持从 Level 1 到 Level 4 的所有功能。
"""
"""RAG 检索流水线"""
def __init__(
self,
embeddings: Embeddings,
embeddings,
llm: Optional[BaseLanguageModel] = None,
config: Optional[RAGConfig] = None,
config: Optional[Dict[str, Any]] = None,
):
"""
初始化 RAG 流水线
Args:
embeddings: 嵌入模型
llm: 语言模型(用于查询改写Level 3+ 需要
config: 配置
llm: 语言模型(用于 RAG-Fusion
config: 配置参数
"""
self.embeddings = embeddings
self.llm = llm
self.config = config or RAGConfig()
self.config = config or {}
# 初始化组件
self._client = None
self._reranker = None
self._query_transformer = None
self._retriever = None
self.collection_name = self.config.get("collection_name", "rag_documents")
self.rag_level = self.config.get("rag_level", RAGLevel.RERANK.value)
self.num_queries = self.config.get("num_queries", 3)
self.rerank_top_n = self.config.get("rerank_top_n", 5)
# 缓存
self._cache = {}
def _get_client(self):
"""获取 Qdrant 客户端"""
if self._client is None:
self._client = create_qdrant_client(
url=self.config.qdrant_url,
api_key=self.config.qdrant_api_key,
)
return self._client
def _get_reranker(self):
"""获取重排序器"""
if self._reranker is None:
self._reranker = CrossEncoderReranker(
model_name=self.config.reranker_model,
top_n=self.config.rerank_top_n,
device=self.config.device,
)
return self._reranker
def _get_query_transformer(self):
"""获取查询改写器"""
if self._query_transformer is None and self.llm is not None:
self._query_transformer = MultiQueryTransformer(
llm=self.llm,
num_queries=self.config.num_queries,
)
return self._query_transformer
def _create_basic_retriever(self):
"""创建基础检索器Level 1"""
return create_base_retriever(
collection_name=self.config.collection_name,
# 初始化基础检索器
self.base_retriever = create_base_retriever(
collection_name=self.collection_name,
embeddings=self.embeddings,
search_kwargs={"k": self.config.total_k},
client=self._get_client(),
)
def _create_hybrid_retriever(self):
"""创建混合检索器Level 2"""
base_retriever = create_hybrid_retriever(
collection_name=self.config.collection_name,
embeddings=self.embeddings,
dense_k=self.config.dense_k,
sparse_k=self.config.sparse_k,
client=self._get_client(),
search_kwargs={"k": 20}, # 召回 20 条
)
# 应用重排序
reranker = self._get_reranker()
return reranker.create_contextual_compression_retriever(base_retriever)
def _create_fusion_retriever(self):
"""创建 RAG-Fusion 检索器Level 3"""
if self.llm is None:
raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写")
# 初始化重排序
try:
self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n)
except Exception as e:
print(f"警告: 无法创建重排序器,将使用基础检索。错误: {e}")
self.reranker = None
# 创建基础混合检索器
base_retriever = create_hybrid_retriever(
collection_name=self.config.collection_name,
embeddings=self.embeddings,
dense_k=self.config.dense_k,
sparse_k=self.config.sparse_k,
client=self._get_client(),
)
# 创建 RAG-Fusion 流水线
reranker = self._get_reranker()
return create_rag_fusion_pipeline(
base_retriever=base_retriever,
llm=self.llm,
reranker=reranker,
num_queries=self.config.num_queries,
)
# 根据 RAG 级别创建检索器
self.retriever = self._create_retriever()
def _get_retriever(self):
"""根据配置级别获取检索器"""
if self._retriever is None:
if self.config.rag_level == RAGLevel.BASIC:
self._retriever = self._create_basic_retriever()
elif self.config.rag_level == RAGLevel.HYBRID:
self._retriever = self._create_hybrid_retriever()
elif self.config.rag_level == RAGLevel.FUSION:
self._retriever = self._create_fusion_retriever()
elif self.config.rag_level == RAGLevel.AGENTIC:
# Agentic RAG 使用 Fusion 作为基础,在 tools.py 中包装
self._retriever = self._create_fusion_retriever()
def _create_retriever(self):
"""根据 RAG 级别创建检索器"""
if self.rag_level == RAGLevel.BASIC.value:
return self.base_retriever
# 基础检索 + 重排序
def rerank_retriever(query):
documents = self.base_retriever.invoke(query)
if self.reranker:
return self.reranker.compress_documents(documents, query)
else:
raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}")
return documents[:self.rerank_top_n]
return self._retriever
if self.rag_level == RAGLevel.RERANK.value:
return SimpleRetriever(rerank_retriever)
# RAG-Fusion
if self.rag_level == RAGLevel.FUSION.value:
if not self.llm:
raise ValueError("RAG-Fusion 需要提供 llm 参数")
# 创建多路查询检索器
transformer = MultiQueryTransformer(
llm=self.llm,
num_queries=self.num_queries
)
multi_query_retriever = transformer.create_multi_query_retriever(
base_retriever=SimpleRetriever(rerank_retriever)
)
return multi_query_retriever
return SimpleRetriever(rerank_retriever)
def retrieve(
self,
query: str,
use_cache: Optional[bool] = None,
**kwargs,
) -> RetrievalResult:
def retrieve(self, query: str) -> List[Document]:
"""
执行检索
Args:
query: 查询文本
use_cache: 是否使用缓存
**kwargs: 额外参数
query: 查询字符串
Returns:
检索结果
相关文档列表
"""
start_time = time.time()
# 检查缓存
if use_cache is None:
use_cache = self.config.enable_cache
cache_key = f"{query}:{self.config.rag_level.value}"
if use_cache and cache_key in self._cache:
if self.config.verbose:
print(f"使用缓存结果: {query}")
return self._cache[cache_key]
# 获取检索器并执行检索
retriever = self._get_retriever()
documents = retriever.invoke(query, **kwargs)
# 计算查询时间
query_time = time.time() - start_time
# 创建结果
result = RetrievalResult(
documents=documents,
query_time=query_time,
level=self.config.rag_level,
metadata={
"query": query,
"collection": self.config.collection_name,
"doc_count": len(documents),
},
)
# 缓存结果
if use_cache:
self._cache[cache_key] = result
if self.config.verbose:
print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s")
return result
return self.retriever.invoke(query)
def format_context(
self,
documents: List[Document],
max_length: Optional[int] = None,
) -> str:
async def aretrieve(self, query: str) -> List[Document]:
"""
格式化检索到的文档为上下文文本
异步执行检索
Args:
query: 查询字符串
Returns:
相关文档列表
"""
return await self.retriever.ainvoke(query)
def format_context(self, documents: List[Document]) -> str:
"""
格式化上下文
Args:
documents: 文档列表
max_length: 最大长度(字符数)
Returns:
格式化后的上下文文本
格式化后的上下文字符串
"""
if not documents:
return ""
context_parts = []
total_length = 0
for i, doc in enumerate(documents, 1):
content = doc.page_content
metadata = doc.metadata or {}
source = metadata.get("source", "未知来源")
part = f"【资料 {i}\n"
part += f"来源: {source}\n"
part += f"内容: {content}\n"
part += "---\n"
context_parts.append(part)
for i, doc in enumerate(documents):
# 提取内容和元数据
content = doc.page_content.strip()
metadata = doc.metadata
# 格式化文档
doc_text = f"[文档 {i+1}]\n"
if metadata.get("source"):
doc_text += f"来源: {metadata['source']}\n"
if metadata.get("page"):
doc_text += f"页码: {metadata['page']}\n"
doc_text += f"内容: {content}\n\n"
# 检查长度限制
if max_length is not None:
if total_length + len(doc_text) > max_length:
# 如果添加这个文档会超限,则截断并添加说明
remaining = max_length - total_length
if remaining > 100: # 至少保留100字符
doc_text = doc_text[:remaining] + "...\n\n[内容已截断]"
context_parts.append(doc_text)
break
else:
break
context_parts.append(doc_text)
total_length += len(doc_text)
return "".join(context_parts).strip()
return "".join(context_parts)
class SimpleRetriever:
"""简单检索器包装类"""
def clear_cache(self):
"""清空缓存"""
self._cache.clear()
def __init__(self, retrieve_func):
self.retrieve_func = retrieve_func
@classmethod
def create_from_config(
cls,
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config_dict: Optional[Dict[str, Any]] = None,
) -> "RAGPipeline":
"""
从配置字典创建流水线
Args:
embeddings: 嵌入模型
llm: 语言模型
config_dict: 配置字典
Returns:
RAGPipeline 实例
"""
config_dict = config_dict or {}
# 创建配置对象
config = RAGConfig(
collection_name=config_dict.get("collection_name", "documents"),
qdrant_url=config_dict.get("qdrant_url"),
qdrant_api_key=config_dict.get("qdrant_api_key"),
rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)),
dense_k=config_dict.get("dense_k", 10),
sparse_k=config_dict.get("sparse_k", 10),
total_k=config_dict.get("total_k", 20),
rerank_top_n=config_dict.get("rerank_top_n", 5),
num_queries=config_dict.get("num_queries", 3),
reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"),
device=config_dict.get("device"),
enable_cache=config_dict.get("enable_cache", True),
verbose=config_dict.get("verbose", True),
)
return cls(embeddings=embeddings, llm=llm, config=config)
def invoke(self, query):
return self.retrieve_func(query)
async def ainvoke(self, query):
return self.retrieve_func(query)

View File

@@ -1,193 +1,62 @@
"""
查询改写器
查询转换器模块
基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围
实现多路查询改写功能,用于 RAG-Fusion
"""
from typing import List, Optional, Any
from langchain.retrievers.multi_query import MultiQueryRetriever
from typing import List, Optional
from langchain_core.language_models import BaseLanguageModel
# from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
class MultiQueryTransformer:
"""
多路查询改写器
"""多路查询改写器,用于 RAG-Fusion。"""
将单个查询改写成多个相关查询,用于 RAG-Fusion。
"""
def __init__(
self,
llm: BaseLanguageModel,
num_queries: int = 3,
prompt_template: Optional[str] = None,
):
def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
"""
初始化查询改写器
初始化多路查询改写器
Args:
llm: 语言模型实例
num_queries: 生成的查询数量
prompt_template: 提示词模板
"""
self.llm = llm
self.num_queries = num_queries
# 默认提示词模板
self.prompt_template = prompt_template or """
你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。
原始问题: {question}
请生成 {num_queries} 个不同版本的查询,每个版本一行。
确保每个版本都是独立、完整的查询语句。
生成 {num_queries} 个查询:
"""
def transform_query(self, query: str) -> List[str]:
def create_multi_query_retriever(self, base_retriever):
"""
将单个查询改写成多个查询
Args:
query: 原始查询
Returns:
改写后的查询列表
"""
prompt = PromptTemplate.from_template(self.prompt_template)
chain = prompt | self.llm | StrOutputParser()
response = chain.invoke({
"question": query,
"num_queries": self.num_queries,
})
# 解析响应,每行一个查询
queries = [
q.strip()
for q in response.strip().split('\n')
if q.strip()
]
# 确保数量正确,如果不够则添加原始查询
if len(queries) < self.num_queries:
queries.extend([query] * (self.num_queries - len(queries)))
elif len(queries) > self.num_queries:
queries = queries[:self.num_queries]
# 确保包含原始查询
if query not in queries:
queries = [query] + queries[:self.num_queries-1]
return queries
def create_multi_query_retriever(
self,
base_retriever: Any,
include_original: bool = True,
) -> MultiQueryRetriever:
"""
创建多路查询检索器
创建多路查询检索器。
Args:
base_retriever: 基础检索器
include_original: 是否包含原始查询
Returns:
MultiQueryRetriever 实例
"""
retriever = MultiQueryRetriever.from_llm(
retriever=base_retriever,
llm=self.llm,
include_original=include_original,
)
# 设置生成的查询数量
retriever.llm_chain.prompt = PromptTemplate.from_template(
"你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
"这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
"原始问题: {question}\n\n"
"请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
"确保每个版本都是独立、完整的查询语句。\n\n"
"生成 {num_queries} 个查询:"
)
# 修改调用参数以包含 num_queries
original_invoke = retriever.llm_chain.invoke
def new_invoke(input_dict):
input_dict["num_queries"] = self.num_queries
return original_invoke(input_dict)
retriever.llm_chain.invoke = new_invoke
return retriever
@classmethod
def create_from_config(
cls,
llm: BaseLanguageModel,
config: Optional[dict] = None,
) -> "MultiQueryTransformer":
"""
从配置创建查询改写器
Args:
llm: 语言模型实例
config: 配置字典
Returns:
MultiQueryTransformer 实例
"""
config = config or {}
return cls(
llm=llm,
num_queries=config.get("num_queries", 3),
prompt_template=config.get("prompt_template", None),
)
def create_rag_fusion_pipeline(
base_retriever: Any,
llm: BaseLanguageModel,
reranker: Optional[Any] = None,
num_queries: int = 3,
) -> Any:
"""
创建完整的 RAG-Fusion 流水线
Args:
base_retriever: 基础检索器
llm: 语言模型(用于查询改写)
reranker: 重排序器(可选)
num_queries: 查询改写数量
Returns:
检索器实例
"""
# 创建多路查询改写器
query_transformer = MultiQueryTransformer(
llm=llm,
num_queries=num_queries,
)
# 创建多路查询检索器
multi_query_retriever = query_transformer.create_multi_query_retriever(
base_retriever=base_retriever,
include_original=True,
)
# 如果提供了重排序器,则应用重排序
if reranker is not None:
from langchain.retrievers import ContextualCompressionRetriever
return ContextualCompressionRetriever(
base_compressor=reranker,
base_retriever=multi_query_retriever,
)
return multi_query_retriever
# 由于当前 LangChain 版本不支持 MultiQueryRetriever,暂时返回基础检索器
# retriever = MultiQueryRetriever.from_llm(
# retriever=base_retriever,
# llm=self.llm,
# include_original=True
# )
#
# # 自定义提示词
# retriever.llm_chain.prompt = PromptTemplate.from_template(
# "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
# "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
# "原始问题: {question}\n\n"
# "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
# "确保每个版本都是独立、完整的查询语句。\n\n"
# "生成 {num_queries} 个查询:"
# )
#
# # 修改调用参数以包含 num_queries
# original_ainvoke = retriever.llm_chain.ainvoke
# async def new_ainvoke(input_dict):
# input_dict["num_queries"] = self.num_queries
# return await original_ainvoke(input_dict)
# retriever.llm_chain.ainvoke = new_ainvoke
#
# return retriever
return base_retriever

View File

@@ -1,141 +1,65 @@
"""
Cross-Encoder 重排序器
重排序器模块
使用 sentence-transformers 加载交叉编码器模型对检索结果进行精排
使用 Cross-Encoder 模型对检索结果进行重排序,提高检索精度
"""
import os
from typing import List, Dict, Any, Optional
from langchain.retrievers.document_compressors import CrossEncoderReranker
from typing import List
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
"""
Cross-Encoder 重排序器包装类
"""使用 Cross-Encoder 对检索结果重排序。"""
支持 BAAI/bge-reranker-base 等模型。
"""
def __init__(
self,
model_name: str = "BAAI/bge-reranker-base",
top_n: int = 5,
device: Optional[str] = None,
cache_folder: Optional[str] = None,
):
def __init__(self, model_name: str = "BAAI/bge-reranker-base", top_n: int = 5):
"""
初始化重排序器
Args:
model_name: 模型名称或路径
top_n: 返回的顶部文档数量
device: 设备cpu/cuda如果为 None 则自动选择
cache_folder: 模型缓存目录
model_name: 预训练模型名称
top_n: 返回前 N 个结果
"""
self.model_name = model_name
self.top_n = top_n
self.device = device
self.cache_folder = cache_folder or os.path.join(
os.path.expanduser("~"), ".cache", "sentence_transformers"
)
self.model = None
# 延迟加载模型
self._model = None
self._langchain_reranker = None
# 尝试加载 Cross-Encoder 模型
try:
from sentence_transformers import CrossEncoder
self.model = CrossEncoder(model_name)
except Exception as e:
print(f"警告: 无法加载 Cross-Encoder 模型 {model_name},将使用简单排序作为回退方案。错误: {e}")
def _load_model(self):
"""加载交叉编码器模型"""
if self._model is None:
try:
self._model = CrossEncoder(
self.model_name,
device=self.device,
cache_folder=self.cache_folder,
)
except Exception as e:
# 如果指定模型加载失败,尝试备用模型
print(f"加载模型 {self.model_name} 失败: {e}")
print("尝试加载备用模型 BAAI/bge-reranker-v2-m3...")
self._model = CrossEncoder(
"BAAI/bge-reranker-v2-m3",
device=self.device,
cache_folder=self.cache_folder,
)
def _create_langchain_reranker(self):
"""创建 LangChain 重排序器"""
if self._langchain_reranker is None:
self._load_model()
self._langchain_reranker = CrossEncoderReranker(
model=self._model,
top_n=self.top_n,
)
def rerank(
self,
query: str,
documents: List[Document],
def compress_documents(
self, documents: List[Document], query: str
) -> List[Document]:
"""
对文档进行重排序
Args:
query: 查询文本
documents: 待排序文档列表
documents: 待排序的文档列表
query: 查询字符串
Returns:
排序后的文档列表
排序后的文档列表
"""
self._create_langchain_reranker()
return self._langchain_reranker.compress_documents(
documents=documents,
query=query,
)
def create_contextual_compression_retriever(
self,
base_retriever: Any,
) -> Any:
"""
创建上下文压缩检索器
if not documents:
return []
Args:
base_retriever: 基础检索器
# 如果模型加载失败,返回前 top_n 个文档
if self.model is None:
return documents[:self.top_n]
# 使用 Cross-Encoder 进行重排序
try:
pairs = [[query, doc.page_content] for doc in documents]
scores = self.model.predict(pairs)
Returns:
上下文压缩检索器
"""
from langchain.retrievers import ContextualCompressionRetriever
self._create_langchain_reranker()
compression_retriever = ContextualCompressionRetriever(
base_compressor=self._langchain_reranker,
base_retriever=base_retriever,
)
return compression_retriever
@classmethod
def create_from_config(
cls,
config: Optional[Dict[str, Any]] = None,
) -> "CrossEncoderReranker":
"""
从配置创建重排序器
Args:
config: 配置字典,包含 model_name, top_n, device 等
Returns:
CrossEncoderReranker 实例
"""
config = config or {}
return cls(
model_name=config.get("model_name", "BAAI/bge-reranker-base"),
top_n=config.get("top_n", 5),
device=config.get("device", None),
cache_folder=config.get("cache_folder", None),
)
# 按分数降序排序
scored_docs = sorted(
zip(documents, scores), key=lambda x: x[1], reverse=True
)
return [doc for doc, _ in scored_docs[:self.top_n]]
except Exception as e:
print(f"警告: 重排序过程出错,将使用原始排序。错误: {e}")
return documents[:self.top_n]

View File

@@ -4,15 +4,12 @@ Qdrant 向量检索器
提供基础向量检索、混合检索Dense + BM25功能。
"""
import os
from typing import List, Dict, Any, Optional
from langchain_qdrant import Qdrant
from langchain_qdrant import QdrantVectorStore
from langchain.embeddings.base import Embeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.retrievers import EnsembleRetriever
# from langchain.retrievers import EnsembleRetriever
from qdrant_client import QdrantClient
from qdrant_client.http import models
from rag_core import QDRANT_URL, QDRANT_API_KEY
def create_qdrant_client(
@@ -21,21 +18,21 @@ def create_qdrant_client(
) -> QdrantClient:
"""
创建 Qdrant 客户端
Args:
url: Qdrant 服务地址,默认从环境变量 QDRANT_URL 读取
api_key: API 密钥,默认从环境变量 QDRANT_API_KEY 读取
Returns:
QdrantClient 实例
"""
url = url or os.getenv("QDRANT_URL", "http://localhost:6333")
api_key = api_key or os.getenv("QDRANT_API_KEY")
url = url or QDRANT_URL
api_key = api_key or QDRANT_API_KEY
client_args = {"url": url}
if api_key:
client_args["api_key"] = api_key
return QdrantClient(**client_args)
@@ -44,34 +41,33 @@ def create_base_retriever(
embeddings: Embeddings,
search_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[QdrantClient] = None,
) -> Qdrant:
) -> QdrantVectorStore:
"""
创建基础向量检索器
Args:
collection_name: Qdrant 集合名称
embeddings: 嵌入模型
search_kwargs: 搜索参数,默认 {"k": 20}
client: Qdrant 客户端,如果为 None 则自动创建
Returns:
Qdrant 检索器实例
QdrantVectorStore 检索器实例
"""
search_kwargs = search_kwargs or {"k": 20}
# 创建 Qdrant 客户端
if client is None:
client = create_qdrant_client()
search_kwargs = search_kwargs or {"k": 20}
# 创建 Qdrant 检索器
retriever = Qdrant.from_existing_collection(
embedding=embeddings,
collection_name=collection_name,
# 使用 QdrantVectorStore 创建向量存储
vector_store = QdrantVectorStore(
client=client,
content_payload_key="content", # 假设存储的文本字段名为 "content"
metadata_payload_key="metadata", # 元数据字段名
collection_name=collection_name,
embedding=embeddings,
)
return retriever.as_retriever(search_kwargs=search_kwargs)
return vector_store.as_retriever(search_kwargs=search_kwargs)
def create_hybrid_retriever(
@@ -80,65 +76,63 @@ def create_hybrid_retriever(
dense_k: int = 10,
sparse_k: int = 10,
client: Optional[QdrantClient] = None,
) -> ContextualCompressionRetriever:
) -> QdrantVectorStore:
"""
创建混合检索器Dense Vector + BM25
Args:
collection_name: Qdrant 集合名称
embeddings: 嵌入模型
dense_k: 向量检索返回数量
sparse_k: BM25 检索返回数量
client: Qdrant 客户端
Returns:
混合检索器
"""
# 创建 Qdrant 客户端
if client is None:
client = create_qdrant_client()
# 基础检索器Qdrant 支持混合检索)
base_retriever = Qdrant.from_existing_collection(
embedding=embeddings,
collection_name=collection_name,
# 使用 QdrantVectorStore 创建向量存储
vector_store = QdrantVectorStore(
client=client,
content_payload_key="content",
metadata_payload_key="metadata",
collection_name=collection_name,
embedding=embeddings,
)
# 配置混合检索参数
search_kwargs = {
"k": dense_k + sparse_k, # 总返回数量
"score_threshold": 0.3, # 相似度阈值
"k": dense_k + sparse_k,
"score_threshold": 0.3,
}
return base_retriever.as_retriever(search_kwargs=search_kwargs)
return vector_store.as_retriever(search_kwargs=search_kwargs)
def create_ensemble_retriever(
retrievers: List[Any],
weights: Optional[List[float]] = None,
c: int = 60,
) -> EnsembleRetriever:
"""
创建集成检索器,支持倒数排名融合 (RRF)
Args:
retrievers: 检索器列表
weights: 检索器权重
c: RRF 常数通常为60
Returns:
集成检索器
"""
if weights is None:
weights = [1.0 / len(retrievers)] * len(retrievers)
ensemble = EnsembleRetriever(
retrievers=retrievers,
weights=weights,
c=c,
search_type="rrf", # 使用倒数排名融合
)
return ensemble
# def create_ensemble_retriever(
# retrievers: List[Any],
# weights: Optional[List[float]] = None,
# c: int = 60,
# ) -> EnsembleRetriever:
# """
# 创建集成检索器,支持倒数排名融合 (RRF)
#
# Args:
# retrievers: 检索器列表
# weights: 检索器权重
# c: RRF 常数通常为60
#
# Returns:
# 集成检索器
# """
# if weights is None:
# weights = [1.0 / len(retrievers)] * len(retrievers)
#
# ensemble = EnsembleRetriever(
# retrievers=retrievers,
# weights=weights,
# c=c,
# search_type="rrf",
# )
#
# return ensemble

View File

@@ -1,230 +1,89 @@
"""
RAG 工具包装
RAG 工具模块
RAG 流水线包装成 LangChain Tool供 Agent 调用。
检索功能封装为 LangChain Tool供 Agent 调用。
"""
from typing import Optional, Dict, Any
from langchain.tools import tool
from langchain_core.tools import Tool
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .pipeline import RAGPipeline, RAGConfig, RAGLevel
from langchain_core.tools import tool
from rag_core import LlamaCppEmbedder, QDRANT_URL, QDRANT_API_KEY
from .pipeline import RAGPipeline, RAGLevel
class RAGTool:
"""
RAG 工具包装器
@tool
async def search_knowledge_base(query: str, rag_level: str = "rerank") -> str:
"""在知识库中搜索与查询相关的文档片段。
将 RAG 流水线包装成 Agent 可调用的工具
"""
def __init__(
self,
pipeline: RAGPipeline,
tool_name: str = "search_knowledge_base",
tool_description: str = None,
):
"""
初始化 RAG 工具
Args:
pipeline: RAG 流水线实例
tool_name: 工具名称
tool_description: 工具描述
"""
self.pipeline = pipeline
self.tool_name = tool_name
# 默认工具描述
self.tool_description = tool_description or (
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
"专业知识或需要基于已知信息回答的问题时使用此工具。"
"输入应为要搜索的查询文本。"
)
# 创建 LangChain 工具
self._tool = self._create_tool()
def _create_tool(self) -> Tool:
"""创建 LangChain 工具"""
@tool(self.tool_name, args_schema=None)
def search_knowledge_base(query: str) -> str:
"""
在知识库中搜索相关信息
Args:
query: 搜索查询
Returns:
格式化后的搜索结果
"""
try:
# 执行检索
result = self.pipeline.retrieve(query)
if not result.documents:
return "在知识库中未找到相关信息。"
# 格式化上下文
context = self.pipeline.format_context(
result.documents,
max_length=4000, # 限制上下文长度
)
# 构建响应
response = (
f"🔍 在知识库中找到了 {len(result.documents)} 条相关信息:\n\n"
f"{context}\n\n"
f"⏱️ 检索耗时: {result.query_time:.2f}"
)
return response
except Exception as e:
error_msg = f"检索过程中发生错误: {str(e)}"
if self.pipeline.config.verbose:
print(f"RAG 工具错误: {error_msg}")
return error_msg
# 设置工具描述
search_knowledge_base.description = self.tool_description
return search_knowledge_base
def get_tool(self) -> Tool:
"""获取 LangChain 工具"""
return self._tool
def __call__(self, query: str) -> str:
"""直接调用工具"""
return self._tool.invoke({"query": query})
def create_rag_tool(
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
tool_name: str = "search_knowledge_base",
tool_description: Optional[str] = None,
) -> Tool:
"""
创建 RAG 工具(便捷函数)
适用于事实性问题、背景知识查询
Args:
embeddings: 嵌入模型
llm: 语言模型(用于高级 RAG 功能
config: RAG 配置字典
tool_name: 工具名称
tool_description: 工具描述
query: 查询字符串
rag_level: 检索级别可选值basic基础向量检索、rerank基础检索+重排序、fusionRAG-Fusion
Returns:
LangChain Tool 实例
检索到的相关文档内容
"""
# 初始化嵌入模型
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 创建 RAG 流水线
pipeline = RAGPipeline.create_from_config(
pipeline = RAGPipeline(
embeddings=embeddings,
llm=llm,
config_dict=config,
config={
"rag_level": rag_level,
"collection_name": "rag_documents",
"rerank_top_n": 5,
}
)
# 创建工具包装器
rag_tool = RAGTool(
pipeline=pipeline,
tool_name=tool_name,
tool_description=tool_description,
)
# 执行检索
try:
documents = await pipeline.aretrieve(query)
if not documents:
return "未找到相关信息。"
# 格式化结果
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
@tool
def search_knowledge_base_sync(query: str, rag_level: str = "rerank") -> str:
"""同步版本的知识库搜索工具。
return rag_tool.get_tool()
# 导出便捷函数
search_knowledge_base_tool = create_rag_tool
def bind_rag_to_agent(
agent_llm: BaseLanguageModel,
embeddings: Embeddings,
rag_llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
tool_name: str = "search_knowledge_base",
) -> BaseLanguageModel:
"""
将 RAG 工具绑定到 Agent 模型
适用于事实性问题、背景知识查询。
Args:
agent_llm: Agent 使用的语言模型
embeddings: 嵌入模型
rag_llm: RAG 流水线使用的语言模型(如果与 agent_llm 不同)
config: RAG 配置
tool_name: 工具名称
Returns:
绑定工具后的模型
"""
# 如果未指定 RAG LLM使用 Agent LLM
if rag_llm is None:
rag_llm = agent_llm
query: 查询字符串
rag_level: 检索级别可选值basic基础向量检索、rerank基础检索+重排序、fusionRAG-Fusion
# 创建 RAG 工具
rag_tool = create_rag_tool(
Returns:
检索到的相关文档内容
"""
# 初始化嵌入模型
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 创建 RAG 流水线
pipeline = RAGPipeline(
embeddings=embeddings,
llm=rag_llm,
config=config,
tool_name=tool_name,
config={
"rag_level": rag_level,
"collection_name": "rag_documents",
"rerank_top_n": 5,
}
)
# 绑定工具到模型
return agent_llm.bind_tools([rag_tool])
def create_agentic_rag_pipeline(
embeddings: Embeddings,
agent_llm: BaseLanguageModel,
rag_llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
创建完整的 Agentic RAG 流水线Level 4
Args:
embeddings: 嵌入模型
agent_llm: Agent 模型
rag_llm: RAG 专用模型
config: 配置
# 执行检索
try:
documents = pipeline.retrieve(query)
if not documents:
return "未找到相关信息。"
Returns:
包含模型和工具的字典
"""
# 配置 Agentic RAG 级别
if config is None:
config = {}
config["rag_level"] = RAGLevel.AGENTIC.value
# 创建 RAG 工具
rag_tool = create_rag_tool(
embeddings=embeddings,
llm=rag_llm or agent_llm,
config=config,
tool_name="search_knowledge_base",
tool_description=(
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
"专业知识或需要基于已知信息回答的问题时使用此工具。"
"Agent 应该先判断是否需要使用此工具,然后调用它获取上下文。"
),
)
# 绑定工具到模型
bound_llm = agent_llm.bind_tools([rag_tool])
return {
"llm": bound_llm,
"tool": rag_tool,
"pipeline": RAGPipeline.create_from_config(
embeddings=embeddings,
llm=rag_llm or agent_llm,
config_dict=config,
),
}
# 格式化结果
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"