Files
ailine/backend/app/model_services/rerank_services.py
root 3bf0446ef8
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m37s
feat: 修复数据库持久化,完善服务降级机制
- 恢复使用 AsyncPostgresSaver 持久化短期记忆
- 添加 LLM 作为 Rerank 服务的最后降级方案
- 完善降级链:Local llama.cpp → Zhipu Rerank → LLM Fallback
2026-04-30 17:45:06 +08:00

323 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
重排模型服务模块
本模块提供统一的重排模型服务获取接口,支持自动降级:
1. 优先使用本地 llama.cpp 重排服务
2. 本地服务不可用时,自动降级到智谱 API 重排服务
主要功能:
- LocalLlamaCppRerankProvider本地 llama.cpp 重排服务提供者
- ZhipuRerankProvider智谱 API 重排服务提供者
- get_rerank_service():获取重排服务的统一接口
注意:本模块只负责调用 rerank server不包含业务逻辑文档处理、排序、top_n
业务逻辑放在 backend/app/rag/ 目录下
"""
import logging
from typing import List
import httpx
from .base import (
BaseServiceProvider,
FallbackServiceChain,
SingletonServiceManager
)
from app.config import (
LLAMACPP_RERANKER_URL,
LLAMACPP_API_KEY,
ZHIPUAI_API_KEY,
ZHIPU_RERANK_MODEL,
ZHIPU_API_BASE
)
# Module-level logger, named after this module, shared by every service and
# provider defined below.
logger = logging.getLogger(__name__)
class BaseRerankService:
    """Common interface for rerank services (pure service layer).

    Implementations only call a rerank backend. Business logic such as
    document processing, sorting and top_n selection is kept out of this
    layer and lives under rag/.
    """

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """Score the relevance of every document against the query.

        Args:
            query: The query string.
            documents: The document strings to score.

        Returns:
            A list with the relevance score of each document.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError()
class LocalLlamaCppRerankService(BaseRerankService):
    """Rerank service backed by a local llama.cpp server (pure service layer).

    Args:
        base_url: Base URL of the server. "/v1" is appended when the URL does
            not already end with it.
        api_key: Bearer token. When empty, no Authorization header is sent.
        model: Model name sent in the request payload.
        timeout: HTTP timeout, in seconds, applied to each rerank request.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "bge-reranker-v2-m3",
        timeout: float = 120,
    ):
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.timeout = timeout

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """POST the query and documents to ``<base>/rerank`` and return scores.

        Args:
            query: The query string.
            documents: The document strings to score.

        Returns:
            One relevance score per document, in the same order as
            ``documents``. An empty list when ``documents`` is empty (no
            request is made in that case).

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
            ValueError: If the response is not a dict with a "results" key, or
                if the result indices do not cover every input document
                exactly once.
        """
        if not documents:
            return []

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        payload = {
            "model": self.model,
            "query": query,
            "documents": documents,
        }

        with httpx.Client(timeout=self.timeout) as client:
            response = client.post(
                f"{self._api_base()}/rerank",
                headers=headers,
                json=payload,
            )
            response.raise_for_status()
            data = response.json()

        return self._parse_scores(data, len(documents))

    def _api_base(self) -> str:
        """Return the base URL without a trailing slash and ending in "/v1"."""
        base = self.base_url.rstrip("/")
        if not base.endswith("/v1"):
            base = base + "/v1"
        return base

    @staticmethod
    def _parse_scores(data, expected: int) -> List[float]:
        """Extract the scores from a rerank response, ordered by document index."""
        if not (isinstance(data, dict) and "results" in data):
            raise ValueError(f"未知的 rerank API 响应格式: {data}")

        ordered = sorted(data["results"], key=lambda item: item["index"])
        indices = [item["index"] for item in ordered]
        # Scores are returned positionally, so a missing or duplicated index
        # would silently pair scores with the wrong documents.
        if indices != list(range(expected)):
            raise ValueError(
                f"rerank API 返回的结果与输入文档不对应: "
                f"期望索引 0..{expected - 1}, 实际索引 {indices}"
            )
        return [item["relevance_score"] for item in ordered]
class ZhipuRerankService(BaseRerankService):
    """Rerank service backed by the Zhipu API (pure service layer).

    Args:
        model: Rerank model name. Falls back to ``ZHIPU_RERANK_MODEL`` when
            omitted or empty.
    """

    def __init__(self, model: str | None = None):
        self.api_key = ZHIPUAI_API_KEY
        self.model = model or ZHIPU_RERANK_MODEL

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """Call the Zhipu rerank API and return the scores by document index.

        Args:
            query: The query string.
            documents: The document strings to score.

        Returns:
            The relevance scores sorted by each result's ``index``. An empty
            list when ``documents`` is empty (no API call is made).

        Raises:
            Exception: Any error raised while importing ``zhipuai`` or calling
                the API is logged as a warning and re-raised unchanged.
        """
        if not documents:
            return []

        try:
            # Imported lazily so the module loads without the zhipuai package.
            from zhipuai import ZhipuAI

            zhipu = ZhipuAI(api_key=self.api_key)
            reply = zhipu.rerank.create(
                model=self.model,
                query=query,
                documents=documents,
            )
            ranked = list(reply.results)
            ranked.sort(key=lambda entry: entry.index)
            scores = []
            for entry in ranked:
                scores.append(entry.relevance_score)
            return scores
        except Exception as exc:
            logger.warning(f"智谱 rerank 调用失败: {exc}")
            raise
class LLMFallbackRerankService(BaseRerankService):
    """Last-resort rerank service that asks an LLM to score each document.

    Args:
        llm: Object exposing ``invoke(prompt)``. When ``None``, the service
            obtained from ``get_chat_service()`` is used.
    """

    # Score used when the LLM call fails or its reply contains no number.
    _DEFAULT_SCORE = 0.5

    def __init__(self, llm=None):
        if llm is None:
            # Imported lazily: only needed when no LLM was injected.
            from .chat_services import get_chat_service

            llm = get_chat_service()
        self.llm = llm

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """Score every document with one LLM call per document.

        Args:
            query: The query string.
            documents: The document strings to score.

        Returns:
            One score in [0.0, 1.0] per document, in input order. An empty
            list when ``documents`` is empty.
        """
        if not documents:
            return []
        return [self._score_single_document(query, doc) for doc in documents]

    def _score_single_document(self, query: str, document: str) -> float:
        """Ask the LLM for a relevance score of one document.

        Returns:
            The first number found in the reply, clamped to [0.0, 1.0]. Falls
            back to 0.5 when the reply has no number or the call raises.
        """
        prompt = f"""你是一个文档相关性评分专家。请评估以下文档与查询的相关性返回一个0到1之间的分数
- 1.0表示完全相关
- 0.0表示完全不相关
查询: {query}
文档: {document}
请只返回一个数字,不要解释。"""
        try:
            result = self.llm.invoke(prompt)
            content = result.content if hasattr(result, 'content') else str(result)
            score = self._parse_score(content)
            if score is None:
                return self._DEFAULT_SCORE
            return score
        except Exception as e:
            # Best effort: a failed or unparsable reply must not abort the
            # whole rerank, so it degrades to the neutral default score.
            logger.warning(f"LLM 打分失败,返回默认分数 0.5: {e}")
            return self._DEFAULT_SCORE

    @staticmethod
    def _parse_score(content: str) -> float | None:
        """Return the first number in ``content`` clamped to [0, 1], or None."""
        import re

        # The ``\d*\.\d+`` alternative comes first so a reply such as ".8" is
        # read as 0.8 instead of the bare digit "8" (which would clamp to 1.0).
        match = re.search(r"\d*\.\d+|\d+", content)
        if match is None:
            return None
        return max(0.0, min(1.0, float(match.group())))
class LLMFallbackRerankProvider(BaseServiceProvider[BaseRerankService]):
    """Provider for the LLM-based fallback rerank service.

    Args:
        llm: Optional LLM object handed to ``LLMFallbackRerankService``. When
            ``None``, the service resolves one via ``get_chat_service()``.
    """

    def __init__(self, llm=None):
        super().__init__("llm_fallback_rerank")
        self._llm = llm

    def is_available(self) -> bool:
        """Report whether an LLM can be obtained for fallback reranking.

        Returns:
            True when an LLM was injected at construction time, or when
            ``get_chat_service()`` succeeds; False when it raises.
        """
        if self._llm is not None:
            # An injected LLM is what get_service() will use, so availability
            # must not depend on the default chat service.
            logger.info("LLM 降级重排服务可用")
            return True
        try:
            from .chat_services import get_chat_service

            get_chat_service()
            logger.info("LLM 降级重排服务可用")
            return True
        except Exception as e:
            logger.warning(f"LLM 降级重排服务不可用: {e}")
            return False

    def get_service(self) -> BaseRerankService:
        """Return the LLM fallback rerank service, creating it on first use.

        NOTE(review): ``_service_instance`` is presumably initialised to None
        by ``BaseServiceProvider.__init__`` - confirm against base.py.
        """
        if self._service_instance is None:
            self._service_instance = LLMFallbackRerankService(self._llm)
        return self._service_instance
class LocalLlamaCppRerankProvider(BaseServiceProvider[BaseRerankService]):
    """Provider for the local llama.cpp rerank service.

    Args:
        model: Model name passed to ``LocalLlamaCppRerankService``.
    """

    def __init__(self, model: str = "bge-reranker-v2-m3"):
        super().__init__("local_llamacpp_rerank")
        self._model = model

    def _build_service(self) -> BaseRerankService:
        """Create a service from the configured URL, API key and model."""
        return LocalLlamaCppRerankService(
            base_url=LLAMACPP_RERANKER_URL,
            api_key=LLAMACPP_API_KEY,
            model=self._model,
        )

    def is_available(self) -> bool:
        """Check whether the local llama.cpp rerank service can be used.

        Returns:
            False when ``LLAMACPP_RERANKER_URL`` is not configured or when a
            probe rerank request raises; True when the probe succeeds.
        """
        if not LLAMACPP_RERANKER_URL:
            logger.warning("LLAMACPP_RERANKER_URL 未配置")
            return False
        try:
            # Probe with a real request; only success or failure matters, so
            # the returned scores are discarded.
            self._build_service().compute_scores("test query", ["test document"])
            logger.info("本地 llama.cpp 重排服务可用")
            return True
        except Exception as e:
            logger.warning(f"本地 llama.cpp 重排服务不可用: {e}")
            return False

    def get_service(self) -> BaseRerankService:
        """Return the local llama.cpp rerank service, creating it on first use.

        NOTE(review): ``_service_instance`` is presumably initialised to None
        by ``BaseServiceProvider.__init__`` - confirm against base.py.
        """
        if self._service_instance is None:
            self._service_instance = self._build_service()
        return self._service_instance
class ZhipuRerankProvider(BaseServiceProvider[BaseRerankService]):
    """Provider for the Zhipu API rerank service.

    Args:
        model: Rerank model name. Falls back to ``ZHIPU_RERANK_MODEL`` when
            omitted or empty.
    """

    def __init__(self, model: str | None = None):
        super().__init__("zhipu_rerank")
        self._model = model or ZHIPU_RERANK_MODEL

    def _build_service(self) -> BaseRerankService:
        """Create a Zhipu rerank service for the configured model."""
        return ZhipuRerankService(model=self._model)

    def is_available(self) -> bool:
        """Check whether the Zhipu rerank service can be used.

        Returns:
            False when ``ZHIPUAI_API_KEY`` is not configured, when the
            ``zhipuai`` package cannot be imported, or when a probe rerank
            request raises; True when the probe succeeds.
        """
        if not ZHIPUAI_API_KEY:
            logger.warning("ZHIPUAI_API_KEY 未配置")
            return False
        try:
            # Probe with a real request; only success or failure matters, so
            # the returned scores are discarded.
            self._build_service().compute_scores("test query", ["test document"])
            logger.info("智谱重排服务可用")
            return True
        except ImportError:
            logger.warning("zhipuai 库未安装")
            return False
        except Exception as e:
            logger.warning(f"智谱重排服务不可用: {e}")
            return False

    def get_service(self) -> BaseRerankService:
        """Return the Zhipu rerank service, creating it on first use.

        NOTE(review): ``_service_instance`` is presumably initialised to None
        by ``BaseServiceProvider.__init__`` - confirm against base.py.
        """
        if self._service_instance is None:
            self._service_instance = self._build_service()
        return self._service_instance
def get_rerank_service() -> BaseRerankService:
    """Return a rerank service, with automatic fallback (pure service layer).

    Fallback order: local llama.cpp -> Zhipu rerank -> LLM fallback.

    The provider chain is obtained through ``SingletonServiceManager`` under
    the key "rerank_service_chain" (presumably built once and reused; confirm
    against base.py), then asked for an available service.

    Returns:
        BaseRerankService: The rerank service selected by the chain.
    """

    def _build_chain():
        # Primary provider first, then the fallbacks in the order tried.
        return FallbackServiceChain(
            LocalLlamaCppRerankProvider(),
            [ZhipuRerankProvider(), LLMFallbackRerankProvider()],
        )

    service_chain = SingletonServiceManager.get_or_create(
        "rerank_service_chain",
        _build_chain,
    )
    return service_chain.get_available_service()