"""
Embedding model service module.

This module exposes a single entry point for obtaining an embedding service,
with automatic fallback:
1. Prefer the local llama.cpp embedding service.
2. When the local service is unavailable, fall back to the Zhipu API
   embedding service.

Main components:
- LocalLlamaCppEmbeddingProvider: provider for the local llama.cpp service
- ZhipuEmbeddingProvider: provider for the Zhipu API service
- get_embedding_service(): unified accessor for the embedding service
"""

import logging
from typing import List

import httpx
from langchain_core.embeddings import Embeddings

from .base import (
    BaseServiceProvider,
    FallbackServiceChain,
    SingletonServiceManager
)
from ..config import (
    LLAMACPP_EMBEDDING_URL,
    LLAMACPP_API_KEY,
    ZHIPUAI_API_KEY,
    ZHIPU_EMBEDDING_MODEL,
    ZHIPU_API_BASE
)

logger = logging.getLogger(__name__)


class LocalLlamaCppEmbeddingProvider(BaseServiceProvider[Embeddings]):
    """Provider for the local llama.cpp embedding service."""

    def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"):
        """
        Register the provider under the name "local_llamacpp_embedding".

        Args:
            model: Model name sent with every embedding request.
        """
        super().__init__("local_llamacpp_embedding")
        self._model = model

    def is_available(self) -> bool:
        """
        Check whether the local llama.cpp embedding service is usable.

        The check performs a real embedding request for the string "test".
        Any exception raised by that request is logged and reported as
        "unavailable" rather than propagated.

        Returns:
            bool: True when the probe request succeeded, False otherwise.
        """
        if not LLAMACPP_EMBEDDING_URL:
            logger.warning("LLAMACPP_EMBEDDING_URL 未配置")
            return False

        try:
            # Probe the service by embedding a test string.
            embedder = LocalLlamaCppEmbedder(model=self._model)
            test_embedding = embedder.embed_query("test")
            logger.info(f"本地 llama.cpp 嵌入服务可用,维度: {len(test_embedding)}")
            return True
        except Exception as e:
            # Deliberately broad: any failure simply means "not available".
            logger.warning(f"本地 llama.cpp 嵌入服务不可用: {e}")
            return False

    def get_service(self) -> Embeddings:
        """
        Return the local llama.cpp embedding service, creating it on first use.

        Returns:
            Embeddings: LangChain-compatible embeddings instance.
        """
        # NOTE(review): `_service_instance` is not assigned in this class;
        # presumably BaseServiceProvider.__init__ initialises it to None.
        if self._service_instance is None:
            embedder = LocalLlamaCppEmbedder(model=self._model)
            self._service_instance = embedder.as_langchain_embeddings()
        return self._service_instance


class ZhipuEmbeddingProvider(BaseServiceProvider[Embeddings]):
    """Provider for the Zhipu API embedding service."""

    def __init__(self, model: str | None = None):
        """
        Register the provider under the name "zhipu_embedding".

        Args:
            model: Embedding model name; falls back to ZHIPU_EMBEDDING_MODEL
                when omitted or falsy.
        """
        super().__init__("zhipu_embedding")
        self._model = model or ZHIPU_EMBEDDING_MODEL

    def is_available(self) -> bool:
        """
        Check whether the Zhipu API embedding service is usable.

        The check performs a real embedding request for the input ["test"].
        A missing `zhipuai` package or any other failure is logged and
        reported as "unavailable" rather than propagated.

        Returns:
            bool: True when the probe request succeeded, False otherwise.
        """
        if not ZHIPUAI_API_KEY:
            logger.warning("ZHIPUAI_API_KEY 未配置")
            return False

        try:
            # Probe the Zhipu API; the import is local so that a missing
            # optional dependency is handled below instead of at module load.
            from zhipuai import ZhipuAI
            client = ZhipuAI(api_key=ZHIPUAI_API_KEY)
            response = client.embeddings.create(
                model=self._model,
                input=["test"]
            )
            logger.info(f"智谱嵌入服务可用,维度: {len(response.data[0].embedding)}")
            return True
        except ImportError:
            logger.warning("zhipuai 库未安装")
            return False
        except Exception as e:
            # Deliberately broad: any failure simply means "not available".
            logger.warning(f"智谱嵌入服务不可用: {e}")
            return False

    def get_service(self) -> Embeddings:
        """
        Return the Zhipu API embedding service, creating it on first use.

        Returns:
            Embeddings: LangChain-compatible embeddings instance.
        """
        # NOTE(review): `_service_instance` is not assigned in this class;
        # presumably BaseServiceProvider.__init__ initialises it to None.
        if self._service_instance is None:
            from langchain_zhipu import ZhipuAIEmbeddings
            self._service_instance = ZhipuAIEmbeddings(
                model=self._model,
                api_key=ZHIPUAI_API_KEY
            )
        return self._service_instance


class LocalLlamaCppEmbedder:
    """Wrap the llama.cpp embedding service behind its OpenAI-compatible API."""

    def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0", timeout: float = 120):
        """
        Capture the endpoint configuration for later requests.

        Args:
            model: Model name sent in the request payload.
            timeout: Timeout value handed to `httpx.Client` for each request.
        """
        self.base_url = LLAMACPP_EMBEDDING_URL
        self.api_key = LLAMACPP_API_KEY
        self.model = model
        self.timeout = timeout

    def as_langchain_embeddings(self) -> Embeddings:
        """Create a LangChain-compatible embeddings instance backed by this object."""
        return _LlamaCppLangchainAdapter(self)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents; an empty batch yields an empty list."""
        return self._call_embedding_api(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector."""
        return self._call_embedding_api([text])[0]

    def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
        """
        Call the llama.cpp embedding endpoint directly.

        Args:
            texts: Strings to embed.

        Returns:
            One embedding vector per input text. An empty `texts` returns
            an empty list without issuing a request.

        Raises:
            ValueError: If no base URL is configured, or the response is
                neither a list nor a dict containing "data".
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        # Nothing to embed: avoid a pointless (and possibly rejected) request.
        if not texts:
            return []
        # Fail with a clear message instead of AttributeError on None.rstrip.
        if not self.base_url:
            raise ValueError("LLAMACPP_EMBEDDING_URL 未配置")

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        # Accept base URLs configured with or without the "/v1" suffix.
        base = self.base_url.rstrip("/")
        if not base.endswith("/v1"):
            base = base + "/v1"

        payload = {
            "input": texts,
            "model": self.model,
        }

        with httpx.Client(timeout=self.timeout) as client:
            response = client.post(
                f"{base}/embeddings",
                headers=headers,
                json=payload,
            )
            response.raise_for_status()
            data = response.json()

        return self._extract_vectors(data)

    @staticmethod
    def _extract_vectors(data) -> List[List[float]]:
        """Pull the embedding vectors out of a decoded JSON response."""
        if isinstance(data, list):
            # Bare list of items: keep the order the server returned.
            return [item["embedding"] for item in data]
        if isinstance(data, dict) and "data" in data:
            # Dict with "data": order the items by their "index" field so
            # the vectors line up with the request inputs.
            ordered = sorted(data["data"], key=lambda x: x["index"])
            return [item["embedding"] for item in ordered]
        raise ValueError(f"未知的嵌入 API 响应格式: {data}")


class _LlamaCppLangchainAdapter(Embeddings):
    """Adapt a LocalLlamaCppEmbedder to the LangChain Embeddings interface."""

    def __init__(self, embedder: "LocalLlamaCppEmbedder"):
        """Keep a reference to the embedder that performs the actual requests."""
        self._embedder = embedder

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Delegate batch embedding to the wrapped embedder."""
        return self._embedder.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Delegate single-query embedding to the wrapped embedder."""
        return self._embedder.embed_query(text)


def get_embedding_service() -> Embeddings:
    """
    Return an embedding service, with automatic fallback.

    The local llama.cpp provider is the primary choice and the Zhipu provider
    is the only fallback.

    Returns:
        Embeddings: LangChain-compatible embeddings instance.
    """
    def _create_chain():
        primary = LocalLlamaCppEmbeddingProvider()
        fallback = ZhipuEmbeddingProvider()
        return FallbackServiceChain(primary, [fallback])

    # NOTE(review): presumably get_or_create builds the chain once and reuses
    # it under this key on later calls - confirm in SingletonServiceManager.
    chain = SingletonServiceManager.get_or_create("embedding_service_chain", _create_chain)
    # NOTE(review): presumably returns the service of the first provider whose
    # is_available() succeeds - confirm in FallbackServiceChain.
    return chain.get_available_service()