"""
Rerank model service module.

Provides a single entry point for obtaining a rerank service, with automatic
fallback:

1. Prefer the local llama.cpp rerank service.
2. If the local service is unavailable, fall back to the Zhipu API rerank service.

Main components:
- LocalLlamaCppRerankProvider: provider for the local llama.cpp rerank service
- ZhipuRerankProvider: provider for the Zhipu API rerank service
- get_rerank_service(): unified accessor for the rerank service
"""
import logging
from typing import List

import requests
from langchain_core.documents import Document

from .base import (
    BaseServiceProvider,
    FallbackServiceChain,
    SingletonServiceManager,
)
from ..config import (
    LLAMACPP_RERANKER_URL,
    LLAMACPP_API_KEY,
    ZHIPUAI_API_KEY,
    ZHIPU_RERANK_MODEL,
    ZHIPU_API_BASE,
)

logger = logging.getLogger(__name__)


class BaseReranker:
    """
    Base class for rerankers, defining the common interface.

    Subclasses provide two entry points:
    - ``rerank``: strict; any failure propagates to the caller (used by the
      availability probes so that a broken backend is actually detected).
    - ``compress_documents``: best-effort; on failure it degrades to the
      original order instead of raising.
    """

    def rerank(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
        """
        Rerank documents, propagating any error to the caller.

        Args:
            documents: documents to rerank
            query: query string
            top_n: number of results to return

        Returns:
            The reranked documents.

        Raises:
            NotImplementedError: must be overridden by subclasses.
        """
        raise NotImplementedError

    def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
        """
        Rerank documents.

        Args:
            documents: documents to rerank
            query: query string
            top_n: number of results to return

        Returns:
            The reranked documents.
        """
        raise NotImplementedError


class LocalLlamaCppReranker(BaseReranker):
    """
    Rerank retrieval results using a llama.cpp server's ``/rerank`` endpoint.
    """

    def __init__(self, base_url: str, api_key: str, model: str = "bge-reranker-v2-m3", timeout: int = 60):
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        # Strip trailing slashes so a base URL like "http://host:port/" does not
        # produce "http://host:port//rerank".
        self.endpoint = f"{(base_url or '').rstrip('/')}/rerank"

    def rerank(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
        """
        Rerank documents via the HTTP endpoint; errors propagate.

        Raises:
            requests.RequestException: on connection/timeout/HTTP-status errors.
            KeyError / IndexError / ValueError: on a malformed response body.
        """
        if not documents:
            return []

        payload = {
            "model": self.model,
            "query": query,
            "documents": [doc.page_content for doc in documents],
            "top_n": top_n,
        }
        headers = {"Content-Type": "application/json"}
        # Only send credentials when a key is configured; otherwise the header
        # would literally be "Bearer None".
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout)
        response.raise_for_status()
        results = response.json()

        # Each result item carries the index of an input document; the order of
        # "results" is taken as the ranked order.
        return [documents[item["index"]] for item in results["results"]]

    def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
        """
        Best-effort rerank: on any failure, log a warning and return the first
        ``top_n`` documents in their original order.
        """
        try:
            return self.rerank(documents, query, top_n)
        except Exception as e:
            logger.warning("远程重排序过程出错,返回原始前 %s 个结果: %s", top_n, e)
            return documents[:top_n]


class ZhipuReranker(BaseReranker):
    """
    Rerank retrieval results using the Zhipu API.
    """

    def __init__(self, model: str | None = None):
        self.model = model or ZHIPU_RERANK_MODEL
        self.api_key = ZHIPUAI_API_KEY
        # SDK client, created on first use and reused for later calls.
        # NOTE(review): ZHIPU_API_BASE is imported by this module but not passed
        # to the client — confirm whether a custom base URL is intended.
        self._client = None

    def _get_client(self):
        """Return the cached ZhipuAI client, creating it on first use."""
        if self._client is None:
            # Imported lazily so this module loads even when zhipuai is absent;
            # an ImportError here propagates out of ``rerank``.
            from zhipuai import ZhipuAI

            self._client = ZhipuAI(api_key=self.api_key)
        return self._client

    def rerank(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
        """
        Rerank documents via the Zhipu SDK; errors propagate.

        Raises:
            ImportError: if the zhipuai package is not installed.
            Exception: any error raised by the SDK call.
        """
        if not documents:
            return []

        response = self._get_client().rerank.create(
            model=self.model,
            query=query,
            documents=[doc.page_content for doc in documents],
            top_n=top_n,
        )
        return [documents[item.index] for item in response.results]

    def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
        """
        Best-effort rerank: on any failure, log a warning and return the first
        ``top_n`` documents in their original order.
        """
        try:
            return self.rerank(documents, query, top_n)
        except Exception as e:
            logger.warning("智谱重排序过程出错,返回原始前 %s 个结果: %s", top_n, e)
            return documents[:top_n]


class LocalLlamaCppRerankProvider(BaseServiceProvider[BaseReranker]):
    """
    Provider for the local llama.cpp rerank service.
    """

    def __init__(self, model: str = "bge-reranker-v2-m3"):
        super().__init__("local_llamacpp_rerank")
        self._model = model

    def is_available(self) -> bool:
        """
        Check whether the local llama.cpp rerank service is usable.

        Sends a small real rerank request through the strict ``rerank`` path so
        that connection or response errors are seen here (the best-effort
        ``compress_documents`` would swallow them).
        """
        if not LLAMACPP_RERANKER_URL:
            logger.warning("LLAMACPP_RERANKER_URL 未配置")
            return False

        try:
            test_docs = [Document(page_content="test document 1"), Document(page_content="test document 2")]
            reranker = LocalLlamaCppReranker(
                base_url=LLAMACPP_RERANKER_URL,
                api_key=LLAMACPP_API_KEY,
                model=self._model,
            )
            reranker.rerank(test_docs, "test query", top_n=1)
            logger.info("本地 llama.cpp 重排服务可用")
            return True
        except Exception as e:
            logger.warning("本地 llama.cpp 重排服务不可用: %s", e)
            return False

    def get_service(self) -> BaseReranker:
        """
        Return the (cached) local llama.cpp reranker instance.
        """
        if self._service_instance is None:
            self._service_instance = LocalLlamaCppReranker(
                base_url=LLAMACPP_RERANKER_URL,
                api_key=LLAMACPP_API_KEY,
                model=self._model,
            )
        return self._service_instance


class ZhipuRerankProvider(BaseServiceProvider[BaseReranker]):
    """
    Provider for the Zhipu API rerank service.
    """

    def __init__(self, model: str | None = None):
        super().__init__("zhipu_rerank")
        self._model = model or ZHIPU_RERANK_MODEL

    def is_available(self) -> bool:
        """
        Check whether the Zhipu API rerank service is usable.

        Sends a small real rerank request through the strict ``rerank`` path so
        that a missing SDK or API errors are detected here.
        """
        if not ZHIPUAI_API_KEY:
            logger.warning("ZHIPUAI_API_KEY 未配置")
            return False

        try:
            test_docs = [Document(page_content="test document 1"), Document(page_content="test document 2")]
            reranker = ZhipuReranker(model=self._model)
            reranker.rerank(test_docs, "test query", top_n=1)
            logger.info("智谱重排服务可用")
            return True
        except ImportError:
            logger.warning("zhipuai 库未安装")
            return False
        except Exception as e:
            logger.warning("智谱重排服务不可用: %s", e)
            return False

    def get_service(self) -> BaseReranker:
        """
        Return the (cached) Zhipu API reranker instance.
        """
        if self._service_instance is None:
            self._service_instance = ZhipuReranker(model=self._model)
        return self._service_instance


def get_rerank_service() -> BaseReranker:
    """
    Get the rerank service, with automatic fallback.

    The provider chain (local llama.cpp first, Zhipu API as fallback) is created
    once and cached under "rerank_service_chain" by SingletonServiceManager.

    Returns:
        BaseReranker: the rerank service instance chosen by the chain.
    """

    def _create_chain():
        primary = LocalLlamaCppRerankProvider()
        fallback = ZhipuRerankProvider()
        return FallbackServiceChain(primary, [fallback])

    chain = SingletonServiceManager.get_or_create("rerank_service_chain", _create_chain)
    return chain.get_available_service()