"""
Cross-Encoder 重排序器

使用 sentence-transformers 加载交叉编码器模型,对检索结果进行精排。
"""
import logging
import os
from typing import List, Dict, Any, Optional

# NOTE: the wrapper class defined in this module reuses the name
# CrossEncoderReranker and therefore shadows this import at module level.
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder


class CrossEncoderReranker:
    """
    Wrapper that reranks retrieved documents with a cross-encoder model.

    Supports models such as BAAI/bge-reranker-base, loaded through
    sentence-transformers. Both the model and the LangChain document
    compressor built on top of it are created lazily on first use.

    NOTE: this class shares its name with LangChain's
    ``langchain.retrievers.document_compressors.CrossEncoderReranker``, which
    the module also imports at top level. Defining this class rebinds that
    module-level name, so the LangChain class is re-imported under an alias
    wherever it is actually needed.
    """

    # Model used when no model_name is given.
    DEFAULT_MODEL_NAME = "BAAI/bge-reranker-base"
    # Model tried when loading the requested model raises.
    FALLBACK_MODEL_NAME = "BAAI/bge-reranker-v2-m3"

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        top_n: int = 5,
        device: Optional[str] = None,
        cache_folder: Optional[str] = None,
    ):
        """
        Initialize the reranker without loading any model yet.

        Args:
            model_name: Model name or local path.
            top_n: Number of top-ranked documents to return.
            device: Device such as "cpu" or "cuda"; None lets sentence-transformers choose.
            cache_folder: Model cache directory; defaults to
                ~/.cache/sentence_transformers.
        """
        self.model_name = model_name
        self.top_n = top_n
        self.device = device
        self.cache_folder = cache_folder or os.path.join(
            os.path.expanduser("~"), ".cache", "sentence_transformers"
        )

        # Populated lazily by _load_model / _create_langchain_reranker.
        self._model = None
        self._langchain_reranker = None

    def _load_model(self) -> None:
        """
        Load the cross-encoder once, falling back to FALLBACK_MODEL_NAME on failure.

        Raises:
            Exception: whatever CrossEncoder raises if the fallback model also
                fails to load, or if the requested model already is the fallback.
        """
        if self._model is not None:
            return
        # NOTE(review): `cache_folder` must be accepted by the installed
        # sentence-transformers CrossEncoder constructor — confirm against the pinned version.
        try:
            self._model = CrossEncoder(
                self.model_name,
                device=self.device,
                cache_folder=self.cache_folder,
            )
        except Exception as exc:
            # Retrying the very same checkpoint would fail the same way.
            if self.model_name == self.FALLBACK_MODEL_NAME:
                raise
            logger = logging.getLogger(__name__)
            logger.warning("加载模型 %s 失败: %s", self.model_name, exc)
            logger.warning("尝试加载备用模型 %s...", self.FALLBACK_MODEL_NAME)
            self._model = CrossEncoder(
                self.FALLBACK_MODEL_NAME,
                device=self.device,
                cache_folder=self.cache_folder,
            )

    @staticmethod
    def _predict_scores(model: Any, text_pairs: List[tuple]) -> List[float]:
        """
        Score (query, passage) pairs with a ``model.predict``-style cross-encoder.

        Args:
            model: Object exposing ``predict(list_of_pairs)``, such as a
                sentence-transformers CrossEncoder.
            text_pairs: Sequence of (query, passage) tuples.

        Returns:
            One float per pair. If ``predict`` returns a 2-D score matrix, column 1
            is used as the relevance score. This mirrors LangChain's
            HuggingFaceCrossEncoder handling of two-label rerankers.
        """
        if not text_pairs:
            return []
        raw = model.predict(list(text_pairs))
        if getattr(raw, "ndim", 1) > 1:
            return [float(row[1]) for row in raw]
        return [float(score) for score in raw]

    def _create_langchain_reranker(self) -> None:
        """Build, once, the LangChain document compressor backed by the loaded model."""
        if self._langchain_reranker is not None:
            return

        # Import under an alias. At module level the name CrossEncoderReranker
        # refers to THIS wrapper class, so calling it here would construct
        # another wrapper and fail with "unexpected keyword argument 'model'".
        from langchain.retrievers.document_compressors import (
            CrossEncoderReranker as LangChainCrossEncoderReranker,
        )
        from langchain.retrievers.document_compressors.cross_encoder import (
            BaseCrossEncoder,
        )

        predict_scores = self._predict_scores

        class _SentenceTransformersCrossEncoder(BaseCrossEncoder):
            """Adapt a sentence-transformers CrossEncoder to LangChain's ``score`` API."""

            def __init__(self, model: Any) -> None:
                self.model = model

            def score(self, text_pairs):
                return predict_scores(self.model, text_pairs)

        self._load_model()
        # LangChain's reranker validates that `model` is a BaseCrossEncoder and calls
        # `model.score(...)`. A raw sentence-transformers CrossEncoder only offers
        # `predict`, hence the adapter.
        self._langchain_reranker = LangChainCrossEncoderReranker(
            model=_SentenceTransformersCrossEncoder(self._model),
            top_n=self.top_n,
        )

    def rerank(
        self,
        query: str,
        documents: List[Document],
    ) -> List[Document]:
        """
        Rerank documents by cross-encoder relevance to the query.

        Args:
            query: Query text.
            documents: Documents to rerank.

        Returns:
            At most ``top_n`` documents, ordered as returned by LangChain's reranker.
            An empty input returns an empty list without loading the model.
        """
        if not documents:
            return []
        self._create_langchain_reranker()
        return list(
            self._langchain_reranker.compress_documents(
                documents=documents,
                query=query,
            )
        )

    def create_contextual_compression_retriever(
        self,
        base_retriever: Any,
    ) -> Any:
        """
        Wrap a retriever so that its results are reranked by this cross-encoder.

        Args:
            base_retriever: Underlying retriever whose documents get reranked.

        Returns:
            A LangChain ContextualCompressionRetriever combining ``base_retriever``
            with this reranker's compressor.
        """
        from langchain.retrievers import ContextualCompressionRetriever

        self._create_langchain_reranker()

        return ContextualCompressionRetriever(
            base_compressor=self._langchain_reranker,
            base_retriever=base_retriever,
        )

    @classmethod
    def create_from_config(
        cls,
        config: Optional[Dict[str, Any]] = None,
    ) -> "CrossEncoderReranker":
        """
        Build a reranker from a configuration dict.

        Args:
            config: Optional dict with keys model_name, top_n, device and
                cache_folder. Missing keys fall back to the constructor defaults.

        Returns:
            A CrossEncoderReranker instance; no model is loaded yet.
        """
        config = config or {}
        return cls(
            model_name=config.get("model_name", cls.DEFAULT_MODEL_NAME),
            top_n=config.get("top_n", 5),
            device=config.get("device", None),
            cache_folder=config.get("cache_folder", None),
        )
|