151 lines
5.7 KiB
Python
151 lines
5.7 KiB
Python
"""
|
||
LLM 调用节点模块
|
||
负责调用大语言模型并处理响应
|
||
"""
|
||
|
||
import time
|
||
from typing import Any, Dict
|
||
from langchain_core.language_models import BaseLLM
|
||
from langchain_core.messages import AIMessage
|
||
|
||
# 本地模块
|
||
from app.graph.state import MessagesState
|
||
from app.agent.prompts import create_system_prompt
|
||
from app.utils.logging import log_state_change
|
||
from app.logger import debug, info, error
|
||
|
||
def create_llm_call_node(llm: BaseLLM, tools: list):
|
||
"""
|
||
工厂函数:创建 LLM 调用节点
|
||
|
||
Args:
|
||
llm: LangChain LLM 实例
|
||
tools: 工具列表
|
||
|
||
Returns:
|
||
异步节点函数
|
||
"""
|
||
# 构建调用链
|
||
prompt = create_system_prompt(tools)
|
||
llm_with_tools = llm.bind_tools(tools)
|
||
|
||
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
|
||
chain = prompt | llm_with_tools
|
||
|
||
from langchain_core.runnables.config import RunnableConfig
|
||
|
||
async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||
"""
|
||
LLM 调用节点(异步方法)
|
||
|
||
Args:
|
||
state: 当前对话状态
|
||
config: LangChain/LangGraph 自动注入的配置,包含 callbacks 等信息
|
||
|
||
Returns:
|
||
更新后的状态字典
|
||
"""
|
||
log_state_change("llm_call", state, "进入")
|
||
|
||
memory_context = state.get("memory_context", "暂无用户信息")
|
||
start_time = time.time()
|
||
|
||
try:
|
||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||
# LangGraph 会自动监听这期间产生的所有 token。
|
||
chunks = []
|
||
async for chunk in chain.astream(
|
||
{
|
||
"messages": state["messages"],
|
||
"memory_context": memory_context
|
||
},
|
||
config=config
|
||
):
|
||
chunks.append(chunk)
|
||
|
||
# 将所有 chunk 合并成最终的 AIMessage
|
||
if chunks:
|
||
response = chunks[0]
|
||
for chunk in chunks[1:]:
|
||
response = response + chunk
|
||
else:
|
||
response = AIMessage(content="")
|
||
|
||
elapsed_time = time.time() - start_time
|
||
|
||
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
|
||
token_usage = {}
|
||
input_tokens = 0
|
||
output_tokens = 0
|
||
|
||
# 尝试从 response_metadata 中提取
|
||
if hasattr(response, 'response_metadata') and response.response_metadata:
|
||
meta = response.response_metadata
|
||
if 'token_usage' in meta:
|
||
token_usage = meta['token_usage']
|
||
elif 'usage' in meta:
|
||
token_usage = meta['usage']
|
||
|
||
# 尝试从 additional_kwargs 中提取
|
||
if not token_usage and hasattr(response, 'additional_kwargs'):
|
||
add_kwargs = response.additional_kwargs
|
||
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
|
||
token_usage = add_kwargs['llm_output']['token_usage']
|
||
|
||
# 提取具体的 token 数值
|
||
if token_usage:
|
||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||
|
||
# 打印 LLM 的完整输出
|
||
debug("\n" + "="*80)
|
||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||
debug(f" 消息类型: {response.type.upper()}")
|
||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||
debug("-"*80)
|
||
debug(f"{response.content}")
|
||
|
||
# 打印响应统计信息
|
||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||
if token_usage:
|
||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||
debug("="*80 + "\n")
|
||
|
||
result = {
|
||
"messages": [response],
|
||
"llm_calls": state.get('llm_calls', 0) + 1,
|
||
"last_token_usage": token_usage,
|
||
"last_elapsed_time": elapsed_time,
|
||
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 递增计数器
|
||
}
|
||
|
||
log_state_change("llm_call", {**state, **result}, "离开")
|
||
return result
|
||
|
||
except Exception as e:
|
||
elapsed_time = time.time() - start_time
|
||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||
error(f" 错误类型: {type(e).__name__}")
|
||
error(f" 错误信息: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
debug("="*80 + "\n")
|
||
|
||
# 返回一个友好的错误消息
|
||
error_response = AIMessage(
|
||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||
)
|
||
error_result = {
|
||
"messages": [error_response],
|
||
"llm_calls": state.get('llm_calls', 0),
|
||
"last_token_usage": {},
|
||
"last_elapsed_time": elapsed_time,
|
||
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 即使出错也递增计数器
|
||
}
|
||
|
||
log_state_change("llm_call", state, "离开(异常)")
|
||
return error_result
|
||
|
||
return call_llm
|