"""
Cross-Encoder reranker.

Loads a cross-encoder model through sentence-transformers and uses it to
re-rank (fine-rank) retrieval results.
"""

import logging
import os
from operator import itemgetter
from typing import Any, Dict, List, Optional, Sequence

from langchain_core.documents import Document
from sentence_transformers import CrossEncoder

try:
    from langchain_core.documents import BaseDocumentCompressor
except ImportError:  # older LangChain releases kept the base class in `langchain`
    from langchain.retrievers.document_compressors.base import BaseDocumentCompressor

# NOTE: `langchain.retrievers.document_compressors.CrossEncoderReranker` is
# intentionally not imported. The public wrapper class below has the same name
# and used to shadow it, so the LangChain reranker could never be constructed.
# That class also requires a LangChain `BaseCrossEncoder` (with `.score()`),
# which a `sentence_transformers.CrossEncoder` (with `.predict()`) is not.

logger = logging.getLogger(__name__)


class _CrossEncoderCompressor(BaseDocumentCompressor):
    """LangChain document compressor that keeps the `top_n` best-scored documents.

    Each document is scored against the query with `model.predict`, and the
    documents are returned in descending score order.
    """

    # The loaded cross-encoder; only its `predict(pairs)` method is used.
    model: Any
    # Maximum number of documents returned by `compress_documents`.
    top_n: int = 5

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Any] = None,
    ) -> Sequence[Document]:
        """Score every `(query, page_content)` pair and return the best `top_n`.

        Args:
            documents: Candidate documents.
            query: Query text each document is scored against.
            callbacks: Accepted for LangChain interface compatibility; unused.

        Returns:
            At most `top_n` documents sorted by descending score. `sorted` is
            stable, so documents with equal scores keep their input order.
        """
        docs = list(documents)
        if not docs:
            # Nothing to score; don't run the model on an empty batch.
            return []
        raw_scores = self.model.predict([(query, doc.page_content) for doc in docs])
        if getattr(raw_scores, "ndim", 1) > 1:
            # Some rerankers emit two logits per pair (not relevant / relevant).
            # Rank by the second ("relevant") column. Assumption — confirm per model.
            raw_scores = [row[1] for row in raw_scores]
        scores = [float(score) for score in raw_scores]
        ranked = sorted(zip(docs, scores), key=itemgetter(1), reverse=True)
        return [doc for doc, _ in ranked[: self.top_n]]


class CrossEncoderReranker:
    """Cross-encoder reranker wrapper.

    Lazily loads a sentence-transformers `CrossEncoder` (for example
    `BAAI/bge-reranker-base`). It can be used directly through `rerank` or
    wrapped around another retriever through
    `create_contextual_compression_retriever`.
    """

    # Model tried when `model_name` fails to load.
    _FALLBACK_MODEL_NAME = "BAAI/bge-reranker-v2-m3"

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-base",
        top_n: int = 5,
        device: Optional[str] = None,
        cache_folder: Optional[str] = None,
    ):
        """Store the configuration; the model itself is loaded on first use.

        Args:
            model_name: Model name or local path.
            top_n: Number of top-ranked documents to return.
            device: Device string (e.g. "cpu" / "cuda"). `None` is passed through
                to `CrossEncoder`, which picks a device itself.
            cache_folder: Model cache directory. Defaults to
                `~/.cache/sentence_transformers`.
        """
        self.model_name = model_name
        self.top_n = top_n
        self.device = device
        self.cache_folder = cache_folder or os.path.join(
            os.path.expanduser("~"), ".cache", "sentence_transformers"
        )
        # Lazily initialised by `_load_model` / `_create_langchain_reranker`.
        self._model: Optional[CrossEncoder] = None
        self._langchain_reranker: Optional[_CrossEncoderCompressor] = None

    def _build_cross_encoder(self, name: str) -> CrossEncoder:
        """Instantiate a `CrossEncoder` with this wrapper's device and cache settings."""
        # NOTE(review): the `cache_folder` keyword is only accepted by
        # sentence-transformers releases whose CrossEncoder supports it — confirm
        # against the pinned version.
        return CrossEncoder(
            name,
            device=self.device,
            cache_folder=self.cache_folder,
        )

    def _load_model(self) -> None:
        """Load the cross-encoder once, falling back to `_FALLBACK_MODEL_NAME` on failure.

        Raises:
            Exception: Whatever the loader raised, if `model_name` already is the
                fallback model or if the fallback also fails. In the latter case
                the original error is attached as the exception context.
        """
        if self._model is not None:
            return
        try:
            self._model = self._build_cross_encoder(self.model_name)
        except Exception as err:
            # Retrying the very same model is pointless; surface the real error.
            if self.model_name == self._FALLBACK_MODEL_NAME:
                raise
            logger.warning("加载模型 %s 失败: %s", self.model_name, err)
            logger.warning("尝试加载备用模型 %s...", self._FALLBACK_MODEL_NAME)
            self._model = self._build_cross_encoder(self._FALLBACK_MODEL_NAME)

    def _create_langchain_reranker(self) -> None:
        """Create (once) the LangChain document compressor backed by the loaded model."""
        if self._langchain_reranker is None:
            self._load_model()
            self._langchain_reranker = _CrossEncoderCompressor(
                model=self._model,
                top_n=self.top_n,
            )

    def rerank(
        self,
        query: str,
        documents: List[Document],
    ) -> List[Document]:
        """Re-rank `documents` by relevance to `query`.

        Args:
            query: Query text.
            documents: Documents to rank.

        Returns:
            At most `top_n` documents, best first. An empty input returns `[]`
            without loading the model.
        """
        if not documents:
            return []
        self._create_langchain_reranker()
        return list(
            self._langchain_reranker.compress_documents(
                documents=documents,
                query=query,
            )
        )

    def create_contextual_compression_retriever(
        self,
        base_retriever: Any,
    ) -> Any:
        """Wrap `base_retriever` so that its results are re-ranked by this model.

        Args:
            base_retriever: The underlying LangChain retriever.

        Returns:
            A `ContextualCompressionRetriever` using this reranker as its compressor.
        """
        from langchain.retrievers import ContextualCompressionRetriever

        self._create_langchain_reranker()
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self._langchain_reranker,
            base_retriever=base_retriever,
        )
        return compression_retriever

    @classmethod
    def create_from_config(
        cls,
        config: Optional[Dict[str, Any]] = None,
    ) -> "CrossEncoderReranker":
        """Build a reranker from a configuration dictionary.

        Args:
            config: Optional dict with the keys `model_name`, `top_n`, `device`
                and `cache_folder`. Missing keys use the constructor defaults.

        Returns:
            A new `CrossEncoderReranker` instance.
        """
        config = config or {}
        return cls(
            model_name=config.get("model_name", "BAAI/bge-reranker-base"),
            top_n=config.get("top_n", 5),
            device=config.get("device", None),
            cache_folder=config.get("cache_folder", None),
        )