Files
ailine/backend/app/rag/rerank.py
root 1260bef5cb
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m31s
添加rag置信度判断
2026-05-06 01:15:52 +08:00

99 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
重排业务逻辑模块
本模块包含 RAG 相关的重排业务逻辑文档处理、排序、top_n
使用 model_services/rerank_services.py 提供的纯服务层
"""
import logging
from typing import List
from langchain_core.documents import Document
from ..model_services import get_rerank_service
logger = logging.getLogger(__name__)
class DocumentReranker:
    """Document reranker - business-logic layer.

    Responsibilities:
    - extract the text (``page_content``) from each ``Document``
    - ask the rerank service for one relevance score per document
    - sort documents by that score, highest first
    - return the top ``top_n`` documents, each annotated with ``rerank_score``
      in its metadata
    """

    def __init__(self, rerank_service=None):
        """Create a document reranker.

        Args:
            rerank_service: Object exposing ``compute_scores(query, texts)``.
                Optional; when ``None`` the service from
                ``get_rerank_service()`` is used.
        """
        # Compare against None explicitly so a caller-supplied service is never
        # silently replaced just because the object happens to be falsy.
        self._rerank_service = (
            rerank_service if rerank_service is not None else get_rerank_service()
        )

    @staticmethod
    def _with_score(doc: Document, score) -> Document:
        """Return a new Document equal to ``doc`` plus ``rerank_score`` metadata."""
        extra = {}
        doc_id = getattr(doc, "id", None)
        if doc_id is not None:
            # Carry the source document's id across; building a fresh Document
            # from page_content/metadata alone would otherwise drop it.
            extra["id"] = doc_id
        return Document(
            page_content=doc.page_content,
            # Copy (not mutate) the original metadata, adding the score.
            metadata={**doc.metadata, "rerank_score": score},
            **extra,
        )

    def compress_documents(
        self,
        documents: List[Document],
        query: str,
        top_n: int = 5
    ) -> List[Document]:
        """Rerank ``documents`` against ``query`` and keep the best ``top_n``.

        Args:
            documents: Documents to rerank.
            query: Query string the documents are scored against.
            top_n: Number of documents to return (applied as a ``[:top_n]``
                slice).

        Returns:
            List[Document]: On success, new Document objects sorted by
            descending score, each with ``rerank_score`` in its metadata.
            If scoring fails, or the service returns a number of scores that
            does not match the number of documents, the first ``top_n``
            ORIGINAL documents are returned unchanged (no ``rerank_score``).
            An empty input yields an empty list.
        """
        if not documents:
            return []

        try:
            # 1. Extract the text of every document.
            doc_contents = [doc.page_content for doc in documents]
            logger.info("[Rerank] 收到 %d 个文档待重排", len(documents))

            # 2. Score every document against the query. Materialize so that
            #    len() and repeated iteration work for any iterable result.
            scores = list(self._rerank_service.compute_scores(query, doc_contents))
            logger.info("[Rerank] 获取到 %d 个得分", len(scores))

            # zip() would silently truncate on a length mismatch, dropping or
            # mis-ranking documents; treat it as a failed rerank instead.
            if len(scores) != len(documents):
                logger.warning(
                    "[Rerank] 得分数量 (%d) 与文档数量 (%d) 不一致, 返回原始结果",
                    len(scores),
                    len(documents),
                )
                return documents[:top_n]

            # 3. Sort (document, score) pairs by score, highest first.
            ranked = sorted(
                zip(documents, scores), key=lambda pair: pair[1], reverse=True
            )

            # 4. Keep top_n and attach each score to a copy of the document.
            return [self._with_score(doc, score) for doc, score in ranked[:top_n]]
        except Exception as e:
            # Best-effort: a reranking failure must not break retrieval, so
            # fall back to the original ordering.
            logger.warning("[Rerank] 重排失败,返回原始结果: %s", e)
            return documents[:top_n]
def create_document_reranker(rerank_service=None) -> DocumentReranker:
    """Factory that builds a ``DocumentReranker``.

    Args:
        rerank_service: Optional rerank service, forwarded unchanged to the
            ``DocumentReranker`` constructor.

    Returns:
        DocumentReranker: A newly constructed reranker instance.
    """
    reranker = DocumentReranker(rerank_service=rerank_service)
    return reranker