Files
ailine/backend/app/model_services/rerank_services.py
root f63c394fcd
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
refactor: 重构 rerank 架构,分离服务层和业务逻辑
- rerank_services.py:纯服务层,只负责调用 rerank server
- rag/rerank.py:业务逻辑层,负责文档处理、排序、top_n
- 更新 pipeline.py 使用新架构
- 架构与 embedding_services.py 保持一致
2026-04-26 11:57:42 +08:00

235 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
重排模型服务模块
本模块提供统一的重排模型服务获取接口,支持自动降级:
1. 优先使用本地 llama.cpp 重排服务
2. 本地服务不可用时,自动降级到智谱 API 重排服务
主要功能:
- LocalLlamaCppRerankProvider：本地 llama.cpp 重排服务提供者
- ZhipuRerankProvider：智谱 API 重排服务提供者
- get_rerank_service()：获取重排服务的统一接口
注意：本模块只负责调用 rerank server，不包含业务逻辑（文档处理、排序、top_n）
业务逻辑放在 backend/app/rag/ 目录下
"""
import logging
from typing import List
import httpx
from .base import (
BaseServiceProvider,
FallbackServiceChain,
SingletonServiceManager
)
from ..config import (
LLAMACPP_RERANKER_URL,
LLAMACPP_API_KEY,
ZHIPUAI_API_KEY,
ZHIPU_RERANK_MODEL,
ZHIPU_API_BASE
)
logger = logging.getLogger(__name__)
class BaseRerankService:
    """
    Abstract base for rerank services (pure service layer).

    Implementations only talk to a rerank server; business logic
    (document handling, ordering, top_n) lives under the rag/ package.
    """

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """
        Score each document's relevance to the query via a raw API call.

        Args:
            query: the query string.
            documents: list of document strings.

        Returns:
            List[float]: one relevance score per input document.
        """
        raise NotImplementedError
class LocalLlamaCppRerankService(BaseRerankService):
    """
    Rerank service backed by a local llama.cpp server (pure service layer).
    """

    def __init__(self, base_url: str, api_key: str, model: str = "bge-reranker-v2-m3"):
        # Connection settings for the llama.cpp rerank endpoint.
        self.base_url = base_url
        self.api_key = api_key
        self.model = model

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """
        Call the llama.cpp rerank API and return one score per document,
        in the original document order.
        """
        if not documents:
            return []

        request_headers = {"Content-Type": "application/json"}
        if self.api_key:
            request_headers["Authorization"] = f"Bearer {self.api_key}"

        # Normalise the endpoint so it always ends with /v1.
        endpoint = self.base_url.rstrip("/")
        if not endpoint.endswith("/v1"):
            endpoint = endpoint + "/v1"

        body = {
            "model": self.model,
            "query": query,
            "documents": documents,
        }
        with httpx.Client(timeout=120) as http:
            resp = http.post(
                f"{endpoint}/rerank",
                headers=request_headers,
                json=body,
            )
            resp.raise_for_status()
            payload = resp.json()

        if not (isinstance(payload, dict) and "results" in payload):
            raise ValueError(f"未知的 rerank API 响应格式: {payload}")
        # Re-sort by index so scores line up with the input documents.
        ordered = sorted(payload["results"], key=lambda entry: entry["index"])
        return [entry["relevance_score"] for entry in ordered]
class ZhipuRerankService(BaseRerankService):
    """
    Rerank service backed by the Zhipu API (pure service layer).
    """

    def __init__(self, model: str | None = None):
        # Fall back to configured defaults when no model is supplied.
        self.model = model or ZHIPU_RERANK_MODEL
        self.api_key = ZHIPUAI_API_KEY

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """
        Call the Zhipu rerank API and return one score per document,
        in the original document order.
        """
        if not documents:
            return []
        try:
            # Imported lazily so the zhipuai dependency stays optional.
            from zhipuai import ZhipuAI

            api = ZhipuAI(api_key=self.api_key)
            reply = api.rerank.create(
                model=self.model,
                query=query,
                documents=documents,
            )
            # Re-sort by index so scores line up with the input documents.
            by_index = sorted(reply.results, key=lambda item: item.index)
            return [item.relevance_score for item in by_index]
        except Exception as exc:
            logger.warning(f"智谱 rerank 调用失败: {exc}")
            raise
class LocalLlamaCppRerankProvider(BaseServiceProvider[BaseRerankService]):
    """
    Provider for the local llama.cpp rerank service.

    is_available() performs a live probe call; get_service() lazily builds
    and caches a single service instance.
    """

    def __init__(self, model: str = "bge-reranker-v2-m3"):
        super().__init__("local_llamacpp_rerank")
        self._model = model

    def is_available(self) -> bool:
        """
        Check whether the local llama.cpp rerank service is reachable.

        Returns:
            bool: True if a probe rerank call succeeds, False otherwise.
        """
        if not LLAMACPP_RERANKER_URL:
            logger.warning("LLAMACPP_RERANKER_URL 未配置")
            return False
        try:
            service = LocalLlamaCppRerankService(
                base_url=LLAMACPP_RERANKER_URL,
                api_key=LLAMACPP_API_KEY,
                model=self._model,
            )
            # Probe call: only success matters, the scores are discarded.
            service.compute_scores("test query", ["test document"])
            logger.info("本地 llama.cpp 重排服务可用")
            return True
        except Exception as e:
            # Lazy %-formatting: the message is only built if the log is emitted.
            logger.warning("本地 llama.cpp 重排服务不可用: %s", e)
            return False

    def get_service(self) -> BaseRerankService:
        """
        Return the cached llama.cpp rerank service, creating it on first use.
        """
        if self._service_instance is None:
            self._service_instance = LocalLlamaCppRerankService(
                base_url=LLAMACPP_RERANKER_URL,
                api_key=LLAMACPP_API_KEY,
                model=self._model,
            )
        return self._service_instance
class ZhipuRerankProvider(BaseServiceProvider[BaseRerankService]):
    """
    Provider for the Zhipu API rerank service.

    is_available() performs a live probe call; get_service() lazily builds
    and caches a single service instance.
    """

    def __init__(self, model: str | None = None):
        super().__init__("zhipu_rerank")
        self._model = model or ZHIPU_RERANK_MODEL

    def is_available(self) -> bool:
        """
        Check whether the Zhipu rerank service is usable.

        Returns:
            bool: True if a probe rerank call succeeds, False otherwise.
        """
        if not ZHIPUAI_API_KEY:
            logger.warning("ZHIPUAI_API_KEY 未配置")
            return False
        try:
            service = ZhipuRerankService(model=self._model)
            # Probe call: only success matters, the scores are discarded.
            service.compute_scores("test query", ["test document"])
            logger.info("智谱重排服务可用")
            return True
        except ImportError:
            logger.warning("zhipuai 库未安装")
            return False
        except Exception as e:
            # Lazy %-formatting: the message is only built if the log is emitted.
            logger.warning("智谱重排服务不可用: %s", e)
            return False

    def get_service(self) -> BaseRerankService:
        """
        Return the cached Zhipu rerank service, creating it on first use.
        """
        if self._service_instance is None:
            self._service_instance = ZhipuRerankService(model=self._model)
        return self._service_instance
def get_rerank_service() -> BaseRerankService:
    """
    Return a rerank service with automatic fallback (pure service layer).

    A singleton fallback chain is built on first call: the local llama.cpp
    service is tried first, with the Zhipu API service as backup.

    Returns:
        BaseRerankService: an available rerank service instance.
    """
    def _build_chain():
        # Local llama.cpp is primary; Zhipu API is the fallback.
        return FallbackServiceChain(
            LocalLlamaCppRerankProvider(),
            [ZhipuRerankProvider()],
        )

    service_chain = SingletonServiceManager.get_or_create(
        "rerank_service_chain", _build_chain
    )
    return service_chain.get_available_service()