Files
ailine/backend/app/rag/rerank.py
root 3ae9daa01a
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m44s
导入方式修改
2026-05-05 23:17:00 +08:00

105 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Rerank business-logic module.

Holds the RAG-specific rerank business logic (document handling, sorting,
top_n selection) and relies on the pure service layer provided by
model_services/rerank_services.py for the actual scoring.
"""
import logging
from typing import List
from langchain_core.documents import Document
from ..model_services import get_rerank_service
# Module-level logger shared by the reranker classes/functions in this file.
logger = logging.getLogger(__name__)
class DocumentReranker:
    """
    Document reranker - business-logic layer.

    Responsibilities:
    - extract the text content (``page_content``) from each Document
    - ask the rerank service for one relevance score per document
    - sort documents by score, highest first
    - return the top_n documents

    Scoring itself is delegated to the injected rerank service; this class
    only does the bookkeeping around it and degrades gracefully on failure.
    """

    def __init__(self, rerank_service=None):
        """
        Initialize the document reranker.

        Args:
            rerank_service: Object exposing ``compute_scores(query, contents)``.
                When ``None`` (the default), ``get_rerank_service()`` is used.
        """
        # Compare against None explicitly: an injected service object that
        # happens to be falsy (e.g. defines __len__/__bool__) must still be
        # used rather than silently replaced by the default service.
        if rerank_service is None:
            rerank_service = get_rerank_service()
        self._rerank_service = rerank_service

    def compress_documents(
        self,
        documents: List[Document],
        query: str,
        top_n: int = 5
    ) -> List[Document]:
        """
        Rerank ``documents`` against ``query`` and keep the best ``top_n``.

        Args:
            documents: Candidate documents; only their ``page_content`` is
                sent to the rerank service.
            query: Query string the documents are scored against.
            top_n: Maximum number of documents to return.

        Returns:
            List[Document]: Up to ``top_n`` documents ordered by descending
            score. If scoring raises, or the service returns a different
            number of scores than there are documents, the first ``top_n``
            input documents are returned in their original order instead.
        """
        if not documents:
            return []
        try:
            # 1. Extract the raw text of every document.
            doc_contents = [doc.page_content for doc in documents]
            logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排, query={query[:50]}")
            total_chars = sum(len(c) for c in doc_contents)
            logger.info(f"[Rerank] 各文档长度: {[len(c) for c in doc_contents]}, 总字符数: {total_chars}")
            # Rough token estimate, assuming ~0.75 tokens per (Chinese) character.
            estimated_tokens = int(total_chars * 0.75)
            logger.info(f"[Rerank] 估算总 tokens: ~{estimated_tokens} (假设中文)")

            # 2. Delegate the scoring to the pure service layer.
            logger.info(f"[Rerank] 正在调用 rerank service: {type(self._rerank_service).__name__}")
            scores = self._rerank_service.compute_scores(query, doc_contents)
            logger.info(f"[Rerank] 获取到 {len(scores)} 个得分: {scores}")

            # zip() would silently truncate on a length mismatch and drop
            # documents without any trace; treat it as a scoring failure.
            if len(scores) != len(documents):
                logger.warning(
                    f"[Rerank] 得分数量 {len(scores)} 与文档数量 {len(documents)} 不一致, "
                    f"返回原始前 {top_n} 个结果"
                )
                return documents[:top_n]

            # 3. Sort by score, highest first (sorted() is stable, so ties
            #    keep their input order).
            ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
            logger.info("[Rerank] 排序后的结果:")
            for i, (doc, score) in enumerate(ranked):
                logger.info(f"  [{i}] score={score:.4f}, content={doc.page_content[:80]}...")

            # 4. Keep the top_n documents.
            return [doc for doc, _ in ranked[:top_n]]
        except Exception as e:
            # Best-effort: reranking is an optimization, so any failure falls
            # back to the first top_n documents in their input order.
            logger.warning(f"重排过程出错,返回原始前 {top_n} 个结果: {e}")
            logger.warning(f"[Rerank] 异常详情: {type(e).__name__}: {e}", exc_info=True)
            return documents[:top_n]
def create_document_reranker(rerank_service=None) -> DocumentReranker:
    """
    Factory function that builds a DocumentReranker.

    Args:
        rerank_service: Optional rerank service handed straight to the
            DocumentReranker constructor; ``None`` lets the reranker pick
            its default service.

    Returns:
        DocumentReranker: a newly constructed reranker instance.
    """
    reranker = DocumentReranker(rerank_service=rerank_service)
    return reranker