重构架构:恢复统一的 llm_call 节点,移除错误的 final_response 节点
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m50s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m50s
This commit is contained in:
@@ -46,7 +46,7 @@ class AIAgentService:
|
||||
for name, llm in chat_services.items():
|
||||
try:
|
||||
info(f"🔄 初始化模型 '{name}'...")
|
||||
graph = build_react_main_graph().compile(checkpointer=self.checkpointer)
|
||||
graph = build_react_main_graph(llm=llm, tools=self.tools).compile(checkpointer=self.checkpointer)
|
||||
self.graphs[name] = graph
|
||||
info(f"✅ 模型 '{name}' 初始化成功")
|
||||
except Exception as e:
|
||||
|
||||
@@ -9,54 +9,68 @@ from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.state import MessagesState
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.agent.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error
|
||||
|
||||
def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
def create_llm_call_node(llm, tools: list):
|
||||
"""
|
||||
工厂函数:创建 LLM 调用节点
|
||||
|
||||
|
||||
Args:
|
||||
llm: LangChain LLM 实例
|
||||
tools: 工具列表
|
||||
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
# 构建调用链
|
||||
prompt = create_system_prompt(tools)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
|
||||
|
||||
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
|
||||
chain = prompt | llm_with_tools
|
||||
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
|
||||
async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
config: LangChain/LangGraph 自动注入的配置,包含 callbacks 等信息
|
||||
|
||||
|
||||
Returns:
|
||||
更新后的状态字典
|
||||
"""
|
||||
log_state_change("llm_call", state, "进入")
|
||||
|
||||
memory_context = state.get("memory_context", "暂无用户信息")
|
||||
|
||||
memory_context = getattr(state, "memory_context", "暂无用户信息")
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# 添加 RAG 上下文到消息
|
||||
messages_with_context = list(state.messages)
|
||||
if state.rag_context:
|
||||
from langchain_core.messages import SystemMessage
|
||||
rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}")
|
||||
inserted = False
|
||||
for i, msg in enumerate(messages_with_context):
|
||||
if msg.type == "human":
|
||||
messages_with_context.insert(i, rag_system_msg)
|
||||
inserted = True
|
||||
break
|
||||
if not inserted:
|
||||
messages_with_context.insert(0, rag_system_msg)
|
||||
|
||||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||||
# LangGraph 会自动监听这期间产生的所有 token。
|
||||
chunks = []
|
||||
async for chunk in chain.astream(
|
||||
{
|
||||
"messages": state["messages"],
|
||||
"messages": messages_with_context,
|
||||
"memory_context": memory_context
|
||||
},
|
||||
config=config
|
||||
@@ -70,14 +84,14 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
response = response + chunk
|
||||
else:
|
||||
response = AIMessage(content="")
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
|
||||
token_usage = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
|
||||
# 尝试从 response_metadata 中提取
|
||||
if hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
meta = response.response_metadata
|
||||
@@ -85,18 +99,18 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
token_usage = meta['token_usage']
|
||||
elif 'usage' in meta:
|
||||
token_usage = meta['usage']
|
||||
|
||||
|
||||
# 尝试从 additional_kwargs 中提取
|
||||
if not token_usage and hasattr(response, 'additional_kwargs'):
|
||||
add_kwargs = response.additional_kwargs
|
||||
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
|
||||
token_usage = add_kwargs['llm_output']['token_usage']
|
||||
|
||||
|
||||
# 提取具体的 token 数值
|
||||
if token_usage:
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
|
||||
|
||||
# 打印 LLM 的完整输出
|
||||
debug("\n" + "="*80)
|
||||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||||
@@ -111,18 +125,21 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
|
||||
result = {
|
||||
"messages": [response],
|
||||
"llm_calls": state.get('llm_calls', 0) + 1,
|
||||
"llm_calls": getattr(state, 'llm_calls', 0) + 1,
|
||||
"last_token_usage": token_usage,
|
||||
"last_elapsed_time": elapsed_time,
|
||||
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 递增计数器
|
||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||
"final_result": response.content,
|
||||
"success": True,
|
||||
"current_phase": "done"
|
||||
}
|
||||
|
||||
|
||||
log_state_change("llm_call", {**state, **result}, "离开")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
@@ -131,20 +148,23 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
debug("="*80 + "\n")
|
||||
|
||||
|
||||
# 返回一个友好的错误消息
|
||||
error_response = AIMessage(
|
||||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||||
)
|
||||
error_result = {
|
||||
"messages": [error_response],
|
||||
"llm_calls": state.get('llm_calls', 0),
|
||||
"llm_calls": getattr(state, 'llm_calls', 0),
|
||||
"last_token_usage": {},
|
||||
"last_elapsed_time": elapsed_time,
|
||||
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 即使出错也递增计数器
|
||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||
"final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。",
|
||||
"success": False,
|
||||
"current_phase": "done"
|
||||
}
|
||||
|
||||
|
||||
log_state_change("llm_call", state, "离开(异常)")
|
||||
return error_result
|
||||
|
||||
|
||||
return call_llm
|
||||
|
||||
@@ -3,7 +3,6 @@ React 模式节点模块 - 带超时和重试功能
|
||||
包含:
|
||||
- react_reason_node: 使用 intent.py 进行推理
|
||||
- error_handling_node: 错误处理节点
|
||||
- final_response_node: 最终回答节点
|
||||
- init_state_node: 初始化状态节点
|
||||
|
||||
注意:为了兼容 LangGraph 的同步接口,我们保留了同步的 react_reason 调用
|
||||
@@ -233,98 +232,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
|
||||
return state
|
||||
|
||||
|
||||
# ========== 3. 最终回答节点 ==========
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
async def final_response_node(state: MainGraphState, config: RunnableConfig) -> MainGraphState:
|
||||
"""
|
||||
最终回答节点:调用 LLM 生成最终回答(支持流式输出)
|
||||
"""
|
||||
state.current_phase = "finalizing"
|
||||
|
||||
# 如果已经有 final_result 了,直接返回
|
||||
if state.final_result:
|
||||
state.current_phase = "done"
|
||||
return state
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 构建 LLM 调用链
|
||||
from app.agent.prompts import create_system_prompt
|
||||
from app.model_services.chat_services import get_chat_service
|
||||
from app.logger import debug, info
|
||||
|
||||
llm = get_chat_service()
|
||||
prompt = create_system_prompt(tools=[])
|
||||
chain = prompt | llm
|
||||
|
||||
# 构建上下文
|
||||
memory_context = getattr(state, "memory_context", "暂无用户信息")
|
||||
|
||||
# 添加 RAG 上下文到消息
|
||||
messages_with_context = list(state.messages)
|
||||
if state.rag_context:
|
||||
# 把 RAG 上下文作为系统消息添加
|
||||
from langchain_core.messages import SystemMessage
|
||||
rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}")
|
||||
# 插入到第一个用户消息之前
|
||||
inserted = False
|
||||
for i, msg in enumerate(messages_with_context):
|
||||
if msg.type == "human":
|
||||
messages_with_context.insert(i, rag_system_msg)
|
||||
inserted = True
|
||||
break
|
||||
if not inserted:
|
||||
messages_with_context.insert(0, rag_system_msg)
|
||||
|
||||
# 调用 LLM(流式输出)
|
||||
chunks = []
|
||||
async for chunk in chain.astream(
|
||||
{
|
||||
"messages": messages_with_context,
|
||||
"memory_context": memory_context
|
||||
},
|
||||
config=config
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0]
|
||||
for chunk in chunks[1:]:
|
||||
response = response + chunk
|
||||
else:
|
||||
response = AIMessage(content="")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 更新状态
|
||||
state.messages.append(response)
|
||||
state.final_result = response.content
|
||||
state.success = True
|
||||
state.current_phase = "done"
|
||||
state.end_time = datetime.now().isoformat()
|
||||
state.llm_calls = getattr(state, "llm_calls", 0) + 1
|
||||
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
|
||||
except Exception as e:
|
||||
from app.logger import error
|
||||
import traceback
|
||||
error(f"❌ [LLM错误] 调用失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
state.final_result = "抱歉,模型暂时无法响应,请稍后再试。"
|
||||
state.success = False
|
||||
state.current_phase = "done"
|
||||
|
||||
return state
|
||||
|
||||
|
||||
# ========== 4. 初始化状态节点 ==========
|
||||
|
||||
def init_state_node(state: MainGraphState) -> MainGraphState:
|
||||
@@ -353,11 +260,11 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
"""
|
||||
# 先检查特殊情况
|
||||
if state.current_phase == "max_steps_exceeded":
|
||||
return "final_response"
|
||||
return "llm_call"
|
||||
if state.current_phase == "error_handling" or state.current_error:
|
||||
return "handle_error"
|
||||
if state.current_phase == "finalizing" or state.current_phase == "done":
|
||||
return "final_response"
|
||||
return "llm_call"
|
||||
if state.current_phase == "retrying":
|
||||
if state.retry_action and "rag" in state.retry_action.lower():
|
||||
return "rag_retrieve"
|
||||
@@ -367,7 +274,7 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
|
||||
|
||||
if not reasoning_result:
|
||||
return "final_response"
|
||||
return "llm_call"
|
||||
|
||||
# 使用 intent.py 提供的路由函数
|
||||
route = get_route_by_reasoning(reasoning_result)
|
||||
@@ -375,18 +282,18 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
# 映射到我们的节点名称
|
||||
# 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致
|
||||
route_mapping = {
|
||||
"direct_response": "final_response",
|
||||
"direct_response": "llm_call",
|
||||
"retrieve_rag": "rag_retrieve",
|
||||
"re_retrieve_rag": "rag_retrieve",
|
||||
"web_search": "web_search", # ⭐ 新增:联网搜索
|
||||
"clarify": "final_response",
|
||||
"call_tool": "final_response", # 暂时映射到 final_response,后续可以扩展
|
||||
"web_search": "web_search",
|
||||
"clarify": "llm_call",
|
||||
"call_tool": "llm_call",
|
||||
"contact": "contact_subgraph",
|
||||
"dictionary": "dictionary_subgraph",
|
||||
"news_analysis": "news_analysis_subgraph",
|
||||
}
|
||||
|
||||
return route_mapping.get(route, "final_response")
|
||||
return route_mapping.get(route, "llm_call")
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
@@ -394,8 +301,7 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
__all__ = [
|
||||
"init_state_node",
|
||||
"react_reason_node",
|
||||
"web_search_node", # ⭐ 新增
|
||||
"web_search_node",
|
||||
"error_handling_node",
|
||||
"final_response_node",
|
||||
"route_by_reasoning"
|
||||
]
|
||||
|
||||
@@ -10,11 +10,11 @@ from app.main_graph.state import MainGraphState, CurrentAction
|
||||
from app.main_graph.nodes.react_nodes import (
|
||||
init_state_node,
|
||||
react_reason_node,
|
||||
web_search_node, # ⭐ 新增
|
||||
web_search_node,
|
||||
error_handling_node,
|
||||
final_response_node,
|
||||
route_by_reasoning
|
||||
)
|
||||
from app.main_graph.nodes.llm_call import create_llm_call_node
|
||||
from app.main_graph.nodes.rag_nodes import rag_retrieve_node
|
||||
from app.subgraphs.contact import build_contact_subgraph
|
||||
from app.subgraphs.dictionary import build_dictionary_subgraph
|
||||
@@ -75,10 +75,10 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
|
||||
|
||||
# ========== 主图构建 ==========
|
||||
def build_react_main_graph() -> StateGraph:
|
||||
def build_react_main_graph(llm=None, tools=None) -> StateGraph:
|
||||
"""
|
||||
构建完整的 React 模式主图
|
||||
|
||||
|
||||
流程:
|
||||
START
|
||||
↓
|
||||
@@ -87,18 +87,23 @@ def build_react_main_graph() -> StateGraph:
|
||||
react_reason (推理) ←──────────────┐
|
||||
↓ │
|
||||
条件路由 │
|
||||
├─→ rag_retrieve →───────────────┤
|
||||
├─→ contact_subgraph →───────────┤
|
||||
├─→ dictionary_subgraph →────────┤
|
||||
├─→ news_analysis_subgraph →─────┤
|
||||
├─→ handle_error → (重试或结束) ──┤
|
||||
└─→ final_response
|
||||
├─ rag_retrieve →───────────────┤
|
||||
├─ contact_subgraph →───────────┤
|
||||
├─ dictionary_subgraph →────────┤
|
||||
├─ news_analysis_subgraph →─────┤
|
||||
├─ handle_error → (重试或结束) ─┤
|
||||
└─ llm_call → END
|
||||
↓
|
||||
END
|
||||
"""
|
||||
# 创建图
|
||||
graph = StateGraph(MainGraphState)
|
||||
|
||||
# 创建 llm_call 节点
|
||||
llm_node = None
|
||||
if llm is not None:
|
||||
llm_node = create_llm_call_node(llm, tools or [])
|
||||
|
||||
# ========== 添加节点 ==========
|
||||
|
||||
# 1. 初始化节点
|
||||
@@ -110,14 +115,15 @@ def build_react_main_graph() -> StateGraph:
|
||||
# 3. RAG 检索节点
|
||||
graph.add_node("rag_retrieve", rag_retrieve_node)
|
||||
|
||||
# 4. 联网搜索节点 ⭐ 新增
|
||||
# 4. 联网搜索节点
|
||||
graph.add_node("web_search", web_search_node)
|
||||
|
||||
# 5. 错误处理节点
|
||||
graph.add_node("handle_error", error_handling_node)
|
||||
|
||||
# 6. 最终回答节点
|
||||
graph.add_node("final_response", final_response_node)
|
||||
# 6. LLM 调用节点(真正的大模型输出)
|
||||
if llm_node is not None:
|
||||
graph.add_node("llm_call", llm_node)
|
||||
|
||||
# ========== 添加子图节点 ==========
|
||||
|
||||
@@ -154,33 +160,34 @@ def build_react_main_graph() -> StateGraph:
|
||||
{
|
||||
# 检索分支 → 检索后回到推理
|
||||
"rag_retrieve": "rag_retrieve",
|
||||
|
||||
# 联网搜索分支 ⭐ 新增
|
||||
|
||||
# 联网搜索分支
|
||||
"web_search": "web_search",
|
||||
|
||||
|
||||
# 子图分支 → 子图后回到推理
|
||||
"contact_subgraph": "contact_subgraph",
|
||||
"dictionary_subgraph": "dictionary_subgraph",
|
||||
"news_analysis_subgraph": "news_analysis_subgraph",
|
||||
|
||||
|
||||
# 错误处理分支
|
||||
"handle_error": "handle_error",
|
||||
|
||||
# 最终回答分支
|
||||
"final_response": "final_response",
|
||||
|
||||
# LLM 调用分支 → 直接输出给用户
|
||||
"llm_call": "llm_call"
|
||||
}
|
||||
)
|
||||
|
||||
# 4. 循环边:检索/搜索/子图/错误处理 后 → 回到推理
|
||||
|
||||
# 4. 循环边:检索/搜索/子图/错误处理后 → 回到推理
|
||||
graph.add_edge("rag_retrieve", "react_reason")
|
||||
graph.add_edge("web_search", "react_reason") # ⭐ 新增
|
||||
graph.add_edge("web_search", "react_reason")
|
||||
graph.add_edge("contact_subgraph", "react_reason")
|
||||
graph.add_edge("dictionary_subgraph", "react_reason")
|
||||
graph.add_edge("news_analysis_subgraph", "react_reason")
|
||||
graph.add_edge("handle_error", "react_reason") # 错误处理后可能重试
|
||||
|
||||
# 5. 最终边:final_response → END
|
||||
graph.add_edge("final_response", END)
|
||||
graph.add_edge("handle_error", "react_reason")
|
||||
|
||||
# 5. 最终边:llm_call → END
|
||||
if llm_node is not None:
|
||||
graph.add_edge("llm_call", END)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
Reference in New Issue
Block a user