Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
- 新增 rag_nodes.py: 独立的 RAG 检索节点 - 从 react_nodes.py 移除 RAG 相关代码 - 更新导入和导出 - rag_nodes.py 包含 rag_retrieve_node 和 rag_re_retrieve_node - 添加 inject_rag_tool_to_state 工具函数
287 lines
8.7 KiB
Python
287 lines
8.7 KiB
Python
"""
|
|
React 模式节点模块 - 带超时和重试功能
|
|
包含:
|
|
- react_reason_node: 使用 intent.py 进行推理
|
|
- error_handling_node: 错误处理节点
|
|
- final_response_node: 最终回答节点
|
|
- init_state_node: 初始化节点
|
|
"""
|
|
|
|
import sys
|
|
from typing import Dict, Any, Optional
|
|
from datetime import datetime
|
|
|
|
# 导入我们的 intent.py
|
|
from ..agent_subgraphs.common.intent import (
|
|
react_reason,
|
|
get_route_by_reasoning,
|
|
ReasoningAction,
|
|
ReasoningResult
|
|
)
|
|
from ..agent_subgraphs.common.state_base import StateUtils
|
|
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
|
from .retry_utils import (
|
|
RetryConfig,
|
|
SUBGRAPH_RETRY_CONFIG
|
|
)
|
|
|
|
|
|
# ========== 1. React 推理节点 ==========
|
|
def react_reason_node(state: MainGraphState) -> MainGraphState:
|
|
"""
|
|
React 模式推理节点:判断下一步做什么
|
|
|
|
Returns: 更新后的状态
|
|
"""
|
|
state.current_phase = "react_reasoning"
|
|
state.reasoning_step += 1
|
|
|
|
# 检查是否超过最大步数
|
|
if state.reasoning_step > state.max_steps:
|
|
state.current_phase = "max_steps_exceeded"
|
|
state.final_result = (
|
|
f"❌ 推理步数超过限制(最大 {state.max_steps} 步),"
|
|
f"已执行 {state.reasoning_step - 1} 步。"
|
|
f"请简化您的问题或分批提问。"
|
|
)
|
|
state.success = False
|
|
return state
|
|
|
|
# 准备上下文
|
|
context = {
|
|
"retrieved_docs": state.rag_docs,
|
|
"previous_actions": [h.get("action") for h in state.reasoning_history],
|
|
"messages": state.messages,
|
|
"errors": state.errors
|
|
}
|
|
|
|
# 使用 intent.py 进行推理
|
|
result: ReasoningResult = react_reason(state.user_query, context)
|
|
|
|
# 记录推理历史
|
|
state.reasoning_history.append({
|
|
"step": state.reasoning_step,
|
|
"action": result.action.name,
|
|
"confidence": result.confidence,
|
|
"reasoning": result.reasoning,
|
|
"timestamp": datetime.now().isoformat()
|
|
})
|
|
|
|
# 更新状态
|
|
state.debug_info["last_reasoning"] = {
|
|
"action": result.action.name,
|
|
"confidence": result.confidence,
|
|
"reasoning": result.reasoning
|
|
}
|
|
|
|
# 保存推理结果到状态
|
|
state.debug_info["reasoning_result"] = result
|
|
|
|
# 确定下一步动作
|
|
state.last_action = result.action.name
|
|
|
|
return state
|
|
|
|
|
|
# ========== 2. 错误处理节点 ==========
|
|
def error_handling_node(state: MainGraphState) -> MainGraphState:
|
|
"""
|
|
错误处理节点:处理子图/工具调用错误
|
|
|
|
返回结构化错误信息,格式如下:
|
|
{
|
|
"tool/node": "...",
|
|
"status": "failed",
|
|
"error": "...",
|
|
"retries_exhausted": true/false,
|
|
"suggestion": "..."
|
|
}
|
|
"""
|
|
state.current_phase = "error_handling"
|
|
|
|
if not state.current_error:
|
|
state.current_phase = "react_reasoning"
|
|
return state
|
|
|
|
error = state.current_error
|
|
|
|
# 更新错误状态
|
|
state.error_message = f"{error.error_type}: {error.error_message}"
|
|
|
|
# 记录结构化错误信息
|
|
structured_error = {
|
|
"tool": error.source,
|
|
"status": "failed",
|
|
"error": error.error_message,
|
|
"retries_exhausted": error.retry_count >= error.max_retries,
|
|
"retry_count": error.retry_count,
|
|
"max_retries": error.max_retries
|
|
}
|
|
|
|
# 根据错误类型添加建议
|
|
if "RAG" in error.error_type:
|
|
structured_error["suggestion"] = "尝试重新表述问题或直接询问"
|
|
elif "subgraph" in error.source or "contact" in error.source:
|
|
structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
|
|
elif "timeout" in error.error_message.lower():
|
|
structured_error["suggestion"] = "请求超时,请稍后再试"
|
|
else:
|
|
structured_error["suggestion"] = "请尝试其他方式提问"
|
|
|
|
state.debug_info["structured_error"] = structured_error
|
|
|
|
# 策略1: 检查是否可以重试
|
|
can_retry = (
|
|
error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
|
|
and error.retry_count < error.max_retries
|
|
)
|
|
|
|
if can_retry:
|
|
error.retry_count += 1
|
|
state.retry_action = error.source
|
|
state.debug_info["retry_count"] = error.retry_count
|
|
|
|
if "RAG" in error.error_type:
|
|
state.last_action = "RE_RETRIEVE_RAG"
|
|
elif "subgraph" in error.source:
|
|
state.last_action = "DIRECT_RESPONSE"
|
|
else:
|
|
state.last_action = "REASON"
|
|
|
|
state.current_phase = "retrying"
|
|
return state
|
|
|
|
# 策略2: 无法重试,尝试降级方案
|
|
if error.severity != ErrorSeverity.FATAL:
|
|
state.final_result = (
|
|
f"⚠️ 遇到一些问题:\n"
|
|
f"```json\n{structured_error}\n```\n"
|
|
f"但我会尽力用现有信息回答您。"
|
|
)
|
|
state.success = True
|
|
state.current_phase = "finalizing"
|
|
return state
|
|
|
|
# 策略3: 致命错误
|
|
state.final_result = (
|
|
f"❌ 服务暂时不可用,请稍后再试。\n"
|
|
f"```json\n{structured_error}\n```"
|
|
)
|
|
state.success = False
|
|
state.current_phase = "finalizing"
|
|
|
|
return state
|
|
|
|
|
|
# ========== 3. 最终回答节点 ==========
|
|
def final_response_node(state: MainGraphState) -> MainGraphState:
|
|
"""
|
|
最终回答节点:整理并生成最终回答
|
|
"""
|
|
state.current_phase = "finalizing"
|
|
|
|
# 如果已经有 final_result 了,直接返回
|
|
if state.final_result:
|
|
state.current_phase = "done"
|
|
return state
|
|
|
|
# 构建最终回答
|
|
parts = []
|
|
|
|
# 添加 RAG 上下文(如果有)
|
|
if state.rag_context:
|
|
parts.append(state.rag_context)
|
|
parts.append("---")
|
|
|
|
# 添加子图结果(如果有)
|
|
if state.contact_result and hasattr(state.contact_result, "get"):
|
|
if state.contact_result.get("final_result"):
|
|
parts.append(state.contact_result["final_result"])
|
|
if state.dictionary_result and hasattr(state.dictionary_result, "get"):
|
|
if state.dictionary_result.get("final_result"):
|
|
parts.append(state.dictionary_result["final_result"])
|
|
if state.news_result and hasattr(state.news_result, "get"):
|
|
if state.news_result.get("final_result"):
|
|
parts.append(state.news_result["final_result"])
|
|
|
|
# 如果都没有,用默认回答
|
|
if not parts:
|
|
parts.append(f"我理解了您的问题:{state.user_query}")
|
|
|
|
state.final_result = "\n".join(parts)
|
|
state.success = True
|
|
state.current_phase = "done"
|
|
state.end_time = datetime.now().isoformat()
|
|
|
|
return state
|
|
|
|
|
|
# ========== 4. 初始化状态节点 ==========
|
|
def init_state_node(state: MainGraphState) -> MainGraphState:
|
|
"""
|
|
初始化状态节点:在流程开始时设置初始值
|
|
"""
|
|
state.current_phase = "initializing"
|
|
state.reasoning_step = 0
|
|
state.start_time = datetime.now().isoformat()
|
|
|
|
# 从 messages 中提取用户查询
|
|
if not state.user_query and state.messages:
|
|
last_msg = state.messages[-1]
|
|
state.user_query = getattr(last_msg, "content", str(last_msg))
|
|
|
|
return state
|
|
|
|
|
|
# ========== 5. 条件路由函数 ==========
|
|
def route_by_reasoning(state: MainGraphState) -> str:
|
|
"""
|
|
根据推理结果决定下一步路由
|
|
|
|
Returns: 路由字符串
|
|
"""
|
|
# 先检查特殊情况
|
|
if state.current_phase == "max_steps_exceeded":
|
|
return "final_response"
|
|
if state.current_phase == "error_handling" or state.current_error:
|
|
return "handle_error"
|
|
if state.current_phase == "finalizing" or state.current_phase == "done":
|
|
return "final_response"
|
|
if state.current_phase == "retrying":
|
|
if state.retry_action and "rag" in state.retry_action.lower():
|
|
return "rag_retrieve"
|
|
return "react_reason"
|
|
|
|
# 获取推理结果
|
|
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
|
|
|
|
if not reasoning_result:
|
|
return "final_response"
|
|
|
|
# 使用 intent.py 提供的路由函数
|
|
route = get_route_by_reasoning(reasoning_result)
|
|
|
|
# 映射到我们的节点名称
|
|
route_mapping = {
|
|
"direct_response": "final_response",
|
|
"retrieve_rag": "rag_retrieve",
|
|
"re_retrieve_rag": "rag_retrieve",
|
|
"clarify": "final_response",
|
|
"call_tool": "final_response",
|
|
"contact": "contact_subgraph",
|
|
"dictionary": "dictionary_subgraph",
|
|
"news_analysis": "news_analysis_subgraph",
|
|
}
|
|
|
|
return route_mapping.get(route, "final_response")
|
|
|
|
|
|
# ========== 导出 ==========
|
|
__all__ = [
|
|
"init_state_node",
|
|
"react_reason_node",
|
|
"error_handling_node",
|
|
"final_response_node",
|
|
"route_by_reasoning"
|
|
]
|