""" 嵌入模型包装器,用于 llama.cpp 服务。 """ import os from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY import httpx from typing import List from langchain_core.embeddings import Embeddings class LlamaCppEmbedder: """通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。""" def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"): """ Args: model: 嵌入模型名称,默认 "Qwen3-Embedding-0.6B-Q8_0"。 """ self.base_url = LLAMACPP_EMBEDDING_URL self.api_key = LLAMACPP_API_KEY self.model = model print(f"初始化 base_url: { self.base_url}") def as_langchain_embeddings(self) -> Embeddings: """创建 LangChain 兼容的嵌入实例。""" return _LlamaCppLangchainAdapter(self) def embed_documents(self, texts: List[str]) -> List[List[float]]: """嵌入一批文档。""" return self._call_embedding_api(texts) def embed_query(self, text: str) -> List[List[float]]: """嵌入单个查询。""" return self._call_embedding_api([text])[0] def get_embedding_dimension(self) -> int: """通过嵌入测试字符串获取嵌入维度。""" test_embedding = self.embed_query("test") return len(test_embedding) def _call_embedding_api(self, texts: List[str]) -> List[List[float]]: """直接调用 llama.cpp 嵌入 API。""" headers = {"Content-Type": "application/json"} if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" base = self.base_url.rstrip("/") if not base.endswith("/v1"): base = base + "/v1" payload = { "input": texts, "model": self.model, } with httpx.Client(timeout=120) as client: response = client.post( f"{base}/embeddings", headers=headers, json=payload, ) response.raise_for_status() data = response.json() if isinstance(data, list): return [item["embedding"] for item in data] elif isinstance(data, dict) and "data" in data: return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])] else: raise ValueError(f"未知的嵌入 API 响应格式: {data}") class _LlamaCppLangchainAdapter(Embeddings): """将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。""" def __init__(self, embedder: LlamaCppEmbedder): self._embedder = embedder def embed_documents(self, texts: List[str]) -> List[List[float]]: return self._embedder.embed_documents(texts) def embed_query(self, text: str) -> List[List[float]]: return self._embedder.embed_query(text)