""" 嵌入模型包装器,用于 llama.cpp 服务。 """ import os import httpx from typing import List, Optional from urllib.parse import urljoin from langchain_core.embeddings import Embeddings class LlamaCppEmbedder: """通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。""" def __init__( self, base_url: Optional[str] = None, api_key: Optional[str] = None, model: str = "embeddinggemma-300M-Q8_0", ): self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082") self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "") self.model = model def as_langchain_embeddings(self) -> Embeddings: """创建 LangChain 兼容的嵌入实例。""" return _LlamaCppLangchainAdapter(self) def embed_documents(self, texts: List[str]) -> List[List[float]]: """嵌入一批文档。""" return self._call_embedding_api(texts) def embed_query(self, text: str) -> List[float]: """嵌入单个查询。""" return self._call_embedding_api([text])[0] def get_embedding_dimension(self) -> int: """通过嵌入测试字符串获取嵌入维度。""" test_embedding = self.embed_query("test") return len(test_embedding) def _call_embedding_api(self, texts: List[str]) -> List[List[float]]: """直接调用 llama.cpp 嵌入 API。""" base = self.base_url.rstrip("/") if not base.endswith("/v1"): base = base + "/v1" headers = {"Content-Type": "application/json"} if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" payload = { "input": texts, "model": self.model, } with httpx.Client(timeout=120) as client: response = client.post( f"{base}/embeddings", headers=headers, json=payload, ) response.raise_for_status() data = response.json() # 处理不同响应格式 if isinstance(data, list): # llama.cpp 直接返回列表 return [item["embedding"] for item in data] elif isinstance(data, dict) and "data" in data: # OpenAI 标准格式 return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])] else: raise ValueError(f"未知的嵌入 API 响应格式: {data}") class _LlamaCppLangchainAdapter(Embeddings): """将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。""" def __init__(self, embedder: LlamaCppEmbedder): self._embedder = embedder def embed_documents(self, texts: List[str]) -> List[List[float]]: return self._embedder.embed_documents(texts) def embed_query(self, text: str) -> List[float]: return self._embedder.embed_query(text)