重构架构:恢复统一的 llm_call 节点,移除错误的 final_response 节点
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m50s

This commit is contained in:
2026-05-01 14:01:48 +08:00
parent 1e15a0e550
commit 4ee769a79f
4 changed files with 94 additions and 161 deletions

View File

@@ -9,54 +9,68 @@ from langchain_core.language_models import BaseLLM
from langchain_core.messages import AIMessage
# 本地模块
from app.main_graph.state import MessagesState
from app.main_graph.state import MainGraphState
from app.agent.prompts import create_system_prompt
from app.utils.logging import log_state_change
from app.logger import debug, info, error
def create_llm_call_node(llm: BaseLLM, tools: list):
def create_llm_call_node(llm, tools: list):
"""
工厂函数:创建 LLM 调用节点
Args:
llm: LangChain LLM 实例
tools: 工具列表
Returns:
异步节点函数
"""
# 构建调用链
prompt = create_system_prompt(tools)
llm_with_tools = llm.bind_tools(tools)
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
chain = prompt | llm_with_tools
from langchain_core.runnables.config import RunnableConfig
async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
"""
LLM 调用节点(异步方法)
Args:
state: 当前对话状态
config: LangChain/LangGraph 自动注入的配置,包含 callbacks 等信息
Returns:
更新后的状态字典
"""
log_state_change("llm_call", state, "进入")
memory_context = state.get("memory_context", "暂无用户信息")
memory_context = getattr(state, "memory_context", "暂无用户信息")
start_time = time.time()
try:
# 添加 RAG 上下文到消息
messages_with_context = list(state.messages)
if state.rag_context:
from langchain_core.messages import SystemMessage
rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}")
inserted = False
for i, msg in enumerate(messages_with_context):
if msg.type == "human":
messages_with_context.insert(i, rag_system_msg)
inserted = True
break
if not inserted:
messages_with_context.insert(0, rag_system_msg)
# 恢复为:手动进行 astream并将所有的 chunk 拼接成最终的 response 返回。
# LangGraph 会自动监听这期间产生的所有 token。
chunks = []
async for chunk in chain.astream(
{
"messages": state["messages"],
"messages": messages_with_context,
"memory_context": memory_context
},
config=config
@@ -70,14 +84,14 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
response = response + chunk
else:
response = AIMessage(content="")
elapsed_time = time.time() - start_time
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
token_usage = {}
input_tokens = 0
output_tokens = 0
# 尝试从 response_metadata 中提取
if hasattr(response, 'response_metadata') and response.response_metadata:
meta = response.response_metadata
@@ -85,18 +99,18 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
token_usage = meta['token_usage']
elif 'usage' in meta:
token_usage = meta['usage']
# 尝试从 additional_kwargs 中提取
if not token_usage and hasattr(response, 'additional_kwargs'):
add_kwargs = response.additional_kwargs
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
token_usage = add_kwargs['llm_output']['token_usage']
# 提取具体的 token 数值
if token_usage:
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
# 打印 LLM 的完整输出
debug("\n" + "="*80)
debug("📥 [LLM输出] 大模型返回的完整响应:")
@@ -111,18 +125,21 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
if token_usage:
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
debug("="*80 + "\n")
result = {
"messages": [response],
"llm_calls": state.get('llm_calls', 0) + 1,
"llm_calls": getattr(state, 'llm_calls', 0) + 1,
"last_token_usage": token_usage,
"last_elapsed_time": elapsed_time,
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 递增计数器
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
"final_result": response.content,
"success": True,
"current_phase": "done"
}
log_state_change("llm_call", {**state, **result}, "离开")
return result
except Exception as e:
elapsed_time = time.time() - start_time
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
@@ -131,20 +148,23 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
import traceback
traceback.print_exc()
debug("="*80 + "\n")
# 返回一个友好的错误消息
error_response = AIMessage(
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
)
error_result = {
"messages": [error_response],
"llm_calls": state.get('llm_calls', 0),
"llm_calls": getattr(state, 'llm_calls', 0),
"last_token_usage": {},
"last_elapsed_time": elapsed_time,
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 即使出错也递增计数器
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
"final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。",
"success": False,
"current_phase": "done"
}
log_state_change("llm_call", state, "离开(异常)")
return error_result
return call_llm