""" RAG 检索流水线模块 提供固定流程的 RAG 检索: 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 默认使用混合检索(稠密+稀疏)+ 父子文档模式。 """ import asyncio import os from typing import List, Optional from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from app.model_services import get_rerank_service from app.rag.rerank import create_document_reranker from app.rag.query_transform import MultiQueryGenerator from app.rag.fusion import reciprocal_rank_fusion from app.rag.retriever import create_parent_hybrid_retriever class RAGPipeline: """ 固定流程的 RAG 检索流水线: 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。 """ def __init__( self, retriever=None, llm: Optional[BaseLanguageModel] = None, num_queries: int = 3, rerank_top_n: int = 5, collection_name: str = "rag_documents", ): """ Args: retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。 如果不提供,会自动创建默认的父子文档混合检索器。 llm: 用于生成多路查询的语言模型。 num_queries: 生成的查询变体数量。 rerank_top_n: 最终返回的文档数量。 collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。 """ # 如果没有提供 retriever,自动创建默认的混合检索器 if retriever is None: self.retriever = create_parent_hybrid_retriever( collection_name=collection_name, search_k=rerank_top_n * 2 # 多取一些给重排序用 ) else: self.retriever = retriever self.llm = llm self.num_queries = num_queries self.rerank_top_n = rerank_top_n # 初始化组件 - 使用统一的重排服务获取接口 self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None self.reranker = create_document_reranker() async def aretrieve(self, query: str) -> List[Document]: """ 异步执行完整检索流程 Args: query: 用户查询 Returns: 检索到的相关文档列表 """ # 如果有 query_generator,做多路改写 if self.query_generator and self.llm: # Step 1: 生成多路查询 queries = await self.query_generator.agenerate(query) # 包含原始查询,确保至少有一条 if query not in queries: queries.insert(0, query) else: # 如果原始查询已在列表中,将其移至首位 queries.remove(query) queries.insert(0, query) # Step 2: 并行检索(每个查询获取文档列表) tasks = [self.retriever.ainvoke(q) for q in queries] doc_lists = await asyncio.gather(*tasks) # Step 3: RRF 融合 fused_docs = reciprocal_rank_fusion(doc_lists) else: # 没有 LLM 做查询改写,直接用原始查询检索 fused_docs = await self.retriever.ainvoke(query) # Step 4: 重排序 try: final_docs = self.reranker.compress_documents(fused_docs, query, top_n=self.rerank_top_n) except Exception: # 若重排序器不可用,直接返回融合后的前 N 个结果 final_docs = fused_docs[:self.rerank_top_n] return final_docs def retrieve(self, query: str) -> List[Document]: """同步检索入口(内部调用异步方法)""" return asyncio.run(self.aretrieve(query)) def format_context(self, documents: List[Document]) -> str: """ 将文档列表格式化为上下文字符串 Args: documents: 文档列表 Returns: 格式化后的上下文字符串 """ if not documents: return "" parts = [] for i, doc in enumerate(documents, 1): source = doc.metadata.get("source", "未知来源") parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n") return "\n".join(parts) def create_rag_pipeline( collection_name: str = "rag_documents", llm: Optional[BaseLanguageModel] = None, num_queries: int = 3, rerank_top_n: int = 5, ) -> RAGPipeline: """ 创建 RAG 检索流水线的便捷函数 Args: collection_name: Qdrant 集合名称 llm: 用于生成多路查询的语言模型 num_queries: 生成的查询变体数量 rerank_top_n: 最终返回的文档数量 Returns: RAGPipeline 实例 """ return RAGPipeline( llm=llm, num_queries=num_queries, rerank_top_n=rerank_top_n, collection_name=collection_name )