Files
ailine/app/rag/reranker.py

141 lines
4.1 KiB
Python
Raw Normal View History

2026-04-18 16:31:48 +08:00
"""
Cross-Encoder 重排序器
使用 sentence-transformers 加载交叉编码器模型对检索结果进行精排
"""
import os
from typing import List, Dict, Any, Optional
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
    """
    Cross-Encoder reranker wrapper.

    Scores ``(query, document)`` pairs with a sentence-transformers
    ``CrossEncoder`` (e.g. ``BAAI/bge-reranker-base``) and keeps the ``top_n``
    highest-scoring documents. The model is loaded lazily on first use.

    NOTE: this class has the same name as LangChain's ``CrossEncoderReranker``
    imported at the top of this module, so at module scope that name refers to
    THIS class. LangChain's class is imported locally under an alias wherever
    it is needed.
    """

    # Model tried when ``model_name`` fails to load.
    _FALLBACK_MODEL_NAME = "BAAI/bge-reranker-v2-m3"

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-base",
        top_n: int = 5,
        device: Optional[str] = None,
        cache_folder: Optional[str] = None,
    ):
        """
        Initialize the reranker (the model itself is not loaded yet).

        Args:
            model_name: Model id or local path of the cross-encoder.
            top_n: Maximum number of documents returned by a rerank.
            device: Device passed to ``CrossEncoder`` (e.g. "cpu" / "cuda");
                ``None`` leaves the choice to sentence-transformers.
            cache_folder: Model cache directory; defaults to
                ``~/.cache/sentence_transformers``.
        """
        self.model_name = model_name
        self.top_n = top_n
        self.device = device
        self.cache_folder = cache_folder or os.path.join(
            os.path.expanduser("~"), ".cache", "sentence_transformers"
        )
        # Both are created lazily on first use.
        self._model = None
        self._langchain_reranker = None

    def _load_model(self):
        """
        Load the cross-encoder once, falling back to ``_FALLBACK_MODEL_NAME``.

        Raises:
            Exception: whatever ``CrossEncoder`` raises when the fallback model
                cannot be loaded either (or when ``model_name`` already is the
                fallback model, in which case no retry is attempted).
        """
        if self._model is not None:
            return
        import logging

        logger = logging.getLogger(__name__)
        try:
            self._model = CrossEncoder(
                self.model_name,
                device=self.device,
                cache_folder=self.cache_folder,
            )
        except Exception as e:
            # Retrying the very same model would only fail again.
            if self.model_name == self._FALLBACK_MODEL_NAME:
                raise
            logger.warning("加载模型 %s 失败: %s", self.model_name, e)
            logger.warning("尝试加载备用模型 BAAI/bge-reranker-v2-m3...")
            self._model = CrossEncoder(
                self._FALLBACK_MODEL_NAME,
                device=self.device,
                cache_folder=self.cache_folder,
            )

    def _score(self, text_pairs: List[Any]) -> List[float]:
        """
        Return one float relevance score per ``(query, text)`` pair.

        An empty input short-circuits to ``[]`` without loading the model,
        since ``CrossEncoder.predict`` is not guaranteed to accept an empty list.
        """
        if not text_pairs:
            return []
        self._load_model()
        return [float(score) for score in self._model.predict(text_pairs)]

    def _create_langchain_reranker(self):
        """Build (once) a LangChain ``CrossEncoderReranker`` backed by this model."""
        if self._langchain_reranker is not None:
            return
        # Aliased local imports: the module-level name ``CrossEncoderReranker``
        # is shadowed by this very class, so calling it here would invoke our
        # own constructor, which rejects the ``model=`` keyword.
        from langchain.retrievers.document_compressors import (
            CrossEncoderReranker as LangChainCrossEncoderReranker,
        )
        from langchain.retrievers.document_compressors.cross_encoder import (
            BaseCrossEncoder,
        )

        owner = self

        class _SentenceTransformersCrossEncoder(BaseCrossEncoder):
            """Expose the sentence-transformers ``predict`` as LangChain's ``score``."""

            def score(self, text_pairs):
                return owner._score(list(text_pairs))

        # Load eagerly so model-loading errors surface when the compressor is
        # built, as before. top_n is captured at this point.
        self._load_model()
        self._langchain_reranker = LangChainCrossEncoderReranker(
            model=_SentenceTransformersCrossEncoder(),
            top_n=self.top_n,
        )

    def rerank(
        self,
        query: str,
        documents: List["Document"],
    ) -> List["Document"]:
        """
        Rerank ``documents`` by cross-encoder relevance to ``query``.

        Args:
            query: Query text.
            documents: Candidate documents; only ``page_content`` is scored.

        Returns:
            At most ``top_n`` documents, highest score first; documents with
            equal scores keep their input order.
        """
        candidates = list(documents)
        scores = self._score([(query, doc.page_content) for doc in candidates])
        ranked = sorted(
            zip(candidates, scores), key=lambda pair: pair[1], reverse=True
        )
        return [doc for doc, _ in ranked[: self.top_n]]

    def create_contextual_compression_retriever(
        self,
        base_retriever: Any,
    ) -> Any:
        """
        Wrap ``base_retriever`` so its results are reranked by this model.

        Args:
            base_retriever: The underlying LangChain retriever.

        Returns:
            A LangChain ``ContextualCompressionRetriever`` using this reranker
            as its compressor.
        """
        from langchain.retrievers import ContextualCompressionRetriever

        self._create_langchain_reranker()
        return ContextualCompressionRetriever(
            base_compressor=self._langchain_reranker,
            base_retriever=base_retriever,
        )

    @classmethod
    def create_from_config(
        cls,
        config: Optional[Dict[str, Any]] = None,
    ) -> "CrossEncoderReranker":
        """
        Build a reranker from a configuration dict.

        Args:
            config: Optional dict with any of ``model_name``, ``top_n``,
                ``device`` and ``cache_folder``; missing keys use the
                constructor defaults.

        Returns:
            A new ``CrossEncoderReranker`` instance (model not yet loaded).
        """
        config = config or {}
        return cls(
            model_name=config.get("model_name", "BAAI/bge-reranker-base"),
            top_n=config.get("top_n", 5),
            device=config.get("device", None),
            cache_folder=config.get("cache_folder", None),
        )