本地RAG尝试

This commit is contained in:
2026-04-18 16:31:48 +08:00
parent 6042d4a476
commit 0470afce13
12 changed files with 1587 additions and 4 deletions

View File

@@ -6,14 +6,40 @@ AI Agent 服务类 - 支持多模型动态切换
import os
import json
from dotenv import load_dotenv
try:
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
HAS_ZHIPUAI = True
except ImportError:
HAS_ZHIPUAI = False
ChatZhipuAI = None
try:
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
HAS_OPENAI = True
except ImportError:
HAS_OPENAI = False
ChatOpenAI = None
OpenAIEmbeddings = None
from pydantic import SecretStr
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
HAS_POSTGRES_CHECKPOINT = True
except ImportError:
HAS_POSTGRES_CHECKPOINT = False
AsyncPostgresSaver = None
# 本地模块
from app.graph_builder import GraphBuilder, GraphContext
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
try:
from app.rag import RAGPipeline
from app.rag.tools import RAGTool
HAS_RAG = True
except ImportError as e:
HAS_RAG = False
RAGPipeline = None
RAGTool = None
from app.logger import debug, info, warning, error
@@ -31,9 +57,13 @@ class AIAgentService:
"""
self.checkpointer = checkpointer
self.graphs = {} # 存储不同模型对应的 graph 实例
self.rag = None # RAG 检索实例
self.rag_tool = None # RAG 工具实例
def _create_zhipu_llm(self):
"""创建智谱在线 LLM"""
if not HAS_ZHIPUAI:
raise ImportError("智谱AI支持不可用请安装langchain-community包")
api_key = os.getenv("ZHIPUAI_API_KEY")
if not api_key:
raise ValueError("ZHIPUAI_API_KEY not set in environment")
@@ -49,6 +79,8 @@ class AIAgentService:
def _create_deepseek_llm(self):
"""创建 DeepSeek LLM使用 OpenAI 兼容 API"""
if not HAS_OPENAI:
raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
raise ValueError("DEEPSEEK_API_KEY not set in environment")
@@ -65,6 +97,8 @@ class AIAgentService:
def _create_local_llm(self):
"""创建本地 vLLM 服务 LLM"""
if not HAS_OPENAI:
raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
vllm_base_url = os.getenv(
"VLLM_BASE_URL",
@@ -80,8 +114,39 @@ class AIAgentService:
streaming=True, # 确保开启流式输出
)
def _create_embeddings(self):
"""创建嵌入模型"""
if not HAS_OPENAI:
raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
embedding_url = os.getenv(
"LLAMACPP_EMBEDDING_URL",
"http://127.0.0.1:8082/v1"
)
return OpenAIEmbeddings(
openai_api_base=embedding_url,
openai_api_key=os.getenv("LLAMACPP_API_KEY", "token-abc123"),
model="text-embedding-ada-002", # 模型名称不重要,兼容即可
)
async def initialize(self):
"""预编译所有模型的 graph使用传入的 checkpointer"""
# 先初始化 RAG 检索系统
if HAS_RAG and RAGPipeline is not None and RAGTool is not None:
try:
info("🔄 正在初始化 RAG 检索系统...")
embeddings = self._create_embeddings()
self.rag = RAGPipeline(embeddings=embeddings)
self.rag_tool = RAGTool(self.rag).get_tool()
info("✅ RAG 检索系统初始化成功")
except Exception as e:
warning(f"⚠️ RAG 检索系统初始化失败: {e}")
self.rag = None
self.rag_tool = None
else:
info("⏭️ RAG 检索系统不可用,跳过初始化")
self.rag = None
self.rag_tool = None
model_configs = {
"local": self._create_local_llm, # 本地模型作为第一个
"deepseek": self._create_deepseek_llm, # DeepSeek 作为中间
@@ -92,7 +157,16 @@ class AIAgentService:
try:
info(f"🔄 正在初始化模型 '{model_name}'...")
llm = llm_creator()
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
# 构建工具列表:基础工具 + RAG工具如果可用
tools = AVAILABLE_TOOLS.copy()
tools_by_name = TOOLS_BY_NAME.copy()
if self.rag_tool is not None:
tools.append(self.rag_tool)
tools_by_name[self.rag_tool.name] = self.rag_tool
builder = GraphBuilder(llm, tools, tools_by_name).build()
graph = builder.compile(checkpointer=self.checkpointer)
self.graphs[model_name] = graph
info(f"✅ 模型 '{model_name}' 初始化成功")

136
app/rag/README.md Normal file
View File

@@ -0,0 +1,136 @@
# 在线 RAG 检索与生成系统 (Online RAG Retriever)
该模块负责 RAG 系统的阶段二:**在线检索与生成**。它接收用户提问,从知识库中检索出上下文,利用各种高级策略去噪、融合,并作为增强上下文输入给大语言模型 (LLM)。
## 📊 RAG-Fusion & 混合检索流水线示意图
```mermaid
graph TD
User((用户提问)) --> A[LLM 查询改写生成器]
subgraph RAG-Fusion 核心流程
A -->|改写为问题 1| B1[查询 1]
A -->|改写为问题 2| B2[查询 2]
A -->|原问题| B3[原始查询]
B1 & B2 & B3 --> C[混合检索器 Hybrid Retriever <br> Dense Vector + BM25 Sparse]
C --> D[多路召回结果合集 N=60条]
D --> E{RRF 倒数排名融合去重}
end
E -->|筛选出前 20 条| F[Cross-Encoder 重排器 Reranker]
F -->|精细打分排序 Top 5| G[最终纯净上下文 Context]
G --> H[将 Context 与原问题拼接输入大模型]
H --> I((LLM 生成最终回答))
```
---
## 🎯 演进路线与算法详解 (Roadmap)
### Level 1: 基础向量搜索 (Basic Similarity Search)
- **核心算法**: 近似最近邻搜索 (ANN, 常用 HNSW 算法)。将用户问题转化为向量后,计算它与库中向量的余弦相似度 (Cosine Similarity),取距离最近的 K 个块。
- **优缺点**: 速度极快。但只能捕捉“语义相似”,如果用户搜索特定专有名词、编号、订单号,纯向量检索往往会失效(产生“幻觉”匹配)。
### Level 2: 混合检索与重排序 (Hybrid Search + Reranker)
混合检索旨在结合向量的“语义泛化”与关键词的“精准匹配”,随后利用重排序模型过滤噪声。
**1. 基础召回 (混合检索)**
- **核心原理**: 结合基于 HNSW 的 Dense Vector 相似度搜索与基于 TF-IDF 的 BM25 稀疏检索 (Sparse Vector)。
- **实现指南**: 使用 `langchain_qdrant` 包中的 `Qdrant` 类连接数据库。通过调用 `Qdrant.from_existing_collection(...)` 实例化向量库,并使用 `.as_retriever(search_kwargs={"k": 20})` 方法生成基础检索器。Qdrant 底层会自动处理双路召回。
**2. 二次精排 (Cross-Encoder)**
- **核心原理**: 不同于双塔模型(分别算向量再求距离),交叉编码器将“用户问题 + 检索到的单例文档”拼接后整体输入 Transformer 模型,由模型直接输出 0~1 的相关性得分,精度极高。
- **实现指南**:
- 使用 `sentence-transformers` 库加载本地轻量级重排模型(如 `BAAI/bge-reranker-base`)。
- 引入 `langchain.retrievers.document_compressors` 包中的 `CrossEncoderReranker` 类包装该模型,设置参数 `top_n=5`
- 最后,使用 `langchain.retrievers` 包中的 `ContextualCompressionRetriever` 类,将 `base_compressor` (重排器) 和 `base_retriever` (基础检索器) 组合。
- **如何调用**: 业务逻辑中直接对组合后的检索器调用 `.invoke(query)` 方法,即可一键完成“大范围召回 20 条 -> 逐一打分精排选 5 条”的去噪流水线。
### Level 3: RAG-Fusion (多路改写与倒数排名融合)
RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果。
**1. 多路查询改写**
- **核心原理**: 克服用户初始提问词不达意或视角受限的问题。
- **实现指南**: 导入 `langchain.retrievers.multi_query` 包中的 `MultiQueryRetriever` 类。需向其提供一个已实例化的 LLM 对象(如基于 `ChatOpenAI` 封装的本地 VLLM 模型)。系统在底层会自动 Prompt 模型,将原始 `query` 转化为包含 3-5 个不同表述的查询列表。
**2. 倒数排名融合 (RRF)**
- **核心原理**: RRF (Reciprocal Rank Fusion) 是一种无需评分归一化的融合算法。公式为 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$,有效避免某一极端检索结果主导全局。
- **实现指南**:
- 针对每个改写后的查询 $q$,分别调用精排检索器的 `.invoke(q)` 获取文档列表。
- 使用 `langchain.retrievers` 中的 `EnsembleRetriever` 类(原生支持 RRF或在代码中遍历收集到的 `Document` 对象,基于其排名 `rank` 累加得分,最终通过 Python 的 `set` 去重并提取 `doc.page_content`
### Level 4: Agentic RAG / Self-RAG (智能体与自我反思)
- **核心原理**: 基于 LangGraph 的 ReAct (Reasoning and Acting) 状态机路由。大模型并非每次都去死板地执行检索,而是先判断问题:“这是闲聊?还是需要查知识库?”。如果是后者,模型输出一个 `ToolCall` 指令,触发检索。
- **实现指南**: 请参考下方的**与现有系统整合调用**章节。
- **示意图**:
```mermaid
sequenceDiagram
participant User
participant LangGraph Agent
participant RAG_Tool
participant Qdrant
User->>LangGraph Agent: "公司报销流程是什么?"
LangGraph Agent->>LangGraph Agent: 思考: 这是一个内部规章问题,需要查资料
LangGraph Agent->>RAG_Tool: ToolCall(search_knowledge_base, "公司报销流程")
RAG_Tool->>Qdrant: RAG-Fusion & 混合检索
Qdrant-->>RAG_Tool: 原始分块
RAG_Tool->>RAG_Tool: Cross-Encoder 重排过滤
RAG_Tool-->>LangGraph Agent: 返回最相关的5条报销规定
LangGraph Agent->>LangGraph Agent: 思考: 资料充分,开始撰写回答
LangGraph Agent-->>User: "根据知识库规定报销流程分为以下3步..."
```
---
## 📦 所需依赖与安装
除了基础的 LangChain 包外,在线检索模块为了支持重排和稀疏检索,还需要安装:
```bash
# 用于 Cross-Encoder 重排序模型 (如 BAAI/bge-reranker-base)
pip install sentence-transformers
# 用于 BM25 关键词混合检索
pip install rank_bm25
# 基础框架
pip install langchain langchain-core langchain-openai langchain-qdrant
```
---
## 📂 架构与文件结构设计
在 `app/rag/` 目录下,需创建以下文件来模块化上述功能:
```text
app/rag/
├── __init__.py
├── retriever.py # 负责 Qdrant 的基础召回与 ContextualCompressionRetriever
├── reranker.py # 负责加载 sentence-transformers 交叉编码器
├── query_transform.py # 负责基于 MultiQueryRetriever 的改写逻辑
├── pipeline.py # 组合上述组件,暴露出核心的 retrieve() 方法
└── tools.py # 将 Pipeline 包装成 LangChain Tool 供 Agent 调用
```
---
## <20> 与现有系统整合调用 (Agentic RAG 实现)
基于目前 LangGraph 系统的架构,我们将摒弃将代码堆砌在一起的旧方式,而是利用 **LangChain Tools** 的特性将 RAG 优雅地注入系统:
1. **封装检索工具 (Tool)**:
从 `langchain.tools` 导入 `@tool` 装饰器。定义一个名为 `search_knowledge_base(query: str)` 的函数。在函数内部,实例化并调用我们在 `pipeline.py` 中写好的多路召回与重排逻辑。
2. **模型绑定 (Bind)**:
在 `app/agent.py` 或 `app/nodes/tool_call.py` 中,将这个工具引入,并通过 `llm.bind_tools([search_knowledge_base])` 绑定到现有的本地大模型实例上。
3. **状态机路由 (Graph Routing)**:
你的 LangGraph 状态机会像处理普通对话一样自动接管:当模型判断需要调用查阅规章制度或专业资料时,它会输出 `ToolCall` 消息,流转到 `tool_node` 执行上述的 RAG 检索逻辑并返回上下文。
这让你无需修改任何前端 Streamlit 流式代码,就能平滑升级为具备超级知识库检索能力的智能体 (Agent)

22
app/rag/__init__.py Normal file
View File

@@ -0,0 +1,22 @@
"""
在线 RAG 检索与生成系统
提供高级RAG检索功能支持混合检索、重排序、RAG-Fusion和多路查询改写。
"""
from .pipeline import RAGPipeline
from .retriever import create_hybrid_retriever, create_base_retriever
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer
from .tools import search_knowledge_base_tool
__all__ = [
"RAGPipeline",
"create_hybrid_retriever",
"create_base_retriever",
"CrossEncoderReranker",
"MultiQueryTransformer",
"search_knowledge_base_tool",
]
__version__ = "0.1.0"

232
app/rag/example.py Normal file
View File

@@ -0,0 +1,232 @@
#!/usr/bin/env python3
"""
RAG 系统使用示例
演示如何使用 app/rag 模块进行知识检索。
"""
import sys
import os
# 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from langchain_openai import OpenAIEmbeddings
from langchain_community.llms import VLLMOpenAI
def setup_environment():
"""设置环境变量"""
# 设置 Qdrant 连接信息(根据实际情况修改)
os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333")
# 如果需要 API 密钥,请设置 QDRANT_API_KEY
print("环境变量已设置")
print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}")
def demonstrate_basic_rag():
"""演示基础 RAG 功能"""
print("\n" + "="*60)
print("演示: 基础 RAG 检索 (Level 1)")
print("="*60)
# 创建嵌入模型(使用 OpenAI 兼容的本地模型)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1", # 本地 VLLM 服务
openai_api_key="no-key-needed",
model="text-embedding-ada-002", # 假设的模型名称
)
# 创建 RAG 流水线
from app.rag import RAGPipeline, RAGConfig, RAGLevel
config = RAGConfig(
collection_name="documents", # 你的集合名称
rag_level=RAGLevel.BASIC,
verbose=True,
)
pipeline = RAGPipeline(
embeddings=embeddings,
config=config,
)
# 示例查询
query = "公司报销流程是什么?"
print(f"\n查询: {query}")
try:
result = pipeline.retrieve(query)
print(f"找到 {len(result.documents)} 个相关文档")
# 格式化上下文
context = pipeline.format_context(result.documents)
print(f"\n上下文预览:\n{context[:500]}...")
except Exception as e:
print(f"检索失败: {e}")
print("请确保 Qdrant 服务正常运行且集合存在")
def demonstrate_hybrid_rag():
"""演示混合 RAG 功能"""
print("\n" + "="*60)
print("演示: 混合 RAG 检索 (Level 2)")
print("="*60)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model="text-embedding-ada-002",
)
from app.rag import RAGPipeline, RAGConfig, RAGLevel
config = RAGConfig(
collection_name="documents",
rag_level=RAGLevel.HYBRID,
dense_k=10,
sparse_k=10,
rerank_top_n=5,
verbose=True,
)
pipeline = RAGPipeline(
embeddings=embeddings,
config=config,
)
query = "如何申请年假?"
print(f"\n查询: {query}")
try:
result = pipeline.retrieve(query)
print(f"找到 {len(result.documents)} 个重排序后的文档")
except Exception as e:
print(f"检索失败: {e}")
def demonstrate_rag_fusion():
"""演示 RAG-Fusion 功能"""
print("\n" + "="*60)
print("演示: RAG-Fusion (Level 3)")
print("="*60)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model="text-embedding-ada-002",
)
# 创建语言模型用于查询改写
llm = VLLMOpenAI(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model_name="Qwen2.5-7B-Instruct", # 你的本地模型
temperature=0.3,
max_tokens=512,
)
from app.rag import RAGPipeline, RAGConfig, RAGLevel
config = RAGConfig(
collection_name="documents",
rag_level=RAGLevel.FUSION,
num_queries=3,
verbose=True,
)
pipeline = RAGPipeline(
embeddings=embeddings,
llm=llm,
config=config,
)
query = "项目上线需要哪些审批?"
print(f"\n查询: {query}")
try:
result = pipeline.retrieve(query)
print(f"找到 {len(result.documents)} 个文档 (经过多路查询改写和重排序)")
except Exception as e:
print(f"检索失败: {e}")
def demonstrate_agentic_rag():
"""演示 Agentic RAG 功能"""
print("\n" + "="*60)
print("演示: Agentic RAG (Level 4)")
print("="*60)
embeddings = OpenAIEmbeddings(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model="text-embedding-ada-002",
)
llm = VLLMOpenAI(
openai_api_base="http://localhost:8000/v1",
openai_api_key="no-key-needed",
model_name="Qwen2.5-7B-Instruct",
temperature=0.3,
max_tokens=512,
)
from app.rag import create_agentic_rag_pipeline
try:
# 创建 Agentic RAG 流水线
agentic_rag = create_agentic_rag_pipeline(
embeddings=embeddings,
agent_llm=llm,
config={
"collection_name": "documents",
"verbose": True,
},
)
print("Agentic RAG 流水线创建成功!")
print(f"- 绑定的模型: {agentic_rag['llm']}")
print(f"- RAG 工具: {agentic_rag['tool'].name}")
# 演示工具调用
print("\n工具调用示例:")
response = agentic_rag["tool"].invoke({"query": "员工福利有哪些?"})
print(f"工具响应预览: {response[:200]}...")
except Exception as e:
print(f"创建 Agentic RAG 失败: {e}")
import traceback
traceback.print_exc()
def main():
"""主函数"""
print("RAG 系统演示")
print("="*60)
# 设置环境
setup_environment()
# 演示各级功能
demonstrate_basic_rag()
demonstrate_hybrid_rag()
demonstrate_rag_fusion()
demonstrate_agentic_rag()
print("\n" + "="*60)
print("演示完成!")
print("="*60)
print("\n使用说明:")
print("1. 确保 Qdrant 服务运行且集合已创建")
print("2. 根据需要修改 embeddings 和 llm 配置")
print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base_tool")
print("4. 将工具绑定到你的 Agent 模型")
if __name__ == "__main__":
main()

341
app/rag/pipeline.py Normal file
View File

@@ -0,0 +1,341 @@
"""
RAG 检索流水线
组合检索器、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。
"""
import time
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
from enum import Enum
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
create_ensemble_retriever,
create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline
class RAGLevel(Enum):
"""RAG 功能级别"""
BASIC = 1 # 基础向量搜索
HYBRID = 2 # 混合检索 + 重排序
FUSION = 3 # RAG-Fusion
AGENTIC = 4 # Agentic RAG
@dataclass
class RAGConfig:
"""RAG 配置"""
# Qdrant 配置
collection_name: str = "documents"
qdrant_url: Optional[str] = None
qdrant_api_key: Optional[str] = None
# 检索配置
rag_level: RAGLevel = RAGLevel.FUSION
dense_k: int = 10 # 向量检索数量
sparse_k: int = 10 # BM25 检索数量
total_k: int = 20 # 总检索数量
rerank_top_n: int = 5 # 重排序返回数量
# 查询改写配置
num_queries: int = 3 # RAG-Fusion 查询数量
# 模型配置
reranker_model: str = "BAAI/bge-reranker-base"
device: Optional[str] = None
# 性能配置
enable_cache: bool = True
verbose: bool = True
@dataclass
class RetrievalResult:
"""检索结果"""
documents: List[Document]
query_time: float
level: RAGLevel
metadata: Dict[str, Any] = field(default_factory=dict)
class RAGPipeline:
"""
RAG 检索流水线
支持从 Level 1 到 Level 4 的所有功能。
"""
def __init__(
self,
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config: Optional[RAGConfig] = None,
):
"""
初始化 RAG 流水线
Args:
embeddings: 嵌入模型
llm: 语言模型用于查询改写Level 3+ 需要)
config: 配置
"""
self.embeddings = embeddings
self.llm = llm
self.config = config or RAGConfig()
# 初始化组件
self._client = None
self._reranker = None
self._query_transformer = None
self._retriever = None
# 缓存
self._cache = {}
def _get_client(self):
"""获取 Qdrant 客户端"""
if self._client is None:
self._client = create_qdrant_client(
url=self.config.qdrant_url,
api_key=self.config.qdrant_api_key,
)
return self._client
def _get_reranker(self):
"""获取重排序器"""
if self._reranker is None:
self._reranker = CrossEncoderReranker(
model_name=self.config.reranker_model,
top_n=self.config.rerank_top_n,
device=self.config.device,
)
return self._reranker
def _get_query_transformer(self):
"""获取查询改写器"""
if self._query_transformer is None and self.llm is not None:
self._query_transformer = MultiQueryTransformer(
llm=self.llm,
num_queries=self.config.num_queries,
)
return self._query_transformer
def _create_basic_retriever(self):
"""创建基础检索器Level 1"""
return create_base_retriever(
collection_name=self.config.collection_name,
embeddings=self.embeddings,
search_kwargs={"k": self.config.total_k},
client=self._get_client(),
)
def _create_hybrid_retriever(self):
"""创建混合检索器Level 2"""
base_retriever = create_hybrid_retriever(
collection_name=self.config.collection_name,
embeddings=self.embeddings,
dense_k=self.config.dense_k,
sparse_k=self.config.sparse_k,
client=self._get_client(),
)
# 应用重排序
reranker = self._get_reranker()
return reranker.create_contextual_compression_retriever(base_retriever)
def _create_fusion_retriever(self):
"""创建 RAG-Fusion 检索器Level 3"""
if self.llm is None:
raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写")
# 创建基础混合检索器
base_retriever = create_hybrid_retriever(
collection_name=self.config.collection_name,
embeddings=self.embeddings,
dense_k=self.config.dense_k,
sparse_k=self.config.sparse_k,
client=self._get_client(),
)
# 创建 RAG-Fusion 流水线
reranker = self._get_reranker()
return create_rag_fusion_pipeline(
base_retriever=base_retriever,
llm=self.llm,
reranker=reranker,
num_queries=self.config.num_queries,
)
def _get_retriever(self):
"""根据配置级别获取检索器"""
if self._retriever is None:
if self.config.rag_level == RAGLevel.BASIC:
self._retriever = self._create_basic_retriever()
elif self.config.rag_level == RAGLevel.HYBRID:
self._retriever = self._create_hybrid_retriever()
elif self.config.rag_level == RAGLevel.FUSION:
self._retriever = self._create_fusion_retriever()
elif self.config.rag_level == RAGLevel.AGENTIC:
# Agentic RAG 使用 Fusion 作为基础,在 tools.py 中包装
self._retriever = self._create_fusion_retriever()
else:
raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}")
return self._retriever
def retrieve(
self,
query: str,
use_cache: Optional[bool] = None,
**kwargs,
) -> RetrievalResult:
"""
执行检索
Args:
query: 查询文本
use_cache: 是否使用缓存
**kwargs: 额外参数
Returns:
检索结果
"""
start_time = time.time()
# 检查缓存
if use_cache is None:
use_cache = self.config.enable_cache
cache_key = f"{query}:{self.config.rag_level.value}"
if use_cache and cache_key in self._cache:
if self.config.verbose:
print(f"使用缓存结果: {query}")
return self._cache[cache_key]
# 获取检索器并执行检索
retriever = self._get_retriever()
documents = retriever.invoke(query, **kwargs)
# 计算查询时间
query_time = time.time() - start_time
# 创建结果
result = RetrievalResult(
documents=documents,
query_time=query_time,
level=self.config.rag_level,
metadata={
"query": query,
"collection": self.config.collection_name,
"doc_count": len(documents),
},
)
# 缓存结果
if use_cache:
self._cache[cache_key] = result
if self.config.verbose:
print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s")
return result
def format_context(
self,
documents: List[Document],
max_length: Optional[int] = None,
) -> str:
"""
格式化检索到的文档为上下文文本
Args:
documents: 文档列表
max_length: 最大长度(字符数)
Returns:
格式化后的上下文文本
"""
context_parts = []
total_length = 0
for i, doc in enumerate(documents):
# 提取内容和元数据
content = doc.page_content.strip()
metadata = doc.metadata
# 格式化文档
doc_text = f"[文档 {i+1}]\n"
if metadata.get("source"):
doc_text += f"来源: {metadata['source']}\n"
if metadata.get("page"):
doc_text += f"页码: {metadata['page']}\n"
doc_text += f"内容: {content}\n\n"
# 检查长度限制
if max_length is not None:
if total_length + len(doc_text) > max_length:
# 如果添加这个文档会超限,则截断并添加说明
remaining = max_length - total_length
if remaining > 100: # 至少保留100字符
doc_text = doc_text[:remaining] + "...\n\n[内容已截断]"
context_parts.append(doc_text)
break
else:
break
context_parts.append(doc_text)
total_length += len(doc_text)
return "".join(context_parts).strip()
def clear_cache(self):
"""清空缓存"""
self._cache.clear()
@classmethod
def create_from_config(
cls,
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config_dict: Optional[Dict[str, Any]] = None,
) -> "RAGPipeline":
"""
从配置字典创建流水线
Args:
embeddings: 嵌入模型
llm: 语言模型
config_dict: 配置字典
Returns:
RAGPipeline 实例
"""
config_dict = config_dict or {}
# 创建配置对象
config = RAGConfig(
collection_name=config_dict.get("collection_name", "documents"),
qdrant_url=config_dict.get("qdrant_url"),
qdrant_api_key=config_dict.get("qdrant_api_key"),
rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)),
dense_k=config_dict.get("dense_k", 10),
sparse_k=config_dict.get("sparse_k", 10),
total_k=config_dict.get("total_k", 20),
rerank_top_n=config_dict.get("rerank_top_n", 5),
num_queries=config_dict.get("num_queries", 3),
reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"),
device=config_dict.get("device"),
enable_cache=config_dict.get("enable_cache", True),
verbose=config_dict.get("verbose", True),
)
return cls(embeddings=embeddings, llm=llm, config=config)

193
app/rag/query_transform.py Normal file
View File

@@ -0,0 +1,193 @@
"""
查询改写器
基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围。
"""
from typing import List, Optional, Any
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
class MultiQueryTransformer:
"""
多路查询改写器
将单个查询改写成多个相关查询,用于 RAG-Fusion。
"""
def __init__(
self,
llm: BaseLanguageModel,
num_queries: int = 3,
prompt_template: Optional[str] = None,
):
"""
初始化查询改写器
Args:
llm: 语言模型实例
num_queries: 生成的查询数量
prompt_template: 提示词模板
"""
self.llm = llm
self.num_queries = num_queries
# 默认提示词模板
self.prompt_template = prompt_template or """
你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。
原始问题: {question}
请生成 {num_queries} 个不同版本的查询,每个版本一行。
确保每个版本都是独立、完整的查询语句。
生成 {num_queries} 个查询:
"""
def transform_query(self, query: str) -> List[str]:
"""
将单个查询改写成多个查询
Args:
query: 原始查询
Returns:
改写后的查询列表
"""
prompt = PromptTemplate.from_template(self.prompt_template)
chain = prompt | self.llm | StrOutputParser()
response = chain.invoke({
"question": query,
"num_queries": self.num_queries,
})
# 解析响应,每行一个查询
queries = [
q.strip()
for q in response.strip().split('\n')
if q.strip()
]
# 确保数量正确,如果不够则添加原始查询
if len(queries) < self.num_queries:
queries.extend([query] * (self.num_queries - len(queries)))
elif len(queries) > self.num_queries:
queries = queries[:self.num_queries]
# 确保包含原始查询
if query not in queries:
queries = [query] + queries[:self.num_queries-1]
return queries
def create_multi_query_retriever(
self,
base_retriever: Any,
include_original: bool = True,
) -> MultiQueryRetriever:
"""
创建多路查询检索器
Args:
base_retriever: 基础检索器
include_original: 是否包含原始查询
Returns:
MultiQueryRetriever 实例
"""
retriever = MultiQueryRetriever.from_llm(
retriever=base_retriever,
llm=self.llm,
include_original=include_original,
)
# 设置生成的查询数量
retriever.llm_chain.prompt = PromptTemplate.from_template(
"你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
"这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
"原始问题: {question}\n\n"
"请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
"确保每个版本都是独立、完整的查询语句。\n\n"
"生成 {num_queries} 个查询:"
)
# 修改调用参数以包含 num_queries
original_invoke = retriever.llm_chain.invoke
def new_invoke(input_dict):
input_dict["num_queries"] = self.num_queries
return original_invoke(input_dict)
retriever.llm_chain.invoke = new_invoke
return retriever
@classmethod
def create_from_config(
cls,
llm: BaseLanguageModel,
config: Optional[dict] = None,
) -> "MultiQueryTransformer":
"""
从配置创建查询改写器
Args:
llm: 语言模型实例
config: 配置字典
Returns:
MultiQueryTransformer 实例
"""
config = config or {}
return cls(
llm=llm,
num_queries=config.get("num_queries", 3),
prompt_template=config.get("prompt_template", None),
)
def create_rag_fusion_pipeline(
base_retriever: Any,
llm: BaseLanguageModel,
reranker: Optional[Any] = None,
num_queries: int = 3,
) -> Any:
"""
创建完整的 RAG-Fusion 流水线
Args:
base_retriever: 基础检索器
llm: 语言模型(用于查询改写)
reranker: 重排序器(可选)
num_queries: 查询改写数量
Returns:
检索器实例
"""
# 创建多路查询改写器
query_transformer = MultiQueryTransformer(
llm=llm,
num_queries=num_queries,
)
# 创建多路查询检索器
multi_query_retriever = query_transformer.create_multi_query_retriever(
base_retriever=base_retriever,
include_original=True,
)
# 如果提供了重排序器,则应用重排序
if reranker is not None:
from langchain.retrievers import ContextualCompressionRetriever
return ContextualCompressionRetriever(
base_compressor=reranker,
base_retriever=multi_query_retriever,
)
return multi_query_retriever

23
app/rag/requirements.txt Normal file
View File

@@ -0,0 +1,23 @@
# RAG 系统依赖
# 基础框架
langchain>=0.1.0
langchain-core>=0.1.0
langchain-openai>=0.0.1
langchain-qdrant>=0.1.0
# 用于 Cross-Encoder 重排序模型
sentence-transformers>=2.2.0
# 用于 BM25 关键词混合检索
rank-bm25>=0.2.2
# Qdrant 客户端
qdrant-client>=1.6.0
# 可选的本地模型支持
# vllm>=0.5.0 # 如果需要本地模型推理
# transformers>=4.35.0 # 如果需要其他模型支持
# 开发依赖(测试用)
pytest>=7.0.0
pytest-asyncio>=0.21.0

141
app/rag/reranker.py Normal file
View File

@@ -0,0 +1,141 @@
"""
Cross-Encoder 重排序器
使用 sentence-transformers 加载交叉编码器模型,对检索结果进行精排。
"""
import os
from typing import List, Dict, Any, Optional
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
"""
Cross-Encoder 重排序器包装类
支持 BAAI/bge-reranker-base 等模型。
"""
def __init__(
self,
model_name: str = "BAAI/bge-reranker-base",
top_n: int = 5,
device: Optional[str] = None,
cache_folder: Optional[str] = None,
):
"""
初始化重排序器
Args:
model_name: 模型名称或路径
top_n: 返回的顶部文档数量
device: 设备cpu/cuda如果为 None 则自动选择
cache_folder: 模型缓存目录
"""
self.model_name = model_name
self.top_n = top_n
self.device = device
self.cache_folder = cache_folder or os.path.join(
os.path.expanduser("~"), ".cache", "sentence_transformers"
)
# 延迟加载模型
self._model = None
self._langchain_reranker = None
def _load_model(self):
"""加载交叉编码器模型"""
if self._model is None:
try:
self._model = CrossEncoder(
self.model_name,
device=self.device,
cache_folder=self.cache_folder,
)
except Exception as e:
# 如果指定模型加载失败,尝试备用模型
print(f"加载模型 {self.model_name} 失败: {e}")
print("尝试加载备用模型 BAAI/bge-reranker-v2-m3...")
self._model = CrossEncoder(
"BAAI/bge-reranker-v2-m3",
device=self.device,
cache_folder=self.cache_folder,
)
def _create_langchain_reranker(self):
"""创建 LangChain 重排序器"""
if self._langchain_reranker is None:
self._load_model()
self._langchain_reranker = CrossEncoderReranker(
model=self._model,
top_n=self.top_n,
)
def rerank(
self,
query: str,
documents: List[Document],
) -> List[Document]:
"""
对文档进行重排序
Args:
query: 查询文本
documents: 待排序文档列表
Returns:
重排序后的文档列表
"""
self._create_langchain_reranker()
return self._langchain_reranker.compress_documents(
documents=documents,
query=query,
)
def create_contextual_compression_retriever(
self,
base_retriever: Any,
) -> Any:
"""
创建上下文压缩检索器
Args:
base_retriever: 基础检索器
Returns:
上下文压缩检索器
"""
from langchain.retrievers import ContextualCompressionRetriever
self._create_langchain_reranker()
compression_retriever = ContextualCompressionRetriever(
base_compressor=self._langchain_reranker,
base_retriever=base_retriever,
)
return compression_retriever
@classmethod
def create_from_config(
cls,
config: Optional[Dict[str, Any]] = None,
) -> "CrossEncoderReranker":
"""
从配置创建重排序器
Args:
config: 配置字典,包含 model_name, top_n, device 等
Returns:
CrossEncoderReranker 实例
"""
config = config or {}
return cls(
model_name=config.get("model_name", "BAAI/bge-reranker-base"),
top_n=config.get("top_n", 5),
device=config.get("device", None),
cache_folder=config.get("cache_folder", None),
)

144
app/rag/retriever.py Normal file
View File

@@ -0,0 +1,144 @@
"""
Qdrant 向量检索器
提供基础向量检索、混合检索Dense + BM25功能。
"""
import os
from typing import List, Dict, Any, Optional
from langchain_qdrant import Qdrant
from langchain.embeddings.base import Embeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.retrievers import EnsembleRetriever
from qdrant_client import QdrantClient
from qdrant_client.http import models
def create_qdrant_client(
url: Optional[str] = None,
api_key: Optional[str] = None,
) -> QdrantClient:
"""
创建 Qdrant 客户端
Args:
url: Qdrant 服务地址,默认从环境变量 QDRANT_URL 读取
api_key: API 密钥,默认从环境变量 QDRANT_API_KEY 读取
Returns:
QdrantClient 实例
"""
url = url or os.getenv("QDRANT_URL", "http://localhost:6333")
api_key = api_key or os.getenv("QDRANT_API_KEY")
client_args = {"url": url}
if api_key:
client_args["api_key"] = api_key
return QdrantClient(**client_args)
def create_base_retriever(
collection_name: str,
embeddings: Embeddings,
search_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[QdrantClient] = None,
) -> Qdrant:
"""
创建基础向量检索器
Args:
collection_name: Qdrant 集合名称
embeddings: 嵌入模型
search_kwargs: 搜索参数,默认 {"k": 20}
client: Qdrant 客户端,如果为 None 则自动创建
Returns:
Qdrant 检索器实例
"""
if client is None:
client = create_qdrant_client()
search_kwargs = search_kwargs or {"k": 20}
# 创建 Qdrant 检索器
retriever = Qdrant.from_existing_collection(
embedding=embeddings,
collection_name=collection_name,
client=client,
content_payload_key="content", # 假设存储的文本字段名为 "content"
metadata_payload_key="metadata", # 元数据字段名
)
return retriever.as_retriever(search_kwargs=search_kwargs)
def create_hybrid_retriever(
collection_name: str,
embeddings: Embeddings,
dense_k: int = 10,
sparse_k: int = 10,
client: Optional[QdrantClient] = None,
) -> ContextualCompressionRetriever:
"""
创建混合检索器Dense Vector + BM25
Args:
collection_name: Qdrant 集合名称
embeddings: 嵌入模型
dense_k: 向量检索返回数量
sparse_k: BM25 检索返回数量
client: Qdrant 客户端
Returns:
混合检索器
"""
if client is None:
client = create_qdrant_client()
# 基础检索器Qdrant 支持混合检索)
base_retriever = Qdrant.from_existing_collection(
embedding=embeddings,
collection_name=collection_name,
client=client,
content_payload_key="content",
metadata_payload_key="metadata",
)
# 配置混合检索参数
search_kwargs = {
"k": dense_k + sparse_k, # 总返回数量
"score_threshold": 0.3, # 相似度阈值
}
return base_retriever.as_retriever(search_kwargs=search_kwargs)
def create_ensemble_retriever(
retrievers: List[Any],
weights: Optional[List[float]] = None,
c: int = 60,
) -> EnsembleRetriever:
"""
创建集成检索器,支持倒数排名融合 (RRF)
Args:
retrievers: 检索器列表
weights: 检索器权重
c: RRF 常数通常为60
Returns:
集成检索器
"""
if weights is None:
weights = [1.0 / len(retrievers)] * len(retrievers)
ensemble = EnsembleRetriever(
retrievers=retrievers,
weights=weights,
c=c,
search_type="rrf", # 使用倒数排名融合
)
return ensemble

230
app/rag/tools.py Normal file
View File

@@ -0,0 +1,230 @@
"""
RAG 工具包装
将 RAG 流水线包装成 LangChain Tool供 Agent 调用。
"""
from typing import Optional, Dict, Any
from langchain.tools import tool
from langchain_core.tools import Tool
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .pipeline import RAGPipeline, RAGConfig, RAGLevel
class RAGTool:
"""
RAG 工具包装器
将 RAG 流水线包装成 Agent 可调用的工具。
"""
def __init__(
self,
pipeline: RAGPipeline,
tool_name: str = "search_knowledge_base",
tool_description: str = None,
):
"""
初始化 RAG 工具
Args:
pipeline: RAG 流水线实例
tool_name: 工具名称
tool_description: 工具描述
"""
self.pipeline = pipeline
self.tool_name = tool_name
# 默认工具描述
self.tool_description = tool_description or (
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
"专业知识或需要基于已知信息回答的问题时使用此工具。"
"输入应为要搜索的查询文本。"
)
# 创建 LangChain 工具
self._tool = self._create_tool()
def _create_tool(self) -> Tool:
"""创建 LangChain 工具"""
@tool(self.tool_name, args_schema=None)
def search_knowledge_base(query: str) -> str:
"""
在知识库中搜索相关信息
Args:
query: 搜索查询
Returns:
格式化后的搜索结果
"""
try:
# 执行检索
result = self.pipeline.retrieve(query)
if not result.documents:
return "在知识库中未找到相关信息。"
# 格式化上下文
context = self.pipeline.format_context(
result.documents,
max_length=4000, # 限制上下文长度
)
# 构建响应
response = (
f"🔍 在知识库中找到了 {len(result.documents)} 条相关信息:\n\n"
f"{context}\n\n"
f"⏱️ 检索耗时: {result.query_time:.2f}"
)
return response
except Exception as e:
error_msg = f"检索过程中发生错误: {str(e)}"
if self.pipeline.config.verbose:
print(f"RAG 工具错误: {error_msg}")
return error_msg
# 设置工具描述
search_knowledge_base.description = self.tool_description
return search_knowledge_base
def get_tool(self) -> Tool:
"""获取 LangChain 工具"""
return self._tool
def __call__(self, query: str) -> str:
"""直接调用工具"""
return self._tool.invoke({"query": query})
def create_rag_tool(
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
tool_name: str = "search_knowledge_base",
tool_description: Optional[str] = None,
) -> Tool:
"""
创建 RAG 工具(便捷函数)
Args:
embeddings: 嵌入模型
llm: 语言模型(用于高级 RAG 功能)
config: RAG 配置字典
tool_name: 工具名称
tool_description: 工具描述
Returns:
LangChain Tool 实例
"""
# 创建 RAG 流水线
pipeline = RAGPipeline.create_from_config(
embeddings=embeddings,
llm=llm,
config_dict=config,
)
# 创建工具包装器
rag_tool = RAGTool(
pipeline=pipeline,
tool_name=tool_name,
tool_description=tool_description,
)
return rag_tool.get_tool()
# 导出便捷函数
search_knowledge_base_tool = create_rag_tool
def bind_rag_to_agent(
agent_llm: BaseLanguageModel,
embeddings: Embeddings,
rag_llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
tool_name: str = "search_knowledge_base",
) -> BaseLanguageModel:
"""
将 RAG 工具绑定到 Agent 模型
Args:
agent_llm: Agent 使用的语言模型
embeddings: 嵌入模型
rag_llm: RAG 流水线使用的语言模型(如果与 agent_llm 不同)
config: RAG 配置
tool_name: 工具名称
Returns:
绑定工具后的模型
"""
# 如果未指定 RAG LLM使用 Agent LLM
if rag_llm is None:
rag_llm = agent_llm
# 创建 RAG 工具
rag_tool = create_rag_tool(
embeddings=embeddings,
llm=rag_llm,
config=config,
tool_name=tool_name,
)
# 绑定工具到模型
return agent_llm.bind_tools([rag_tool])
def create_agentic_rag_pipeline(
embeddings: Embeddings,
agent_llm: BaseLanguageModel,
rag_llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
创建完整的 Agentic RAG 流水线Level 4
Args:
embeddings: 嵌入模型
agent_llm: Agent 模型
rag_llm: RAG 专用模型
config: 配置
Returns:
包含模型和工具的字典
"""
# 配置 Agentic RAG 级别
if config is None:
config = {}
config["rag_level"] = RAGLevel.AGENTIC.value
# 创建 RAG 工具
rag_tool = create_rag_tool(
embeddings=embeddings,
llm=rag_llm or agent_llm,
config=config,
tool_name="search_knowledge_base",
tool_description=(
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
"专业知识或需要基于已知信息回答的问题时使用此工具。"
"Agent 应该先判断是否需要使用此工具,然后调用它获取上下文。"
),
)
# 绑定工具到模型
bound_llm = agent_llm.bind_tools([rag_tool])
return {
"llm": bound_llm,
"tool": rag_tool,
"pipeline": RAGPipeline.create_from_config(
embeddings=embeddings,
llm=rag_llm or agent_llm,
config_dict=config,
),
}

View File

@@ -119,6 +119,7 @@ def _handle_ai_response():
api_thought = ""
display_text = ""
display_thought = ""
rag_sources = None # 存储 RAG 检索来源信息
# 调用流式 API
stream = api_client.chat_stream(
@@ -213,6 +214,25 @@ def _handle_ai_response():
last_msg = messages_update[-1] if messages_update else {}
if isinstance(last_msg, dict) and last_msg.get("role") == "tool":
tool_name = last_msg.get("name", "unknown")
tool_content = last_msg.get("content", "")
# 存储 RAG 检索结果
if tool_name == "search_knowledge_base":
# 尝试解析 tool_content它可能是 JSON 字符串
sources = []
try:
if isinstance(tool_content, str):
import json
data = json.loads(tool_content)
else:
data = tool_content
# 提取来源列表
if isinstance(data, dict) and "sources" in data:
sources = data["sources"]
else:
sources = [str(data)]
except Exception:
sources = [str(tool_content)]
rag_sources = sources
tool_status_placeholder.success(f"✅ 工具 {tool_name} 执行完成")
# 短暂显示后清除,保持界面清爽
import time
@@ -270,6 +290,31 @@ def _handle_ai_response():
# 移除光标
message_placeholder.markdown(display_text)
# 显示 RAG 检索来源(如果有)
if rag_sources:
with st.expander("🔍 检索来源", expanded=False):
# 格式化来源列表
if isinstance(rag_sources, list):
for i, source in enumerate(rag_sources, 1):
if isinstance(source, dict):
content = source.get("page_content", source.get("content", str(source)))
metadata = source.get("metadata", {})
filename = metadata.get("filename", metadata.get("source", "未知文件"))
page = metadata.get("page", metadata.get("page_number", ""))
if page:
source_info = f"**来源 {i}:** {filename} (第{page}页)"
else:
source_info = f"**来源 {i}:** {filename}"
st.markdown(source_info)
# 显示内容预览前200字符
preview = content[:200] + "..." if len(content) > 200 else content
st.markdown(f"> {preview}")
st.markdown("---")
else:
st.markdown(f"**来源 {i}:** {str(source)}")
else:
st.markdown(str(rag_sources))
# 拼装包含思考过程的完整内容,以便后续在历史中正确渲染
final_content = display_text
if display_thought:

View File

@@ -48,6 +48,8 @@ pydantic==2.12.5
python-dotenv==1.2.2
typing-extensions==4.15.0
unstructured>=0.0.1
# ============================================================================
# 注意:
# 1. 此文件包含项目直接依赖的精确版本