# --- Repository metadata (captured from the Gitea web view; kept as comments
# --- so this file remains valid Python) ---
# File:   ailine/backend/app/main_graph/nodes/react_nodes.py
# Commit: root 9ed946cbe3 — "修复: final_response_node 调用 LLM 并支持流式输出"
# Date:   2026-05-01 13:42:12 +08:00
# CI:     构建并部署 AI Agent 服务 / deploy (push) — Successful in 6m36s
# Size:   402 lines, 13 KiB, Python
"""
React 模式节点模块 - 带超时和重试功能
包含:
- react_reason_node: 使用 intent.py 进行推理
- error_handling_node: 错误处理节点
- final_response_node: 最终回答节点
- init_state_node: 初始化状态节点
注意:为了兼容 LangGraph 的同步接口,我们保留了同步的 react_reason 调用
但内部会根据情况使用规则推理或尝试异步调用
"""
import sys
from typing import Dict, Any, Optional
from datetime import datetime
# 导入我们的 intent.py
from app.core.intent import (
react_reason,
get_route_by_reasoning,
ReasoningAction,
ReasoningResult
)
from app.core.state_base import StateUtils
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.main_graph.utils.retry_utils import (
RetryConfig,
SUBGRAPH_RETRY_CONFIG
)
# ========== 1. React 推理节点 ==========
def react_reason_node(state: MainGraphState) -> MainGraphState:
    """
    ReAct reasoning node: decide what the graph should do next.

    Increments the step counter, aborts when the step budget is exhausted,
    otherwise delegates to ``react_reason`` and records the outcome on the
    state (history, debug info, last_action).

    Returns: the updated state
    """
    state.current_phase = "react_reasoning"
    state.reasoning_step += 1

    # Abort early once the step budget is exhausted.
    if state.reasoning_step > state.max_steps:
        state.current_phase = "max_steps_exceeded"
        state.final_result = (
            f"❌ 推理步数超过限制(最大 {state.max_steps} 步),"
            f"已执行 {state.reasoning_step - 1} 步。"
            f"请简化您的问题或分批提问。"
        )
        state.success = False
        return state

    # Assemble the reasoning context from everything gathered so far.
    reasoning_context = {
        "retrieved_docs": state.rag_docs,
        "previous_actions": [entry.get("action") for entry in state.reasoning_history],
        "messages": state.messages,
        "errors": state.errors,
    }

    # Synchronous entry point into intent.py; it decides internally whether
    # to use rule-based reasoning or attempt an async call.
    outcome: ReasoningResult = react_reason(state.user_query, reasoning_context)

    # Append this step to the reasoning history.
    state.reasoning_history.append({
        "step": state.reasoning_step,
        "action": outcome.action.name,
        "confidence": outcome.confidence,
        "reasoning": outcome.reasoning,
        "timestamp": datetime.now().isoformat(),
    })

    # Mirror the latest reasoning into debug info for observability.
    state.debug_info["last_reasoning"] = {
        "action": outcome.action.name,
        "confidence": outcome.confidence,
        "reasoning": outcome.reasoning,
    }
    # Keep the full result around so route_by_reasoning can consume it.
    state.debug_info["reasoning_result"] = outcome

    state.last_action = outcome.action.name
    return state
# ========== 2. 联网搜索节点 ==========
def web_search_node(state: MainGraphState) -> MainGraphState:
    """
    Web search node: run a web search and store the results on the state.

    The search query is taken from the last reasoning result's metadata
    (``search_query``) when available, falling back to the raw user query.
    Results are appended to ``state.web_search_results`` and surfaced through
    ``state.rag_context`` so the LLM can use them. On failure an ErrorRecord
    is appended and the graph is routed to the error-handling phase.
    """
    state.current_phase = "web_searching"

    # Prefer the query the reasoner extracted; fall back to the user's text.
    reasoning_result = state.debug_info.get("reasoning_result")
    search_query = (
        reasoning_result.metadata.get("search_query", state.user_query)
        if reasoning_result
        else state.user_query
    )

    try:
        from app.core import web_search

        print(f"[WebSearch] 搜索: {search_query}")
        search_result = web_search(search_query, max_results=5)

        # Older state instances may predate this attribute.
        if not hasattr(state, "web_search_results"):
            state.web_search_results = []
        state.web_search_results.append(search_result)

        # Surface the results through rag_context so the LLM sees them.
        if state.rag_context:
            state.rag_context = f"{state.rag_context}\n\n---\n\n## 🌐 联网搜索结果:\n{search_result}"
        else:
            state.rag_context = f"## 🌐 联网搜索结果:\n{search_result}"

        state.success = True
        print(f"[WebSearch] 搜索完成")
    except Exception as e:
        # NOTE: ErrorRecord / ErrorSeverity / datetime are module-level
        # imports; the former local re-imports here were redundant.
        error_record = ErrorRecord(
            error_type="WebSearchError",
            error_message=str(e),
            severity=ErrorSeverity.WARNING,
            source="web_search_node",
            timestamp=datetime.now().isoformat(),
            retry_count=0,
            max_retries=2,
            context={"search_query": search_query},
        )
        state.errors.append(error_record)
        state.current_error = error_record
        state.current_phase = "error_handling"
        state.success = False
    return state
# ========== 3. 错误处理节点 ==========
def error_handling_node(state: MainGraphState) -> MainGraphState:
    """
    Error-handling node: handle subgraph / tool-call failures.

    Produces structured error information of the form:
    {
        "tool/node": "...",
        "status": "failed",
        "error": "...",
        "retries_exceeded": true/false,
        "suggestion": "..."
    }

    Strategy: retry when the budget allows, degrade gracefully for
    non-fatal errors, otherwise report the failure and finish.
    """
    # stdlib; used below to render structured_error as real JSON.
    import json

    state.current_phase = "error_handling"
    # Nothing to handle -> go back to reasoning.
    if not state.current_error:
        state.current_phase = "react_reasoning"
        return state

    error = state.current_error
    # Record a human-readable summary on the state.
    state.error_message = f"{error.error_type}: {error.error_message}"

    # Structured error info for debugging and the user-facing message.
    structured_error = {
        "tool": error.source,
        "status": "failed",
        "error": error.error_message,
        "retries_exceeded": error.retry_count >= error.max_retries,
        "retry_count": error.retry_count,
        "max_retries": error.max_retries
    }
    # Attach a suggestion based on the error type / source.
    if "RAG" in error.error_type:
        structured_error["suggestion"] = "尝试重新表述问题或直接询问"
    elif "subgraph" in error.source or "contact" in error.source:
        structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
    elif "timeout" in error.error_message.lower():
        structured_error["suggestion"] = "请求超时,请稍后再试"
    else:
        structured_error["suggestion"] = "请尝试其他方式提问"
    state.debug_info["structured_error"] = structured_error

    # FIX: the messages below present this payload inside a ```json fence,
    # but a dict's repr is not valid JSON (single quotes, True/False) —
    # serialize it properly once here.
    structured_error_json = json.dumps(structured_error, ensure_ascii=False, indent=2)

    # Strategy 1: retry if the severity allows and the budget isn't spent.
    can_retry = (
        error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
        and error.retry_count < error.max_retries
    )
    if can_retry:
        error.retry_count += 1
        state.retry_action = error.source
        state.debug_info["retry_count"] = error.retry_count
        if "RAG" in error.error_type:
            state.last_action = "RE_RETRIEVE_RAG"
        elif "subgraph" in error.source:
            state.last_action = "DIRECT_RESPONSE"
        else:
            state.last_action = "REASON"
        state.current_phase = "retrying"
        return state

    # Strategy 2: cannot retry — degrade gracefully for non-fatal errors.
    if error.severity != ErrorSeverity.FATAL:
        state.final_result = (
            f"⚠️ 遇到一些问题:\n"
            f"```json\n{structured_error_json}\n```\n"
            f"但我会尽力用现有信息回答您。"
        )
        state.success = True
        state.current_phase = "finalizing"
        return state

    # Strategy 3: fatal error — report and finish.
    state.final_result = (
        f"❌ 服务暂时不可用,请稍后再试。\n"
        f"```json\n{structured_error_json}\n```"
    )
    state.success = False
    state.current_phase = "finalizing"
    return state
# ========== Final response node ==========
from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import AIMessage
async def final_response_node(state: MainGraphState, config: RunnableConfig) -> MainGraphState:
    """
    Final response node: call the LLM to generate the final answer
    (supports streaming output).

    Flow: short-circuit if ``state.final_result`` is already set; otherwise
    build a prompt|llm chain, optionally prepend the RAG context as a system
    message, stream the completion, merge the chunks into one AIMessage, and
    record the result on the state. On any exception a canned apology is
    returned instead of raising.
    """
    state.current_phase = "finalizing"
    # If a final_result already exists (e.g. set by error handling or the
    # max-steps guard), return it untouched.
    if state.final_result:
        state.current_phase = "done"
        return state
    import time
    start_time = time.time()
    try:
        # Build the LLM invocation chain (project imports kept lazy, matching
        # the file's style for heavyweight dependencies).
        from app.agent.prompts import create_system_prompt
        from app.model_services.chat_services import get_chat_service
        from app.logger import debug, info
        llm = get_chat_service()
        prompt = create_system_prompt(tools=[])
        chain = prompt | llm
        # Per-user memory context; defaults when the attribute is absent.
        memory_context = getattr(state, "memory_context", "暂无用户信息")
        # Add the RAG context to a copy of the message list.
        messages_with_context = list(state.messages)
        if state.rag_context:
            # Inject the RAG context as a system message.
            from langchain_core.messages import SystemMessage
            rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}")
            # Insert it just before the first human message.
            inserted = False
            for i, msg in enumerate(messages_with_context):
                if msg.type == "human":
                    messages_with_context.insert(i, rag_system_msg)
                    inserted = True
                    break
            if not inserted:
                messages_with_context.insert(0, rag_system_msg)
        # Call the LLM with streaming output; passing `config` through keeps
        # the caller's stream callbacks attached.
        chunks = []
        async for chunk in chain.astream(
            {
                "messages": messages_with_context,
                "memory_context": memory_context
            },
            config=config
        ):
            chunks.append(chunk)
        # Merge all chunks into the final AIMessage via `+`.
        if chunks:
            response = chunks[0]
            for chunk in chunks[1:]:
                response = response + chunk
        else:
            # No chunks at all -> empty answer rather than a crash.
            response = AIMessage(content="")
        elapsed_time = time.time() - start_time
        # Record the final answer and bookkeeping on the state.
        state.messages.append(response)
        state.final_result = response.content
        state.success = True
        state.current_phase = "done"
        state.end_time = datetime.now().isoformat()
        state.llm_calls = getattr(state, "llm_calls", 0) + 1
        info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}")
    except Exception as e:
        from app.logger import error
        import traceback
        error(f"❌ [LLM错误] 调用失败: {e}")
        traceback.print_exc()
        # NOTE(review): end_time is not set on this failure path — confirm
        # whether downstream consumers expect it.
        state.final_result = "抱歉,模型暂时无法响应,请稍后再试。"
        state.success = False
        state.current_phase = "done"
    return state
# ========== 4. 初始化状态节点 ==========
def init_state_node(state: MainGraphState) -> MainGraphState:
    """
    Initialization node: seed the state at the start of a run.

    Resets the step counter, stamps the start time, and — when no explicit
    user query was provided — derives one from the most recent message.
    """
    state.current_phase = "initializing"
    state.reasoning_step = 0
    state.start_time = datetime.now().isoformat()

    # Fall back to the latest message when user_query is empty.
    if not state.user_query and state.messages:
        latest = state.messages[-1]
        state.user_query = getattr(latest, "content", str(latest))

    return state
# ========== 5. 条件路由函数 ==========
def route_by_reasoning(state: MainGraphState) -> str:
    """
    Decide the next route from the current phase and the reasoning result.

    Returns: a route key matching an edge defined in graph_builder.py.
    """
    phase = state.current_phase

    # Special phases take precedence over the reasoning result.
    if phase == "max_steps_exceeded":
        return "final_response"
    if phase == "error_handling" or state.current_error:
        return "handle_error"
    if phase in ("finalizing", "done"):
        return "final_response"
    if phase == "retrying":
        retry = state.retry_action
        return "rag_retrieve" if retry and "rag" in retry.lower() else "react_reason"

    # No reasoning result recorded -> answer directly.
    outcome: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
    if not outcome:
        return "final_response"

    # Translate intent.py's route into a graph node name. These values must
    # stay in sync with the node names in main_graph_builder.py.
    node_for_route = {
        "direct_response": "final_response",
        "retrieve_rag": "rag_retrieve",
        "re_retrieve_rag": "rag_retrieve",
        "web_search": "web_search",  # ⭐ web search
        "clarify": "final_response",
        "call_tool": "final_response",  # mapped to final_response for now; extend later
        "contact": "contact_subgraph",
        "dictionary": "dictionary_subgraph",
        "news_analysis": "news_analysis_subgraph",
    }
    return node_for_route.get(get_route_by_reasoning(outcome), "final_response")
# ========== Exports ==========
__all__ = [
    "init_state_node",
    "react_reason_node",
    "web_search_node",  # ⭐ new
    "error_handling_node",
    "final_response_node",
    "route_by_reasoning"
]