Files
ailine/app/rag/pipeline.py
2026-04-18 16:31:48 +08:00

341 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 检索流水线
组合检索器、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。
"""
import time
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
from enum import Enum
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
create_ensemble_retriever,
create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline
class RAGLevel(Enum):
    """Feature levels of the RAG pipeline, ordered from simplest to most capable."""

    BASIC = 1    # plain vector similarity search
    HYBRID = 2   # dense + BM25 hybrid retrieval with reranking
    FUSION = 3   # RAG-Fusion (multi-query rewriting + fused ranking)
    AGENTIC = 4  # agentic RAG, built on top of the Fusion retriever
@dataclass
class RAGConfig:
    """Configuration for the RAG retrieval pipeline.

    All fields have defaults, so ``RAGConfig()`` yields a working
    FUSION-level configuration.
    """
    # --- Qdrant connection ---
    collection_name: str = "documents"
    qdrant_url: Optional[str] = None
    qdrant_api_key: Optional[str] = None
    # --- Retrieval ---
    rag_level: RAGLevel = RAGLevel.FUSION
    dense_k: int = 10  # number of dense (vector) hits to fetch
    sparse_k: int = 10  # number of sparse (BM25) hits to fetch
    total_k: int = 20  # total hits for the basic (Level 1) retriever
    rerank_top_n: int = 5  # documents kept after reranking
    # --- Query rewriting ---
    num_queries: int = 3  # queries generated per input for RAG-Fusion
    # --- Models ---
    reranker_model: str = "BAAI/bge-reranker-base"
    device: Optional[str] = None  # e.g. "cuda"/"cpu"; None lets the reranker decide
    # --- Behaviour ---
    enable_cache: bool = True  # cache retrieval results per (query, level)
    verbose: bool = True  # print progress/cache messages
@dataclass
class RetrievalResult:
    """Result of a single retrieval call."""
    documents: List[Document]  # retrieved (and possibly reranked) documents
    query_time: float  # wall-clock seconds spent in retrieve()
    level: RAGLevel  # the RAG level the pipeline ran at
    metadata: Dict[str, Any] = field(default_factory=dict)  # query, collection, doc_count, ...
class RAGPipeline:
    """
    RAG retrieval pipeline.

    Composes the retriever, reranker and query-transformer components and
    exposes a single entry point (``retrieve``) covering all feature levels
    from Level 1 (basic vector search) to Level 4 (agentic RAG).
    """

    def __init__(
        self,
        embeddings: Embeddings,
        llm: Optional[BaseLanguageModel] = None,
        config: Optional[RAGConfig] = None,
    ):
        """
        Initialize the RAG pipeline.

        Args:
            embeddings: Embedding model used for dense retrieval.
            llm: Language model used for query rewriting (required for
                Level 3 and above).
            config: Pipeline configuration; defaults to ``RAGConfig()``.
        """
        self.embeddings = embeddings
        self.llm = llm
        self.config = config or RAGConfig()
        # Components are created lazily on first use.
        self._client = None
        self._reranker = None
        self._query_transformer = None
        self._retriever = None
        # Result cache keyed by "{query}:{level}".
        # NOTE(review): unbounded — a long-running process may want a cap.
        self._cache = {}

    def _get_client(self):
        """Return the shared Qdrant client, creating it on first use."""
        if self._client is None:
            self._client = create_qdrant_client(
                url=self.config.qdrant_url,
                api_key=self.config.qdrant_api_key,
            )
        return self._client

    def _get_reranker(self):
        """Return the cross-encoder reranker, creating it on first use."""
        if self._reranker is None:
            self._reranker = CrossEncoderReranker(
                model_name=self.config.reranker_model,
                top_n=self.config.rerank_top_n,
                device=self.config.device,
            )
        return self._reranker

    def _get_query_transformer(self):
        """Return the multi-query transformer, or None when no LLM is configured."""
        if self._query_transformer is None and self.llm is not None:
            self._query_transformer = MultiQueryTransformer(
                llm=self.llm,
                num_queries=self.config.num_queries,
            )
        return self._query_transformer

    def _create_basic_retriever(self):
        """Create the Level 1 retriever: plain dense vector search."""
        return create_base_retriever(
            collection_name=self.config.collection_name,
            embeddings=self.embeddings,
            search_kwargs={"k": self.config.total_k},
            client=self._get_client(),
        )

    def _create_hybrid_retriever(self):
        """Create the Level 2 retriever: hybrid (dense + BM25) search plus reranking."""
        base_retriever = create_hybrid_retriever(
            collection_name=self.config.collection_name,
            embeddings=self.embeddings,
            dense_k=self.config.dense_k,
            sparse_k=self.config.sparse_k,
            client=self._get_client(),
        )
        # Wrap the hybrid retriever with a reranking compression stage.
        reranker = self._get_reranker()
        return reranker.create_contextual_compression_retriever(base_retriever)

    def _create_fusion_retriever(self):
        """
        Create the Level 3 retriever: RAG-Fusion over a hybrid base retriever.

        Raises:
            ValueError: If no LLM was supplied (query rewriting requires one).
        """
        if self.llm is None:
            raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写")
        # Hybrid retrieval feeds the fusion pipeline.
        base_retriever = create_hybrid_retriever(
            collection_name=self.config.collection_name,
            embeddings=self.embeddings,
            dense_k=self.config.dense_k,
            sparse_k=self.config.sparse_k,
            client=self._get_client(),
        )
        reranker = self._get_reranker()
        return create_rag_fusion_pipeline(
            base_retriever=base_retriever,
            llm=self.llm,
            reranker=reranker,
            num_queries=self.config.num_queries,
        )

    def _get_retriever(self):
        """Build (once) and return the retriever matching the configured level."""
        if self._retriever is None:
            level = self.config.rag_level
            if level == RAGLevel.BASIC:
                self._retriever = self._create_basic_retriever()
            elif level == RAGLevel.HYBRID:
                self._retriever = self._create_hybrid_retriever()
            elif level in (RAGLevel.FUSION, RAGLevel.AGENTIC):
                # Agentic RAG reuses the Fusion retriever as its base; the
                # agent wrapping happens in tools.py.
                self._retriever = self._create_fusion_retriever()
            else:
                raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}")
        return self._retriever

    def retrieve(
        self,
        query: str,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> RetrievalResult:
        """
        Run retrieval for a single query.

        Args:
            query: Query text.
            use_cache: Whether to use the result cache; ``None`` falls back
                to ``config.enable_cache``.
            **kwargs: Extra arguments forwarded to the retriever's ``invoke``.

        Returns:
            A RetrievalResult with documents, timing and metadata.
        """
        start_time = time.time()
        # Resolve caching preference and check for a hit.
        if use_cache is None:
            use_cache = self.config.enable_cache
        cache_key = f"{query}:{self.config.rag_level.value}"
        if use_cache and cache_key in self._cache:
            if self.config.verbose:
                print(f"使用缓存结果: {query}")
            return self._cache[cache_key]
        # Run the level-appropriate retriever.
        retriever = self._get_retriever()
        documents = retriever.invoke(query, **kwargs)
        query_time = time.time() - start_time
        result = RetrievalResult(
            documents=documents,
            query_time=query_time,
            level=self.config.rag_level,
            metadata={
                "query": query,
                "collection": self.config.collection_name,
                "doc_count": len(documents),
            },
        )
        if use_cache:
            self._cache[cache_key] = result
        if self.config.verbose:
            print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s")
        return result

    def format_context(
        self,
        documents: List[Document],
        max_length: Optional[int] = None,
    ) -> str:
        """
        Format retrieved documents into a single context string.

        Each document is rendered as a numbered header, optional source/page
        metadata lines, and its content. When ``max_length`` is given the
        output never exceeds it: the first document that would overflow is
        truncated (with a marker) if enough room remains, otherwise
        formatting stops before it.

        Args:
            documents: Documents to format.
            max_length: Maximum total length in characters, or None for no limit.

        Returns:
            The formatted context text (trailing whitespace stripped).
        """
        context_parts = []
        total_length = 0
        truncation_marker = "...\n\n[内容已截断]"
        for i, doc in enumerate(documents):
            content = doc.page_content.strip()
            metadata = doc.metadata
            doc_text = f"[文档 {i+1}]\n"
            if metadata.get("source"):
                doc_text += f"来源: {metadata['source']}\n"
            # Explicit None check so a legitimate page 0 is not dropped
            # (the original truthiness test skipped it).
            if metadata.get("page") is not None:
                doc_text += f"页码: {metadata['page']}\n"
            doc_text += f"内容: {content}\n\n"
            if max_length is not None and total_length + len(doc_text) > max_length:
                remaining = max_length - total_length
                # Only truncate when at least ~100 characters of payload fit.
                # Reserve room for the marker so the limit is never exceeded
                # (the original could overshoot by the marker length).
                if remaining > 100 + len(truncation_marker):
                    doc_text = doc_text[: remaining - len(truncation_marker)] + truncation_marker
                    context_parts.append(doc_text)
                break
            context_parts.append(doc_text)
            total_length += len(doc_text)
        return "".join(context_parts).strip()

    def clear_cache(self):
        """Drop all cached retrieval results."""
        self._cache.clear()

    @classmethod
    def create_from_config(
        cls,
        embeddings: Embeddings,
        llm: Optional[BaseLanguageModel] = None,
        config_dict: Optional[Dict[str, Any]] = None,
    ) -> "RAGPipeline":
        """
        Build a pipeline from a plain configuration dictionary.

        Args:
            embeddings: Embedding model.
            llm: Language model (required for Level 3+).
            config_dict: Configuration values; missing keys fall back to
                the ``RAGConfig`` defaults.

        Returns:
            A configured RAGPipeline instance.
        """
        config_dict = config_dict or {}
        config = RAGConfig(
            collection_name=config_dict.get("collection_name", "documents"),
            qdrant_url=config_dict.get("qdrant_url"),
            qdrant_api_key=config_dict.get("qdrant_api_key"),
            # Accepts the integer level value; defaults to FUSION.
            rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)),
            dense_k=config_dict.get("dense_k", 10),
            sparse_k=config_dict.get("sparse_k", 10),
            total_k=config_dict.get("total_k", 20),
            rerank_top_n=config_dict.get("rerank_top_n", 5),
            num_queries=config_dict.get("num_queries", 3),
            reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"),
            device=config_dict.get("device"),
            enable_cache=config_dict.get("enable_cache", True),
            verbose=config_dict.get("verbose", True),
        )
        return cls(embeddings=embeddings, llm=llm, config=config)