# Files
# ailine/app/rag/reranker.py
# 2026-04-18 16:31:48 +08:00
#
# 141 lines
# 4.1 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Cross-Encoder 重排序器
使用 sentence-transformers 加载交叉编码器模型,对检索结果进行精排。
"""
import os
from typing import List, Dict, Any, Optional
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
    """
    Cross-encoder reranker wrapper.

    Scores (query, document) pairs with a sentence-transformers ``CrossEncoder``
    (e.g. ``BAAI/bge-reranker-base``) and keeps the ``top_n`` highest-scoring
    documents. The model is loaded lazily on first use.

    Note: this class has the same name as the LangChain ``CrossEncoderReranker``
    imported at the top of this module and, being defined after that import,
    replaces it in the module namespace. Code in this class therefore must not
    use the module-level name to reach the LangChain class.
    """

    # Model tried when ``model_name`` cannot be loaded.
    _FALLBACK_MODEL_NAME = "BAAI/bge-reranker-v2-m3"

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-base",
        top_n: int = 5,
        device: Optional[str] = None,
        cache_folder: Optional[str] = None,
    ):
        """
        Initialize the reranker.

        Args:
            model_name: Model name or local path.
            top_n: Number of top-ranked documents to return.
            device: Device such as "cpu" or "cuda"; ``None`` is passed through
                to sentence-transformers, which then picks the device.
            cache_folder: Model cache directory; defaults to
                ``~/.cache/sentence_transformers``.
        """
        self.model_name = model_name
        self.top_n = top_n
        self.device = device
        self.cache_folder = cache_folder or os.path.join(
            os.path.expanduser("~"), ".cache", "sentence_transformers"
        )
        # Both the model and the LangChain adapter are created lazily.
        self._model = None
        self._langchain_reranker = None

    def _load_model(self):
        """
        Load the cross-encoder model if it has not been loaded yet.

        If ``model_name`` fails to load, the fallback model is tried instead,
        unless ``model_name`` already is the fallback model, in which case the
        original error is re-raised.
        """
        if self._model is not None:
            return
        try:
            # NOTE(review): ``cache_folder`` is accepted by recent
            # sentence-transformers CrossEncoder releases; older releases may
            # reject it - confirm against the pinned version.
            self._model = CrossEncoder(
                self.model_name,
                device=self.device,
                cache_folder=self.cache_folder,
            )
        except Exception as e:
            if self.model_name == self._FALLBACK_MODEL_NAME:
                # Retrying the same model name would just repeat the failed load.
                raise
            print(f"加载模型 {self.model_name} 失败: {e}")
            print("尝试加载备用模型 BAAI/bge-reranker-v2-m3...")
            self._model = CrossEncoder(
                self._FALLBACK_MODEL_NAME,
                device=self.device,
                cache_folder=self.cache_folder,
            )

    def _create_langchain_reranker(self):
        """
        Create (once) a LangChain document compressor that delegates to
        ``rerank``, and make sure the model is loaded.

        LangChain's ``CrossEncoderReranker`` is not instantiated here because
        the module-level name ``CrossEncoderReranker`` refers to this class.
        NOTE(review): LangChain's reranker also expects a LangChain
        cross-encoder exposing ``score``, not a raw sentence-transformers
        model - confirm before switching back to it.
        """
        if self._langchain_reranker is not None:
            return
        from langchain_core.documents import BaseDocumentCompressor

        owner = self

        class _RerankCompressor(BaseDocumentCompressor):
            """LangChain compressor adapter that forwards to ``owner.rerank``."""

            def compress_documents(self, documents, query, callbacks=None):
                return owner.rerank(query, list(documents))

        # Load eagerly so that model-loading errors surface at setup time.
        self._load_model()
        self._langchain_reranker = _RerankCompressor()

    def rerank(
        self,
        query: str,
        documents: "List[Document]",
    ) -> "List[Document]":
        """
        Rerank documents by cross-encoder relevance to the query.

        Args:
            query: Query text.
            documents: Candidate documents; each is scored on its
                ``page_content``.

        Returns:
            At most ``top_n`` documents, ordered by descending score. Documents
            with equal scores keep their input order (stable sort).
        """
        if not documents:
            # Nothing to score; also avoids loading the model needlessly.
            return []
        self._load_model()
        pairs = [(query, doc.page_content) for doc in documents]
        # NOTE(review): assumes the model yields one relevance score per pair
        # (single-label reranker such as bge-reranker); multi-label outputs
        # would make float() below fail.
        scores = self._model.predict(pairs)
        ranked = sorted(
            zip(documents, scores),
            key=lambda item: float(item[1]),
            reverse=True,
        )
        return [doc for doc, _ in ranked[: self.top_n]]

    def create_contextual_compression_retriever(
        self,
        base_retriever: Any,
    ) -> Any:
        """
        Create a LangChain contextual compression retriever.

        Args:
            base_retriever: Base retriever whose results are reranked.

        Returns:
            A ``ContextualCompressionRetriever`` that reranks the base
            retriever's documents with this reranker.
        """
        from langchain.retrievers import ContextualCompressionRetriever

        self._create_langchain_reranker()
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self._langchain_reranker,
            base_retriever=base_retriever,
        )
        return compression_retriever

    @classmethod
    def create_from_config(
        cls,
        config: Optional[Dict[str, Any]] = None,
    ) -> "CrossEncoderReranker":
        """
        Build a reranker from a configuration dict.

        Args:
            config: Optional dict with any of ``model_name``, ``top_n``,
                ``device`` and ``cache_folder``; missing keys use the
                constructor defaults.

        Returns:
            A new ``CrossEncoderReranker`` instance.
        """
        config = config or {}
        return cls(
            model_name=config.get("model_name", "BAAI/bge-reranker-base"),
            top_n=config.get("top_n", 5),
            device=config.get("device", None),
            cache_folder=config.get("cache_folder", None),
        )