All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s
## 核心改动 ### 1. 单图方案重构 - 删除了多图(self.graphs),改为单图(self.graph) - 新增 MainGraphState.current_model 字段用于运行时注入模型 - llm_call 节点改为动态选择模型(create_dynamic_llm_call_node) ### 2. chat_services 优化 - 添加 _cached_services 缓存,避免重复初始化 - 新增 get_cached_chat_services() 函数,用于单图注入 - 新增 _check_http_service_available() 统一HTTP探测逻辑 - 减少重复代码,LocalVLLMChatProvider和LocalSmallModelProvider共用探测方法 ### 3. AIAgentService 重构 - initialize() 只构建一次图,传入 chat_services 字典 - 新增 _resolve_model() 模型回退逻辑 - 新增 _build_invocation() 统一构建调用参数 - process_message() 和 process_message_stream() 改为注入 current_model - 流式处理代码拆分,增加可读性 ### 4. 新增和删除文件 - 新增:backend/app/main_graph/main_graph_builder.py(图构建) - 新增:backend/app/main_graph/subgraph_wrapper.py(子图封装) - 新增:tools/test/test_tavily_search.py(测试) - 删除:backend/app/main_graph/graph.py(旧图) - 删除:backend/app/main_graph/utils/main_graph_builder.py(旧构建器) - 删除:backend/app/main_graph/utils/__init__.py ### 5. 其他更新 - README.md:新增模型服务使用情况详解章节 - backend/app/model_services/__init__.py:新增 get_cached_chat_services 导出 ## 方案优势 - 内存优化:N张图 → 1张图 - 灵活性:运行时动态选择模型,支持同会话不同模型 - 性能:模型服务缓存,初始化仅一次 - 可维护性:减少重复代码,统一HTTP探测逻辑
355 lines
12 KiB
Python
355 lines
12 KiB
Python
"""
|
||
生成式大模型服务模块
|
||
|
||
本模块提供统一的生成式大模型服务获取接口,支持多种模型:
|
||
1. Local VLLM 服务:本地 gemma-4-E4B-it 模型
|
||
2. Zhipu AI:智谱 glm-5.1 模型
|
||
3. DeepSeek:deepseek-v4-pro 模型
|
||
|
||
主要功能:
|
||
- LocalVLLMChatProvider:本地 VLLM 服务提供者
|
||
- ZhipuChatProvider:智谱 API 服务提供者
|
||
- DeepSeekChatProvider:DeepSeek API 服务提供者
|
||
- get_chat_service():获取默认服务(带自动降级)
|
||
- get_all_chat_services():获取所有可用模型服务(用于多模型切换)
|
||
"""
|
||
|
||
import logging
|
||
from typing import Dict, Callable
|
||
from langchain_core.language_models import BaseChatModel
|
||
|
||
from .base import (
|
||
BaseServiceProvider,
|
||
FallbackServiceChain,
|
||
SingletonServiceManager
|
||
)
|
||
from app.config import (
|
||
VLLM_BASE_URL,
|
||
LLM_API_KEY,
|
||
ZHIPUAI_API_KEY,
|
||
DEEPSEEK_API_KEY,
|
||
LOCAL_MODEL_NAME
|
||
)
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Cache of already-initialised chat models keyed by provider name.
# Stays None until get_cached_chat_services() fills it on first use.
_cached_services: Dict[str, BaseChatModel] | None = None
|
||
|
||
|
||
def _check_http_service_available(base_url: str, api_key: str = "", timeout: float = 2.0) -> bool:
|
||
"""通过探测 /models 端点检查 HTTP API 是否可用(内部工具函数)"""
|
||
try:
|
||
import httpx
|
||
client = httpx.Client(base_url=base_url.rstrip('/'), timeout=timeout)
|
||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||
resp = client.get("/models", headers=headers)
|
||
return resp.status_code == 200
|
||
except Exception:
|
||
return False
|
||
|
||
|
||
class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]):
    """Chat-model provider for the local vLLM service.

    The model is reached at ``VLLM_BASE_URL`` through ``ChatOpenAI``.
    """

    def __init__(self, model: str | None = None):
        """Register the provider under the name ``"local_vllm_chat"``.

        Args:
            model: Model name to request. Falls back to ``LOCAL_MODEL_NAME``
                when omitted or empty.
        """
        super().__init__("local_vllm_chat")
        self._model = model or LOCAL_MODEL_NAME

    def is_available(self) -> bool:
        """Check whether the local vLLM service can be used.

        Returns:
            bool: False (with a warning logged) when ``VLLM_BASE_URL`` is not
            configured; otherwise the result of probing the server with
            ``_check_http_service_available`` using a 2-second timeout.
        """
        if not VLLM_BASE_URL:
            logger.warning("VLLM_BASE_URL 未配置")
            return False

        # Shared HTTP probe, also used by the small-model provider.
        return _check_http_service_available(VLLM_BASE_URL, LLM_API_KEY, timeout=2.0)

    def get_service(self) -> BaseChatModel:
        """Return the chat model for the local vLLM service, creating it on first use.

        Returns:
            BaseChatModel: A LangChain-compatible chat model instance.
        """
        # NOTE(review): relies on ``_service_instance`` being set to None by
        # BaseServiceProvider.__init__ — confirm in .base.
        if self._service_instance is None:
            # Imported lazily so these packages are only loaded when the
            # client is actually built.
            from langchain_openai import ChatOpenAI
            from pydantic import SecretStr

            self._service_instance = ChatOpenAI(
                base_url=VLLM_BASE_URL,
                # An unset key becomes an empty secret.
                api_key=SecretStr(LLM_API_KEY or ""),
                model=self._model,
                timeout=60.0,
                max_retries=2,
                streaming=True,
            )
        return self._service_instance
|
||
|
||
|
||
class ZhipuChatProvider(BaseServiceProvider[BaseChatModel]):
    """Chat-model provider for the Zhipu AI API."""

    def __init__(self, model: str = "glm-5.1"):
        """Register the provider under the name ``"zhipu_chat"``.

        Args:
            model: Zhipu model name to request.
        """
        super().__init__("zhipu_chat")
        self._model = model

    def is_available(self) -> bool:
        """Check whether the Zhipu AI service is configured.

        Only the presence of ``ZHIPUAI_API_KEY`` is checked; no network
        request is made.

        Returns:
            bool: True when the API key is configured, False otherwise.
        """
        if not ZHIPUAI_API_KEY:
            logger.warning("ZHIPUAI_API_KEY 未配置")
            return False

        # The former try/except around this log call guarded nothing that
        # can realistically raise, so its handler was effectively dead code.
        logger.info(f"智谱 AI 服务配置正确,准备使用: {self._model}")
        return True

    def get_service(self) -> BaseChatModel:
        """Return the Zhipu AI chat model, creating it on first use.

        Returns:
            BaseChatModel: A LangChain-compatible chat model instance.
        """
        # NOTE(review): relies on ``_service_instance`` being set to None by
        # BaseServiceProvider.__init__ — confirm in .base.
        if self._service_instance is None:
            # Imported lazily so the package is only loaded when needed.
            from langchain_community.chat_models import ChatZhipuAI

            self._service_instance = ChatZhipuAI(
                model=self._model,
                api_key=ZHIPUAI_API_KEY,
                temperature=0.1,
                max_tokens=4096,
                timeout=120.0,
                max_retries=3,
                streaming=True,
            )
        return self._service_instance
|
||
|
||
|
||
class DeepSeekChatProvider(BaseServiceProvider[BaseChatModel]):
    """Chat-model provider for the DeepSeek API."""

    def __init__(self, model: str = "deepseek-v4-pro"):
        """Register the provider under the name ``"deepseek_chat"``.

        Args:
            model: DeepSeek model name to request.
        """
        super().__init__("deepseek_chat")
        self._model = model

    def is_available(self) -> bool:
        """Check whether the DeepSeek service is configured.

        Only the presence of ``DEEPSEEK_API_KEY`` is checked; no network
        request is made.

        Returns:
            bool: True when the API key is configured, False otherwise.
        """
        if not DEEPSEEK_API_KEY:
            logger.warning("DEEPSEEK_API_KEY 未配置")
            return False

        # The former try/except around this log call guarded nothing that
        # can realistically raise, so its handler was effectively dead code.
        logger.info(f"DeepSeek 服务配置正确,准备使用: {self._model}")
        return True

    def get_service(self) -> BaseChatModel:
        """Return the DeepSeek chat model, creating it on first use.

        Returns:
            BaseChatModel: A LangChain-compatible chat model instance.
        """
        # NOTE(review): relies on ``_service_instance`` being set to None by
        # BaseServiceProvider.__init__ — confirm in .base.
        if self._service_instance is None:
            # Imported lazily so these packages are only loaded when needed.
            from langchain_openai import ChatOpenAI
            from pydantic import SecretStr

            self._service_instance = ChatOpenAI(
                base_url="https://api.deepseek.com",
                api_key=SecretStr(DEEPSEEK_API_KEY),
                model=self._model,
                temperature=0.1,
                max_tokens=4096,
                timeout=60.0,
                max_retries=2,
                streaming=True,
            )
        return self._service_instance
|
||
|
||
|
||
# ========== 轻量级模型 Provider ==========
|
||
|
||
class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
    """Provider for the local lightweight model.

    Intended for simple tasks such as query rewriting and intent
    classification. It uses its own ``SMALL_*`` configuration values rather
    than the main model's settings.
    """

    def __init__(self, model: str | None = None):
        """Register the provider under the name ``"local_small"``.

        Args:
            model: Model name to request. Falls back to
                ``SMALL_LOCAL_MODEL_NAME`` when omitted or empty.
        """
        # Configuration is imported at construction time, not module import.
        from app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY

        super().__init__("local_small")
        self._model = model or SMALL_LOCAL_MODEL_NAME
        self._base_url = SMALL_VLLM_BASE_URL
        self._api_key = SMALL_LLM_API_KEY

    def is_available(self) -> bool:
        """Check whether the local small-model service can be used.

        Returns:
            bool: False (with a warning logged) when no base URL is
            configured; otherwise the result of probing the server with
            ``_check_http_service_available`` using a 2-second timeout.
        """
        if not self._base_url:
            logger.warning("SMALL_VLLM_BASE_URL 未配置,本地小模型不可用")
            return False

        # Shared HTTP probe, also used by the main local provider.
        return _check_http_service_available(self._base_url, self._api_key, timeout=2.0)

    def get_service(self) -> BaseChatModel:
        """Return the local small chat model, creating it on first use.

        Returns:
            BaseChatModel: A LangChain-compatible chat model instance
            (non-streaming).
        """
        # NOTE(review): relies on ``_service_instance`` being set to None by
        # BaseServiceProvider.__init__ — confirm in .base.
        if self._service_instance is None:
            # Imported lazily so these packages are only loaded when needed.
            from langchain_openai import ChatOpenAI
            from pydantic import SecretStr

            self._service_instance = ChatOpenAI(
                base_url=self._base_url,
                # An unset key becomes an empty secret.
                api_key=SecretStr(self._api_key or ""),
                model=self._model,
                timeout=30.0,
                max_retries=2,
                streaming=False,
            )
        return self._service_instance
|
||
|
||
|
||
class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
    """Provider for the DeepSeek lightweight model.

    Intended for simple tasks such as query rewriting and intent
    classification. It uses its own ``SMALL_DEEPSEEK_*`` configuration
    values rather than the main DeepSeek settings.
    """

    def __init__(self, model: str | None = None):
        """Register the provider under the name ``"deepseek_small"``.

        Args:
            model: Model name to request. Falls back to
                ``SMALL_DEEPSEEK_MODEL`` when omitted or empty.
        """
        # Configuration is imported at construction time, not module import.
        from app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE

        super().__init__("deepseek_small")
        self._model = model or SMALL_DEEPSEEK_MODEL
        self._api_key = SMALL_DEEPSEEK_API_KEY
        self._api_base = SMALL_DEEPSEEK_API_BASE

    def is_available(self) -> bool:
        """Check whether the DeepSeek small-model service is configured.

        Only the presence of the API key is checked; no network request is
        made.

        Returns:
            bool: True when the API key is configured, False otherwise.
        """
        if not self._api_key:
            logger.warning("SMALL_DEEPSEEK_API_KEY 未配置")
            return False
        logger.info(f"DeepSeek 轻量模型配置正确: {self._model}")
        return True

    def get_service(self) -> BaseChatModel:
        """Return the DeepSeek small chat model, creating it on first use.

        Returns:
            BaseChatModel: A LangChain-compatible chat model instance
            (non-streaming).
        """
        # NOTE(review): relies on ``_service_instance`` being set to None by
        # BaseServiceProvider.__init__ — confirm in .base.
        if self._service_instance is None:
            # Imported lazily so these packages are only loaded when needed.
            from langchain_openai import ChatOpenAI
            from pydantic import SecretStr

            self._service_instance = ChatOpenAI(
                base_url=self._api_base,
                api_key=SecretStr(self._api_key),
                model=self._model,
                temperature=0.1,
                max_tokens=2048,
                timeout=30.0,
                max_retries=2,
                streaming=False,
            )
        return self._service_instance
|
||
|
||
|
||
# Global registry: short model name -> zero-argument factory returning a
# provider. Factories are lambdas, so a provider is only constructed when
# its entry is called.
CHAT_PROVIDERS: Dict[str, Callable[[], BaseServiceProvider[BaseChatModel]]] = {
    "local": lambda: LocalVLLMChatProvider(),
    "zhipu": lambda: ZhipuChatProvider(),
    "deepseek": lambda: DeepSeekChatProvider(),
}
|
||
|
||
|
||
def get_chat_service() -> BaseChatModel:
    """Return the default generative chat model, with automatic fallback.

    Provider order: local vLLM -> Zhipu -> DeepSeek. The fallback chain is
    obtained through ``SingletonServiceManager.get_or_create`` under the key
    ``"chat_service_chain"``.

    Returns:
        BaseChatModel: A LangChain-compatible chat model instance, as
        returned by the chain's ``get_available_service()``.
    """
    def _build_chain():
        # Primary provider first, then the fallbacks in the order they
        # are listed.
        return FallbackServiceChain(
            LocalVLLMChatProvider(),
            [ZhipuChatProvider(), DeepSeekChatProvider()],
        )

    service_chain = SingletonServiceManager.get_or_create("chat_service_chain", _build_chain)
    return service_chain.get_available_service()
|
||
|
||
|
||
def _init_chat_services() -> Dict[str, BaseChatModel]:
    """Build the dictionary of every chat model that is currently available.

    Each factory in ``CHAT_PROVIDERS`` is called in turn. A provider whose
    ``is_available()`` is false is skipped, and a provider that raises is
    logged as a warning and skipped.

    Returns:
        Dict[str, BaseChatModel]: Provider name -> chat model instance.

    Raises:
        RuntimeError: If no provider produced a model.
    """
    loaded = {}

    for name, make_provider in CHAT_PROVIDERS.items():
        try:
            candidate = make_provider()
            if not candidate.is_available():
                continue
            loaded[name] = candidate.get_service()
            logger.info(f"已加载模型: {name}")
        except Exception as e:
            # One broken provider must not stop the others from loading.
            logger.warning(f"模型 {name} 初始化失败: {e}")

    if loaded:
        return loaded

    raise RuntimeError(f"没有可用的生成式大模型,尝试了: {list(CHAT_PROVIDERS.keys())}")
|
||
|
||
|
||
def get_cached_chat_services() -> Dict[str, BaseChatModel]:
    """Return the cached dictionary of available chat models.

    The first call fills the module-level ``_cached_services`` by running
    ``_init_chat_services()``. Later calls return the same dictionary
    without re-initialising. If initialisation raises, nothing is cached.

    Returns:
        Dict[str, BaseChatModel]: Provider name -> chat model instance.
    """
    global _cached_services

    if _cached_services is not None:
        return _cached_services

    _cached_services = _init_chat_services()
    return _cached_services
|
||
|
||
|
||
def get_all_chat_services() -> Dict[str, BaseChatModel]:
    """Return every available generative chat model.

    Kept for backward compatibility with multi-model switching. It simply
    delegates to ``get_cached_chat_services()``, which new code should call
    directly.

    Returns:
        Dict[str, BaseChatModel]: Model name -> chat model instance.
    """
    services = get_cached_chat_services()
    return services
|
||
|
||
|
||
def get_small_llm_service() -> BaseChatModel:
    """Return the lightweight chat model used for simple tasks.

    Meant for work such as query rewriting and intent classification.
    Provider order: local small model -> DeepSeek small model. The chain is
    obtained through ``SingletonServiceManager.get_or_create`` under the key
    ``"small_llm_chain"``.

    Note: the chain contains no large-model provider, so small-model tasks
    never fall back to a large model and avoid unnecessary token usage.

    Returns:
        BaseChatModel: A LangChain-compatible chat model instance, as
        returned by the chain's ``get_available_service()``.
    """
    def _build_small_chain():
        # Local model first, DeepSeek small model as the only fallback.
        return FallbackServiceChain(
            LocalSmallModelProvider(),
            [DeepSeekSmallModelProvider()],
        )

    small_chain = SingletonServiceManager.get_or_create("small_llm_chain", _build_small_chain)
    return small_chain.get_available_service()
|