"""
Rerank model service module.

This module exposes a single entry point for obtaining a rerank service, with
automatic fallback between backends:
1. The local llama.cpp rerank server is tried first.
2. If it is unavailable, the Zhipu rerank API is used.
3. If that is unavailable too, an LLM-based scorer is used as the last resort.

Main components:
- LocalLlamaCppRerankProvider: provider for the local llama.cpp rerank server
- ZhipuRerankProvider: provider for the Zhipu rerank API
- LLMFallbackRerankProvider: provider for the LLM-based fallback scorer
- get_rerank_service(): unified accessor for the rerank service

Note: this module only calls the rerank backends. It contains no business
logic (document processing, sorting, top_n); that lives under
backend/app/rag/.
"""

import logging
import re
from typing import List

import httpx

from .base import (
    BaseServiceProvider,
    FallbackServiceChain,
    SingletonServiceManager
)
from app.config import (
    LLAMACPP_RERANKER_URL,
    LLAMACPP_API_KEY,
    ZHIPUAI_API_KEY,
    ZHIPU_RERANK_MODEL,
    ZHIPU_API_BASE
)

logger = logging.getLogger(__name__)

# First unsigned decimal number in an LLM reply. The second alternative accepts
# a leading-dot form such as ".5", which would otherwise be read as "5".
_SCORE_PATTERN = re.compile(r"\d+\.?\d*|\.\d+")


class BaseRerankService:
    """
    Base class for rerank services - pure service layer.

    Subclasses only call a scoring backend; document processing, sorting and
    top_n selection are handled elsewhere (under rag/).
    """

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """
        Compute a relevance score for each document against the query.

        Args:
            query: The query string.
            documents: The document strings to score.

        Returns:
            List[float]: One relevance score per document.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class LocalLlamaCppRerankService(BaseRerankService):
    """
    Rerank service backed by a local llama.cpp server - pure service layer.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "bge-reranker-v2-m3",
        timeout: float = 120,
    ):
        """
        Args:
            base_url: Server base URL, with or without a trailing "/v1".
            api_key: Bearer token; no Authorization header is sent when falsy.
            model: Model name placed in the request payload.
            timeout: HTTP timeout passed to ``httpx.Client``.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.timeout = timeout

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """
        POST to ``<base>/v1/rerank`` and return scores in document order.

        Args:
            query: The query string.
            documents: The document strings to score.

        Returns:
            List[float]: ``relevance_score`` of each result, ordered by the
            result's ``index``. Empty when ``documents`` is empty.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
            ValueError: If the response is not a dict containing "results".
        """
        if not documents:
            return []

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        # Accept a base URL given either with or without the "/v1" suffix.
        base = self.base_url.rstrip("/")
        if not base.endswith("/v1"):
            base = base + "/v1"

        payload = {
            "model": self.model,
            "query": query,
            "documents": documents,
        }

        with httpx.Client(timeout=self.timeout) as client:
            response = client.post(
                f"{base}/rerank",
                headers=headers,
                json=payload,
            )
            response.raise_for_status()
            data = response.json()

        if not (isinstance(data, dict) and "results" in data):
            raise ValueError(f"未知的 rerank API 响应格式: {data}")

        # Sort by "index" so the scores line up with the input documents.
        ordered = sorted(data["results"], key=lambda item: item["index"])
        return [item["relevance_score"] for item in ordered]


class ZhipuRerankService(BaseRerankService):
    """
    Rerank service backed by the Zhipu API - pure service layer.
    """

    def __init__(self, model: str | None = None):
        """
        Args:
            model: Rerank model name; falls back to ``ZHIPU_RERANK_MODEL``.
        """
        self.model = model or ZHIPU_RERANK_MODEL
        self.api_key = ZHIPUAI_API_KEY

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """
        Call the Zhipu rerank API and return scores in document order.

        Args:
            query: The query string.
            documents: The document strings to score.

        Returns:
            List[float]: ``relevance_score`` of each result, ordered by the
            result's ``index``. Empty when ``documents`` is empty.

        Raises:
            Exception: Any failure (including ``ImportError`` when the
                ``zhipuai`` package is missing) is logged and re-raised.
        """
        if not documents:
            return []

        try:
            # Imported lazily so the module loads without the zhipuai package.
            from zhipuai import ZhipuAI

            client = ZhipuAI(api_key=self.api_key)
            response = client.rerank.create(
                model=self.model,
                query=query,
                documents=documents,
            )
            ordered = sorted(response.results, key=lambda item: item.index)
            return [item.relevance_score for item in ordered]
        except Exception as e:
            logger.warning(f"智谱 rerank 调用失败: {e}")
            raise


class LLMFallbackRerankService(BaseRerankService):
    """
    Last-resort rerank service that asks an LLM to grade each document.
    """

    def __init__(self, llm=None):
        """
        Args:
            llm: Object exposing ``invoke(prompt)``. When None, the chat
                service returned by ``get_chat_service()`` is used.
        """
        if llm is None:
            # Only touch the chat service module when no LLM was injected.
            from .chat_services import get_chat_service

            llm = get_chat_service()
        self.llm = llm

    def compute_scores(self, query: str, documents: List[str]) -> List[float]:
        """
        Score every document with one LLM call per document.

        Args:
            query: The query string.
            documents: The document strings to score.

        Returns:
            List[float]: One score in [0.0, 1.0] per document, in input order.
        """
        if not documents:
            return []
        return [self._score_single_document(query, doc) for doc in documents]

    def _score_single_document(self, query: str, document: str) -> float:
        """
        Ask the LLM for a relevance score of a single document.

        Args:
            query: The query string.
            document: The document text.

        Returns:
            float: The first number found in the reply, clamped to
            [0.0, 1.0]; 0.5 when the reply has no number or the call fails.
        """
        prompt = f"""你是一个文档相关性评分专家。请评估以下文档与查询的相关性,返回一个0到1之间的分数:
- 1.0表示完全相关
- 0.0表示完全不相关

查询: {query}

文档: {document}

请只返回一个数字,不要解释。"""

        try:
            result = self.llm.invoke(prompt)
            content = result.content if hasattr(result, 'content') else str(result)

            match = _SCORE_PATTERN.search(content)
            if match:
                # Clamp so an out-of-range reply cannot distort the ranking.
                return max(0.0, min(1.0, float(match.group(0))))

            # No number in the reply: fall back to a neutral score.
            return 0.5
        except Exception as e:
            logger.warning(f"LLM 打分失败,返回默认分数 0.5: {e}")
            return 0.5


class LLMFallbackRerankProvider(BaseServiceProvider[BaseRerankService]):
    """
    Provider for the LLM-based fallback rerank service.
    """

    def __init__(self, llm=None):
        """
        Args:
            llm: Optional LLM handed to ``LLMFallbackRerankService``.
        """
        super().__init__("llm_fallback_rerank")
        self._llm = llm

    def is_available(self) -> bool:
        """
        Report whether an LLM can be obtained for scoring.

        Returns:
            bool: True when an LLM was injected, or when
            ``get_chat_service()`` succeeds; False if that call raises.
        """
        if self._llm is not None:
            # An injected LLM does not depend on the chat service.
            logger.info("LLM 降级重排服务可用")
            return True

        try:
            from .chat_services import get_chat_service

            get_chat_service()
        except Exception as e:
            logger.warning(f"LLM 降级重排服务不可用: {e}")
            return False

        logger.info("LLM 降级重排服务可用")
        return True

    def get_service(self) -> BaseRerankService:
        """
        Return the LLM fallback rerank service, creating it on first use.
        """
        if self._service_instance is None:
            self._service_instance = LLMFallbackRerankService(self._llm)
        return self._service_instance


class LocalLlamaCppRerankProvider(BaseServiceProvider[BaseRerankService]):
    """
    Provider for the local llama.cpp rerank service.
    """

    def __init__(self, model: str = "bge-reranker-v2-m3"):
        """
        Args:
            model: Model name passed to ``LocalLlamaCppRerankService``.
        """
        super().__init__("local_llamacpp_rerank")
        self._model = model

    def is_available(self) -> bool:
        """
        Probe the local llama.cpp rerank server with a one-document request.

        Returns:
            bool: False when ``LLAMACPP_RERANKER_URL`` is not configured or
            the probe request raises; True otherwise.
        """
        if not LLAMACPP_RERANKER_URL:
            logger.warning("LLAMACPP_RERANKER_URL 未配置")
            return False

        try:
            service = LocalLlamaCppRerankService(
                base_url=LLAMACPP_RERANKER_URL,
                api_key=LLAMACPP_API_KEY,
                model=self._model
            )
            # Only success or failure matters; the scores are discarded.
            service.compute_scores("test query", ["test document"])
        except Exception as e:
            logger.warning(f"本地 llama.cpp 重排服务不可用: {e}")
            return False

        logger.info("本地 llama.cpp 重排服务可用")
        return True

    def get_service(self) -> BaseRerankService:
        """
        Return the local llama.cpp rerank service, creating it on first use.
        """
        if self._service_instance is None:
            self._service_instance = LocalLlamaCppRerankService(
                base_url=LLAMACPP_RERANKER_URL,
                api_key=LLAMACPP_API_KEY,
                model=self._model
            )
        return self._service_instance


class ZhipuRerankProvider(BaseServiceProvider[BaseRerankService]):
    """
    Provider for the Zhipu API rerank service.
    """

    def __init__(self, model: str | None = None):
        """
        Args:
            model: Rerank model name; falls back to ``ZHIPU_RERANK_MODEL``.
        """
        super().__init__("zhipu_rerank")
        self._model = model or ZHIPU_RERANK_MODEL

    def is_available(self) -> bool:
        """
        Probe the Zhipu rerank API with a one-document request.

        Returns:
            bool: False when ``ZHIPUAI_API_KEY`` is not configured, the
            ``zhipuai`` package is missing, or the probe request raises;
            True otherwise.
        """
        if not ZHIPUAI_API_KEY:
            logger.warning("ZHIPUAI_API_KEY 未配置")
            return False

        try:
            service = ZhipuRerankService(model=self._model)
            # Only success or failure matters; the scores are discarded.
            service.compute_scores("test query", ["test document"])
        except ImportError:
            logger.warning("zhipuai 库未安装")
            return False
        except Exception as e:
            logger.warning(f"智谱重排服务不可用: {e}")
            return False

        logger.info("智谱重排服务可用")
        return True

    def get_service(self) -> BaseRerankService:
        """
        Return the Zhipu rerank service, creating it on first use.
        """
        if self._service_instance is None:
            self._service_instance = ZhipuRerankService(model=self._model)
        return self._service_instance


def get_rerank_service() -> BaseRerankService:
    """
    Return a rerank service, with automatic fallback - pure service layer.

    Fallback chain: Local llama.cpp -> Zhipu Rerank -> LLM Fallback

    The chain is obtained through ``SingletonServiceManager.get_or_create``
    under the key "rerank_service_chain".

    Returns:
        BaseRerankService: The service returned by the chain's
        ``get_available_service()``.
    """
    def _create_chain():
        primary = LocalLlamaCppRerankProvider()
        fallbacks = [ZhipuRerankProvider(), LLMFallbackRerankProvider()]
        return FallbackServiceChain(primary, fallbacks)

    chain = SingletonServiceManager.get_or_create("rerank_service_chain", _create_chain)
    return chain.get_available_service()