This commit is contained in:
68
rag_indexer/embedders.py
Normal file
68
rag_indexer/embedders.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Embedding model wrapper for llama.cpp service.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
|
||||
class LlamaCppEmbedder:
|
||||
"""Wrapper for llama.cpp embedding service via OpenAI-compatible API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "embeddinggemma-300M-Q8_0",
|
||||
):
|
||||
self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
|
||||
self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
|
||||
self.model = model
|
||||
|
||||
# Ensure URL ends with /v1
|
||||
self.base_url = urljoin(self.base_url.rstrip("/") + "/", "v1")
|
||||
|
||||
def as_langchain_embeddings(self) -> OpenAIEmbeddings:
|
||||
"""Create LangChain OpenAIEmbeddings instance."""
|
||||
return OpenAIEmbeddings(
|
||||
openai_api_base=self.base_url,
|
||||
openai_api_key=self.api_key,
|
||||
model=self.model,
|
||||
)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of documents."""
|
||||
emb = self.as_langchain_embeddings()
|
||||
return emb.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a single query."""
|
||||
emb = self.as_langchain_embeddings()
|
||||
return emb.embed_query(text)
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""Get embedding dimension by embedding a test string."""
|
||||
test_embedding = self.embed_query("test")
|
||||
return len(test_embedding)
|
||||
|
||||
|
||||
class MockEmbedder:
|
||||
"""Mock embedder for testing without a real service."""
|
||||
|
||||
def __init__(self, dimension: int = 768):
|
||||
self.dimension = dimension
|
||||
|
||||
def as_langchain_embeddings(self) -> OpenAIEmbeddings:
|
||||
raise NotImplementedError("MockEmbedder cannot be used as LangChain embeddings")
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return [[0.0] * self.dimension for _ in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return [0.0] * self.dimension
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
return self.dimension
|
||||
Reference in New Issue
Block a user