2026-05-04 02:01:22 +08:00
|
|
|
|
"""
|
|
|
|
|
|
RAG 检索流水线模块
|
|
|
|
|
|
|
|
|
|
|
|
提供固定流程的 RAG 检索:
|
|
|
|
|
|
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
|
|
|
|
|
|
|
|
|
|
|
|
默认使用混合检索(稠密+稀疏)+ 父子文档模式。
|
|
|
|
|
|
"""
|
2026-04-21 11:02:16 +08:00
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import os
|
|
|
|
|
|
from typing import List
|
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
|
|
|
2026-05-04 02:01:22 +08:00
|
|
|
|
from app.model_services import get_rerank_service
|
|
|
|
|
|
from app.rag.rerank import create_document_reranker
|
|
|
|
|
|
from app.rag.query_transform import MultiQueryGenerator
|
|
|
|
|
|
from app.rag.fusion import reciprocal_rank_fusion
|
|
|
|
|
|
from app.rag.retriever import create_parent_hybrid_retriever
|
|
|
|
|
|
|
2026-04-21 11:02:16 +08:00
|
|
|
|
|
|
|
|
|
|
class RAGPipeline:
|
|
|
|
|
|
"""
|
|
|
|
|
|
固定流程的 RAG 检索流水线:
|
2026-05-04 02:01:22 +08:00
|
|
|
|
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
|
|
|
|
|
|
|
|
|
|
|
|
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
2026-04-21 11:02:16 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
self,
|
2026-05-04 02:01:22 +08:00
|
|
|
|
retriever=None,
|
|
|
|
|
|
llm: Optional[BaseLanguageModel] = None,
|
2026-04-21 11:02:16 +08:00
|
|
|
|
num_queries: int = 3,
|
|
|
|
|
|
rerank_top_n: int = 5,
|
2026-05-04 02:01:22 +08:00
|
|
|
|
collection_name: str = "rag_documents",
|
2026-04-21 11:02:16 +08:00
|
|
|
|
):
|
|
|
|
|
|
"""
|
|
|
|
|
|
Args:
|
2026-05-04 02:01:22 +08:00
|
|
|
|
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。
|
|
|
|
|
|
如果不提供,会自动创建默认的父子文档混合检索器。
|
|
|
|
|
|
llm: 用于生成多路查询的语言模型。
|
|
|
|
|
|
num_queries: 生成的查询变体数量。
|
|
|
|
|
|
rerank_top_n: 最终返回的文档数量。
|
|
|
|
|
|
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
|
2026-04-21 11:02:16 +08:00
|
|
|
|
"""
|
2026-05-04 02:01:22 +08:00
|
|
|
|
# 如果没有提供 retriever,自动创建默认的混合检索器
|
|
|
|
|
|
if retriever is None:
|
|
|
|
|
|
self.retriever = create_parent_hybrid_retriever(
|
|
|
|
|
|
collection_name=collection_name,
|
|
|
|
|
|
search_k=rerank_top_n * 2 # 多取一些给重排序用
|
|
|
|
|
|
)
|
|
|
|
|
|
else:
|
|
|
|
|
|
self.retriever = retriever
|
|
|
|
|
|
|
2026-04-21 11:02:16 +08:00
|
|
|
|
self.llm = llm
|
|
|
|
|
|
self.num_queries = num_queries
|
|
|
|
|
|
self.rerank_top_n = rerank_top_n
|
|
|
|
|
|
|
2026-04-24 22:52:36 +08:00
|
|
|
|
# 初始化组件 - 使用统一的重排服务获取接口
|
2026-05-04 02:01:22 +08:00
|
|
|
|
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None
|
2026-04-26 11:57:42 +08:00
|
|
|
|
self.reranker = create_document_reranker()
|
2026-04-21 11:02:16 +08:00
|
|
|
|
|
|
|
|
|
|
async def aretrieve(self, query: str) -> List[Document]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
异步执行完整检索流程
|
|
|
|
|
|
|
2026-05-04 02:01:22 +08:00
|
|
|
|
Args:
|
|
|
|
|
|
query: 用户查询
|
2026-04-21 11:02:16 +08:00
|
|
|
|
|
2026-05-04 02:01:22 +08:00
|
|
|
|
Returns:
|
|
|
|
|
|
检索到的相关文档列表
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 如果有 query_generator,做多路改写
|
|
|
|
|
|
if self.query_generator and self.llm:
|
|
|
|
|
|
# Step 1: 生成多路查询
|
|
|
|
|
|
queries = await self.query_generator.agenerate(query)
|
|
|
|
|
|
# 包含原始查询,确保至少有一条
|
|
|
|
|
|
if query not in queries:
|
|
|
|
|
|
queries.insert(0, query)
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 如果原始查询已在列表中,将其移至首位
|
|
|
|
|
|
queries.remove(query)
|
|
|
|
|
|
queries.insert(0, query)
|
|
|
|
|
|
|
|
|
|
|
|
# Step 2: 并行检索(每个查询获取文档列表)
|
|
|
|
|
|
tasks = [self.retriever.ainvoke(q) for q in queries]
|
|
|
|
|
|
doc_lists = await asyncio.gather(*tasks)
|
|
|
|
|
|
|
|
|
|
|
|
# Step 3: RRF 融合
|
|
|
|
|
|
fused_docs = reciprocal_rank_fusion(doc_lists)
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 没有 LLM 做查询改写,直接用原始查询检索
|
|
|
|
|
|
fused_docs = await self.retriever.ainvoke(query)
|
2026-04-21 11:02:16 +08:00
|
|
|
|
|
|
|
|
|
|
# Step 4: 重排序
|
|
|
|
|
|
try:
|
2026-04-24 22:52:36 +08:00
|
|
|
|
final_docs = self.reranker.compress_documents(fused_docs, query, top_n=self.rerank_top_n)
|
2026-04-21 11:02:16 +08:00
|
|
|
|
except Exception:
|
2026-04-24 22:52:36 +08:00
|
|
|
|
# 若重排序器不可用,直接返回融合后的前 N 个结果
|
2026-04-21 11:02:16 +08:00
|
|
|
|
final_docs = fused_docs[:self.rerank_top_n]
|
|
|
|
|
|
|
|
|
|
|
|
return final_docs
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve(self, query: str) -> List[Document]:
|
|
|
|
|
|
"""同步检索入口(内部调用异步方法)"""
|
|
|
|
|
|
return asyncio.run(self.aretrieve(query))
|
|
|
|
|
|
|
|
|
|
|
|
def format_context(self, documents: List[Document]) -> str:
|
2026-05-04 02:01:22 +08:00
|
|
|
|
"""
|
|
|
|
|
|
将文档列表格式化为上下文字符串
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
documents: 文档列表
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
格式化后的上下文字符串
|
|
|
|
|
|
"""
|
2026-04-21 11:02:16 +08:00
|
|
|
|
if not documents:
|
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
|
|
parts = []
|
|
|
|
|
|
for i, doc in enumerate(documents, 1):
|
|
|
|
|
|
source = doc.metadata.get("source", "未知来源")
|
|
|
|
|
|
parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
|
2026-05-04 02:01:22 +08:00
|
|
|
|
return "\n".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_rag_pipeline(
|
|
|
|
|
|
collection_name: str = "rag_documents",
|
|
|
|
|
|
llm: Optional[BaseLanguageModel] = None,
|
|
|
|
|
|
num_queries: int = 3,
|
|
|
|
|
|
rerank_top_n: int = 5,
|
|
|
|
|
|
) -> RAGPipeline:
|
|
|
|
|
|
"""
|
|
|
|
|
|
创建 RAG 检索流水线的便捷函数
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
collection_name: Qdrant 集合名称
|
|
|
|
|
|
llm: 用于生成多路查询的语言模型
|
|
|
|
|
|
num_queries: 生成的查询变体数量
|
|
|
|
|
|
rerank_top_n: 最终返回的文档数量
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
RAGPipeline 实例
|
|
|
|
|
|
"""
|
|
|
|
|
|
return RAGPipeline(
|
|
|
|
|
|
llm=llm,
|
|
|
|
|
|
num_queries=num_queries,
|
|
|
|
|
|
rerank_top_n=rerank_top_n,
|
|
|
|
|
|
collection_name=collection_name
|
|
|
|
|
|
)
|