Files
ailine/backend/app/main_graph/nodes/react_nodes.py
root 3f6bbdec92
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 6m41s
给关键节点添加思考过程输出
- react_reason_node: 直接发送自定义推理事件
- web_search_node: 添加开始/完成/错误事件
- rag_retrieve_node: 添加开始/完成/重试/错误事件
- 子图包装器: 添加子图开始/完成/错误事件
2026-05-02 09:23:07 +08:00

407 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
React 模式节点模块 - 带超时和重试功能
包含:
- react_reason_node: 使用 intent.py 进行推理
- error_handling_node: 错误处理节点
- init_state_node: 初始化状态节点
注意:为了兼容 LangGraph 的同步接口,我们保留了同步的 react_reason 调用
但内部会根据情况使用规则推理或尝试异步调用
"""
import sys
from typing import Dict, Any, Optional
from datetime import datetime
# 导入我们的 intent.py
from app.core.intent import (
react_reason,
react_reason_async,
get_route_by_reasoning,
ReasoningAction,
ReasoningResult
)
from app.core.state_base import StateUtils
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.main_graph.utils.retry_utils import (
RetryConfig,
SUBGRAPH_RETRY_CONFIG
)
from app.logger import info
# ========== 1. React 推理节点 ==========
async def react_reason_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
React 模式推理节点:判断下一步做什么(异步版本)
Returns: 更新后的状态
"""
state.current_phase = "react_reasoning"
state.reasoning_step += 1
info(f"[react_reason] 第 {state.reasoning_step} 次推理开始")
# 检查是否超过最大步数
if state.reasoning_step > state.max_steps:
state.current_phase = "max_steps_exceeded"
state.final_result = (
f"❌ 推理步数超过限制(最大 {state.max_steps} 步),"
f"已执行 {state.reasoning_step - 1} 步。"
f"请简化您的问题或分批提问。"
)
state.success = False
return state
# 准备上下文
context = {
"retrieved_docs": state.rag_docs,
"previous_actions": [h.get("action") for h in state.reasoning_history],
"messages": state.messages,
"errors": state.errors
}
# 使用 intent.py 进行推理(现在直接用异步版本)
result: ReasoningResult = await react_reason_async(state.user_query, context)
info(f"[react_reason] 推理结果: action={result.action.name}, confidence={result.confidence}")
if result.reasoning:
info(f"[react_reason] 推理过程: {result.reasoning}")
# 关键修复:直接发送自定义事件给 agent_service而不是通过 state
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
info(f"[react_reason] 直接发送推理事件 #{state.reasoning_step}")
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning
},
callbacks=callbacks
)
except Exception as e:
info(f"[react_reason] 无法发送自定义事件: {e}")
# 记录推理历史
state.reasoning_history.append({
"step": state.reasoning_step,
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning,
"timestamp": datetime.now().isoformat()
})
# 更新状态
state.debug_info["last_reasoning"] = {
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning
}
# 保存推理结果到状态
state.debug_info["reasoning_result"] = result
# 确定下一步动作
state.last_action = result.action.name
# 关键修复:不再设置 latest_reasoning避免 agent_service 重复读取
if "latest_reasoning" in state.debug_info:
del state.debug_info["latest_reasoning"]
return state
# ========== 2. 联网搜索节点 ==========
async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
联网搜索节点:执行搜索并将结果保存到状态
"""
state.current_phase = "web_searching"
# 发送开始事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_start",
"confidence": 1.0,
"reasoning": "开始执行联网搜索..."
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送开始事件: {e}")
# 获取搜索查询
reasoning_result = state.debug_info.get("reasoning_result")
search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query
try:
from app.core import web_search
print(f"[WebSearch] 搜索: {search_query}")
search_result = web_search(search_query, max_results=5)
# 保存搜索结果到状态
if not hasattr(state, "web_search_results"):
state.web_search_results = []
state.web_search_results.append(search_result)
# 将搜索结果添加到 rag_context供 LLM 使用
if state.rag_context:
state.rag_context = f"{state.rag_context}\n\n---\n\n## 🌐 联网搜索结果:\n{search_result}"
else:
state.rag_context = f"## 🌐 联网搜索结果:\n{search_result}"
state.success = True
print(f"[WebSearch] 搜索完成")
# 发送完成事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_complete",
"confidence": 1.0,
"reasoning": f"联网搜索完成,找到 {len(search_result) if isinstance(search_result, list) else 1} 条结果"
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送完成事件: {e}")
except Exception as e:
from app.main_graph.state import ErrorRecord, ErrorSeverity
from datetime import datetime
error_record = ErrorRecord(
error_type="WebSearchError",
error_message=str(e),
severity=ErrorSeverity.WARNING,
source="web_search_node",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=2,
context={"search_query": search_query}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
state.success = False
# 发送错误事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_error",
"confidence": 1.0,
"reasoning": f"联网搜索失败: {str(e)}"
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送错误事件: {e}")
return state
# ========== 3. 错误处理节点 ==========
def error_handling_node(state: MainGraphState) -> MainGraphState:
"""
错误处理节点:处理子图/工具调用错误
返回结构化错误信息,格式如下:
{
"tool/node": "...",
"status": "failed",
"error": "...",
"retries_exceeded": true/false,
"suggestion": "..."
}
"""
state.current_phase = "error_handling"
if not state.current_error:
state.current_phase = "react_reasoning"
return state
error = state.current_error
# 更新错误状态
state.error_message = f"{error.error_type}: {error.error_message}"
# 记录结构化错误信息
structured_error = {
"tool": error.source,
"status": "failed",
"error": error.error_message,
"retries_exceeded": error.retry_count >= error.max_retries,
"retry_count": error.retry_count,
"max_retries": error.max_retries
}
# 根据错误类型添加建议
if "RAG" in error.error_type:
structured_error["suggestion"] = "尝试重新表述问题或直接询问"
elif "subgraph" in error.source or "contact" in error.source:
structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
elif "timeout" in error.error_message.lower():
structured_error["suggestion"] = "请求超时,请稍后再试"
else:
structured_error["suggestion"] = "请尝试其他方式提问"
state.debug_info["structured_error"] = structured_error
# 策略1: 检查是否可以重试
can_retry = (
error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
and error.retry_count < error.max_retries
)
if can_retry:
error.retry_count += 1
state.retry_action = error.source
state.debug_info["retry_count"] = error.retry_count
if "RAG" in error.error_type:
state.last_action = "RE_RETRIEVE_RAG"
elif "subgraph" in error.source:
state.last_action = "DIRECT_RESPONSE"
else:
state.last_action = "REASON"
state.current_phase = "retrying"
return state
# 策略2: 无法重试,尝试降级方案
if error.severity != ErrorSeverity.FATAL:
state.final_result = (
f"⚠️ 遇到一些问题:\n"
f"```json\n{structured_error}\n```\n"
f"但我会尽力用现有信息回答您。"
)
state.success = True
state.current_phase = "finalizing"
return state
# 策略3: 致命错误
state.final_result = (
f"❌ 服务暂时不可用,请稍后再试。\n"
f"```json\n{structured_error}\n```"
)
state.success = False
state.current_phase = "finalizing"
return state
# ========== 4. 初始化状态节点 ==========
def init_state_node(state: MainGraphState) -> MainGraphState:
"""
初始化状态节点:在流程开始时设置初始值
"""
state.current_phase = "initializing"
state.reasoning_step = 0
state.start_time = datetime.now().isoformat()
# 从 messages 中提取用户查询
if not state.user_query and state.messages:
last_msg = state.messages[-1]
state.user_query = getattr(last_msg, "content", str(last_msg))
return state
# ========== 5. 条件路由函数 ==========
def route_by_reasoning(state: MainGraphState) -> str:
"""
根据推理结果决定下一步路由
Returns: 路由标识,对应 graph_builder.py 中的边
"""
# 先检查特殊情况
if state.current_phase == "max_steps_exceeded":
return "llm_call"
if state.current_phase == "error_handling" or state.current_error:
return "handle_error"
if state.current_phase == "finalizing" or state.current_phase == "done":
return "llm_call"
if state.current_phase == "retrying":
if state.retry_action and "rag" in state.retry_action.lower():
return "rag_retrieve"
return "react_reason"
# 关键修复:检查是否已经执行过子图,如果是,直接去 llm_call
previous_actions = [h.get("action") for h in state.reasoning_history]
if "subgraph_completed" in previous_actions or state.final_result:
return "llm_call"
# 检查是否刚刚执行完 rag 或 web search应该继续推理一次然后去 llm_call
# 但为了避免死循环,我们设置一个简单的规则
if len(previous_actions) > 3:
return "llm_call"
# 获取推理结果
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
if not reasoning_result:
return "llm_call"
# 使用 intent.py 提供的路由函数
route = get_route_by_reasoning(reasoning_result)
# 映射到我们的节点名称
# 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致
route_mapping = {
"direct_response": "llm_call",
"retrieve_rag": "rag_retrieve",
"re_retrieve_rag": "rag_retrieve",
"web_search": "web_search",
"clarify": "llm_call",
"call_tool": "llm_call",
"contact": "contact_subgraph",
"dictionary": "dictionary_subgraph",
"news_analysis": "news_analysis_subgraph",
}
return route_mapping.get(route, "llm_call")
# ========== 导出 ==========
__all__ = [
"init_state_node",
"react_reason_node",
"web_search_node",
"error_handling_node",
"route_by_reasoning"
]