Files
ailine/app/rag/pipeline.py

169 lines
4.9 KiB
Python
Raw Normal View History

2026-04-18 16:31:48 +08:00
"""
RAG 检索流水线
2026-04-19 22:01:55 +08:00
整合基础检索、重排序和 RAG-Fusion 功能
2026-04-18 16:31:48 +08:00
"""
import asyncio
import logging
from enum import Enum
from typing import List, Optional, Dict, Any

from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel

from .retriever import (
    create_base_retriever,
    create_hybrid_retriever,
    create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer
from rag_core import QDRANT_URL, QDRANT_API_KEY
class RAGLevel(Enum):
    """Retrieval strategy levels ("RAG levels") supported by the pipeline."""

    BASIC = "basic"  # plain vector retrieval
    RERANK = "rerank"  # vector retrieval followed by reranking
    FUSION = "fusion"  # RAG-Fusion: multiple queries + RRF
class RAGPipeline:
    """RAG retrieval pipeline.

    Wraps a base vector retriever and, depending on the configured level,
    returns it as-is (``"basic"``), adds reranking on top of it
    (``"rerank"``), or builds a multi-query RAG-Fusion retriever over the
    reranking retriever (``"fusion"``).
    """

    def __init__(
        self,
        embeddings,
        llm: Optional[BaseLanguageModel] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the RAG pipeline.

        Args:
            embeddings: Embedding model handed to ``create_base_retriever``.
            llm: Language model passed to ``MultiQueryTransformer``; required
                only when the level is ``"fusion"``.
            config: Optional settings dict. Keys read here:
                - ``collection_name``: collection for the base retriever
                  (default ``"rag_documents"``).
                - ``rag_level``: a ``RAGLevel`` member or its string value
                  (default ``"rerank"``).
                - ``num_queries``: forwarded to ``MultiQueryTransformer``
                  (default 3).
                - ``rerank_top_n``: ``top_n`` for the reranker, and the
                  truncation size when no reranker is available (default 5).
                - ``retrieve_k``: number of candidates the base retriever
                  fetches (``search_kwargs["k"]``, default 20).

        Raises:
            ValueError: If the level is ``"fusion"`` and ``llm`` is None.
        """
        self.embeddings = embeddings
        self.llm = llm
        self.config = config or {}

        self.collection_name = self.config.get("collection_name", "rag_documents")
        level = self.config.get("rag_level", RAGLevel.RERANK.value)
        # Normalize enum members to their string value: the level checks in
        # _create_retriever compare against ``RAGLevel.X.value``, so a raw
        # member would never match and would silently fall back to "rerank".
        self.rag_level = level.value if isinstance(level, RAGLevel) else level
        self.num_queries = self.config.get("num_queries", 3)
        self.rerank_top_n = self.config.get("rerank_top_n", 5)
        self.retrieve_k = self.config.get("retrieve_k", 20)

        # Recall a wider candidate pool (retrieve_k) than is finally returned
        # (rerank_top_n) so the reranker has candidates to choose from.
        self.base_retriever = create_base_retriever(
            collection_name=self.collection_name,
            embeddings=self.embeddings,
            search_kwargs={"k": self.retrieve_k},
        )

        # The reranker is optional: if constructing it raises, log the error
        # and fall back to truncating the base retrieval results.
        try:
            self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n)
        except Exception as e:
            logging.getLogger(__name__).warning(
                "无法创建重排序器,将使用基础检索。错误: %s", e
            )
            self.reranker = None

        self.retriever = self._create_retriever()

    def _create_retriever(self):
        """Build the retriever that matches ``self.rag_level``.

        Returns:
            ``self.base_retriever`` for ``"basic"``; for ``"fusion"``, the
            retriever produced by ``MultiQueryTransformer`` over the rerank
            retriever; otherwise (``"rerank"`` or any unrecognized level) a
            ``SimpleRetriever`` that retrieves and then reranks.

        Raises:
            ValueError: If the level is ``"fusion"`` and no ``llm`` was given.
        """
        if self.rag_level == RAGLevel.BASIC.value:
            return self.base_retriever

        def rerank_retriever(query):
            """Retrieve candidates with the base retriever, then rerank them."""
            documents = self.base_retriever.invoke(query)
            if not documents:
                # Nothing to score; skip the reranker call entirely.
                return documents
            if self.reranker is not None:
                return self.reranker.compress_documents(documents, query)
            # No reranker available: keep the first rerank_top_n results in
            # the order the base retriever returned them.
            return documents[:self.rerank_top_n]

        rerank = SimpleRetriever(rerank_retriever)

        if self.rag_level == RAGLevel.FUSION.value:
            if self.llm is None:
                raise ValueError("RAG-Fusion 需要提供 llm 参数")
            transformer = MultiQueryTransformer(
                llm=self.llm,
                num_queries=self.num_queries,
            )
            return transformer.create_multi_query_retriever(base_retriever=rerank)

        if self.rag_level != RAGLevel.RERANK.value:
            # Keep the historical fallback to "rerank", but make it visible.
            logging.getLogger(__name__).warning(
                "未知的 rag_level: %r,将使用 rerank 级别", self.rag_level
            )
        return rerank

    def retrieve(self, query: str) -> List[Document]:
        """
        Run retrieval synchronously.

        Args:
            query: Query string.

        Returns:
            Documents returned by the configured retriever's ``invoke``.
        """
        return self.retriever.invoke(query)

    async def aretrieve(self, query: str) -> List[Document]:
        """
        Run retrieval asynchronously.

        Args:
            query: Query string.

        Returns:
            Documents returned by the configured retriever's ``ainvoke``.
        """
        return await self.retriever.ainvoke(query)

    def format_context(self, documents: List[Document]) -> str:
        """
        Format documents into a single context string.

        Each document becomes a numbered section (starting at 1) with its
        ``metadata["source"]`` (or a placeholder when missing) and its
        ``page_content``, terminated by a ``---`` separator line.

        Args:
            documents: Documents to format.

        Returns:
            The concatenated context; an empty string for no documents.
        """
        parts = []
        for i, doc in enumerate(documents, 1):
            source = (doc.metadata or {}).get("source", "未知来源")
            parts.append(
                f"【资料 {i}】\n"
                f"来源: {source}\n"
                f"内容: {doc.page_content}\n"
                "---\n"
            )
        return "".join(parts)
class SimpleRetriever:
    """Minimal adapter exposing a plain function via ``invoke``/``ainvoke``.

    Lets a synchronous ``query -> documents`` callable be used where an
    object with ``invoke``/``ainvoke`` methods is expected.
    """

    def __init__(self, retrieve_func):
        """
        Args:
            retrieve_func: Synchronous callable that takes a query and
                returns the retrieved documents.
        """
        self.retrieve_func = retrieve_func

    def invoke(self, query):
        """Call the wrapped function synchronously and return its result."""
        return self.retrieve_func(query)

    async def ainvoke(self, query):
        """Call the wrapped function in a worker thread and return its result.

        The wrapped function is synchronous; calling it directly inside the
        coroutine would block the event loop for the whole call, so it is
        offloaded with ``asyncio.to_thread``.
        """
        return await asyncio.to_thread(self.retrieve_func, query)