68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
"""
|
|
Embedding model wrapper for llama.cpp service.
|
|
"""
|
|
|
|
import os
|
|
from typing import List, Optional
|
|
from urllib.parse import urljoin
|
|
|
|
from langchain_openai import OpenAIEmbeddings
|
|
|
|
|
|
class LlamaCppEmbedder:
    """Wrapper for a llama.cpp embedding service via its OpenAI-compatible API.

    Connection settings come from the constructor arguments, falling back to
    the ``LLAMACPP_EMBEDDING_URL`` and ``LLAMACPP_API_KEY`` environment
    variables. After construction ``base_url`` always points at the ``/v1``
    API root.

    A new LangChain client is built on every ``embed_*`` call; callers with
    high request volume may prefer to keep the object returned by
    :meth:`as_langchain_embeddings` and use it directly.
    """

    # Sent instead of an empty key. An empty key appears to be treated as
    # "not provided" by the LangChain/OpenAI client stack (fallback to
    # OPENAI_API_KEY, or an error when that is unset), which would break
    # key-less llama.cpp setups -- TODO confirm against the pinned
    # langchain_openai version.
    _PLACEHOLDER_API_KEY = "sk-no-key-required"

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: str = "embeddinggemma-300M-Q8_0",
    ):
        """Resolve connection settings.

        Args:
            base_url: Service URL, with or without a trailing ``/v1``.
                Defaults to ``$LLAMACPP_EMBEDDING_URL`` or
                ``http://127.0.0.1:8082``.
            api_key: API key. Defaults to ``$LLAMACPP_API_KEY`` or ``""``.
            model: Model name forwarded to the service.
        """
        configured_url = base_url or os.getenv(
            "LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082"
        )
        self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
        self.model = model
        self.base_url = self._with_v1_suffix(configured_url)

    @staticmethod
    def _with_v1_suffix(url: str) -> str:
        """Return *url* ending in exactly one ``/v1`` path segment."""
        # Local import: the module only imports ``urljoin`` at file level.
        from urllib.parse import urlsplit

        trimmed = url.rstrip("/")
        # Check the parsed path as well as the raw string so that a host
        # literally named "v1" (``http://v1``) is not mistaken for the suffix.
        if trimmed.endswith("/v1") and urlsplit(trimmed).path.endswith("/v1"):
            return trimmed
        return urljoin(trimmed + "/", "v1")

    def as_langchain_embeddings(self) -> "OpenAIEmbeddings":
        """Create a new LangChain ``OpenAIEmbeddings`` bound to this service.

        Returns:
            A freshly constructed ``OpenAIEmbeddings`` instance.
        """
        # NOTE(review): OpenAIEmbeddings has a ``check_embedding_ctx_length``
        # option that, by default, pre-tokenizes input with an OpenAI
        # tokenizer -- confirm that is appropriate for a llama.cpp backend.
        return OpenAIEmbeddings(
            openai_api_base=self.base_url,
            openai_api_key=self.api_key or self._PLACEHOLDER_API_KEY,
            model=self.model,
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents.

        Args:
            texts: Documents to embed.

        Returns:
            One embedding vector per input text; ``[]`` for empty input.
        """
        if not texts:
            # Nothing to embed: skip client construction and the round trip.
            return []
        return self.as_langchain_embeddings().embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector."""
        return self.as_langchain_embeddings().embed_query(text)

    def get_embedding_dimension(self) -> int:
        """Return the embedding size, measured by embedding a test string.

        This performs a real request via :meth:`embed_query` on every call.
        """
        return len(self.embed_query("test"))
|
|
|
|
|
|
class MockEmbedder:
    """Offline stand-in for an embedding service.

    Every embedding it produces is a vector of ``dimension`` zeros, so tests
    can run without a live backend.
    """

    def __init__(self, dimension: int = 768):
        """Store the length of the zero vectors to produce."""
        self.dimension = dimension

    def as_langchain_embeddings(self) -> "OpenAIEmbeddings":
        """Unsupported for the mock; always raises ``NotImplementedError``."""
        raise NotImplementedError("MockEmbedder cannot be used as LangChain embeddings")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one independent zero vector per entry in *texts*."""
        vectors: List[List[float]] = []
        for _unused in texts:
            vectors.append([0.0] * self.dimension)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Return a single zero vector; *text* is ignored."""
        zero_vector = [0.0] * self.dimension
        return zero_vector

    def get_embedding_dimension(self) -> int:
        """Return the configured vector length."""
        return self.dimension