105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
"""
|
||
重排业务逻辑模块
|
||
|
||
本模块包含 RAG 相关的重排业务逻辑(文档处理、排序、top_n)
|
||
使用 model_services/rerank_services.py 提供的纯服务层
|
||
"""
|
||
|
||
import logging
|
||
from typing import List
|
||
from langchain_core.documents import Document
|
||
|
||
from ..model_services import get_rerank_service
|
||
|
||
logger = logging.getLogger(__name__)


class DocumentReranker:
    """Business-logic layer that reranks documents against a query.

    Responsibilities:
    - extract the text content from each Document
    - ask the rerank service for one score per document
    - sort the documents by score, highest first
    - return the first ``top_n`` documents
    """

    # Characters-to-tokens ratio used only for the diagnostic estimate written
    # to the log (the original author's rough assumption for Chinese text).
    _TOKENS_PER_CHAR = 0.75

    def __init__(self, rerank_service=None):
        """Create a document reranker.

        Args:
            rerank_service: Object exposing ``compute_scores(query, contents)``.
                When ``None``, the result of ``get_rerank_service()`` is used.
        """
        # Compare with None explicitly: an injected service that happens to be
        # falsy (defines __bool__ / __len__) must not be silently replaced.
        if rerank_service is None:
            rerank_service = get_rerank_service()
        self._rerank_service = rerank_service

    def compress_documents(
        self,
        documents: List["Document"],
        query: str,
        top_n: int = 5,
    ) -> List["Document"]:
        """Rerank ``documents`` for ``query`` and keep the best ``top_n``.

        Args:
            documents: Documents to rank; each must expose ``page_content``.
            query: Query string passed to the rerank service.
            top_n: Maximum number of documents to return. Zero or a negative
                value yields an empty list; ``None`` keeps every document.

        Returns:
            The documents ordered by descending score, truncated to ``top_n``.
            Documents with equal scores keep their input order. If anything
            fails (the service raises, or it returns a number of scores that
            does not match the number of documents), the first ``top_n``
            documents are returned in their original order instead.
        """
        if not documents:
            return []
        # A negative slice bound would drop documents from the *end* of the
        # list instead of limiting the result, so treat it like zero.
        if top_n is not None and top_n <= 0:
            return []

        try:
            # 1. Extract the raw text of every document.
            doc_contents = [doc.page_content for doc in documents]
            logger.info(
                "[Rerank] 收到 %d 个文档待重排, query=%s",
                len(documents),
                query[:50],
            )
            doc_lengths = [len(content) for content in doc_contents]
            total_chars = sum(doc_lengths)
            logger.info(
                "[Rerank] 各文档长度: %s, 总字符数: %d", doc_lengths, total_chars
            )
            estimated_tokens = int(total_chars * self._TOKENS_PER_CHAR)
            logger.info("[Rerank] 估算总 tokens: ~%d (假设中文)", estimated_tokens)

            # 2. Ask the service layer for one score per document.
            logger.info(
                "[Rerank] 正在调用 rerank service: %s",
                type(self._rerank_service).__name__,
            )
            scores = self._rerank_service.compute_scores(query, doc_contents)
            logger.info("[Rerank] 获取到 %d 个得分: %s", len(scores), scores)

            # zip() stops at the shorter input, which would silently discard
            # documents; a count mismatch is handled as a failure instead.
            if len(scores) != len(documents):
                raise ValueError(
                    f"rerank service 返回了 {len(scores)} 个得分,"
                    f"但文档数为 {len(documents)}"
                )

            # 3. Sort by score, highest first (sorted() is stable, so ties keep
            #    their input order).
            ranked = sorted(
                zip(documents, scores), key=lambda pair: pair[1], reverse=True
            )

            logger.info("[Rerank] 排序后的结果:")
            for i, (doc, score) in enumerate(ranked):
                logger.info(
                    " [%d] score=%.4f, content=%s...",
                    i,
                    score,
                    doc.page_content[:80],
                )

            # 4. Keep the top_n best documents.
            return [doc for doc, _ in ranked[:top_n]]

        except Exception as e:
            # Deliberate best-effort: reranking is an optimisation, so on any
            # failure fall back to the incoming order rather than propagating.
            logger.warning("重排过程出错,返回原始前 %s 个结果: %s", top_n, e)
            # exc_info attaches the full traceback to this record.
            logger.warning(
                "[Rerank] 异常详情: %s: %s", type(e).__name__, e, exc_info=True
            )
            return documents[:top_n]
|
||
|
||
|
||
def create_document_reranker(rerank_service=None) -> DocumentReranker:
    """Factory function that builds a DocumentReranker.

    Args:
        rerank_service: Optional rerank service, passed to the
            DocumentReranker constructor unchanged.

    Returns:
        DocumentReranker: A new reranker instance.
    """
    reranker = DocumentReranker(rerank_service)
    return reranker
|
||
|