2026-04-18 16:31:48 +08:00
|
|
|
|
"""
|
|
|
|
|
|
RAG 检索流水线
|
|
|
|
|
|
|
2026-04-19 22:01:55 +08:00
|
|
|
|
整合基础检索、重排序和 RAG-Fusion 功能。
|
2026-04-18 16:31:48 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from enum import Enum
|
2026-04-19 22:01:55 +08:00
|
|
|
|
from typing import List, Optional, Dict, Any
|
2026-04-18 16:31:48 +08:00
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
|
|
|
|
|
|
|
|
from .retriever import (
|
|
|
|
|
|
create_base_retriever,
|
|
|
|
|
|
create_hybrid_retriever,
|
|
|
|
|
|
create_qdrant_client,
|
|
|
|
|
|
)
|
|
|
|
|
|
from .reranker import CrossEncoderReranker
|
2026-04-19 22:01:55 +08:00
|
|
|
|
from .query_transform import MultiQueryTransformer
|
|
|
|
|
|
from rag_core import QDRANT_URL, QDRANT_API_KEY
|
2026-04-18 16:31:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RAGLevel(Enum):
    """Retrieval strategies a RAG pipeline can run.

    Each member's value is the plain string used to select that strategy.
    """

    # Plain vector retrieval, no post-processing.
    BASIC = "basic"

    # Vector retrieval followed by a reranking pass.
    RERANK = "rerank"

    # RAG-Fusion: multi-query retrieval (original note: multi-query + RRF).
    FUSION = "fusion"
|
2026-04-18 16:31:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RAGPipeline:
    """Retrieval pipeline combining base retrieval, reranking and RAG-Fusion.

    The strategy is chosen by the ``rag_level`` config entry (see ``RAGLevel``).
    The resulting retriever is stored on ``self.retriever`` and used by
    ``retrieve`` / ``aretrieve``.
    """

    def __init__(
        self,
        embeddings,
        llm: Optional[BaseLanguageModel] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Build the base retriever, the reranker and the active retriever.

        Args:
            embeddings: Embedding model handed to ``create_base_retriever``.
            llm: Language model; required only for the fusion level.
            config: Optional settings. Recognised keys and defaults:
                ``collection_name`` ("rag_documents"),
                ``rag_level`` ("rerank"; a ``RAGLevel`` member or its value),
                ``num_queries`` (3), ``rerank_top_n`` (5),
                ``recall_k`` (20, number of documents recalled before
                reranking).

        Raises:
            ValueError: If the fusion level is selected without ``llm``.
        """
        self.embeddings = embeddings
        self.llm = llm
        self.config = config or {}

        settings = self.config
        self.collection_name = settings.get("collection_name", "rag_documents")
        self.rag_level = settings.get("rag_level", RAGLevel.RERANK.value)
        self.num_queries = settings.get("num_queries", 3)
        self.rerank_top_n = settings.get("rerank_top_n", 5)
        # Recall depth used to be hard-coded to 20; it is now configurable
        # with the same default.
        self.recall_k = settings.get("recall_k", 20)

        # Base retriever over the configured collection.
        self.base_retriever = create_base_retriever(
            collection_name=self.collection_name,
            embeddings=self.embeddings,
            search_kwargs={"k": self.recall_k},
        )

        # Best effort: if the reranker cannot be built, the pipeline keeps
        # working and simply truncates the base results instead.
        try:
            self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n)
        except Exception as e:
            print(f"警告: 无法创建重排序器，将使用基础检索。错误: {e}")
            self.reranker = None

        # Pick the retriever that matches the configured level.
        self.retriever = self._create_retriever()

    def _rerank_retrieve(self, query):
        """Recall documents for ``query`` and rerank (or truncate) them."""
        documents = self.base_retriever.invoke(query)
        if not documents:
            # Nothing was recalled, so there is nothing to rerank.
            return []
        if self.reranker:
            return self.reranker.compress_documents(documents, query)
        # No reranker available: keep the first ``rerank_top_n`` results.
        return documents[:self.rerank_top_n]

    def _create_retriever(self):
        """Return the retriever object for the configured RAG level.

        ``self.rag_level`` may be a ``RAGLevel`` member or its string value.
        Unrecognised levels fall back to the rerank retriever.

        Raises:
            ValueError: If the fusion level is selected but ``self.llm`` is
                not set.
        """
        level = self.rag_level
        if isinstance(level, RAGLevel):
            # Accept the enum member itself; previously it never matched the
            # string comparisons and silently fell through to the fallback.
            level = level.value

        if level == RAGLevel.BASIC.value:
            return self.base_retriever

        if level == RAGLevel.FUSION.value:
            if not self.llm:
                raise ValueError("RAG-Fusion 需要提供 llm 参数")

            # Multi-query retriever layered on top of the rerank retriever.
            transformer = MultiQueryTransformer(
                llm=self.llm,
                num_queries=self.num_queries,
            )
            return transformer.create_multi_query_retriever(
                base_retriever=SimpleRetriever(self._rerank_retrieve)
            )

        # RERANK, and the fallback for any unrecognised level.
        return SimpleRetriever(self._rerank_retrieve)

    def retrieve(self, query: str) -> List[Document]:
        """Run retrieval synchronously.

        Args:
            query: Query string.

        Returns:
            Whatever the active retriever's ``invoke`` returns for ``query``.
        """
        return self.retriever.invoke(query)

    async def aretrieve(self, query: str) -> List[Document]:
        """Run retrieval asynchronously.

        Args:
            query: Query string.

        Returns:
            Whatever the active retriever's ``ainvoke`` returns for ``query``.
        """
        return await self.retriever.ainvoke(query)

    def format_context(self, documents: List[Document]) -> str:
        """Render documents as one numbered context string.

        Args:
            documents: Documents to render; each needs ``page_content`` and
                ``metadata`` attributes.

        Returns:
            The concatenated entries (numbered from 1), or an empty string
            when ``documents`` is empty.
        """
        if not documents:
            return ""

        context_parts = []
        for i, doc in enumerate(documents, 1):
            metadata = doc.metadata or {}
            source = metadata.get("source", "未知来源")
            context_parts.append(
                f"【资料 {i}】\n"
                f"来源: {source}\n"
                f"内容: {doc.page_content}\n"
                "---\n"
            )

        return "".join(context_parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SimpleRetriever:
    """Thin wrapper exposing a plain callable through ``invoke``/``ainvoke``."""

    def __init__(self, retrieve_func):
        """Store the callable used for retrieval.

        Args:
            retrieve_func: Synchronous callable taking a query and returning
                the retrieval result.
        """
        self.retrieve_func = retrieve_func

    def invoke(self, query):
        """Call the wrapped function with ``query`` and return its result."""
        return self.retrieve_func(query)

    async def ainvoke(self, query):
        """Run the wrapped function without blocking the event loop.

        The wrapped callable is synchronous, so calling it directly inside
        the coroutine would stall every other task on the loop for the
        duration of the retrieval. It is dispatched to the loop's default
        executor instead.

        Returns:
            The result of ``retrieve_func(query)``.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.retrieve_func, query)
|