检索器重构
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s

This commit is contained in:
2026-04-19 22:01:55 +08:00
parent cc8ef41ef9
commit 933d418d77
26 changed files with 1694 additions and 1717 deletions

View File

@@ -1,141 +1,65 @@
"""
Cross-Encoder 重排序器
重排序器模块
使用 sentence-transformers 加载交叉编码器模型对检索结果进行精排
使用 Cross-Encoder 模型对检索结果进行重排序,提高检索精度
"""
import os
from typing import List, Dict, Any, Optional
from langchain.retrievers.document_compressors import CrossEncoderReranker
from typing import List
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
"""
Cross-Encoder 重排序器包装类
"""使用 Cross-Encoder 对检索结果重排序。"""
支持 BAAI/bge-reranker-base 等模型。
"""
def __init__(
self,
model_name: str = "BAAI/bge-reranker-base",
top_n: int = 5,
device: Optional[str] = None,
cache_folder: Optional[str] = None,
):
def __init__(self, model_name: str = "BAAI/bge-reranker-base", top_n: int = 5):
"""
初始化重排序器
Args:
model_name: 模型名称或路径
top_n: 返回的顶部文档数量
device: 设备 (cpu/cuda),如果为 None 则自动选择
cache_folder: 模型缓存目录
model_name: 预训练模型名称
top_n: 返回前 N 个结果
"""
self.model_name = model_name
self.top_n = top_n
self.device = device
self.cache_folder = cache_folder or os.path.join(
os.path.expanduser("~"), ".cache", "sentence_transformers"
)
self.model = None
# 延迟加载模型
self._model = None
self._langchain_reranker = None
# 尝试加载 Cross-Encoder 模型
try:
from sentence_transformers import CrossEncoder
self.model = CrossEncoder(model_name)
except Exception as e:
print(f"警告: 无法加载 Cross-Encoder 模型 {model_name},将使用简单排序作为回退方案。错误: {e}")
def _load_model(self):
"""加载交叉编码器模型"""
if self._model is None:
try:
self._model = CrossEncoder(
self.model_name,
device=self.device,
cache_folder=self.cache_folder,
)
except Exception as e:
# 如果指定模型加载失败,尝试备用模型
print(f"加载模型 {self.model_name} 失败: {e}")
print("尝试加载备用模型 BAAI/bge-reranker-v2-m3...")
self._model = CrossEncoder(
"BAAI/bge-reranker-v2-m3",
device=self.device,
cache_folder=self.cache_folder,
)
def _create_langchain_reranker(self):
"""创建 LangChain 重排序器"""
if self._langchain_reranker is None:
self._load_model()
self._langchain_reranker = CrossEncoderReranker(
model=self._model,
top_n=self.top_n,
)
def rerank(
self,
query: str,
documents: List[Document],
def compress_documents(
self, documents: List[Document], query: str
) -> List[Document]:
"""
对文档进行重排序
Args:
query: 查询文本
documents: 待排序文档列表
documents: 待排序的文档列表
query: 查询字符串
Returns:
排序后的文档列表
排序后的文档列表
"""
self._create_langchain_reranker()
return self._langchain_reranker.compress_documents(
documents=documents,
query=query,
)
def create_contextual_compression_retriever(
self,
base_retriever: Any,
) -> Any:
"""
创建上下文压缩检索器
if not documents:
return []
Args:
base_retriever: 基础检索器
# 如果模型加载失败,返回前 top_n 个文档
if self.model is None:
return documents[:self.top_n]
# 使用 Cross-Encoder 进行重排序
try:
pairs = [[query, doc.page_content] for doc in documents]
scores = self.model.predict(pairs)
Returns:
上下文压缩检索器
"""
from langchain.retrievers import ContextualCompressionRetriever
self._create_langchain_reranker()
compression_retriever = ContextualCompressionRetriever(
base_compressor=self._langchain_reranker,
base_retriever=base_retriever,
)
return compression_retriever
@classmethod
def create_from_config(
cls,
config: Optional[Dict[str, Any]] = None,
) -> "CrossEncoderReranker":
"""
从配置创建重排序器
Args:
config: 配置字典,包含 model_name, top_n, device 等
Returns:
CrossEncoderReranker 实例
"""
config = config or {}
return cls(
model_name=config.get("model_name", "BAAI/bge-reranker-base"),
top_n=config.get("top_n", 5),
device=config.get("device", None),
cache_folder=config.get("cache_folder", None),
)
# 按分数降序排序
scored_docs = sorted(
zip(documents, scores), key=lambda x: x[1], reverse=True
)
return [doc for doc, _ in scored_docs[:self.top_n]]
except Exception as e:
print(f"警告: 重排序过程出错,将使用原始排序。错误: {e}")
return documents[:self.top_n]