This commit is contained in:
91
backend/app/rag/pipeline.py
Normal file
91
backend/app/rag/pipeline.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# rag/pipeline.py
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from ..config import LLAMACPP_RERANKER_URL, LLAMACPP_API_KEY
|
||||
from typing import List
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from .reranker import LLaMaCPPReranker
|
||||
from .query_transform import MultiQueryGenerator
|
||||
from .fusion import reciprocal_rank_fusion
|
||||
|
||||
class RAGPipeline:
|
||||
"""
|
||||
固定流程的 RAG 检索流水线:
|
||||
多路改写 → 并行检索 → RRF融合 → 重排序 → 返回父文档
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever, # 基础检索器(应返回父文档,例如 ParentDocumentRetriever 实例)
|
||||
llm: BaseLanguageModel,
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法
|
||||
llm: 用于生成多路查询的语言模型
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
rerank_model: 重排序模型名称
|
||||
"""
|
||||
self.retriever = retriever
|
||||
self.llm = llm
|
||||
self.num_queries = num_queries
|
||||
self.rerank_top_n = rerank_top_n
|
||||
|
||||
# 初始化组件
|
||||
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
|
||||
self.reranker = LLaMaCPPReranker(
|
||||
base_url=LLAMACPP_RERANKER_URL,
|
||||
api_key=LLAMACPP_API_KEY,
|
||||
top_n=rerank_top_n,
|
||||
)
|
||||
|
||||
async def aretrieve(self, query: str) -> List[Document]:
|
||||
"""
|
||||
异步执行完整检索流程
|
||||
"""
|
||||
# Step 1: 生成多路查询
|
||||
queries = await self.query_generator.agenerate(query)
|
||||
# 包含原始查询,确保至少有一条
|
||||
if query not in queries:
|
||||
queries.insert(0, query)
|
||||
else:
|
||||
# 如果原始查询已在列表中,将其移至首位
|
||||
queries.remove(query)
|
||||
queries.insert(0, query)
|
||||
|
||||
# Step 2: 并行检索(每个查询获取文档列表)
|
||||
tasks = [self.retriever.ainvoke(q) for q in queries]
|
||||
doc_lists = await asyncio.gather(*tasks)
|
||||
|
||||
# Step 3: RRF 融合
|
||||
fused_docs = reciprocal_rank_fusion(doc_lists)
|
||||
|
||||
# Step 4: 重排序
|
||||
try:
|
||||
final_docs = self.reranker.compress_documents(fused_docs, query)
|
||||
except Exception:
|
||||
# 若重排序器不可用,直接返回融合后的前 N 条
|
||||
final_docs = fused_docs[:self.rerank_top_n]
|
||||
|
||||
return final_docs
|
||||
|
||||
def retrieve(self, query: str) -> List[Document]:
|
||||
"""同步检索入口(内部调用异步方法)"""
|
||||
return asyncio.run(self.aretrieve(query))
|
||||
|
||||
def format_context(self, documents: List[Document]) -> str:
|
||||
"""将文档列表格式化为上下文字符串"""
|
||||
if not documents:
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
for i, doc in enumerate(documents, 1):
|
||||
source = doc.metadata.get("source", "未知来源")
|
||||
parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
|
||||
return "\n".join(parts)
|
||||
Reference in New Issue
Block a user