Files
ailine/backend/app/model_services/chat_services.py
root 1260bef5cb
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m31s
添加rag置信度判断
2026-05-06 01:15:52 +08:00

355 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
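生成式大模型服务模块
本模块提供统一的生成式大模型服务获取接口,支持多种模型:
1. Local VLLM 服务:本地 gemma-4-E4B-it 模型
2. Zhipu AI:智谱 glm-5.1 模型
3. DeepSeek:deepseek-v4-pro 模型
主要功能:
- LocalVLLMChatProvider:本地 VLLM 服务提供者
- ZhipuChatProvider:智谱 API 服务提供者
- DeepSeekChatProvider:DeepSeek API 服务提供者
- LocalSmallModelProvider / DeepSeekSmallModelProvider:轻量级模型服务提供者(查询改写、意图分类等简单任务)
- get_chat_service():获取默认服务(带自动降级)
- get_all_chat_services() / get_cached_chat_services():获取所有可用模型服务(用于多模型切换)
- get_small_llm_service():获取轻量级模型服务(本地小模型 -> DeepSeek 小模型,不降级到大模型)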
"""
import logging
from typing import Dict, Callable
from langchain_core.language_models import BaseChatModel
from .base import (
BaseServiceProvider,
FallbackServiceChain,
SingletonServiceManager
)
from backend.app.config import (
VLLM_BASE_URL,
LLM_API_KEY,
ZHIPUAI_API_KEY,
DEEPSEEK_API_KEY,
LOCAL_MODEL_NAME
)
logger = logging.getLogger(__name__)
# Module-level cache of initialised chat models (backend name -> ChatModel).
# Stays None until get_cached_chat_services() fills it on its first call.
_cached_services: Dict[str, BaseChatModel] | None = None
def _check_http_service_available(base_url: str, api_key: str = "", timeout: float = 2.0) -> bool:
"""通过探测 /models 端点检查 HTTP API 是否可用(内部工具函数)"""
try:
import httpx
client = httpx.Client(base_url=base_url.rstrip('/'), timeout=timeout)
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
resp = client.get("/models", headers=headers)
return resp.status_code == 200
except Exception:
return False
class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]):
    """Chat-model provider backed by a self-hosted vLLM server.

    The server is reached through an OpenAI-compatible API rooted at
    ``VLLM_BASE_URL``. The ChatOpenAI client is created on the first
    ``get_service()`` call and the same instance is returned afterwards.
    """

    def __init__(self, model: str | None = None):
        super().__init__("local_vllm_chat")
        # An empty/None model name falls back to the configured local model.
        self._model = model if model else LOCAL_MODEL_NAME

    def is_available(self) -> bool:
        """Report whether the local vLLM server can be used.

        Returns:
            bool: True when VLLM_BASE_URL is configured and its /models
            endpoint answers HTTP 200 within 2 seconds.
        """
        if VLLM_BASE_URL:
            return _check_http_service_available(VLLM_BASE_URL, LLM_API_KEY, timeout=2.0)
        logger.warning("VLLM_BASE_URL 未配置")
        return False

    def get_service(self) -> BaseChatModel:
        """Return the (lazily created) streaming ChatOpenAI client for the local model.

        Returns:
            BaseChatModel: a LangChain-compatible chat model instance.
        """
        if self._service_instance is not None:
            return self._service_instance

        from langchain_openai import ChatOpenAI
        from pydantic import SecretStr

        self._service_instance = ChatOpenAI(
            base_url=VLLM_BASE_URL,
            api_key=SecretStr(LLM_API_KEY or ""),
            model=self._model,
            timeout=60.0,
            max_retries=2,
            streaming=True,
        )
        return self._service_instance
class ZhipuChatProvider(BaseServiceProvider[BaseChatModel]):
    """Chat-model provider for the Zhipu AI (智谱) hosted API.

    ``is_available()`` is a configuration-only check: it returns True whenever
    ``ZHIPUAI_API_KEY`` is non-empty and makes no network call. The client is
    created on the first ``get_service()`` call and reused afterwards.
    """

    def __init__(self, model: str = "glm-5.1"):
        super().__init__("zhipu_chat")
        self._model = model

    def is_available(self) -> bool:
        """Report whether the Zhipu AI service is configured.

        Returns:
            bool: True if ZHIPUAI_API_KEY is set, otherwise False (with a warning).
        """
        if not ZHIPUAI_API_KEY:
            logger.warning("ZHIPUAI_API_KEY 未配置")
            return False
        # Configuration-only check: nothing here can fail, so no try/except is
        # needed; connectivity problems surface when the model is invoked.
        logger.info(f"智谱 AI 服务配置正确,准备使用: {self._model}")
        return True

    def get_service(self) -> BaseChatModel:
        """Return the (lazily created) streaming ChatZhipuAI client.

        Returns:
            BaseChatModel: a LangChain-compatible chat model instance.
        """
        if self._service_instance is None:
            from langchain_community.chat_models import ChatZhipuAI
            self._service_instance = ChatZhipuAI(
                model=self._model,
                api_key=ZHIPUAI_API_KEY,
                temperature=0.1,
                max_tokens=4096,
                timeout=120.0,
                max_retries=3,
                streaming=True,
            )
        return self._service_instance
class DeepSeekChatProvider(BaseServiceProvider[BaseChatModel]):
    """Chat-model provider for the DeepSeek hosted API (OpenAI-compatible).

    ``is_available()`` is a configuration-only check on ``DEEPSEEK_API_KEY``
    and makes no network call. The ChatOpenAI client is created on the first
    ``get_service()`` call and reused afterwards.
    """

    # Public DeepSeek endpoint, used unless a different base_url is passed in.
    DEFAULT_BASE_URL = "https://api.deepseek.com"

    def __init__(self, model: str = "deepseek-v4-pro", base_url: str = DEFAULT_BASE_URL):
        """
        Args:
            model: DeepSeek model name.
            base_url: OpenAI-compatible API root. Override it to route through a
                proxy or gateway; defaults to the public DeepSeek endpoint.
        """
        super().__init__("deepseek_chat")
        self._model = model
        self._base_url = base_url

    def is_available(self) -> bool:
        """Report whether the DeepSeek service is configured.

        Returns:
            bool: True if DEEPSEEK_API_KEY is set, otherwise False (with a warning).
        """
        if not DEEPSEEK_API_KEY:
            logger.warning("DEEPSEEK_API_KEY 未配置")
            return False
        # Configuration-only check: nothing here can fail, so no try/except is
        # needed; connectivity problems surface when the model is invoked.
        logger.info(f"DeepSeek 服务配置正确,准备使用: {self._model}")
        return True

    def get_service(self) -> BaseChatModel:
        """Return the (lazily created) streaming ChatOpenAI client for DeepSeek.

        Returns:
            BaseChatModel: a LangChain-compatible chat model instance.
        """
        if self._service_instance is None:
            from langchain_openai import ChatOpenAI
            from pydantic import SecretStr
            self._service_instance = ChatOpenAI(
                base_url=self._base_url,
                api_key=SecretStr(DEEPSEEK_API_KEY),
                model=self._model,
                temperature=0.1,
                max_tokens=4096,
                timeout=60.0,
                max_retries=2,
                streaming=True,
            )
        return self._service_instance
# ========== 轻量级模型 Provider ==========
class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
    """Self-hosted lightweight model for simple tasks (query rewriting, intent classification).

    Uses its own settings (SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL,
    SMALL_LLM_API_KEY), independent of the main local chat model.
    """

    def __init__(self, model: str | None = None):
        from backend.app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
        super().__init__("local_small")
        # An empty/None model name falls back to the configured small model.
        self._model = model if model else SMALL_LOCAL_MODEL_NAME
        self._base_url = SMALL_VLLM_BASE_URL
        self._api_key = SMALL_LLM_API_KEY

    def is_available(self) -> bool:
        """Return True when a base URL is configured and its /models probe answers HTTP 200."""
        base_url = self._base_url
        if base_url:
            return _check_http_service_available(base_url, self._api_key, timeout=2.0)
        logger.warning("SMALL_VLLM_BASE_URL 未配置,本地小模型不可用")
        return False

    def get_service(self) -> BaseChatModel:
        """Return the (lazily created) non-streaming ChatOpenAI client for the small model."""
        if self._service_instance is not None:
            return self._service_instance

        from langchain_openai import ChatOpenAI
        from pydantic import SecretStr

        self._service_instance = ChatOpenAI(
            base_url=self._base_url,
            api_key=SecretStr(self._api_key or ""),
            model=self._model,
            timeout=30.0,
            max_retries=2,
            streaming=False,
        )
        return self._service_instance
class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
    """Lightweight DeepSeek model for simple tasks (query rewriting, intent classification).

    Reads its own SMALL_DEEPSEEK_* settings, independent of DeepSeekChatProvider.
    ``is_available()`` is a configuration-only check on SMALL_DEEPSEEK_API_KEY.
    """

    # API root used when SMALL_DEEPSEEK_API_BASE is empty/unset. Without it,
    # ChatOpenAI would receive an empty base_url and fall back to its own
    # default endpoint (OPENAI_API_BASE or the OpenAI API), sending the
    # DeepSeek key to the wrong service.
    _DEFAULT_API_BASE = "https://api.deepseek.com"

    def __init__(self, model: str | None = None):
        from backend.app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
        super().__init__("deepseek_small")
        self._model = model or SMALL_DEEPSEEK_MODEL
        self._api_key = SMALL_DEEPSEEK_API_KEY
        self._api_base = SMALL_DEEPSEEK_API_BASE or self._DEFAULT_API_BASE

    def is_available(self) -> bool:
        """Return True if SMALL_DEEPSEEK_API_KEY is configured (no network call is made)."""
        if not self._api_key:
            logger.warning("SMALL_DEEPSEEK_API_KEY 未配置")
            return False
        logger.info(f"DeepSeek 轻量模型配置正确: {self._model}")
        return True

    def get_service(self) -> BaseChatModel:
        """Return the (lazily created) non-streaming ChatOpenAI client for the small DeepSeek model."""
        if self._service_instance is None:
            from langchain_openai import ChatOpenAI
            from pydantic import SecretStr
            self._service_instance = ChatOpenAI(
                base_url=self._api_base,
                api_key=SecretStr(self._api_key),
                model=self._model,
                temperature=0.1,
                max_tokens=2048,
                timeout=30.0,
                max_retries=2,
                streaming=False,
            )
        return self._service_instance
# Registry of selectable chat backends: public name -> zero-argument provider factory.
# Lambdas (rather than bare class references) defer the class-name lookup until
# a provider is actually built; insertion order is local, zhipu, deepseek.
CHAT_PROVIDERS: Dict[str, Callable[[], BaseServiceProvider[BaseChatModel]]] = dict(
    local=lambda: LocalVLLMChatProvider(),
    zhipu=lambda: ZhipuChatProvider(),
    deepseek=lambda: DeepSeekChatProvider(),
)
def get_chat_service() -> BaseChatModel:
    """Return the default chat model, falling back automatically when needed.

    Providers are tried in the order local -> zhipu -> deepseek. The fallback
    chain is obtained via SingletonServiceManager.get_or_create under the key
    "chat_service_chain" (by its name, presumably built once and reused).

    Returns:
        BaseChatModel: a LangChain-compatible chat model instance.
    """

    def _build_chain():
        # The primary provider is constructed first, then the fallbacks in order.
        return FallbackServiceChain(
            LocalVLLMChatProvider(),
            [ZhipuChatProvider(), DeepSeekChatProvider()],
        )

    service_chain = SingletonServiceManager.get_or_create("chat_service_chain", _build_chain)
    return service_chain.get_available_service()
def _init_chat_services() -> Dict[str, BaseChatModel]:
    """Build a chat model for every available provider in CHAT_PROVIDERS.

    Intended to run only once, on first use (see get_cached_chat_services).
    Providers whose is_available() returns False are skipped. A provider that
    raises while being constructed, checked, or turned into a client is logged
    as a warning and skipped.

    Returns:
        Dict[str, BaseChatModel]: backend name -> chat model, in CHAT_PROVIDERS order.

    Raises:
        RuntimeError: if no provider could be loaded.
    """
    loaded: Dict[str, BaseChatModel] = {}
    for backend_name, make_provider in CHAT_PROVIDERS.items():
        try:
            candidate = make_provider()
            if not candidate.is_available():
                continue
            loaded[backend_name] = candidate.get_service()
            logger.info(f"已加载模型: {backend_name}")
        except Exception as err:
            logger.warning(f"模型 {backend_name} 初始化失败: {err}")

    if loaded:
        return loaded
    raise RuntimeError(f"没有可用的生成式大模型,尝试了: {list(CHAT_PROVIDERS.keys())}")
def get_cached_chat_services() -> Dict[str, BaseChatModel]:
    """Return the cached dict of available chat models (for dynamic injection into a single graph).

    The dict is built by _init_chat_services() on the first call and stored in
    the module-level _cached_services. If initialisation raises, nothing is
    cached and the next call retries. No lock is taken, so concurrent first
    calls may each run the initialisation.
    """
    global _cached_services
    if _cached_services is not None:
        return _cached_services
    _cached_services = _init_chat_services()
    return _cached_services
def get_all_chat_services() -> Dict[str, BaseChatModel]:
    """Return every available chat model (multi-model switching; kept for backward compatibility).

    This is an alias of get_cached_chat_services(); new code should call that
    function directly.

    Returns:
        Dict[str, BaseChatModel]: model name -> ChatModel instance.
    """
    services = get_cached_chat_services()
    return services
def get_small_llm_service() -> BaseChatModel:
    """Return a lightweight chat model for simple tasks (query rewriting, intent classification).

    Order: local small model -> DeepSeek small model. The chain deliberately
    never falls back to the large models, to avoid unnecessary token spend.
    It is obtained via SingletonServiceManager.get_or_create under the key
    "small_llm_chain".

    Returns:
        BaseChatModel: a LangChain-compatible chat model instance.
    """
    chain_key = "small_llm_chain"

    def _make_small_chain():
        local_small = LocalSmallModelProvider()
        return FallbackServiceChain(local_small, [DeepSeekSmallModelProvider()])

    small_chain = SingletonServiceManager.get_or_create(chain_key, _make_small_chain)
    return small_chain.get_available_service()