feat: 实现 BM25 稀疏 + 稠密向量混合检索功能
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled

This commit is contained in:
2026-05-04 02:01:22 +08:00
parent 2183c901b4
commit 60afa86ded
26 changed files with 905 additions and 656 deletions

View File

@@ -1,4 +1,11 @@
# rag/pipeline.py
"""
RAG 检索流水线模块
提供固定流程的 RAG 检索:
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
默认使用混合检索(稠密+稀疏)+ 父子文档模式。
"""
import asyncio
import os
@@ -6,61 +13,86 @@ from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from ..model_services import get_rerank_service
from .rerank import create_document_reranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from app.model_services import get_rerank_service
from app.rag.rerank import create_document_reranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
from app.rag.retriever import create_parent_hybrid_retriever
class RAGPipeline:
"""
固定流程的 RAG 检索流水线:
多路改写 → 并行检索 → RRF融合 → 重排序 → 返回父文档
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
"""
def __init__(
self,
retriever, # 基础检索器(应返回父文档,例如 ParentDocumentRetriever 实例)
llm: BaseLanguageModel,
retriever=None,
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
):
"""
Args:
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法
llm: 用于生成多路查询的语言模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
rerank_model: 重排序模型名称
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法
如果不提供,会自动创建默认的父子文档混合检索器。
llm: 用于生成多路查询的语言模型。
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量。
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
"""
self.retriever = retriever
# 如果没有提供 retriever自动创建默认的混合检索器
if retriever is None:
self.retriever = create_parent_hybrid_retriever(
collection_name=collection_name,
search_k=rerank_top_n * 2 # 多取一些给重排序用
)
else:
self.retriever = retriever
self.llm = llm
self.num_queries = num_queries
self.rerank_top_n = rerank_top_n
# 初始化组件 - 使用统一的重排服务获取接口
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None
self.reranker = create_document_reranker()
async def aretrieve(self, query: str) -> List[Document]:
"""
异步执行完整检索流程
Args:
query: 用户查询
Returns:
检索到的相关文档列表
"""
# Step 1: 生成多路查询
queries = await self.query_generator.agenerate(query)
# 包含原始查询,确保至少有一条
if query not in queries:
queries.insert(0, query)
# 如果有 query_generator做多路改写
if self.query_generator and self.llm:
# Step 1: 生成多路查询
queries = await self.query_generator.agenerate(query)
# 包含原始查询,确保至少有一条
if query not in queries:
queries.insert(0, query)
else:
# 如果原始查询已在列表中,将其移至首位
queries.remove(query)
queries.insert(0, query)
# Step 2: 并行检索(每个查询获取文档列表)
tasks = [self.retriever.ainvoke(q) for q in queries]
doc_lists = await asyncio.gather(*tasks)
# Step 3: RRF 融合
fused_docs = reciprocal_rank_fusion(doc_lists)
else:
# 如果原始查询已在列表中,将其移至首位
queries.remove(query)
queries.insert(0, query)
# Step 2: 并行检索(每个查询获取文档列表)
tasks = [self.retriever.ainvoke(q) for q in queries]
doc_lists = await asyncio.gather(*tasks)
# Step 3: RRF 融合
fused_docs = reciprocal_rank_fusion(doc_lists)
# 没有 LLM 做查询改写,直接用原始查询检索
fused_docs = await self.retriever.ainvoke(query)
# Step 4: 重排序
try:
@@ -76,7 +108,15 @@ class RAGPipeline:
return asyncio.run(self.aretrieve(query))
def format_context(self, documents: List[Document]) -> str:
"""将文档列表格式化为上下文字符串"""
"""
将文档列表格式化为上下文字符串
Args:
documents: 文档列表
Returns:
格式化后的上下文字符串
"""
if not documents:
return ""
@@ -84,4 +124,30 @@ class RAGPipeline:
for i, doc in enumerate(documents, 1):
source = doc.metadata.get("source", "未知来源")
parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
return "\n".join(parts)
return "\n".join(parts)
def create_rag_pipeline(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
) -> RAGPipeline:
"""
创建 RAG 检索流水线的便捷函数
Args:
collection_name: Qdrant 集合名称
llm: 用于生成多路查询的语言模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
Returns:
RAGPipeline 实例
"""
return RAGPipeline(
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name
)