refactor: 将生成式大模型提取为服务层架构,移除 llm_factory
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m0s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m0s
This commit is contained in:
@@ -1,57 +0,0 @@
|
||||
# app/llm_factory.py
|
||||
import os
|
||||
from ..config import ZHIPUAI_API_KEY, DEEPSEEK_API_KEY, VLLM_BASE_URL, LLAMACPP_API_KEY
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
class LLMFactory:
|
||||
@staticmethod
|
||||
def create_zhipu():
|
||||
api_key = ZHIPUAI_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError("ZHIPUAI_API_KEY not set")
|
||||
return ChatZhipuAI(
|
||||
model="glm-4.7-flash",
|
||||
api_key=api_key,
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
timeout=120.0,
|
||||
max_retries=3,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_deepseek():
|
||||
api_key = DEEPSEEK_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError("DEEPSEEK_API_KEY not set")
|
||||
return ChatOpenAI(
|
||||
base_url="https://api.deepseek.com",
|
||||
api_key=SecretStr(api_key),
|
||||
model="deepseek-reasoner",
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
timeout=60.0,
|
||||
max_retries=2,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_local():
|
||||
base_url = VLLM_BASE_URL
|
||||
return ChatOpenAI(
|
||||
base_url=base_url,
|
||||
api_key=SecretStr(LLAMACPP_API_KEY),
|
||||
model="gemma-4-E4B-it",
|
||||
timeout=60.0,
|
||||
max_retries=2,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# 模型创建器映射
|
||||
CREATORS = {
|
||||
"zhipu": create_zhipu,
|
||||
"local": create_local,
|
||||
"deepseek": create_deepseek,
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
# 本地模块
|
||||
from ..graph.graph_builder import GraphBuilder, GraphContext
|
||||
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from .llm_factory import LLMFactory
|
||||
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
|
||||
from .rag_initializer import init_rag_tool
|
||||
from ..logger import info, warning
|
||||
|
||||
@@ -21,16 +21,19 @@ class AIAgentService:
|
||||
|
||||
async def initialize(self):
|
||||
# 1. 初始化 RAG 工具(如果需要)
|
||||
rag_tool = await init_rag_tool(LLMFactory.create_local)
|
||||
def create_local_llm():
|
||||
provider = LocalVLLMChatProvider()
|
||||
return provider.get_service()
|
||||
rag_tool = await init_rag_tool(create_local_llm)
|
||||
if rag_tool:
|
||||
self.tools.append(rag_tool)
|
||||
self.tools_by_name[rag_tool.name] = rag_tool
|
||||
|
||||
# 2. 构建各模型的 Graph
|
||||
for name, creator in LLMFactory.CREATORS.items():
|
||||
chat_services = get_all_chat_services()
|
||||
for name, llm in chat_services.items():
|
||||
try:
|
||||
info(f"🔄 初始化模型 '{name}'...")
|
||||
llm = creator()
|
||||
builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
|
||||
graph = builder.compile(checkpointer=self.checkpointer)
|
||||
self.graphs[name] = graph
|
||||
|
||||
246
backend/app/model_services/chat_services.py
Normal file
246
backend/app/model_services/chat_services.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
生成式大模型服务模块
|
||||
|
||||
本模块提供统一的生成式大模型服务获取接口,支持多种模型:
|
||||
1. Local VLLM 服务:本地 gemma-4-E4B-it 模型
|
||||
2. Zhipu AI:智谱 glm-4.7-flash 模型
|
||||
3. DeepSeek:deepseek-reasoner 模型
|
||||
|
||||
主要功能:
|
||||
- LocalVLLMChatProvider:本地 VLLM 服务提供者
|
||||
- ZhipuChatProvider:智谱 API 服务提供者
|
||||
- DeepSeekChatProvider:DeepSeek API 服务提供者
|
||||
- get_chat_service():获取默认服务(带自动降级)
|
||||
- get_all_chat_services():获取所有可用模型服务(用于多模型切换)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Callable
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from .base import (
|
||||
BaseServiceProvider,
|
||||
FallbackServiceChain,
|
||||
SingletonServiceManager
|
||||
)
|
||||
from ..config import (
|
||||
VLLM_BASE_URL,
|
||||
LLM_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
DEEPSEEK_API_KEY
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
本地 VLLM 生成式大模型服务提供者
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "gemma-4-E4B-it"):
|
||||
super().__init__("local_vllm_chat")
|
||||
self._model = model
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
检查本地 VLLM 服务是否可用
|
||||
|
||||
Returns:
|
||||
bool: 服务是否可用
|
||||
"""
|
||||
if not VLLM_BASE_URL:
|
||||
logger.warning("VLLM_BASE_URL 未配置")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 尝试创建一个简单的测试调用
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
llm = ChatOpenAI(
|
||||
base_url=VLLM_BASE_URL,
|
||||
api_key=SecretStr(LLM_API_KEY) if LLM_API_KEY else SecretStr("dummy"),
|
||||
model=self._model,
|
||||
timeout=10.0,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
# 简单的 ping 测试(不实际调用模型)
|
||||
logger.info(f"本地 VLLM 服务配置正确,准备使用: {self._model}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"本地 VLLM 服务不可用: {e}")
|
||||
return False
|
||||
|
||||
def get_service(self) -> BaseChatModel:
|
||||
"""
|
||||
获取本地 VLLM 服务
|
||||
|
||||
Returns:
|
||||
BaseChatModel: LangChain 兼容的 ChatModel 实例
|
||||
"""
|
||||
if self._service_instance is None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
self._service_instance = ChatOpenAI(
|
||||
base_url=VLLM_BASE_URL,
|
||||
api_key=SecretStr(LLM_API_KEY) if LLM_API_KEY else SecretStr(""),
|
||||
model=self._model,
|
||||
timeout=60.0,
|
||||
max_retries=2,
|
||||
streaming=True,
|
||||
)
|
||||
return self._service_instance
|
||||
|
||||
|
||||
class ZhipuChatProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
智谱 AI 生成式大模型服务提供者
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "glm-4.7-flash"):
|
||||
super().__init__("zhipu_chat")
|
||||
self._model = model
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
检查智谱 AI 服务是否可用
|
||||
|
||||
Returns:
|
||||
bool: 服务是否可用
|
||||
"""
|
||||
if not ZHIPUAI_API_KEY:
|
||||
logger.warning("ZHIPUAI_API_KEY 未配置")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info(f"智谱 AI 服务配置正确,准备使用: {self._model}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"智谱 AI 服务不可用: {e}")
|
||||
return False
|
||||
|
||||
def get_service(self) -> BaseChatModel:
|
||||
"""
|
||||
获取智谱 AI 服务
|
||||
|
||||
Returns:
|
||||
BaseChatModel: LangChain 兼容的 ChatModel 实例
|
||||
"""
|
||||
if self._service_instance is None:
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
|
||||
self._service_instance = ChatZhipuAI(
|
||||
model=self._model,
|
||||
api_key=ZHIPUAI_API_KEY,
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
timeout=120.0,
|
||||
max_retries=3,
|
||||
streaming=True,
|
||||
)
|
||||
return self._service_instance
|
||||
|
||||
|
||||
class DeepSeekChatProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
DeepSeek 生成式大模型服务提供者
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "deepseek-reasoner"):
|
||||
super().__init__("deepseek_chat")
|
||||
self._model = model
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
检查 DeepSeek 服务是否可用
|
||||
|
||||
Returns:
|
||||
bool: 服务是否可用
|
||||
"""
|
||||
if not DEEPSEEK_API_KEY:
|
||||
logger.warning("DEEPSEEK_API_KEY 未配置")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info(f"DeepSeek 服务配置正确,准备使用: {self._model}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"DeepSeek 服务不可用: {e}")
|
||||
return False
|
||||
|
||||
def get_service(self) -> BaseChatModel:
|
||||
"""
|
||||
获取 DeepSeek 服务
|
||||
|
||||
Returns:
|
||||
BaseChatModel: LangChain 兼容的 ChatModel 实例
|
||||
"""
|
||||
if self._service_instance is None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
self._service_instance = ChatOpenAI(
|
||||
base_url="https://api.deepseek.com",
|
||||
api_key=SecretStr(DEEPSEEK_API_KEY),
|
||||
model=self._model,
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
timeout=60.0,
|
||||
max_retries=2,
|
||||
streaming=True,
|
||||
)
|
||||
return self._service_instance
|
||||
|
||||
|
||||
# 全局服务映射表 - 名称 -> Provider
|
||||
CHAT_PROVIDERS: Dict[str, Callable[[], BaseServiceProvider[BaseChatModel]]] = {
|
||||
"local": lambda: LocalVLLMChatProvider(),
|
||||
"zhipu": lambda: ZhipuChatProvider(),
|
||||
"deepseek": lambda: DeepSeekChatProvider(),
|
||||
}
|
||||
|
||||
|
||||
def get_chat_service() -> BaseChatModel:
|
||||
"""
|
||||
获取默认的生成式大模型服务(带自动降级)
|
||||
优先顺序: local -> zhipu -> deepseek
|
||||
|
||||
Returns:
|
||||
BaseChatModel: LangChain 兼容的 ChatModel 实例
|
||||
"""
|
||||
def _create_chain():
|
||||
primary = LocalVLLMChatProvider()
|
||||
fallbacks = [ZhipuChatProvider(), DeepSeekChatProvider()]
|
||||
return FallbackServiceChain(primary, fallbacks)
|
||||
|
||||
chain = SingletonServiceManager.get_or_create("chat_service_chain", _create_chain)
|
||||
return chain.get_available_service()
|
||||
|
||||
|
||||
def get_all_chat_services() -> Dict[str, BaseChatModel]:
|
||||
"""
|
||||
获取所有可用的生成式大模型服务(用于多模型切换)
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseChatModel]: 模型名称 -> ChatModel 实例 的字典
|
||||
"""
|
||||
services = {}
|
||||
|
||||
for name, provider_factory in CHAT_PROVIDERS.items():
|
||||
try:
|
||||
provider = provider_factory()
|
||||
if provider.is_available():
|
||||
logger.info(f"模型 '{name}' 可用")
|
||||
services[name] = provider.get_service()
|
||||
else:
|
||||
logger.warning(f"模型 '{name}' 不可用,跳过")
|
||||
except Exception as e:
|
||||
logger.warning(f"初始化模型 '{name}' 失败: {e}")
|
||||
|
||||
if not services:
|
||||
raise RuntimeError(f"没有可用的生成式大模型,尝试了: {list(CHAT_PROVIDERS.keys())}")
|
||||
|
||||
return services
|
||||
Reference in New Issue
Block a user