重排,多路查询
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 35m37s

This commit is contained in:
2026-04-20 01:10:18 +08:00
parent 933d418d77
commit 3c906e91d9
21 changed files with 728 additions and 635 deletions

View File

@@ -1,168 +1,92 @@
"""
RAG 检索流水线
# rag/pipeline.py
整合基础检索、重排序和 RAG-Fusion 功能。
"""
from enum import Enum
from typing import List, Optional, Dict, Any
import asyncio
from typing import List, Optional
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer
from rag_core import QDRANT_URL, QDRANT_API_KEY
class RAGLevel(Enum):
    """RAG level: selects how much processing the retrieval step applies."""

    # Plain vector retrieval only.
    BASIC = "basic"
    # Base retrieval followed by a reranking pass.
    RERANK = "rerank"
    # RAG-Fusion: multi-query retrieval combined with RRF.
    FUSION = "fusion"
from .retriever import create_qdrant_client # 可能不需要直接使用
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
class RAGPipeline:
    """Fixed-flow RAG retrieval pipeline.

    Stages, in order: multi-query rewriting -> parallel retrieval ->
    reciprocal rank fusion -> reranking. The documents returned are whatever
    the injected retriever yields; the original author's note expects a
    retriever that returns parent documents (e.g. a ParentDocumentRetriever)
    -- TODO confirm against the caller.
    """

    def __init__(
        self,
        retriever,
        llm: BaseLanguageModel,
        num_queries: int = 3,
        rerank_top_n: int = 5,
        rerank_model: str = "BAAI/bge-reranker-base",
        reranker_base_url: str = "http://127.0.0.1:8083",
        reranker_api_key: str = "huang1998",
    ):
        """Wire up the query generator and the reranker.

        Args:
            retriever: Base retriever; must expose an awaitable
                ``ainvoke(query)``.
            llm: Language model used to generate the query variants.
            num_queries: Number of query variants to request.
            rerank_top_n: Number of documents returned at the end.
            rerank_model: Reranker model name. Accepted for interface
                compatibility; nothing in this class reads it.
            reranker_base_url: Base URL handed to ``LLaMaCPPReranker``.
            reranker_api_key: API key handed to ``LLaMaCPPReranker``.
        """
        # NOTE(review): the default ``reranker_api_key`` is a credential that
        # was hard-coded in the original call; it is kept only so existing
        # callers behave the same. Move it to configuration and rotate it.
        self.retriever = retriever
        self.llm = llm
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n

        # Components used by aretrieve().
        self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
        self.reranker = LLaMaCPPReranker(
            base_url=reranker_base_url,
            top_n=rerank_top_n,
            api_key=reranker_api_key,
        )

    @staticmethod
    def _ordered_queries(query: str, variants: Optional[List[str]]) -> List[str]:
        """Return the original query first, then the unique variants in order.

        A new list is built so the list handed back by the query generator is
        never mutated, and duplicates are dropped so the same query is not
        retrieved twice.
        """
        return list(dict.fromkeys([query, *(variants or [])]))

    async def aretrieve(self, query: str) -> List[Document]:
        """Run the full retrieval flow asynchronously.

        Args:
            query: The user query.

        Returns:
            The reranked documents, or the first ``rerank_top_n`` fused
            documents when ``self.reranker.model`` is ``None``.
        """
        # Step 1: generate query variants; the original query always leads,
        # which also guarantees at least one query.
        variants = await self.query_generator.agenerate(query)
        queries = self._ordered_queries(query, variants)

        # Step 2: retrieve for every query concurrently.
        doc_lists = await asyncio.gather(
            *(self.retriever.ainvoke(q) for q in queries)
        )

        # Step 3: merge the per-query result lists.
        fused_docs = reciprocal_rank_fusion(doc_lists)

        # Step 4: rerank against the original query when a reranker model is
        # available.
        # NOTE(review): compress_documents is called synchronously; if it does
        # blocking I/O it will stall the event loop -- confirm.
        if self.reranker.model is not None:
            return self.reranker.compress_documents(fused_docs, query)
        return fused_docs[: self.rerank_top_n]

    def retrieve(self, query: str) -> List[Document]:
        """Synchronous entry point that drives :meth:`aretrieve`.

        ``asyncio.run`` raises ``RuntimeError`` when the calling thread
        already runs an event loop, so in that case the coroutine is executed
        on a fresh loop in a worker thread instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: asyncio.run is safe.
            return asyncio.run(self.aretrieve(query))

        from concurrent.futures import ThreadPoolExecutor

        # NOTE(review): the coroutine runs on a different loop than the
        # caller's; retrievers holding loop-bound async clients may need care.
        with ThreadPoolExecutor(max_workers=1) as pool:
            future = pool.submit(lambda: asyncio.run(self.aretrieve(query)))
            return future.result()

    def format_context(self, documents: List[Document]) -> str:
        """Format a list of documents into a single context string.

        Returns an empty string for an empty (or falsy) list. A document
        whose ``metadata`` is ``None`` is reported with the fallback source
        label instead of raising.
        """
        if not documents:
            return ""
        parts = []
        for i, doc in enumerate(documents, 1):
            source = (doc.metadata or {}).get("source", "未知来源")
            parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
        return "\n".join(parts)


class SimpleRetriever:
    """Adapter exposing ``invoke``/``ainvoke`` around a plain callable."""

    def __init__(self, retrieve_func):
        """Store the callable that maps a query to its results."""
        self.retrieve_func = retrieve_func

    def invoke(self, query):
        """Call the wrapped function and return its result unchanged."""
        return self.retrieve_func(query)

    async def ainvoke(self, query):
        """Async counterpart of :meth:`invoke`.

        If the wrapped callable returns an awaitable (e.g. it is a coroutine
        function), the awaitable is awaited so callers receive the result
        rather than a pending coroutine.
        """
        import inspect

        result = self.retrieve_func(query)
        if inspect.isawaitable(result):
            result = await result
        return result