"""
重排序器模块

使用 Cross-Encoder 模型对检索结果进行重排序，提高检索精度。
"""

import logging
from typing import List

from langchain_core.documents import Document
|
|
|
|
|
|
logger = logging.getLogger(__name__)


class CrossEncoderReranker:
    """Rerank retrieved documents with a Cross-Encoder model.

    The model is loaded best-effort at construction time. If loading fails,
    or if scoring fails at call time, the reranker degrades to returning the
    first ``top_n`` documents in their original order instead of raising.
    Failures are reported through the module logger rather than stdout.
    """

    def __init__(self, model_name: str = "BAAI/bge-reranker-base", top_n: int = 5):
        """Initialize the reranker.

        Args:
            model_name: Name or path of the pretrained model, passed to
                ``sentence_transformers.CrossEncoder``.
            top_n: Maximum number of documents returned by
                ``compress_documents``. ``None`` keeps every document.

        Raises:
            ValueError: If ``top_n`` is negative. A negative value would
                otherwise act as a slice bound and silently drop documents
                from the end of the result instead of limiting its size.
        """
        if top_n is not None and top_n < 0:
            raise ValueError(f"top_n must be non-negative, got {top_n}")

        self.model_name = model_name
        self.top_n = top_n
        self.model = None

        # Best-effort load: the sentence_transformers package may be missing,
        # or the model may fail to load. Either way keep model=None so that
        # compress_documents falls back to plain truncation.
        try:
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(model_name)
        except Exception as e:
            logger.warning(
                "警告: 无法加载 Cross-Encoder 模型 %s,将使用简单排序作为回退方案。错误: %s",
                model_name,
                e,
            )

    def compress_documents(
        self, documents: List[Document], query: str
    ) -> List[Document]:
        """Rerank documents by Cross-Encoder relevance to ``query``.

        Args:
            documents: Candidate documents; each must expose ``page_content``.
            query: Query string that every document is scored against.

        Returns:
            At most ``top_n`` documents, highest score first. Documents with
            equal scores keep their input order because ``sorted`` is stable.
            If no model is loaded, or scoring raises, the first ``top_n``
            documents are returned in their original order.
        """
        if not documents:
            return []

        # No model available: keep the retriever's own ordering.
        if self.model is None:
            return documents[: self.top_n]

        # Scoring is best-effort; any failure falls back to the input order.
        try:
            pairs = [[query, doc.page_content] for doc in documents]
            scores = self.model.predict(pairs)
            ranked = sorted(
                zip(documents, scores), key=lambda pair: pair[1], reverse=True
            )
        except Exception as e:
            logger.warning("警告: 重排序过程出错,将使用原始排序。错误: %s", e)
            return documents[: self.top_n]

        return [doc for doc, _ in ranked[: self.top_n]]