Files
ailine/backend/app/model_services/rerank_services.py
root 1dc1ecad62
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 9m56s
优化: LLM降级重排一次调用给所有文档打分
2026-05-06 17:25:42 +08:00

422 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
重排模型服务模块
本模块提供统一的重排模型服务获取接口,支持自动降级:
1. 优先使用本地 llama.cpp 重排服务
2. 本地服务不可用时,自动降级到硅基流动(SiliconFlow) API 重排服务
3. 硅基流动服务不可用时,自动降级到智谱 API 重排服务
4. 所有 API 服务不可用时,自动降级到 LLM 评分重排服务
主要功能:
- LocalLlamaCppRerankProvider: 本地 llama.cpp 重排服务提供者
- SiliconFlowRerankProvider: 硅基流动 API 重排服务提供者
- ZhipuRerankProvider: 智谱 API 重排服务提供者
- LLMFallbackRerankProvider: LLM 评分降级重排服务提供者
- get_rerank_service(): 获取重排服务的统一接口
注意: 本模块只负责调用 rerank server,不包含业务逻辑(文档处理、排序、top_n),
业务逻辑放在 backend/app/rag/ 目录下。
"""
import logging
from typing import List
import httpx
from .base import (
BaseServiceProvider,
FallbackServiceChain,
SingletonServiceManager
)
from backend.app.config import (
LLAMACPP_RERANKER_URL,
LLAMACPP_API_KEY,
ZHIPUAI_API_KEY,
ZHIPU_RERANK_MODEL,
ZHIPU_API_BASE,
SILICONFLOW_API_KEY,
SILICONFLOW_RERANK_MODEL,
SILICONFLOW_API_BASE
)
logger = logging.getLogger(__name__)
class BaseRerankService:
    """Interface shared by all rerank backends (pure service layer).

    Implementations only call a rerank server. Business logic such as
    document handling, sorting and top_n selection lives under ``rag/``.
    """

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """Return one relevance score per document for ``query``.

        Args:
            query: The search query.
            documents: Candidate documents to score.

        Returns:
            Relevance scores, one per entry of ``documents``.

        Raises:
            NotImplementedError: Always; concrete services must override this.
        """
        raise NotImplementedError()
class LocalLlamaCppRerankService(BaseRerankService):
    """Rerank client for a local llama.cpp server (pure service layer)."""

    def __init__(self, base_url: str, api_key: str, model: str = "bge-reranker-v2-m3"):
        """
        Args:
            base_url: Server base URL, with or without a trailing ``/v1``.
            api_key: Bearer token; when falsy no Authorization header is sent.
            model: Model name sent in the request payload.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.model = model

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """Score every document against ``query`` via ``<base>/v1/rerank``.

        Args:
            query: The search query.
            documents: Candidate documents; an empty list returns ``[]``
                without any request.

        Returns:
            One relevance score per document, aligned with ``documents``.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
            ValueError: If the response is not a ``{"results": [...]}`` object,
                or its results do not cover every document index exactly once.
        """
        if not documents:
            return []

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        # Accept base URLs configured with or without the /v1 suffix.
        base = self.base_url.rstrip("/")
        if not base.endswith("/v1"):
            base = base + "/v1"

        payload = {
            "model": self.model,
            "query": query,
            "documents": documents,
        }
        logger.info(f"[LocalLlamaCppRerank] 调用 rerank API: {base}/rerank")
        logger.info(f"[LocalLlamaCppRerank] 请求 payload: query={query[:50]}, documents数量={len(documents)}")
        with httpx.Client(timeout=120) as client:
            response = client.post(
                f"{base}/rerank",
                headers=headers,
                json=payload,
            )

        logger.info(f"[LocalLlamaCppRerank] 响应状态码: {response.status_code}")
        if response.status_code != 200:
            logger.error(f"[LocalLlamaCppRerank] 请求失败: {response.status_code}")
            logger.error(f"[LocalLlamaCppRerank] 响应内容: {response.text[:500]}")
        response.raise_for_status()

        data = response.json()
        logger.info(f"[LocalLlamaCppRerank] 响应数据类型: {type(data)}")
        if not (isinstance(data, dict) and "results" in data):
            logger.error(f"[LocalLlamaCppRerank] 未知响应格式: {type(data)}")
            raise ValueError(f"未知的 rerank API 响应格式: {data}")

        # Place each score at its document index. Sorting by index and reading
        # positionally would silently shift scores onto the wrong documents if
        # an index were missing, duplicated or out of range.
        by_index: dict[int, float] = {}
        for item in data["results"]:
            index = item["index"]
            if not isinstance(index, int) or not 0 <= index < len(documents) or index in by_index:
                raise ValueError(f"rerank API 返回的 index 无效或重复: {index}")
            by_index[index] = item["relevance_score"]
        if len(by_index) != len(documents):
            missing = sorted(set(range(len(documents))) - by_index.keys())
            raise ValueError(f"rerank API 未返回以下文档的得分: {missing}")

        scores = [by_index[i] for i in range(len(documents))]
        logger.info(f"[LocalLlamaCppRerank] 返回 {len(scores)} 个得分")
        return scores
class ZhipuRerankService(BaseRerankService):
    """Rerank client for the Zhipu API (pure service layer)."""

    def __init__(self, model: str | None = None, api_key: str | None = None):
        """
        Args:
            model: Rerank model name; falls back to ZHIPU_RERANK_MODEL.
            api_key: API key; falls back to ZHIPUAI_API_KEY. Optional, for
                parity with the other API-backed rerank services.
        """
        self.model = model or ZHIPU_RERANK_MODEL
        self.api_key = api_key or ZHIPUAI_API_KEY

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """Score every document against ``query`` via the Zhipu rerank API.

        Returns:
            One relevance score per document, aligned with ``documents``;
            ``[]`` for an empty ``documents`` list.

        Raises:
            ValueError: If the results do not cover every document index
                exactly once.
            Exception: Any error from the zhipuai client is logged and re-raised.
        """
        if not documents:
            return []
        try:
            # Imported here, so zhipuai is only required when this method runs.
            from zhipuai import ZhipuAI

            client = ZhipuAI(api_key=self.api_key)
            response = client.rerank.create(
                model=self.model,
                query=query,
                documents=documents,
            )
            # Place each score at its document index so a missing, duplicated or
            # out-of-range index cannot shift scores onto the wrong documents.
            by_index: dict[int, float] = {}
            for item in response.results:
                index = item.index
                if not isinstance(index, int) or not 0 <= index < len(documents) or index in by_index:
                    raise ValueError(f"智谱 rerank API 返回的 index 无效或重复: {index}")
                by_index[index] = item.relevance_score
            if len(by_index) != len(documents):
                missing = sorted(set(range(len(documents))) - by_index.keys())
                raise ValueError(f"智谱 rerank API 未返回以下文档的得分: {missing}")
            return [by_index[i] for i in range(len(documents))]
        except Exception as e:
            logger.warning(f"智谱 rerank 调用失败: {e}")
            raise
class SiliconFlowRerankService(BaseRerankService):
    """Rerank client for the SiliconFlow API (pure service layer)."""

    def __init__(self, model: str | None = None, api_key: str | None = None, api_base: str | None = None):
        """
        Args:
            model: Rerank model name; falls back to SILICONFLOW_RERANK_MODEL.
            api_key: Bearer token; falls back to SILICONFLOW_API_KEY.
            api_base: API base URL; falls back to SILICONFLOW_API_BASE.
        """
        self.model = model or SILICONFLOW_RERANK_MODEL
        self.api_key = api_key or SILICONFLOW_API_KEY
        self.api_base = api_base or SILICONFLOW_API_BASE

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """Score every document against ``query`` via ``<api_base>/rerank``.

        Returns:
            One relevance score per document, aligned with ``documents``;
            ``[]`` for an empty ``documents`` list (no request is sent).

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
            ValueError: If the response is not a ``{"results": [...]}`` object,
                or its results do not cover every document index exactly once.
        """
        if not documents:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        base = self.api_base.rstrip("/")
        payload = {
            "model": self.model,
            "query": query,
            "documents": documents,
            "return_documents": False
        }
        with httpx.Client(timeout=120) as client:
            response = client.post(
                f"{base}/rerank",
                headers=headers,
                json=payload,
            )
        response.raise_for_status()

        data = response.json()
        if not (isinstance(data, dict) and "results" in data):
            raise ValueError(f"未知的 SiliconFlow rerank API 响应格式: {data}")

        # Place each score at its document index so a missing, duplicated or
        # out-of-range index cannot shift scores onto the wrong documents.
        by_index: dict[int, float] = {}
        for item in data["results"]:
            index = item["index"]
            if not isinstance(index, int) or not 0 <= index < len(documents) or index in by_index:
                raise ValueError(f"SiliconFlow rerank API 返回的 index 无效或重复: {index}")
            by_index[index] = item["relevance_score"]
        if len(by_index) != len(documents):
            missing = sorted(set(range(len(documents))) - by_index.keys())
            raise ValueError(f"SiliconFlow rerank API 未返回以下文档的得分: {missing}")
        return [by_index[i] for i in range(len(documents))]
class LLMFallbackRerankService(BaseRerankService):
    """Last-resort reranker that asks a chat LLM to score document relevance.

    All documents are scored with a single LLM call. If the call fails, or the
    reply cannot be turned into exactly one score per document, every document
    gets the neutral default score (0.5) instead of an exception being raised.
    """

    # Each document is truncated to this many characters in the prompt.
    _MAX_DOC_CHARS = 500
    # Score returned for every document when LLM scoring is unusable.
    _DEFAULT_SCORE = 0.5

    def __init__(self, llm=None):
        """Use ``llm`` when it is truthy, otherwise the shared chat service."""
        from .chat_services import get_chat_service
        self.llm = llm or get_chat_service()

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """Score all ``documents`` against ``query`` with one LLM call.

        Returns:
            One score in [0.0, 1.0] per document, in input order; ``[]`` for no
            documents; every score 0.5 if the LLM call fails or its reply is unusable.
        """
        if not documents:
            return []
        logger.info(f"[LLMFallbackRerank] 开始为 {len(documents)} 个文档打分")
        prompt = self._build_prompt(query, documents)
        try:
            result = self.llm.invoke(prompt)
            content = result.content if hasattr(result, 'content') else str(result)
            scores = self._parse_scores(content, len(documents))
        except Exception as e:
            logger.warning(f"[LLMFallbackRerank] LLM 打分失败,返回默认分数 0.5: {e}")
            return [self._DEFAULT_SCORE] * len(documents)
        if scores is None:
            logger.warning("[LLMFallbackRerank] 无法从LLM响应提取分数返回默认值")
            return [self._DEFAULT_SCORE] * len(documents)
        for i, score in enumerate(scores):
            logger.info(f"[LLMFallbackRerank] doc[{i}] score={score:.4f}")
        return scores

    @classmethod
    def _build_prompt(cls, query: str, documents: List[str]) -> str:
        """Build one prompt listing documents (numbered from 0) and asking for a JSON score list."""
        doc_lines = []
        for i, doc in enumerate(documents):
            snippet = doc[:cls._MAX_DOC_CHARS]
            if len(doc) > cls._MAX_DOC_CHARS:
                # Mark only documents that were actually cut off.
                snippet += "..."
            doc_lines.append(f"文档{i}: {snippet}")
        docs_str = "\n".join(doc_lines)
        return (
            "你是一个文档相关性评分专家。请评估以下文档与查询的相关性,"
            "为每个文档返回一个0到1之间的分数:\n"
            "- 1.0表示完全相关\n"
            "- 0.0表示完全不相关\n"
            f"查询: {query}\n"
            f"{docs_str}\n"
            "请按以下JSON格式返回,不要解释。"
            f"scores 必须按文档编号顺序恰好包含 {len(documents)} 个分数:\n"
            "{\n"
            '  "scores": [\n'
            "    0.95,\n"
            "    0.12,\n"
            "    ...\n"
            "  ]\n"
            "}"
        )

    @staticmethod
    def _parse_scores(content: str, expected: int) -> list[float] | None:
        """Extract exactly ``expected`` scores from the LLM reply, clamped to [0, 1].

        Scans for the first JSON object that has a ``"scores"`` key. Text around
        it, such as code fences or trailing remarks, is ignored. Returns ``None``
        when no such object exists, when the list length differs from
        ``expected``, or when any entry is not a finite number (numeric strings
        such as ``"0.8"`` are accepted).
        """
        import json
        import math

        decoder = json.JSONDecoder()
        start = content.find("{")
        while start != -1:
            try:
                # raw_decode stops at the end of the first complete JSON value,
                # so trailing text (even text containing '}') is tolerated.
                data, _ = decoder.raw_decode(content[start:])
            except ValueError:
                data = None
            if isinstance(data, dict) and "scores" in data:
                break
            start = content.find("{", start + 1)
        else:
            return None

        raw_scores = data["scores"]
        if not isinstance(raw_scores, list) or len(raw_scores) != expected:
            return None
        scores = []
        for value in raw_scores:
            # bool is an int subclass; reject it rather than reading true/false as 1/0.
            if isinstance(value, bool):
                return None
            try:
                score = float(value)
            except (TypeError, ValueError):
                return None
            if not math.isfinite(score):
                return None
            scores.append(min(1.0, max(0.0, score)))
        return scores
class LLMFallbackRerankProvider(BaseServiceProvider[BaseRerankService]):
    """Provider for the LLM-scoring fallback reranker."""

    def __init__(self, llm=None):
        """
        Args:
            llm: Optional chat model to use; when falsy, the service falls back
                to ``get_chat_service()``.
        """
        super().__init__("llm_fallback_rerank")
        self._llm = llm

    def is_available(self) -> bool:
        """Report whether an LLM can back the fallback reranker.

        An injected (truthy) ``llm`` is sufficient on its own, which mirrors how
        ``LLMFallbackRerankService`` only calls ``get_chat_service()`` when no
        ``llm`` is given. Otherwise, availability means ``get_chat_service()``
        succeeds.
        """
        if self._llm:
            logger.info("LLM 降级重排服务可用")
            return True
        try:
            from .chat_services import get_chat_service
            get_chat_service()
        except Exception as e:
            logger.warning(f"LLM 降级重排服务不可用: {e}")
            return False
        logger.info("LLM 降级重排服务可用")
        return True

    def get_service(self) -> BaseRerankService:
        """Return the cached LLM fallback rerank service, creating it on first use."""
        if self._service_instance is None:
            self._service_instance = LLMFallbackRerankService(self._llm)
        return self._service_instance
class LocalLlamaCppRerankProvider(BaseServiceProvider[BaseRerankService]):
    """Provider for the local llama.cpp rerank service.

    ``is_available`` sends a real one-document rerank request, so it performs
    network I/O.
    """

    def __init__(self, model: str = "bge-reranker-v2-m3"):
        super().__init__("local_llamacpp_rerank")
        self._model = model

    def _create_rerank_service(self) -> BaseRerankService:
        """Build a client from LLAMACPP_RERANKER_URL / LLAMACPP_API_KEY and this provider's model."""
        return LocalLlamaCppRerankService(
            base_url=LLAMACPP_RERANKER_URL,
            api_key=LLAMACPP_API_KEY,
            model=self._model
        )

    def is_available(self) -> bool:
        """Return True if the URL is configured and a probe rerank request succeeds."""
        if not LLAMACPP_RERANKER_URL:
            logger.warning("LLAMACPP_RERANKER_URL 未配置")
            return False
        try:
            # Any failure of the probe (network, HTTP status, bad payload) marks it unavailable.
            self._create_rerank_service().compute_scores("test query", ["test document"])
        except Exception as e:
            logger.warning(f"本地 llama.cpp 重排服务不可用: {e}")
            return False
        logger.info("本地 llama.cpp 重排服务可用")
        return True

    def get_service(self) -> BaseRerankService:
        """Return the cached local llama.cpp rerank service, creating it on first use."""
        if self._service_instance is None:
            self._service_instance = self._create_rerank_service()
        return self._service_instance
class ZhipuRerankProvider(BaseServiceProvider[BaseRerankService]):
    """Provider for the Zhipu API rerank service.

    Always reported as unavailable: per the original note, zhipuai 2.0.1
    has no rerank API, so the chain skips ahead to the LLM fallback.
    """

    def __init__(self, model: str | None = None):
        super().__init__("zhipu_rerank")
        self._model = model if model else ZHIPU_RERANK_MODEL

    def is_available(self) -> bool:
        """Log a warning and report the Zhipu rerank API as unavailable."""
        message = "智谱 rerank API 在当前 zhipuai 库版本不可用,降级到 LLM 方案"
        logger.warning(message)
        return False

    def get_service(self) -> BaseRerankService:
        """Return the cached Zhipu rerank service, building it lazily."""
        if self._service_instance is not None:
            return self._service_instance
        self._service_instance = ZhipuRerankService(model=self._model)
        return self._service_instance
class SiliconFlowRerankProvider(BaseServiceProvider[BaseRerankService]):
    """Provider for the SiliconFlow API rerank service.

    ``is_available`` sends a real one-document rerank request, so it performs
    network I/O.
    """

    def __init__(self, model: str | None = None):
        super().__init__("siliconflow_rerank")
        self._model = model or SILICONFLOW_RERANK_MODEL

    def _create_rerank_service(self) -> BaseRerankService:
        """Build a SiliconFlow client for this provider's model (key/base from config)."""
        return SiliconFlowRerankService(model=self._model)

    def is_available(self) -> bool:
        """Return True if SILICONFLOW_API_KEY is set and a probe rerank request succeeds."""
        if not SILICONFLOW_API_KEY:
            logger.warning("SILICONFLOW_API_KEY 未配置")
            return False
        try:
            # Any failure of the probe (network, HTTP status, bad payload) marks it unavailable.
            self._create_rerank_service().compute_scores("test query", ["test document"])
        except Exception as e:
            logger.warning(f"SiliconFlow 重排服务不可用: {e}")
            return False
        logger.info("SiliconFlow 重排服务可用")
        return True

    def get_service(self) -> BaseRerankService:
        """Return the cached SiliconFlow rerank service, creating it on first use."""
        if self._service_instance is None:
            self._service_instance = self._create_rerank_service()
        return self._service_instance
def get_rerank_service() -> BaseRerankService:
    """Return a rerank service from the fallback chain (pure service layer).

    Chain order: local llama.cpp -> SiliconFlow -> Zhipu -> LLM scoring.
    The chain object is obtained through ``SingletonServiceManager`` under the
    key ``"rerank_service_chain"`` (presumably built once and reused; see that
    class), and selection is delegated to its ``get_available_service()``.
    """
    def _build_chain():
        return FallbackServiceChain(
            LocalLlamaCppRerankProvider(),
            [
                SiliconFlowRerankProvider(),
                ZhipuRerankProvider(),
                LLMFallbackRerankProvider(),
            ],
        )

    service_chain = SingletonServiceManager.get_or_create("rerank_service_chain", _build_chain)
    return service_chain.get_available_service()