"""
RAG retrieval pipeline.

Combines the retriever, reranker and query-rewriting components into a single
entry point that provides the complete RAG retrieval flow.
"""

import time
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field, fields
from enum import Enum

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel

from .retriever import (
    create_base_retriever,
    create_hybrid_retriever,
    create_ensemble_retriever,
    create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline


class RAGLevel(Enum):
    """RAG capability level selected by ``RAGConfig.rag_level``."""

    BASIC = 1  # plain vector search
    HYBRID = 2  # hybrid retrieval + reranking
    FUSION = 3  # RAG-Fusion
    AGENTIC = 4  # Agentic RAG


@dataclass
class RAGConfig:
    """Configuration for :class:`RAGPipeline`."""

    # Qdrant settings
    collection_name: str = "documents"
    qdrant_url: Optional[str] = None
    qdrant_api_key: Optional[str] = None

    # Retrieval settings
    rag_level: RAGLevel = RAGLevel.FUSION
    dense_k: int = 10  # number of hits requested from vector search
    sparse_k: int = 10  # number of hits requested from BM25 search
    total_k: int = 20  # ``k`` passed to the basic (Level 1) retriever
    rerank_top_n: int = 5  # number of documents the reranker returns

    # Query-rewriting settings
    num_queries: int = 3  # number of queries used by RAG-Fusion

    # Model settings
    reranker_model: str = "BAAI/bge-reranker-base"
    device: Optional[str] = None

    # Runtime settings
    enable_cache: bool = True
    verbose: bool = True


@dataclass
class RetrievalResult:
    """Outcome of a single :meth:`RAGPipeline.retrieve` call."""

    documents: List[Document]
    query_time: float  # elapsed seconds, measured inside ``retrieve``
    level: RAGLevel
    metadata: Dict[str, Any] = field(default_factory=dict)


class RAGPipeline:
    """
    RAG retrieval pipeline.

    Builds a retriever for the configured :class:`RAGLevel` (Level 1 to
    Level 4) on first use and reuses it for later queries.
    """

    # Text appended to a document that had to be cut to honour ``max_length``.
    _TRUNCATION_MARKER = "...\n\n[内容已截断]"
    # A partially fitting document is only emitted when more than this many
    # characters of the budget are still free.
    _MIN_TRUNCATION_BUDGET = 100

    def __init__(
        self,
        embeddings: Embeddings,
        llm: Optional[BaseLanguageModel] = None,
        config: Optional[RAGConfig] = None,
    ):
        """
        Initialise the pipeline.

        No client, model or retriever is created here; each component is
        built lazily by the corresponding ``_get_*`` method.

        Args:
            embeddings: Embedding model.
            llm: Language model used for query rewriting; required for the
                FUSION and AGENTIC levels.
            config: Pipeline configuration; defaults to ``RAGConfig()``.
        """
        self.embeddings = embeddings
        self.llm = llm
        self.config = config or RAGConfig()

        # Lazily created components
        self._client = None
        self._reranker = None
        self._query_transformer = None
        self._retriever = None

        # Results of earlier queries, keyed by "<query>:<level value>"
        self._cache: Dict[str, RetrievalResult] = {}

    def _get_client(self):
        """Return the Qdrant client, creating it on first use."""
        if self._client is None:
            self._client = create_qdrant_client(
                url=self.config.qdrant_url,
                api_key=self.config.qdrant_api_key,
            )
        return self._client

    def _get_reranker(self):
        """Return the cross-encoder reranker, creating it on first use."""
        if self._reranker is None:
            self._reranker = CrossEncoderReranker(
                model_name=self.config.reranker_model,
                top_n=self.config.rerank_top_n,
                device=self.config.device,
            )
        return self._reranker

    def _get_query_transformer(self):
        """
        Return the query transformer, creating it on first use.

        Returns:
            The transformer, or ``None`` when no language model was supplied.
        """
        if self._query_transformer is None and self.llm is not None:
            self._query_transformer = MultiQueryTransformer(
                llm=self.llm,
                num_queries=self.config.num_queries,
            )
        return self._query_transformer

    def _create_basic_retriever(self):
        """Create the basic retriever (Level 1)."""
        return create_base_retriever(
            collection_name=self.config.collection_name,
            embeddings=self.embeddings,
            search_kwargs={"k": self.config.total_k},
            client=self._get_client(),
        )

    def _create_hybrid_base_retriever(self):
        """Create the un-reranked hybrid retriever shared by Levels 2 and 3."""
        return create_hybrid_retriever(
            collection_name=self.config.collection_name,
            embeddings=self.embeddings,
            dense_k=self.config.dense_k,
            sparse_k=self.config.sparse_k,
            client=self._get_client(),
        )

    def _create_hybrid_retriever(self):
        """Create the hybrid retriever with reranking applied (Level 2)."""
        base_retriever = self._create_hybrid_base_retriever()
        reranker = self._get_reranker()
        return reranker.create_contextual_compression_retriever(base_retriever)

    def _create_fusion_retriever(self):
        """
        Create the RAG-Fusion retriever (Level 3).

        Raises:
            ValueError: If the pipeline was built without a language model.
        """
        if self.llm is None:
            raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写")

        return create_rag_fusion_pipeline(
            base_retriever=self._create_hybrid_base_retriever(),
            llm=self.llm,
            reranker=self._get_reranker(),
            num_queries=self.config.num_queries,
        )

    def _get_retriever(self):
        """
        Return the retriever for the configured level, building it once.

        Raises:
            ValueError: If the configured level is not a known ``RAGLevel``,
                or if a FUSION/AGENTIC retriever is requested without an LLM.
        """
        if self._retriever is None:
            level = self.config.rag_level
            if level == RAGLevel.BASIC:
                self._retriever = self._create_basic_retriever()
            elif level == RAGLevel.HYBRID:
                self._retriever = self._create_hybrid_retriever()
            elif level in (RAGLevel.FUSION, RAGLevel.AGENTIC):
                # AGENTIC uses the Fusion retriever as its base. Per the
                # original author's note the agentic wrapping is done in
                # tools.py, which is not visible from this module.
                self._retriever = self._create_fusion_retriever()
            else:
                raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}")
        return self._retriever

    def retrieve(
        self,
        query: str,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> RetrievalResult:
        """
        Run retrieval for ``query``.

        Args:
            query: Query text.
            use_cache: Whether to read from and write to the result cache;
                ``None`` means "use ``config.enable_cache``".
            **kwargs: Extra arguments forwarded to ``retriever.invoke``.
                Passing any disables the cache for this call.

        Returns:
            The retrieval result. On a cache hit the previously stored
            object is returned unchanged, including its ``query_time``.
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        start_time = time.perf_counter()

        if use_cache is None:
            use_cache = self.config.enable_cache
        if kwargs:
            # The cache key only covers the query text and level, so an
            # entry cannot be assumed valid for other invoke arguments.
            use_cache = False

        cache_key = f"{query}:{self.config.rag_level.value}"
        if use_cache and cache_key in self._cache:
            if self.config.verbose:
                print(f"使用缓存结果: {query}")
            return self._cache[cache_key]

        retriever = self._get_retriever()
        documents = retriever.invoke(query, **kwargs)

        query_time = time.perf_counter() - start_time

        result = RetrievalResult(
            documents=documents,
            query_time=query_time,
            level=self.config.rag_level,
            metadata={
                "query": query,
                "collection": self.config.collection_name,
                "doc_count": len(documents),
            },
        )

        if use_cache:
            self._cache[cache_key] = result

        if self.config.verbose:
            print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s")

        return result

    def format_context(
        self,
        documents: List[Document],
        max_length: Optional[int] = None,
    ) -> str:
        """
        Format retrieved documents into a single context string.

        Each document becomes a numbered section with optional source and
        page lines followed by its stripped content.

        Args:
            documents: Documents to format, in output order.
            max_length: Maximum length of the result in characters, or
                ``None`` for no limit. The first document that does not fit
                is cut (truncation marker included in the budget) when more
                than 100 characters of budget remain, otherwise dropped;
                later documents are always dropped.

        Returns:
            The formatted context with surrounding whitespace removed.
        """
        marker = self._TRUNCATION_MARKER
        context_parts = []
        total_length = 0

        for index, doc in enumerate(documents, start=1):
            content = doc.page_content.strip()
            metadata = doc.metadata

            doc_text = f"[文档 {index}]\n"
            if metadata.get("source"):
                doc_text += f"来源: {metadata['source']}\n"
            # Compare against None rather than truthiness: page numbers may
            # be zero-based, and page 0 must still be reported.
            if metadata.get("page") is not None:
                doc_text += f"页码: {metadata['page']}\n"
            doc_text += f"内容: {content}\n\n"

            if max_length is not None and total_length + len(doc_text) > max_length:
                remaining = max_length - total_length
                if remaining > self._MIN_TRUNCATION_BUDGET:
                    # Reserve room for the marker so the final string never
                    # exceeds max_length.
                    keep = remaining - len(marker)
                    context_parts.append(doc_text[:keep] + marker)
                break

            context_parts.append(doc_text)
            total_length += len(doc_text)

        return "".join(context_parts).strip()

    def clear_cache(self):
        """Remove every cached retrieval result."""
        self._cache.clear()

    @staticmethod
    def _coerce_rag_level(value: Any) -> RAGLevel:
        """Convert an enum value, member, or member name into a ``RAGLevel``."""
        if isinstance(value, str):
            member = RAGLevel.__members__.get(value.strip().upper())
            if member is not None:
                return member
        # Enum lookup by value; raises ValueError for anything unknown.
        return RAGLevel(value)

    @classmethod
    def create_from_config(
        cls,
        embeddings: Embeddings,
        llm: Optional[BaseLanguageModel] = None,
        config_dict: Optional[Dict[str, Any]] = None,
    ) -> "RAGPipeline":
        """
        Create a pipeline from a plain configuration dictionary.

        Keys that match ``RAGConfig`` field names override the dataclass
        defaults; any other key is ignored.

        Args:
            embeddings: Embedding model.
            llm: Language model.
            config_dict: Configuration values. ``rag_level`` may be the
                numeric enum value, a ``RAGLevel`` member, or a member name
                such as ``"fusion"`` (case-insensitive).

        Returns:
            A new ``RAGPipeline`` instance.

        Raises:
            ValueError: If ``rag_level`` does not identify a ``RAGLevel``.
        """
        config_dict = config_dict or {}

        # Only pass what the caller supplied, so the defaults are defined in
        # exactly one place (the RAGConfig dataclass) and cannot drift.
        known_fields = {spec.name for spec in fields(RAGConfig)}
        overrides = {
            key: value for key, value in config_dict.items() if key in known_fields
        }
        if "rag_level" in overrides:
            overrides["rag_level"] = cls._coerce_rag_level(overrides["rag_level"])

        return cls(embeddings=embeddings, llm=llm, config=RAGConfig(**overrides))