检索器重构
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s

This commit is contained in:
2026-04-19 22:01:55 +08:00
parent cc8ef41ef9
commit 933d418d77
26 changed files with 1694 additions and 1717 deletions

View File

@@ -1,341 +1,168 @@
"""
RAG 检索流水线
组合检索、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。
整合基础检索、重排序和 RAG-Fusion 功能。
"""
import time
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
create_ensemble_retriever,
create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline
from .query_transform import MultiQueryTransformer
from rag_core import QDRANT_URL, QDRANT_API_KEY
class RAGLevel(Enum):
"""RAG 功能级别"""
BASIC = 1 # 基础向量
HYBRID = 2 # 混合检索 + 重排序
FUSION = 3 # RAG-Fusion
AGENTIC = 4 # Agentic RAG
@dataclass
class RAGConfig:
"""RAG 配置"""
# Qdrant 配置
collection_name: str = "documents"
qdrant_url: Optional[str] = None
qdrant_api_key: Optional[str] = None
# 检索配置
rag_level: RAGLevel = RAGLevel.FUSION
dense_k: int = 10 # 向量检索数量
sparse_k: int = 10 # BM25 检索数量
total_k: int = 20 # 总检索数量
rerank_top_n: int = 5 # 重排序返回数量
# 查询改写配置
num_queries: int = 3 # RAG-Fusion 查询数量
# 模型配置
reranker_model: str = "BAAI/bge-reranker-base"
device: Optional[str] = None
# 性能配置
enable_cache: bool = True
verbose: bool = True
@dataclass
class RetrievalResult:
"""检索结果"""
documents: List[Document]
query_time: float
level: RAGLevel
metadata: Dict[str, Any] = field(default_factory=dict)
"""RAG 级别"""
BASIC = "basic" # 基础向量
RERANK = "rerank" # 基础检索 + 重排序
FUSION = "fusion" # RAG-Fusion(多路查询 + RRF
class RAGPipeline:
"""
RAG 检索流水线
支持从 Level 1 到 Level 4 的所有功能。
"""
"""RAG 检索流水线"""
def __init__(
self,
embeddings: Embeddings,
embeddings,
llm: Optional[BaseLanguageModel] = None,
config: Optional[RAGConfig] = None,
config: Optional[Dict[str, Any]] = None,
):
"""
初始化 RAG 流水线
Args:
embeddings: 嵌入模型
llm: 语言模型(用于查询改写Level 3+ 需要
config: 配置
llm: 语言模型(用于 RAG-Fusion
config: 配置参数
"""
self.embeddings = embeddings
self.llm = llm
self.config = config or RAGConfig()
self.config = config or {}
# 初始化组件
self._client = None
self._reranker = None
self._query_transformer = None
self._retriever = None
self.collection_name = self.config.get("collection_name", "rag_documents")
self.rag_level = self.config.get("rag_level", RAGLevel.RERANK.value)
self.num_queries = self.config.get("num_queries", 3)
self.rerank_top_n = self.config.get("rerank_top_n", 5)
# 缓存
self._cache = {}
def _get_client(self):
"""获取 Qdrant 客户端"""
if self._client is None:
self._client = create_qdrant_client(
url=self.config.qdrant_url,
api_key=self.config.qdrant_api_key,
)
return self._client
def _get_reranker(self):
"""获取重排序器"""
if self._reranker is None:
self._reranker = CrossEncoderReranker(
model_name=self.config.reranker_model,
top_n=self.config.rerank_top_n,
device=self.config.device,
)
return self._reranker
def _get_query_transformer(self):
"""获取查询改写器"""
if self._query_transformer is None and self.llm is not None:
self._query_transformer = MultiQueryTransformer(
llm=self.llm,
num_queries=self.config.num_queries,
)
return self._query_transformer
def _create_basic_retriever(self):
"""创建基础检索器Level 1"""
return create_base_retriever(
collection_name=self.config.collection_name,
# 初始化基础检索器
self.base_retriever = create_base_retriever(
collection_name=self.collection_name,
embeddings=self.embeddings,
search_kwargs={"k": self.config.total_k},
client=self._get_client(),
)
def _create_hybrid_retriever(self):
"""创建混合检索器Level 2"""
base_retriever = create_hybrid_retriever(
collection_name=self.config.collection_name,
embeddings=self.embeddings,
dense_k=self.config.dense_k,
sparse_k=self.config.sparse_k,
client=self._get_client(),
search_kwargs={"k": 20}, # 召回 20 条
)
# 应用重排序
reranker = self._get_reranker()
return reranker.create_contextual_compression_retriever(base_retriever)
def _create_fusion_retriever(self):
"""创建 RAG-Fusion 检索器Level 3"""
if self.llm is None:
raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写")
# 初始化重排序
try:
self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n)
except Exception as e:
print(f"警告: 无法创建重排序器,将使用基础检索。错误: {e}")
self.reranker = None
# 创建基础混合检索器
base_retriever = create_hybrid_retriever(
collection_name=self.config.collection_name,
embeddings=self.embeddings,
dense_k=self.config.dense_k,
sparse_k=self.config.sparse_k,
client=self._get_client(),
)
# 创建 RAG-Fusion 流水线
reranker = self._get_reranker()
return create_rag_fusion_pipeline(
base_retriever=base_retriever,
llm=self.llm,
reranker=reranker,
num_queries=self.config.num_queries,
)
# 根据 RAG 级别创建检索器
self.retriever = self._create_retriever()
def _get_retriever(self):
"""根据配置级别获取检索器"""
if self._retriever is None:
if self.config.rag_level == RAGLevel.BASIC:
self._retriever = self._create_basic_retriever()
elif self.config.rag_level == RAGLevel.HYBRID:
self._retriever = self._create_hybrid_retriever()
elif self.config.rag_level == RAGLevel.FUSION:
self._retriever = self._create_fusion_retriever()
elif self.config.rag_level == RAGLevel.AGENTIC:
# Agentic RAG 使用 Fusion 作为基础,在 tools.py 中包装
self._retriever = self._create_fusion_retriever()
def _create_retriever(self):
"""根据 RAG 级别创建检索器"""
if self.rag_level == RAGLevel.BASIC.value:
return self.base_retriever
# 基础检索 + 重排序
def rerank_retriever(query):
documents = self.base_retriever.invoke(query)
if self.reranker:
return self.reranker.compress_documents(documents, query)
else:
raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}")
return documents[:self.rerank_top_n]
return self._retriever
if self.rag_level == RAGLevel.RERANK.value:
return SimpleRetriever(rerank_retriever)
# RAG-Fusion
if self.rag_level == RAGLevel.FUSION.value:
if not self.llm:
raise ValueError("RAG-Fusion 需要提供 llm 参数")
# 创建多路查询检索器
transformer = MultiQueryTransformer(
llm=self.llm,
num_queries=self.num_queries
)
multi_query_retriever = transformer.create_multi_query_retriever(
base_retriever=SimpleRetriever(rerank_retriever)
)
return multi_query_retriever
return SimpleRetriever(rerank_retriever)
def retrieve(
self,
query: str,
use_cache: Optional[bool] = None,
**kwargs,
) -> RetrievalResult:
def retrieve(self, query: str) -> List[Document]:
"""
执行检索
Args:
query: 查询文本
use_cache: 是否使用缓存
**kwargs: 额外参数
query: 查询字符串
Returns:
检索结果
相关文档列表
"""
start_time = time.time()
# 检查缓存
if use_cache is None:
use_cache = self.config.enable_cache
cache_key = f"{query}:{self.config.rag_level.value}"
if use_cache and cache_key in self._cache:
if self.config.verbose:
print(f"使用缓存结果: {query}")
return self._cache[cache_key]
# 获取检索器并执行检索
retriever = self._get_retriever()
documents = retriever.invoke(query, **kwargs)
# 计算查询时间
query_time = time.time() - start_time
# 创建结果
result = RetrievalResult(
documents=documents,
query_time=query_time,
level=self.config.rag_level,
metadata={
"query": query,
"collection": self.config.collection_name,
"doc_count": len(documents),
},
)
# 缓存结果
if use_cache:
self._cache[cache_key] = result
if self.config.verbose:
print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s")
return result
return self.retriever.invoke(query)
def format_context(
self,
documents: List[Document],
max_length: Optional[int] = None,
) -> str:
async def aretrieve(self, query: str) -> List[Document]:
"""
格式化检索到的文档为上下文文本
异步执行检索
Args:
query: 查询字符串
Returns:
相关文档列表
"""
return await self.retriever.ainvoke(query)
def format_context(self, documents: List[Document]) -> str:
"""
格式化上下文
Args:
documents: 文档列表
max_length: 最大长度(字符数)
Returns:
格式化后的上下文文本
格式化后的上下文字符串
"""
if not documents:
return ""
context_parts = []
total_length = 0
for i, doc in enumerate(documents, 1):
content = doc.page_content
metadata = doc.metadata or {}
source = metadata.get("source", "未知来源")
part = f"【资料 {i}\n"
part += f"来源: {source}\n"
part += f"内容: {content}\n"
part += "---\n"
context_parts.append(part)
for i, doc in enumerate(documents):
# 提取内容和元数据
content = doc.page_content.strip()
metadata = doc.metadata
# 格式化文档
doc_text = f"[文档 {i+1}]\n"
if metadata.get("source"):
doc_text += f"来源: {metadata['source']}\n"
if metadata.get("page"):
doc_text += f"页码: {metadata['page']}\n"
doc_text += f"内容: {content}\n\n"
# 检查长度限制
if max_length is not None:
if total_length + len(doc_text) > max_length:
# 如果添加这个文档会超限,则截断并添加说明
remaining = max_length - total_length
if remaining > 100: # 至少保留100字符
doc_text = doc_text[:remaining] + "...\n\n[内容已截断]"
context_parts.append(doc_text)
break
else:
break
context_parts.append(doc_text)
total_length += len(doc_text)
return "".join(context_parts).strip()
return "".join(context_parts)
class SimpleRetriever:
"""简单检索器包装类"""
def clear_cache(self):
"""清空缓存"""
self._cache.clear()
def __init__(self, retrieve_func):
self.retrieve_func = retrieve_func
@classmethod
def create_from_config(
cls,
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config_dict: Optional[Dict[str, Any]] = None,
) -> "RAGPipeline":
"""
从配置字典创建流水线
Args:
embeddings: 嵌入模型
llm: 语言模型
config_dict: 配置字典
Returns:
RAGPipeline 实例
"""
config_dict = config_dict or {}
# 创建配置对象
config = RAGConfig(
collection_name=config_dict.get("collection_name", "documents"),
qdrant_url=config_dict.get("qdrant_url"),
qdrant_api_key=config_dict.get("qdrant_api_key"),
rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)),
dense_k=config_dict.get("dense_k", 10),
sparse_k=config_dict.get("sparse_k", 10),
total_k=config_dict.get("total_k", 20),
rerank_top_n=config_dict.get("rerank_top_n", 5),
num_queries=config_dict.get("num_queries", 3),
reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"),
device=config_dict.get("device"),
enable_cache=config_dict.get("enable_cache", True),
verbose=config_dict.get("verbose", True),
)
return cls(embeddings=embeddings, llm=llm, config=config)
def invoke(self, query):
return self.retrieve_func(query)
async def ainvoke(self, query):
return self.retrieve_func(query)