"""
RAG retrieval pipeline.

Combines basic vector retrieval, reranking and RAG-Fusion (multi-query
retrieval) behind a single ``RAGPipeline`` interface.
"""

from enum import Enum
from typing import List, Optional, Dict, Any

from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel

from .retriever import (
    create_base_retriever,
    create_hybrid_retriever,
    create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer
from rag_core import QDRANT_URL, QDRANT_API_KEY


class RAGLevel(Enum):
    """Retrieval strategy levels supported by ``RAGPipeline``."""

    BASIC = "basic"  # plain vector retrieval
    RERANK = "rerank"  # vector retrieval + reranking
    FUSION = "fusion"  # RAG-Fusion (multiple queries + RRF)


class RAGPipeline:
    """RAG retrieval pipeline.

    Builds a retriever according to ``config["rag_level"]`` and exposes
    synchronous / asynchronous retrieval plus context formatting.
    """

    def __init__(
        self,
        embeddings,
        llm: Optional[BaseLanguageModel] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialise the RAG pipeline.

        Args:
            embeddings: Embedding model passed to ``create_base_retriever``.
            llm: Language model; required only for the "fusion" level.
            config: Optional settings. Recognised keys:
                ``collection_name`` (default "rag_documents"),
                ``rag_level`` (a ``RAGLevel`` member or its string value,
                default "rerank"), ``num_queries`` (default 3),
                ``rerank_top_n`` (default 5) and ``retrieve_k`` (number of
                candidates recalled by the base retriever, default 20).

        Raises:
            ValueError: if ``rag_level`` is "fusion" and no ``llm`` is given.
        """
        self.embeddings = embeddings
        self.llm = llm
        self.config = config or {}

        self.collection_name = self.config.get("collection_name", "rag_documents")

        # Accept either the enum member or its string value; internally the
        # level is always stored as the string value so comparisons work.
        rag_level = self.config.get("rag_level", RAGLevel.RERANK.value)
        if isinstance(rag_level, RAGLevel):
            rag_level = rag_level.value
        self.rag_level = rag_level

        self.num_queries = self.config.get("num_queries", 3)
        self.rerank_top_n = self.config.get("rerank_top_n", 5)
        # Recall size of the base retriever (candidates before rerank/truncation).
        self.retrieve_k = self.config.get("retrieve_k", 20)

        # Base retriever
        self.base_retriever = create_base_retriever(
            collection_name=self.collection_name,
            embeddings=self.embeddings,
            search_kwargs={"k": self.retrieve_k},
        )

        # Reranker is optional: if it cannot be built, retrieval falls back to
        # truncating the base results to rerank_top_n.
        try:
            self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n)
        except Exception as e:
            print(f"警告: 无法创建重排序器,将使用基础检索。错误: {e}")
            self.reranker = None

        # Build the retriever for the configured RAG level
        self.retriever = self._create_retriever()

    def _rerank_retrieve(self, query):
        """Recall candidates with the base retriever, then rerank or truncate them."""
        documents = self.base_retriever.invoke(query)
        if not documents:
            # Nothing to rerank; avoid calling the reranker on an empty batch.
            return []
        if self.reranker:
            return self.reranker.compress_documents(documents, query)
        return documents[:self.rerank_top_n]

    def _create_retriever(self):
        """
        Create the retriever matching ``self.rag_level``.

        Returns:
            The base retriever for "basic", a reranking ``SimpleRetriever`` for
            "rerank", or a multi-query retriever built by
            ``MultiQueryTransformer`` for "fusion". Unknown levels fall back to
            the reranking retriever (with a printed warning).

        Raises:
            ValueError: for the "fusion" level when no ``llm`` was provided.
        """
        level = self.rag_level

        if level == RAGLevel.BASIC.value:
            return self.base_retriever

        # Base retrieval + reranking
        rerank_retriever = SimpleRetriever(self._rerank_retrieve)

        if level == RAGLevel.RERANK.value:
            return rerank_retriever

        # RAG-Fusion: multiple query variants, each going through the
        # reranking retriever.
        if level == RAGLevel.FUSION.value:
            if not self.llm:
                raise ValueError("RAG-Fusion 需要提供 llm 参数")
            transformer = MultiQueryTransformer(
                llm=self.llm,
                num_queries=self.num_queries,
            )
            return transformer.create_multi_query_retriever(
                base_retriever=rerank_retriever
            )

        # Previously an unknown level silently became "rerank"; keep the
        # fallback but make the misconfiguration visible.
        print(f"警告: 未知的 RAG 级别 {level!r},将使用重排序检索。")
        return rerank_retriever

    def retrieve(self, query: str) -> List[Document]:
        """
        Run retrieval.

        Args:
            query: Query string.

        Returns:
            List of relevant documents.
        """
        return self.retriever.invoke(query)

    async def aretrieve(self, query: str) -> List[Document]:
        """
        Run retrieval asynchronously.

        Args:
            query: Query string.

        Returns:
            List of relevant documents.
        """
        return await self.retriever.ainvoke(query)

    def format_context(self, documents: List[Document]) -> str:
        """
        Format documents into a single context string.

        Each document becomes a numbered section with its ``source`` metadata
        (or a placeholder when absent) and its page content, terminated by a
        ``---`` separator line.

        Args:
            documents: List of documents.

        Returns:
            The formatted context string ("" for no documents).
        """
        if not documents:
            return ""

        context_parts = []
        for i, doc in enumerate(documents, 1):
            metadata = doc.metadata or {}
            source = metadata.get("source", "未知来源")
            context_parts.append(
                f"【资料 {i}】\n"
                f"来源: {source}\n"
                f"内容: {doc.page_content}\n"
                "---\n"
            )

        return "".join(context_parts)


class SimpleRetriever:
    """Minimal retriever wrapper around a plain ``query -> documents`` callable."""

    def __init__(self, retrieve_func):
        self.retrieve_func = retrieve_func

    def invoke(self, query, config=None, **kwargs):
        """Call the wrapped function.

        ``config`` / ``kwargs`` are accepted (and ignored) so callers using the
        LangChain-style ``invoke(input, config=...)`` signature do not break.
        """
        return self.retrieve_func(query)

    async def ainvoke(self, query, config=None, **kwargs):
        """Async variant of ``invoke``.

        NOTE(review): the wrapped function is synchronous and is called
        directly here, so it blocks the event loop for the duration of the
        retrieval. Offloading it to a thread was not done because the
        thread-safety of the wrapped retriever/reranker is not established.
        """
        return self.retrieve_func(query)