Files
ailine/backend/app/main_graph/nodes/llm_call.py
root b5c15ef445
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s
refactor: 单图方案重构 + 动态模型选择 + chat_services优化
## 核心改动

### 1. 单图方案重构
- 删除了多图(self.graphs),改为单图(self.graph)
- 新增 MainGraphState.current_model 字段用于运行时注入模型
- llm_call 节点改为动态选择模型(create_dynamic_llm_call_node)

### 2. chat_services 优化
- 添加 _cached_services 缓存,避免重复初始化
- 新增 get_cached_chat_services() 函数,用于单图注入
- 新增 _check_http_service_available() 统一HTTP探测逻辑
- 减少重复代码,LocalVLLMChatProvider和LocalSmallModelProvider共用探测方法

### 3. AIAgentService 重构
- initialize() 只构建一次图,传入 chat_services 字典
- 新增 _resolve_model() 模型回退逻辑
- 新增 _build_invocation() 统一构建调用参数
- process_message() 和 process_message_stream() 改为注入 current_model
- 流式处理代码拆分,增加可读性

### 4. 新增和删除文件
- 新增:backend/app/main_graph/main_graph_builder.py(图构建)
- 新增:backend/app/main_graph/subgraph_wrapper.py(子图封装)
- 新增:tools/test/test_tavily_search.py(测试)
- 删除:backend/app/main_graph/graph.py(旧图)
- 删除:backend/app/main_graph/utils/main_graph_builder.py(旧构建器)
- 删除:backend/app/main_graph/utils/__init__.py

### 5. 其他更新
- README.md:新增模型服务使用情况详解章节
- backend/app/model_services/__init__.py:新增 get_cached_chat_services 导出

## 方案优势

- 内存优化:N张图 → 1张图
- 灵活性:运行时动态选择模型,支持同会话不同模型
- 性能:模型服务缓存,初始化仅一次
- 可维护性:减少重复代码,统一HTTP探测逻辑
2026-05-05 17:30:55 +08:00

207 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LLM 调用节点模块
负责调用大语言模型并处理响应
"""
import time
from typing import Any, Dict
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
# 本地模块
from app.main_graph.state import MainGraphState
from app.agent.prompts import create_system_prompt
from app.utils.logging import log_state_change
from app.logger import debug, info, error
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
"""
工厂函数:创建动态 LLM 调用节点(根据 state.current_model 选择模型)
Args:
chat_services: 模型名称 -> ChatModel 实例 的字典
tools: 工具列表
Returns:
异步节点函数
"""
# 预构建所有模型的 tools 绑定(避免每次调用都 bind
bound_models: Dict[str, Any] = {}
for name, llm in chat_services.items():
if tools:
bound_models[name] = llm.bind_tools(tools)
else:
bound_models[name] = llm
# 预构建 prompt
prompt = create_system_prompt(tools)
from langchain_core.runnables.config import RunnableConfig
async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
"""
LLM 调用节点(动态选择模型)
Args:
state: 当前对话状态
config: LangChain/LangGraph 自动注入的配置,包含 callbacks 等信息
Returns:
更新后的状态字典
"""
log_state_change("llm_call", state, "进入")
memory_context = getattr(state, "memory_context", "暂无用户信息")
start_time = time.time()
# 关键修复:如果 state.final_result 已经存在(比如子图执行完),直接返回
if state.final_result:
info(f"[llm_call] 检测到已有最终结果,直接返回: {state.final_result[:100]}...")
elapsed_time = time.time() - start_time
return {
"final_result": state.final_result,
"success": True,
"current_phase": "done",
"llm_calls": getattr(state, 'llm_calls', 0) + 1,
"last_elapsed_time": elapsed_time,
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
}
# 动态选择模型
model_name = getattr(state, "current_model", "")
if not model_name or model_name not in bound_models:
# 回退到第一个可用模型
fallback_name = next(iter(bound_models.keys()))
info(f"[llm_call] 模型 '{model_name}' 不可用,回退到 '{fallback_name}'")
model_name = fallback_name
llm_with_tools = bound_models[model_name]
info(f"[llm_call] 使用模型: {model_name}")
try:
# 添加上下文到消息
messages_with_context = list(state.messages)
if state.rag_context:
from langchain_core.messages import SystemMessage
rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}")
inserted = False
for i, msg in enumerate(messages_with_context):
if msg.type == "human":
messages_with_context.insert(i, rag_system_msg)
inserted = True
break
if not inserted:
messages_with_context.insert(0, rag_system_msg)
# 恢复为:手动进行 astream并将所有的 chunk 拼接成最终的 response 返回。
# LangGraph 会自动监听这期间产生的所有 token。
chain = prompt | llm_with_tools
chunks = []
async for chunk in chain.astream(
{
"messages": messages_with_context,
"memory_context": memory_context
},
config=config
):
chunks.append(chunk)
# 将所有 chunk 合并成最终的 AIMessage
if chunks:
response = chunks[0]
for chunk in chunks[1:]:
response = response + chunk
else:
response = AIMessage(content="")
elapsed_time = time.time() - start_time
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
token_usage = {}
input_tokens = 0
output_tokens = 0
# 尝试从 response_metadata 中提取
if hasattr(response, 'response_metadata') and response.response_metadata:
meta = response.response_metadata
if 'token_usage' in meta:
token_usage = meta['token_usage']
elif 'usage' in meta:
token_usage = meta['usage']
# 尝试从 additional_kwargs 中提取
if not token_usage and hasattr(response, 'additional_kwargs'):
add_kwargs = response.additional_kwargs
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
token_usage = add_kwargs['llm_output']['token_usage']
# 提取具体的 token 数值
if token_usage:
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
# 打印 LLM 的完整输出
debug("\n" + "="*80)
debug(f"📥 [LLM输出] 模型: {model_name} 返回的完整响应:")
debug(f" 消息类型: {response.type.upper()}")
debug(f" 内容长度: {len(str(response.content))} 字符")
debug("-"*80)
debug(f"{response.content}")
# 打印响应统计信息
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}")
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
if token_usage:
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
debug("="*80 + "\n")
# 检查是否有工具调用
has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0
result = {
"messages": [response],
"llm_calls": getattr(state, 'llm_calls', 0) + 1,
"last_token_usage": token_usage,
"last_elapsed_time": elapsed_time,
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
"final_result": response.content,
"success": True,
"current_phase": "done",
"has_tool_calls": has_tool_calls,
"current_model": model_name # 记录实际使用的模型
}
log_state_change("llm_call", state, "离开")
return result
except Exception as e:
elapsed_time = time.time() - start_time
error(f"\n❌ [LLM错误] 模型 {model_name} 调用失败 (耗时: {elapsed_time:.2f}秒)")
error(f" 错误类型: {type(e).__name__}")
error(f" 错误信息: {str(e)}")
import traceback
error(f"📋 堆栈: {traceback.format_exc()}")
debug("="*80 + "\n")
# 返回一个友好的错误消息
error_response = AIMessage(
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
)
error_result = {
"messages": [error_response],
"llm_calls": getattr(state, 'llm_calls', 0),
"last_token_usage": {},
"last_elapsed_time": elapsed_time,
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
"final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。",
"success": False,
"current_phase": "done",
"current_model": model_name
}
log_state_change("llm_call", state, "离开(异常)")
return error_result
return call_llm