341 lines
11 KiB
Python
341 lines
11 KiB
Python
|
|
"""
|
|||
|
|
RAG 检索流水线
|
|||
|
|
|
|||
|
|
组合检索器、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import time
|
|||
|
|
from typing import List, Dict, Any, Optional, Union
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from enum import Enum
|
|||
|
|
from langchain_core.documents import Document
|
|||
|
|
from langchain_core.embeddings import Embeddings
|
|||
|
|
from langchain_core.language_models import BaseLanguageModel
|
|||
|
|
|
|||
|
|
from .retriever import (
|
|||
|
|
create_base_retriever,
|
|||
|
|
create_hybrid_retriever,
|
|||
|
|
create_ensemble_retriever,
|
|||
|
|
create_qdrant_client,
|
|||
|
|
)
|
|||
|
|
from .reranker import CrossEncoderReranker
|
|||
|
|
from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RAGLevel(Enum):
    """RAG capability levels, from plain vector search up to agentic RAG."""

    BASIC = 1    # basic vector search
    HYBRID = 2   # hybrid retrieval + reranking
    FUSION = 3   # RAG-Fusion
    AGENTIC = 4  # Agentic RAG
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class RAGConfig:
    """Configuration for the RAG retrieval pipeline."""

    # Qdrant connection settings
    collection_name: str = "documents"
    qdrant_url: Optional[str] = None
    qdrant_api_key: Optional[str] = None

    # Retrieval settings
    rag_level: RAGLevel = RAGLevel.FUSION
    dense_k: int = 10       # number of dense (vector) hits
    sparse_k: int = 10      # number of sparse (BM25) hits
    total_k: int = 20       # total hits for basic (Level 1) retrieval
    rerank_top_n: int = 5   # documents kept after reranking

    # Query-rewrite settings
    num_queries: int = 3    # number of RAG-Fusion query variants

    # Model settings
    reranker_model: str = "BAAI/bge-reranker-base"
    device: Optional[str] = None  # passed to the reranker; presumably "cpu"/"cuda" — confirm in CrossEncoderReranker

    # Runtime behaviour
    enable_cache: bool = True
    verbose: bool = True
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class RetrievalResult:
    """Result of a single retrieval call."""

    documents: List[Document]   # retrieved documents
    query_time: float           # wall-clock retrieval time in seconds
    level: RAGLevel             # RAG level the result was produced at
    # Extra info (query text, collection name, doc count)
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RAGPipeline:
    """
    RAG retrieval pipeline.

    Combines retrievers, rerankers and query transformers into a complete
    retrieval workflow, supporting all levels from Level 1 (basic vector
    search) through Level 4 (agentic RAG).
    """

    def __init__(
        self,
        embeddings: Embeddings,
        llm: Optional[BaseLanguageModel] = None,
        config: Optional[RAGConfig] = None,
    ):
        """
        Initialize the RAG pipeline.

        Args:
            embeddings: Embedding model used for dense retrieval.
            llm: Language model used for query rewriting (required for Level 3+).
            config: Pipeline configuration; defaults to ``RAGConfig()``.
        """
        self.embeddings = embeddings
        self.llm = llm
        self.config = config or RAGConfig()

        # Lazily-initialized components (created on first access).
        self._client = None
        self._reranker = None
        self._query_transformer = None
        self._retriever = None

        # Query-result cache. NOTE(review): unbounded — callers should use
        # clear_cache() to reclaim memory in long-running processes.
        self._cache = {}

    def _get_client(self):
        """Return the Qdrant client, creating it on first access."""
        if self._client is None:
            self._client = create_qdrant_client(
                url=self.config.qdrant_url,
                api_key=self.config.qdrant_api_key,
            )
        return self._client

    def _get_reranker(self):
        """Return the cross-encoder reranker, creating it on first access."""
        if self._reranker is None:
            self._reranker = CrossEncoderReranker(
                model_name=self.config.reranker_model,
                top_n=self.config.rerank_top_n,
                device=self.config.device,
            )
        return self._reranker

    def _get_query_transformer(self):
        """Return the multi-query transformer, or None when no LLM is configured."""
        if self._query_transformer is None and self.llm is not None:
            self._query_transformer = MultiQueryTransformer(
                llm=self.llm,
                num_queries=self.config.num_queries,
            )
        return self._query_transformer

    def _create_basic_retriever(self):
        """Create the Level 1 retriever (plain dense vector search)."""
        return create_base_retriever(
            collection_name=self.config.collection_name,
            embeddings=self.embeddings,
            search_kwargs={"k": self.config.total_k},
            client=self._get_client(),
        )

    def _create_hybrid_retriever(self):
        """Create the Level 2 retriever (dense + BM25 hybrid with reranking)."""
        base_retriever = create_hybrid_retriever(
            collection_name=self.config.collection_name,
            embeddings=self.embeddings,
            dense_k=self.config.dense_k,
            sparse_k=self.config.sparse_k,
            client=self._get_client(),
        )

        # Wrap the hybrid retriever with cross-encoder reranking.
        reranker = self._get_reranker()
        return reranker.create_contextual_compression_retriever(base_retriever)

    def _create_fusion_retriever(self):
        """
        Create the Level 3 retriever (RAG-Fusion over the hybrid retriever).

        Raises:
            ValueError: If no LLM is available for query rewriting.
        """
        if self.llm is None:
            raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写")

        # Base hybrid retriever feeding the fusion pipeline.
        base_retriever = create_hybrid_retriever(
            collection_name=self.config.collection_name,
            embeddings=self.embeddings,
            dense_k=self.config.dense_k,
            sparse_k=self.config.sparse_k,
            client=self._get_client(),
        )

        # Assemble the RAG-Fusion pipeline (multi-query rewrite + fusion + rerank).
        reranker = self._get_reranker()
        return create_rag_fusion_pipeline(
            base_retriever=base_retriever,
            llm=self.llm,
            reranker=reranker,
            num_queries=self.config.num_queries,
        )

    def _get_retriever(self):
        """
        Return the retriever matching the configured RAG level (cached).

        Raises:
            ValueError: If the configured level is not recognized.
        """
        if self._retriever is None:
            if self.config.rag_level == RAGLevel.BASIC:
                self._retriever = self._create_basic_retriever()
            elif self.config.rag_level == RAGLevel.HYBRID:
                self._retriever = self._create_hybrid_retriever()
            elif self.config.rag_level == RAGLevel.FUSION:
                self._retriever = self._create_fusion_retriever()
            elif self.config.rag_level == RAGLevel.AGENTIC:
                # Agentic RAG builds on Fusion; the agent wrapper lives in tools.py.
                self._retriever = self._create_fusion_retriever()
            else:
                raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}")

        return self._retriever

    def retrieve(
        self,
        query: str,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> RetrievalResult:
        """
        Execute retrieval for a query.

        Args:
            query: Query text.
            use_cache: Whether to use the result cache; defaults to the
                ``enable_cache`` config flag when None.
            **kwargs: Extra arguments forwarded to the retriever's ``invoke``.

        Returns:
            A RetrievalResult with the documents, timing, and metadata.
        """
        start_time = time.time()

        # Resolve the caching flag from config when not given explicitly.
        if use_cache is None:
            use_cache = self.config.enable_cache

        # BUGFIX: the key must include the forwarded kwargs — previously a
        # repeat query with different retriever arguments was served a stale
        # cached result produced with the old arguments.
        cache_key = f"{query}:{self.config.rag_level.value}:{sorted(kwargs.items())!r}"
        if use_cache and cache_key in self._cache:
            if self.config.verbose:
                print(f"使用缓存结果: {query}")
            return self._cache[cache_key]

        # Run the retriever for the configured level.
        retriever = self._get_retriever()
        documents = retriever.invoke(query, **kwargs)

        # Wall-clock time for the whole retrieval.
        query_time = time.time() - start_time

        result = RetrievalResult(
            documents=documents,
            query_time=query_time,
            level=self.config.rag_level,
            metadata={
                "query": query,
                "collection": self.config.collection_name,
                "doc_count": len(documents),
            },
        )

        # Store in the cache for subsequent identical calls.
        if use_cache:
            self._cache[cache_key] = result

        if self.config.verbose:
            print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s")

        return result

    def format_context(
        self,
        documents: List[Document],
        max_length: Optional[int] = None,
    ) -> str:
        """
        Format retrieved documents into a single context string.

        Args:
            documents: Documents to format.
            max_length: Optional cap on total length in characters; when the
                cap would be exceeded, output is truncated with a marker.

        Returns:
            The formatted context text.
        """
        context_parts = []
        total_length = 0

        for i, doc in enumerate(documents):
            content = doc.page_content.strip()
            metadata = doc.metadata

            # Render one document with its source/page header when available.
            doc_text = f"[文档 {i+1}]\n"
            if metadata.get("source"):
                doc_text += f"来源: {metadata['source']}\n"
            # BUGFIX: compare against None so page 0 (0-indexed loaders) is
            # not silently dropped by the truthiness check.
            if metadata.get("page") is not None:
                doc_text += f"页码: {metadata['page']}\n"

            doc_text += f"内容: {content}\n\n"

            # Enforce the optional length budget.
            if max_length is not None:
                if total_length + len(doc_text) > max_length:
                    remaining = max_length - total_length
                    if remaining > 100:  # only emit a fragment if it is useful
                        doc_text = doc_text[:remaining] + "...\n\n[内容已截断]"
                        context_parts.append(doc_text)
                        break
                    else:
                        break

            context_parts.append(doc_text)
            total_length += len(doc_text)

        return "".join(context_parts).strip()

    def clear_cache(self):
        """Clear the query-result cache."""
        self._cache.clear()

    @classmethod
    def create_from_config(
        cls,
        embeddings: Embeddings,
        llm: Optional[BaseLanguageModel] = None,
        config_dict: Optional[Dict[str, Any]] = None,
    ) -> "RAGPipeline":
        """
        Build a pipeline from a plain configuration dict.

        Args:
            embeddings: Embedding model.
            llm: Language model (required for Level 3+).
            config_dict: Flat dict of RAGConfig field overrides; missing keys
                fall back to the RAGConfig defaults.

        Returns:
            A configured RAGPipeline instance.
        """
        config_dict = config_dict or {}

        # Translate the flat dict into a typed config object.
        config = RAGConfig(
            collection_name=config_dict.get("collection_name", "documents"),
            qdrant_url=config_dict.get("qdrant_url"),
            qdrant_api_key=config_dict.get("qdrant_api_key"),
            rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)),
            dense_k=config_dict.get("dense_k", 10),
            sparse_k=config_dict.get("sparse_k", 10),
            total_k=config_dict.get("total_k", 20),
            rerank_top_n=config_dict.get("rerank_top_n", 5),
            num_queries=config_dict.get("num_queries", 3),
            reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"),
            device=config_dict.get("device"),
            enable_cache=config_dict.get("enable_cache", True),
            verbose=config_dict.get("verbose", True),
        )

        return cls(embeddings=embeddings, llm=llm, config=config)
|