Files
ailine/app/rag/pipeline.py

92 lines
3.3 KiB
Python
Raw Normal View History

2026-04-20 01:10:18 +08:00
# rag/pipeline.py
2026-04-18 16:31:48 +08:00
2026-04-20 01:10:18 +08:00
import asyncio
2026-04-20 14:05:57 +08:00
import os
2026-04-20 01:10:18 +08:00
from typing import List, Optional
2026-04-18 16:31:48 +08:00
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
2026-04-20 01:10:18 +08:00
from .retriever import create_qdrant_client # 可能不需要直接使用
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
2026-04-18 16:31:48 +08:00
class RAGPipeline:
    """Fixed-flow RAG retrieval pipeline.

    Flow: multi-query rewriting -> parallel retrieval -> RRF fusion ->
    reranking -> return parent documents.
    """

    def __init__(
        self,
        retriever,  # base retriever; should return parent documents (e.g. a ParentDocumentRetriever)
        llm: BaseLanguageModel,
        num_queries: int = 3,
        rerank_top_n: int = 5,
    ):
        """
        Args:
            retriever: base retriever object; must implement an async
                ``ainvoke(query)`` method.
            llm: language model used to generate the query variants.
            num_queries: number of query variants to generate.
            rerank_top_n: number of documents ultimately returned.
        """
        self.retriever = retriever
        self.llm = llm
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n
        # Component setup: multi-query rewriter and llama.cpp-backed reranker.
        self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
        self.reranker = LLaMaCPPReranker(
            base_url=os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083"),
            api_key=os.getenv("LLAMACPP_API_KEY", "huang1998"),
            top_n=rerank_top_n,
        )

    async def aretrieve(self, query: str) -> List[Document]:
        """Run the full retrieval flow asynchronously.

        Args:
            query: the user's original query.

        Returns:
            At most ``rerank_top_n`` documents — reranked when the reranker
            service is reachable, otherwise the top fused results.
        """
        # Step 1: generate query variants.  Place the original query first
        # and drop any duplicate of it, so the same query is never
        # retrieved twice (duplicates would be double-weighted by RRF).
        variants = await self.query_generator.agenerate(query)
        queries = [query] + [q for q in variants if q != query]
        # Step 2: retrieve for every query concurrently.
        doc_lists = await asyncio.gather(*(self.retriever.ainvoke(q) for q in queries))
        # Step 3: fuse the per-query ranked lists with Reciprocal Rank Fusion.
        fused_docs = reciprocal_rank_fusion(doc_lists)
        # Step 4: rerank.  The reranker is an external HTTP service, so we
        # degrade gracefully (best-effort by design) rather than fail the
        # whole retrieval when it is unreachable.
        try:
            final_docs = self.reranker.compress_documents(fused_docs, query)
        except Exception:
            # Reranker unavailable: return the top-N fused documents as-is.
            final_docs = fused_docs[: self.rerank_top_n]
        return final_docs

    def retrieve(self, query: str) -> List[Document]:
        """Synchronous entry point (wraps the async flow).

        NOTE: uses ``asyncio.run`` and therefore must NOT be called from
        inside an already-running event loop; use :meth:`aretrieve` there.
        """
        return asyncio.run(self.aretrieve(query))

    def format_context(self, documents: List[Document]) -> str:
        """Format a document list into a single context string for prompting.

        Each document is rendered with its 1-based index and its
        ``metadata["source"]`` (falling back to an "unknown source" label),
        followed by its content and a ``---`` separator.
        """
        if not documents:
            return ""

        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
        return "\n".join(parts)