refactor: 单图方案重构 + 动态模型选择 + chat_services优化
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s
## 核心改动

### 1. 单图方案重构
- 删除了多图(`self.graphs`),改为单图(`self.graph`)
- 新增 `MainGraphState.current_model` 字段,用于运行时注入模型
- `llm_call` 节点改为动态选择模型(`create_dynamic_llm_call_node`)

### 2. chat_services 优化
- 添加 `_cached_services` 缓存,避免重复初始化
- 新增 `get_cached_chat_services()` 函数,用于单图注入
- 新增 `_check_http_service_available()`,统一 HTTP 探测逻辑
- 减少重复代码,`LocalVLLMChatProvider` 和 `LocalSmallModelProvider` 共用探测方法

### 3. AIAgentService 重构
- `initialize()` 只构建一次图,传入 `chat_services` 字典
- 新增 `_resolve_model()` 模型回退逻辑
- 新增 `_build_invocation()` 统一构建调用参数
- `process_message()` 和 `process_message_stream()` 改为注入 `current_model`
- 流式处理代码拆分,增加可读性

### 4. 新增和删除文件
- 新增:`backend/app/main_graph/main_graph_builder.py`(图构建)
- 新增:`backend/app/main_graph/subgraph_wrapper.py`(子图封装)
- 新增:`tools/test/test_tavily_search.py`(测试)
- 删除:`backend/app/main_graph/graph.py`(旧图)
- 删除:`backend/app/main_graph/utils/main_graph_builder.py`(旧构建器)
- 删除:`backend/app/main_graph/utils/__init__.py`

### 5. 其他更新
- `README.md`:新增模型服务使用情况详解章节
- `backend/app/model_services/__init__.py`:新增 `get_cached_chat_services` 导出

## 方案优势
- 内存优化:N 张图 → 1 张图
- 灵活性:运行时动态选择模型,支持同会话不同模型
- 性能:模型服务缓存,初始化仅一次
- 可维护性:减少重复代码,统一 HTTP 探测逻辑
This commit is contained in:
@@ -6,8 +6,8 @@
|
||||
from .reasoning import react_reason_node
|
||||
from .web_search import web_search_node
|
||||
from .error_handling import error_handling_node
|
||||
from .routing import init_state_node, route_by_reasoning
|
||||
from .llm_call import create_llm_call_node
|
||||
from .routing import init_state_node, route_by_reasoning, should_summarize
|
||||
from .llm_call import create_dynamic_llm_call_node
|
||||
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node
|
||||
|
||||
# 记忆节点
|
||||
@@ -38,7 +38,8 @@ __all__ = [
|
||||
"web_search_node",
|
||||
"error_handling_node",
|
||||
"route_by_reasoning",
|
||||
"create_llm_call_node",
|
||||
"should_summarize",
|
||||
"create_dynamic_llm_call_node",
|
||||
"rag_retrieve_node",
|
||||
"rag_re_retrieve_node",
|
||||
# 记忆节点
|
||||
|
||||
@@ -5,7 +5,7 @@ LLM 调用节点模块
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
@@ -14,29 +14,34 @@ from app.agent.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error
|
||||
|
||||
def create_llm_call_node(llm, tools: list):
|
||||
|
||||
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
|
||||
"""
|
||||
工厂函数:创建 LLM 调用节点
|
||||
|
||||
工厂函数:创建动态 LLM 调用节点(根据 state.current_model 选择模型)
|
||||
|
||||
Args:
|
||||
llm: LangChain LLM 实例
|
||||
chat_services: 模型名称 -> ChatModel 实例 的字典
|
||||
tools: 工具列表
|
||||
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
# 构建调用链
|
||||
# 预构建所有模型的 tools 绑定(避免每次调用都 bind)
|
||||
bound_models: Dict[str, Any] = {}
|
||||
for name, llm in chat_services.items():
|
||||
if tools:
|
||||
bound_models[name] = llm.bind_tools(tools)
|
||||
else:
|
||||
bound_models[name] = llm
|
||||
|
||||
# 预构建 prompt
|
||||
prompt = create_system_prompt(tools)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
|
||||
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
|
||||
chain = prompt | llm_with_tools
|
||||
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
LLM 调用节点(动态选择模型)
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
@@ -46,7 +51,7 @@ def create_llm_call_node(llm, tools: list):
|
||||
更新后的状态字典
|
||||
"""
|
||||
log_state_change("llm_call", state, "进入")
|
||||
|
||||
|
||||
memory_context = getattr(state, "memory_context", "暂无用户信息")
|
||||
start_time = time.time()
|
||||
|
||||
@@ -62,9 +67,20 @@ def create_llm_call_node(llm, tools: list):
|
||||
"last_elapsed_time": elapsed_time,
|
||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||
}
|
||||
|
||||
|
||||
# 动态选择模型
|
||||
model_name = getattr(state, "current_model", "")
|
||||
if not model_name or model_name not in bound_models:
|
||||
# 回退到第一个可用模型
|
||||
fallback_name = next(iter(bound_models.keys()))
|
||||
info(f"[llm_call] 模型 '{model_name}' 不可用,回退到 '{fallback_name}'")
|
||||
model_name = fallback_name
|
||||
|
||||
llm_with_tools = bound_models[model_name]
|
||||
info(f"[llm_call] 使用模型: {model_name}")
|
||||
|
||||
try:
|
||||
# 添加 RAG 上下文到消息
|
||||
# 添加上下文到消息
|
||||
messages_with_context = list(state.messages)
|
||||
if state.rag_context:
|
||||
from langchain_core.messages import SystemMessage
|
||||
@@ -77,9 +93,10 @@ def create_llm_call_node(llm, tools: list):
|
||||
break
|
||||
if not inserted:
|
||||
messages_with_context.insert(0, rag_system_msg)
|
||||
|
||||
|
||||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||||
# LangGraph 会自动监听这期间产生的所有 token。
|
||||
chain = prompt | llm_with_tools
|
||||
chunks = []
|
||||
async for chunk in chain.astream(
|
||||
{
|
||||
@@ -89,7 +106,7 @@ def create_llm_call_node(llm, tools: list):
|
||||
config=config
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0]
|
||||
@@ -97,14 +114,14 @@ def create_llm_call_node(llm, tools: list):
|
||||
response = response + chunk
|
||||
else:
|
||||
response = AIMessage(content="")
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
|
||||
token_usage = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
|
||||
# 尝试从 response_metadata 中提取
|
||||
if hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
meta = response.response_metadata
|
||||
@@ -112,33 +129,33 @@ def create_llm_call_node(llm, tools: list):
|
||||
token_usage = meta['token_usage']
|
||||
elif 'usage' in meta:
|
||||
token_usage = meta['usage']
|
||||
|
||||
|
||||
# 尝试从 additional_kwargs 中提取
|
||||
if not token_usage and hasattr(response, 'additional_kwargs'):
|
||||
add_kwargs = response.additional_kwargs
|
||||
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
|
||||
token_usage = add_kwargs['llm_output']['token_usage']
|
||||
|
||||
|
||||
# 提取具体的 token 数值
|
||||
if token_usage:
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
|
||||
|
||||
# 打印 LLM 的完整输出
|
||||
debug("\n" + "="*80)
|
||||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||||
debug(f"📥 [LLM输出] 模型: {model_name} 返回的完整响应:")
|
||||
debug(f" 消息类型: {response.type.upper()}")
|
||||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||||
debug("-"*80)
|
||||
debug(f"{response.content}")
|
||||
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
|
||||
# 检查是否有工具调用
|
||||
has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0
|
||||
|
||||
@@ -151,21 +168,22 @@ def create_llm_call_node(llm, tools: list):
|
||||
"final_result": response.content,
|
||||
"success": True,
|
||||
"current_phase": "done",
|
||||
"has_tool_calls": has_tool_calls
|
||||
"has_tool_calls": has_tool_calls,
|
||||
"current_model": model_name # 记录实际使用的模型
|
||||
}
|
||||
|
||||
|
||||
log_state_change("llm_call", state, "离开")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f"\n❌ [LLM错误] 模型 {model_name} 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f" 错误类型: {type(e).__name__}")
|
||||
error(f" 错误信息: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
error(f"📋 堆栈: {traceback.format_exc()}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
|
||||
# 返回一个友好的错误消息
|
||||
error_response = AIMessage(
|
||||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||||
@@ -178,10 +196,11 @@ def create_llm_call_node(llm, tools: list):
|
||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||
"final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。",
|
||||
"success": False,
|
||||
"current_phase": "done"
|
||||
"current_phase": "done",
|
||||
"current_model": model_name
|
||||
}
|
||||
|
||||
|
||||
log_state_change("llm_call", state, "离开(异常)")
|
||||
return error_result
|
||||
|
||||
return call_llm
|
||||
|
||||
return call_llm
|
||||
|
||||
@@ -118,3 +118,21 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
|
||||
info(f"[条件路由] 动作={latest_action}, 目标={target}")
|
||||
return target
|
||||
|
||||
|
||||
# ========== 完成阶段条件路由函数 ==========
|
||||
|
||||
def should_summarize(state: MainGraphState) -> str:
    """
    Decide whether the conversation should be summarized before finishing.

    Args:
        state: Current graph state; only its ``turns_since_last_summary``
            counter is read.

    Returns:
        "summarize" when the counter has reached 5 or more,
        otherwise "finalize".
    """
    # Summarize once every 5 conversation turns.
    if state.turns_since_last_summary >= 5:
        return "summarize"
    return "finalize"
|
||||
|
||||
Reference in New Issue
Block a user