Files
ailine/app/rag/reranker.py

66 lines
2.0 KiB
Python
Raw Normal View History

2026-04-18 16:31:48 +08:00
"""
2026-04-19 22:01:55 +08:00
重排序器模块
2026-04-18 16:31:48 +08:00
2026-04-19 22:01:55 +08:00
使用 Cross-Encoder 模型对检索结果进行重排序,以提高检索精度
2026-04-18 16:31:48 +08:00
"""
2026-04-19 22:01:55 +08:00
import logging
from typing import List, Optional
2026-04-18 16:31:48 +08:00
from langchain_core.documents import Document
class CrossEncoderReranker:
    """Rerank retrieved documents by query relevance using a Cross-Encoder.

    If the Cross-Encoder model cannot be loaded (e.g. the
    ``sentence_transformers`` package is missing or loading raises), the
    reranker degrades gracefully: ``compress_documents`` then just returns
    the first ``top_n`` documents in their original order.
    """

    # Warnings go through logging rather than stdout so the application can
    # route, format or silence them.
    _logger = logging.getLogger(__name__)

    def __init__(
        self, model_name: str = "BAAI/bge-reranker-base", top_n: Optional[int] = 5
    ):
        """Initialise the reranker and try to load the Cross-Encoder model.

        Args:
            model_name: Pretrained model name/path passed to
                ``sentence_transformers.CrossEncoder``.
            top_n: Maximum number of documents to return. ``None`` means
                "return every document" (reranked, not truncated).

        Raises:
            ValueError: If ``top_n`` is negative (a negative slice bound would
                silently drop documents from the tail instead).
        """
        if top_n is not None and top_n < 0:
            raise ValueError(f"top_n must be non-negative or None, got {top_n}")
        self.model_name = model_name
        self.top_n = top_n
        self.model = None

        # Best effort: any failure while importing or loading leaves
        # self.model as None, which switches compress_documents to its
        # plain truncation fallback.
        try:
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(model_name)
        except Exception as e:
            self._logger.warning(
                "警告: 无法加载 Cross-Encoder 模型 %s,将使用简单排序作为回退方案。错误: %s",
                model_name,
                e,
            )

    def compress_documents(
        self, documents: List["Document"], query: str
    ) -> List["Document"]:
        """Rerank ``documents`` by relevance to ``query``.

        Args:
            documents: Candidate documents; each must expose ``page_content``.
            query: Query string the documents are scored against.

        Returns:
            At most ``top_n`` documents (all of them when ``top_n`` is None),
            highest model score first. When no model is loaded, or scoring
            raises, the first ``top_n`` documents are returned in their
            original order instead.
        """
        if not documents:
            return []

        # Fallback: no model available, keep the retriever's original order.
        if self.model is None:
            return documents[: self.top_n]

        # Best effort: any failure while scoring or ranking (e.g. unexpected
        # score shape) falls back to the original order rather than failing
        # the whole retrieval.
        try:
            pairs = [[query, doc.page_content] for doc in documents]
            scores = self.model.predict(pairs)
            # sorted() is stable even with reverse=True, so documents with
            # equal scores keep their original relative order.
            ranked = sorted(
                zip(documents, scores), key=lambda pair: pair[1], reverse=True
            )
            return [doc for doc, _ in ranked[: self.top_n]]
        except Exception as e:
            self._logger.warning("警告: 重排序过程出错,将使用原始排序。错误: %s", e)
            return documents[: self.top_n]