169 lines
4.9 KiB
Python
169 lines
4.9 KiB
Python
"""
|
||
RAG 检索流水线
|
||
|
||
整合基础检索、重排序和 RAG-Fusion 功能。
|
||
"""
|
||
|
||
from enum import Enum
|
||
from typing import List, Optional, Dict, Any
|
||
from langchain_core.documents import Document
|
||
from langchain_core.language_models import BaseLanguageModel
|
||
|
||
from .retriever import (
|
||
create_base_retriever,
|
||
create_hybrid_retriever,
|
||
create_qdrant_client,
|
||
)
|
||
from .reranker import CrossEncoderReranker
|
||
from .query_transform import MultiQueryTransformer
|
||
from rag_core import QDRANT_URL, QDRANT_API_KEY
|
||
|
||
|
||
class RAGLevel(Enum):
    """Retrieval strategies selectable for the RAG pipeline.

    Each member's value is the plain string used in configuration.
    """

    # Plain vector retrieval.
    BASIC = "basic"

    # Base retrieval followed by reranking.
    RERANK = "rerank"

    # RAG-Fusion (multi-query + RRF).
    FUSION = "fusion"
|
||
|
||
|
||
class RAGPipeline:
    """RAG retrieval pipeline.

    Builds a base retriever, tries to build a cross-encoder reranker, and
    then selects the retriever used by ``retrieve`` / ``aretrieve`` from the
    configured RAG level.

    Recognised ``config`` keys (all optional):
        collection_name: Collection name handed to ``create_base_retriever``
            (default ``"rag_documents"``).
        rag_level: A ``RAGLevel`` member or its string value
            (default ``RAGLevel.RERANK.value``).
        num_queries: Passed to ``MultiQueryTransformer`` (default ``3``).
        rerank_top_n: Passed to ``CrossEncoderReranker`` as ``top_n`` and
            used as the truncation length when no reranker is available
            (default ``5``).
        recall_k: ``k`` sent to the base retriever in ``search_kwargs``
            (default ``20``).
    """

    # Number of candidates the base retriever recalls unless ``recall_k``
    # is given in the config.
    DEFAULT_RECALL_K = 20

    def __init__(
        self,
        embeddings,
        llm: Optional[BaseLanguageModel] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Initialise the pipeline.

        Args:
            embeddings: Embedding model forwarded to ``create_base_retriever``.
            llm: Language model; required only for the fusion level.
            config: Optional settings, see the class docstring for the keys.

        Raises:
            ValueError: If the fusion level is selected and ``llm`` is None.
        """
        self.embeddings = embeddings
        self.llm = llm
        self.config = config or {}

        self.collection_name = self.config.get("collection_name", "rag_documents")
        self.rag_level = self.config.get("rag_level", RAGLevel.RERANK.value)
        self.num_queries = self.config.get("num_queries", 3)
        self.rerank_top_n = self.config.get("rerank_top_n", 5)
        self.recall_k = self.config.get("recall_k", self.DEFAULT_RECALL_K)

        # Base retriever: recalls ``recall_k`` candidates per query.
        self.base_retriever = create_base_retriever(
            collection_name=self.collection_name,
            embeddings=self.embeddings,
            search_kwargs={"k": self.recall_k},
        )

        # Reranker is best-effort: if it cannot be constructed the pipeline
        # keeps working and simply truncates the recalled documents.
        try:
            self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n)
        except Exception as e:
            print(f"警告: 无法创建重排序器,将使用基础检索。错误: {e}")
            self.reranker = None

        # Pick the retriever matching the configured RAG level.
        self.retriever = self._create_retriever()

    @staticmethod
    def _normalize_level(level):
        """Return the plain value of a level given as an Enum member or raw value."""
        return level.value if isinstance(level, Enum) else level

    def _create_retriever(self):
        """Create the retriever for the configured RAG level.

        The level may be a ``RAGLevel`` member or its string value. Any
        unrecognised level falls back to the rerank retriever.

        Returns:
            The base retriever for the basic level, the multi-query retriever
            built by ``MultiQueryTransformer`` for the fusion level, and a
            ``SimpleRetriever`` wrapping recall + rerank otherwise.

        Raises:
            ValueError: If the fusion level is selected and ``llm`` is None.
        """
        # Accept both RAGLevel.X and "x"; without this an Enum member would
        # match none of the string comparisons below.
        level = self._normalize_level(self.rag_level)

        if level == RAGLevel.BASIC.value:
            return self.base_retriever

        def rerank_retriever(query):
            """Recall with the base retriever, then rerank or truncate."""
            documents = self.base_retriever.invoke(query)
            if not documents:
                # Nothing to rerank or truncate.
                return []
            # ``self.reranker`` is read at call time, so it reflects the
            # current attribute value rather than the one at creation.
            if self.reranker is not None:
                return self.reranker.compress_documents(documents, query)
            return documents[: self.rerank_top_n]

        if level == RAGLevel.FUSION.value:
            if self.llm is None:
                raise ValueError("RAG-Fusion 需要提供 llm 参数")

            # Multi-query retrieval layered on top of recall + rerank.
            transformer = MultiQueryTransformer(
                llm=self.llm,
                num_queries=self.num_queries,
            )
            return transformer.create_multi_query_retriever(
                base_retriever=SimpleRetriever(rerank_retriever)
            )

        # Rerank level, and the fallback for any unrecognised level.
        return SimpleRetriever(rerank_retriever)

    def retrieve(self, query: str) -> List[Document]:
        """Run retrieval synchronously.

        Args:
            query: Query string.

        Returns:
            The result of ``self.retriever.invoke(query)``.
        """
        return self.retriever.invoke(query)

    async def aretrieve(self, query: str) -> List[Document]:
        """Run retrieval asynchronously.

        Args:
            query: Query string.

        Returns:
            The awaited result of ``self.retriever.ainvoke(query)``.
        """
        return await self.retriever.ainvoke(query)

    def format_context(self, documents: List[Document]) -> str:
        """Format documents into a single context string.

        Each document becomes a numbered section (starting at 1) holding its
        ``source`` metadata entry and its ``page_content``, followed by a
        ``---`` separator line.

        Args:
            documents: Documents to format.

        Returns:
            The concatenated sections, or ``""`` when ``documents`` is empty.
        """
        if not documents:
            return ""

        sections = []
        for index, doc in enumerate(documents, 1):
            # Metadata may be None/empty; fall back to the placeholder source.
            source = (doc.metadata or {}).get("source", "未知来源")
            sections.append(
                f"【资料 {index}】\n"
                f"来源: {source}\n"
                f"内容: {doc.page_content}\n"
                "---\n"
            )

        return "".join(sections)
|
||
|
||
|
||
class SimpleRetriever:
    """Minimal retriever wrapper around a plain ``query -> result`` callable.

    Exposes ``invoke`` and ``ainvoke`` so a function can stand in where an
    object with those two methods is expected.
    """

    def __init__(self, retrieve_func):
        """Store the wrapped callable.

        Args:
            retrieve_func: Synchronous callable that takes the query and
                returns the retrieval result.
        """
        self.retrieve_func = retrieve_func

    def invoke(self, query):
        """Call the wrapped function with ``query`` and return its result."""
        return self.retrieve_func(query)

    async def ainvoke(self, query):
        """Run the wrapped synchronous function without blocking the event loop.

        The function is executed in the event loop's default executor, with
        the caller's context variables copied into the worker so
        context-dependent state remains visible to the wrapped function.

        Args:
            query: Query forwarded to the wrapped function.

        Returns:
            Whatever the wrapped function returns.
        """
        # Local imports keep this block self-contained.
        import asyncio
        import contextvars

        loop = asyncio.get_running_loop()
        ctx = contextvars.copy_context()
        # Calling retrieve_func directly here would stall every other task
        # on the loop for the full duration of the retrieval.
        return await loop.run_in_executor(None, ctx.run, self.retrieve_func, query)
|