ailine/backend/app/model_services/embedding_services.py
root 8db63e7a8d refactor: add model services module with automatic fallback for embedding and rerank services
New features:
- Add an app/model_services module that provides a unified interface for obtaining model services
- Implement the BaseServiceProvider base class and the FallbackServiceChain fallback chain
- Implement get_embedding_service(): prefer local llama.cpp, fall back to the Zhipu API (see the usage sketch after this list)
- Implement get_rerank_service(): prefer local llama.cpp, fall back to the Zhipu API
- Manage services as singletons, so there is only one service instance globally
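For orientation, a minimal usage sketch of the new entry point (the import path is inferred from the file location below; get_rerank_service would live in a sibling module not shown here):

    from app.model_services.embedding_services import get_embedding_service

    embeddings = get_embedding_service()  # tries local llama.cpp first, then Zhipu
    vector = embeddings.embed_query("test")
    print(len(vector))  # embedding dimension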

Changes:
- Update app/config.py: add Zhipu API configuration
- Update rag_core/vector_store.py: accept externally supplied embeddings (see the wiring sketch after this list)
- Update rag_core/retriever_factory.py: accept externally supplied embeddings
- Update app/agent/rag_initializer.py: use get_embedding_service()
- Update app/rag/pipeline.py: use get_rerank_service()
- Update app/memory/mem0_client.py: detect which service is available when configuring mem0
- Update rag_indexer/index_builder.py: support the new services while staying backward compatible
- Update rag_indexer/config.py: add Zhipu configuration
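A hedged sketch of the embeddings injection described above (build_vector_store and its embeddings parameter are hypothetical stand-ins; the actual rag_core signatures are not shown here):

    from app.model_services.embedding_services import get_embedding_service
    from rag_core.vector_store import build_vector_store  # hypothetical helper

    vector_store = build_vector_store(embeddings=get_embedding_service())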

Environment variables (a config-reading sketch follows this list):
- ZHIPUAI_API_KEY: Zhipu API key (required)
- ZHIPU_EMBEDDING_MODEL: optional, defaults to embedding-3
- ZHIPU_RERANK_MODEL: optional, defaults to rerank-2
- ZHIPU_API_BASE: optional, defaults to https://open.bigmodel.cn/api/paas/v4
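A hedged sketch of how app/config.py might read these variables (the actual config code is not shown here; defaults are taken from the list above):

    import os

    ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY", "")
    ZHIPU_EMBEDDING_MODEL = os.getenv("ZHIPU_EMBEDDING_MODEL", "embedding-3")
    ZHIPU_RERANK_MODEL = os.getenv("ZHIPU_RERANK_MODEL", "rerank-2")
    ZHIPU_API_BASE = os.getenv("ZHIPU_API_BASE", "https://open.bigmodel.cn/api/paas/v4")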
2026-04-24 22:52:36 +08:00


"""
嵌入模型服务模块
本模块提供统一的嵌入模型服务获取接口,支持自动降级:
1. 优先使用本地 llama.cpp 嵌入服务
2. 本地服务不可用时,自动降级到智谱 API 嵌入服务
主要功能:
- LocalLlamaCppEmbeddingProvider本地 llama.cpp 嵌入服务提供者
- ZhipuEmbeddingProvider智谱 API 嵌入服务提供者
- get_embedding_service():获取嵌入服务的统一接口
"""
import logging
from typing import List

import httpx
from langchain_core.embeddings import Embeddings

from .base import (
    BaseServiceProvider,
    FallbackServiceChain,
    SingletonServiceManager,
)
from ..config import (
    LLAMACPP_EMBEDDING_URL,
    LLAMACPP_API_KEY,
    ZHIPUAI_API_KEY,
    ZHIPU_EMBEDDING_MODEL,
    ZHIPU_API_BASE,
)

logger = logging.getLogger(__name__)


class LocalLlamaCppEmbeddingProvider(BaseServiceProvider[Embeddings]):
    """Local llama.cpp embedding service provider."""

    def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"):
        super().__init__("local_llamacpp_embedding")
        self._model = model

    def is_available(self) -> bool:
        """
        Check whether the local llama.cpp embedding service is available.

        Returns:
            bool: whether the service is available
        """
        if not LLAMACPP_EMBEDDING_URL:
            logger.warning("LLAMACPP_EMBEDDING_URL is not configured")
            return False
        try:
            # Probe the service by embedding a test string
            embedder = LocalLlamaCppEmbedder(model=self._model)
            test_embedding = embedder.embed_query("test")
            logger.info(f"Local llama.cpp embedding service available, dimension: {len(test_embedding)}")
            return True
        except Exception as e:
            logger.warning(f"Local llama.cpp embedding service unavailable: {e}")
            return False

    def get_service(self) -> Embeddings:
        """
        Get the local llama.cpp embedding service.

        Returns:
            Embeddings: a LangChain-compatible embeddings instance
        """
        if self._service_instance is None:
            embedder = LocalLlamaCppEmbedder(model=self._model)
            self._service_instance = embedder.as_langchain_embeddings()
        return self._service_instance


class ZhipuEmbeddingProvider(BaseServiceProvider[Embeddings]):
    """Zhipu API embedding service provider."""

    def __init__(self, model: str | None = None):
        super().__init__("zhipu_embedding")
        self._model = model or ZHIPU_EMBEDDING_MODEL

    def is_available(self) -> bool:
        """
        Check whether the Zhipu API embedding service is available.

        Returns:
            bool: whether the service is available
        """
        if not ZHIPUAI_API_KEY:
            logger.warning("ZHIPUAI_API_KEY is not configured")
            return False
        try:
            # Probe the Zhipu API with a test embedding request
            from zhipuai import ZhipuAI
            client = ZhipuAI(api_key=ZHIPUAI_API_KEY)
            response = client.embeddings.create(
                model=self._model,
                input=["test"]
            )
            logger.info(f"Zhipu embedding service available, dimension: {len(response.data[0].embedding)}")
            return True
        except ImportError:
            logger.warning("zhipuai library is not installed")
            return False
        except Exception as e:
            logger.warning(f"Zhipu embedding service unavailable: {e}")
            return False

    def get_service(self) -> Embeddings:
        """
        Get the Zhipu API embedding service.

        Returns:
            Embeddings: a LangChain-compatible embeddings instance
        """
        if self._service_instance is None:
            from langchain_zhipu import ZhipuAIEmbeddings
            self._service_instance = ZhipuAIEmbeddings(
                model=self._model,
                api_key=ZHIPUAI_API_KEY
            )
        return self._service_instance


class LocalLlamaCppEmbedder:
    """Wraps the llama.cpp embedding service via its OpenAI-compatible API."""

    def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"):
        self.base_url = LLAMACPP_EMBEDDING_URL
        self.api_key = LLAMACPP_API_KEY
        self.model = model

    def as_langchain_embeddings(self) -> Embeddings:
        """Create a LangChain-compatible embeddings instance."""
        return _LlamaCppLangchainAdapter(self)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents."""
        return self._call_embedding_api(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query."""
        return self._call_embedding_api([text])[0]

    def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
        """Call the llama.cpp embedding API directly."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        # Normalize the base URL so it ends with the OpenAI-style /v1 prefix
        base = self.base_url.rstrip("/")
        if not base.endswith("/v1"):
            base = base + "/v1"
        payload = {
            "input": texts,
            "model": self.model,
        }
        with httpx.Client(timeout=120) as client:
            response = client.post(
                f"{base}/embeddings",
                headers=headers,
                json=payload,
            )
            response.raise_for_status()
            data = response.json()
        # Some llama.cpp builds return a bare list of embedding objects,
        # others return the OpenAI-style {"data": [...]} envelope.
        if isinstance(data, list):
            return [item["embedding"] for item in data]
        elif isinstance(data, dict) and "data" in data:
            # Sort by index to preserve input order
            return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])]
        else:
            raise ValueError(f"Unknown embedding API response format: {data}")


class _LlamaCppLangchainAdapter(Embeddings):
    """Adapts LocalLlamaCppEmbedder to the LangChain Embeddings interface."""

    def __init__(self, embedder: "LocalLlamaCppEmbedder"):
        self._embedder = embedder

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self._embedder.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        return self._embedder.embed_query(text)


def get_embedding_service() -> Embeddings:
    """
    Get an embedding service (with automatic fallback).

    Returns:
        Embeddings: a LangChain-compatible embeddings instance
    """
    def _create_chain():
        # Local llama.cpp first; fall back to the Zhipu API
        primary = LocalLlamaCppEmbeddingProvider()
        fallback = ZhipuEmbeddingProvider()
        return FallbackServiceChain(primary, [fallback])

    chain = SingletonServiceManager.get_or_create("embedding_service_chain", _create_chain)
    return chain.get_available_service()
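

# Illustrative usage (comment only, not part of the original module): batch
# embedding through whichever provider the chain resolves to.
#
#     embeddings = get_embedding_service()
#     vectors = embeddings.embed_documents(["first doc", "second doc"])
#     print(len(vectors), len(vectors[0]))  # batch size, embedding dimension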