"""Cross-Encoder reranker.

Loads a cross-encoder model through sentence-transformers and uses it to
rerank (fine-rank) documents returned by a first-stage retriever.
"""
import logging
import os
from typing import Any, Dict, List, Optional, Tuple

# NOTE: this name is rebound by the CrossEncoderReranker class defined later in
# this module, so at call time it no longer refers to the LangChain class.
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder

class CrossEncoderReranker:
|
||
"""
|
||
Cross-Encoder 重排序器包装类
|
||
|
||
支持 BAAI/bge-reranker-base 等模型。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
model_name: str = "BAAI/bge-reranker-base",
|
||
top_n: int = 5,
|
||
device: Optional[str] = None,
|
||
cache_folder: Optional[str] = None,
|
||
):
|
||
"""
|
||
初始化重排序器
|
||
|
||
Args:
|
||
model_name: 模型名称或路径
|
||
top_n: 返回的顶部文档数量
|
||
device: 设备(cpu/cuda),如果为 None 则自动选择
|
||
cache_folder: 模型缓存目录
|
||
"""
|
||
self.model_name = model_name
|
||
self.top_n = top_n
|
||
self.device = device
|
||
self.cache_folder = cache_folder or os.path.join(
|
||
os.path.expanduser("~"), ".cache", "sentence_transformers"
|
||
)
|
||
|
||
# 延迟加载模型
|
||
self._model = None
|
||
self._langchain_reranker = None
|
||
|
||
def _load_model(self):
|
||
"""加载交叉编码器模型"""
|
||
if self._model is None:
|
||
try:
|
||
self._model = CrossEncoder(
|
||
self.model_name,
|
||
device=self.device,
|
||
cache_folder=self.cache_folder,
|
||
)
|
||
except Exception as e:
|
||
# 如果指定模型加载失败,尝试备用模型
|
||
print(f"加载模型 {self.model_name} 失败: {e}")
|
||
print("尝试加载备用模型 BAAI/bge-reranker-v2-m3...")
|
||
self._model = CrossEncoder(
|
||
"BAAI/bge-reranker-v2-m3",
|
||
device=self.device,
|
||
cache_folder=self.cache_folder,
|
||
)
|
||
|
||
def _create_langchain_reranker(self):
|
||
"""创建 LangChain 重排序器"""
|
||
if self._langchain_reranker is None:
|
||
self._load_model()
|
||
self._langchain_reranker = CrossEncoderReranker(
|
||
model=self._model,
|
||
top_n=self.top_n,
|
||
)
|
||
|
||
def rerank(
|
||
self,
|
||
query: str,
|
||
documents: List[Document],
|
||
) -> List[Document]:
|
||
"""
|
||
对文档进行重排序
|
||
|
||
Args:
|
||
query: 查询文本
|
||
documents: 待排序文档列表
|
||
|
||
Returns:
|
||
重排序后的文档列表
|
||
"""
|
||
self._create_langchain_reranker()
|
||
return self._langchain_reranker.compress_documents(
|
||
documents=documents,
|
||
query=query,
|
||
)
|
||
|
||
def create_contextual_compression_retriever(
|
||
self,
|
||
base_retriever: Any,
|
||
) -> Any:
|
||
"""
|
||
创建上下文压缩检索器
|
||
|
||
Args:
|
||
base_retriever: 基础检索器
|
||
|
||
Returns:
|
||
上下文压缩检索器
|
||
"""
|
||
from langchain.retrievers import ContextualCompressionRetriever
|
||
|
||
self._create_langchain_reranker()
|
||
|
||
compression_retriever = ContextualCompressionRetriever(
|
||
base_compressor=self._langchain_reranker,
|
||
base_retriever=base_retriever,
|
||
)
|
||
|
||
return compression_retriever
|
||
|
||
@classmethod
|
||
def create_from_config(
|
||
cls,
|
||
config: Optional[Dict[str, Any]] = None,
|
||
) -> "CrossEncoderReranker":
|
||
"""
|
||
从配置创建重排序器
|
||
|
||
Args:
|
||
config: 配置字典,包含 model_name, top_n, device 等
|
||
|
||
Returns:
|
||
CrossEncoderReranker 实例
|
||
"""
|
||
config = config or {}
|
||
return cls(
|
||
model_name=config.get("model_name", "BAAI/bge-reranker-base"),
|
||
top_n=config.get("top_n", 5),
|
||
device=config.get("device", None),
|
||
cache_folder=config.get("cache_folder", None),
|
||
) |