""" 重排序器模块 使用 Cross-Encoder 模型对检索结果进行重排序,提高检索精度。 """ from typing import List from langchain_core.documents import Document class CrossEncoderReranker: """使用 Cross-Encoder 对检索结果重排序。""" def __init__(self, model_name: str = "BAAI/bge-reranker-base", top_n: int = 5): """ 初始化重排序器 Args: model_name: 预训练模型名称 top_n: 返回前 N 个结果 """ self.model_name = model_name self.top_n = top_n self.model = None # 尝试加载 Cross-Encoder 模型 try: from sentence_transformers import CrossEncoder self.model = CrossEncoder(model_name) except Exception as e: print(f"警告: 无法加载 Cross-Encoder 模型 {model_name},将使用简单排序作为回退方案。错误: {e}") def compress_documents( self, documents: List[Document], query: str ) -> List[Document]: """ 对文档进行重排序 Args: documents: 待排序的文档列表 query: 查询字符串 Returns: 排序后的文档列表 """ if not documents: return [] # 如果模型加载失败,返回前 top_n 个文档 if self.model is None: return documents[:self.top_n] # 使用 Cross-Encoder 进行重排序 try: pairs = [[query, doc.page_content] for doc in documents] scores = self.model.predict(pairs) # 按分数降序排序 scored_docs = sorted( zip(documents, scores), key=lambda x: x[1], reverse=True ) return [doc for doc, _ in scored_docs[:self.top_n]] except Exception as e: print(f"警告: 重排序过程出错,将使用原始排序。错误: {e}") return documents[:self.top_n]