Files
ailine/backend/app/main_graph/nodes/react_nodes.py
root c9bf21be0e
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
fix: 修复 RAG 无限循环问题和导入错误
主要修复:
1. 修复 RAG 推理无限循环问题(大小写不匹配 + 缺少已检索结果检查)
2. 修复 intent_classifier.py 的绝对导入错误
3. 删除旧的 start.sh 脚本,添加新的启动脚本
4. 优化路由逻辑和状态管理
2026-05-04 18:59:15 +08:00

419 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
React 模式节点模块 - 带超时和重试功能
包含:
- react_reason_node: 使用 intent.py 进行推理
- error_handling_node: 错误处理节点
- init_state_node: 初始化状态节点
注意:为了兼容 LangGraph 的同步接口,我们保留了同步的 react_reason 调用
但内部会根据情况使用规则推理或尝试异步调用
"""
import sys
from typing import Dict, Any, Optional
from datetime import datetime
# 导入我们的 intent.py
from app.core.intent import (
react_reason,
react_reason_async,
get_route_by_reasoning,
ReasoningAction,
ReasoningResult
)
from app.core.state_base import StateUtils
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.main_graph.utils.retry_utils import (
RetryConfig,
SUBGRAPH_RETRY_CONFIG
)
from app.logger import info
# ========== 1. React 推理节点 ==========
async def react_reason_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
React 模式推理节点:判断下一步做什么(异步版本)
Returns: 更新后的状态
"""
state.current_phase = "react_reasoning"
state.reasoning_step += 1
info(f"[react_reason] 第 {state.reasoning_step} 次推理开始")
# 检查是否超过最大步数
if state.reasoning_step > state.max_steps:
state.current_phase = "max_steps_exceeded"
state.final_result = (
f"❌ 推理步数超过限制(最大 {state.max_steps} 步),"
f"已执行 {state.reasoning_step - 1} 步。"
f"请简化您的问题或分批提问。"
)
state.success = False
return state
# 准备上下文
context = {
"retrieved_docs": state.rag_docs,
"previous_actions": [h.get("action") for h in state.reasoning_history],
"messages": state.messages,
"errors": state.errors
}
# 使用 intent.py 进行推理(现在直接用异步版本)
result: ReasoningResult = await react_reason_async(state.user_query, context)
info(f"[react_reason] 推理结果: action={result.action.name}, confidence={result.confidence}")
if result.reasoning:
info(f"[react_reason] 推理过程: {result.reasoning}")
# 关键修复:直接发送自定义事件给 agent_service而不是通过 state
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
info(f"[react_reason] 直接发送推理事件 #{state.reasoning_step}")
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning
},
callbacks=callbacks
)
except Exception as e:
info(f"[react_reason] 无法发送自定义事件: {e}")
# 记录推理历史
state.reasoning_history.append({
"step": state.reasoning_step,
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning,
"timestamp": datetime.now().isoformat()
})
# 更新状态
state.debug_info["last_reasoning"] = {
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning
}
# 保存推理结果到状态
state.debug_info["reasoning_result"] = result
# 确定下一步动作
state.last_action = result.action.name
# 关键修复:不再设置 latest_reasoning避免 agent_service 重复读取
if "latest_reasoning" in state.debug_info:
del state.debug_info["latest_reasoning"]
return state
# ========== 2. 联网搜索节点 ==========
async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
联网搜索节点:执行搜索并将结果保存到状态
"""
state.current_phase = "web_searching"
# 发送开始事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_start",
"confidence": 1.0,
"reasoning": "开始执行联网搜索..."
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送开始事件: {e}")
# 获取搜索查询
reasoning_result = state.debug_info.get("reasoning_result")
search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query
try:
from app.core import web_search
print(f"[WebSearch] 搜索: {search_query}")
search_result = web_search(search_query, max_results=5)
# 保存搜索结果到状态
if not hasattr(state, "web_search_results"):
state.web_search_results = []
state.web_search_results.append(search_result)
# 将搜索结果添加到 rag_context供 LLM 使用
if state.rag_context:
state.rag_context = f"{state.rag_context}\n\n---\n\n## 🌐 联网搜索结果:\n{search_result}"
else:
state.rag_context = f"## 🌐 联网搜索结果:\n{search_result}"
state.success = True
print(f"[WebSearch] 搜索完成")
# 发送完成事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_complete",
"confidence": 1.0,
"reasoning": f"联网搜索完成,找到 {len(search_result) if isinstance(search_result, list) else 1} 条结果"
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送完成事件: {e}")
except Exception as e:
from app.main_graph.state import ErrorRecord, ErrorSeverity
from datetime import datetime
error_record = ErrorRecord(
error_type="WebSearchError",
error_message=str(e),
severity=ErrorSeverity.WARNING,
source="web_search_node",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=2,
context={"search_query": search_query}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
state.success = False
# 发送错误事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_error",
"confidence": 1.0,
"reasoning": f"联网搜索失败: {str(e)}"
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送错误事件: {e}")
return state
# ========== 3. 错误处理节点 ==========
def error_handling_node(state: MainGraphState) -> MainGraphState:
"""
错误处理节点:处理子图/工具调用错误
返回结构化错误信息,格式如下:
{
"tool/node": "...",
"status": "failed",
"error": "...",
"retries_exceeded": true/false,
"suggestion": "..."
}
"""
state.current_phase = "error_handling"
if not state.current_error:
state.current_phase = "react_reasoning"
return state
error = state.current_error
# 更新错误状态
state.error_message = f"{error.error_type}: {error.error_message}"
# 记录结构化错误信息
structured_error = {
"tool": error.source,
"status": "failed",
"error": error.error_message,
"retries_exceeded": error.retry_count >= error.max_retries,
"retry_count": error.retry_count,
"max_retries": error.max_retries
}
# 根据错误类型添加建议
if "RAG" in error.error_type:
structured_error["suggestion"] = "尝试重新表述问题或直接询问"
elif "subgraph" in error.source or "contact" in error.source:
structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
elif "timeout" in error.error_message.lower():
structured_error["suggestion"] = "请求超时,请稍后再试"
else:
structured_error["suggestion"] = "请尝试其他方式提问"
state.debug_info["structured_error"] = structured_error
# 策略1: 检查是否可以重试
can_retry = (
error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
and error.retry_count < error.max_retries
)
if can_retry:
error.retry_count += 1
state.retry_action = error.source
state.debug_info["retry_count"] = error.retry_count
if "RAG" in error.error_type:
state.last_action = "RE_RETRIEVE_RAG"
elif "subgraph" in error.source:
state.last_action = "DIRECT_RESPONSE"
else:
state.last_action = "REASON"
state.current_phase = "retrying"
return state
# 策略2: 无法重试,尝试降级方案
if error.severity != ErrorSeverity.FATAL:
state.final_result = (
f"⚠️ 遇到一些问题:\n"
f"```json\n{structured_error}\n```\n"
f"但我会尽力用现有信息回答您。"
)
state.success = True
state.current_phase = "finalizing"
return state
# 策略3: 致命错误
state.final_result = (
f"❌ 服务暂时不可用,请稍后再试。\n"
f"```json\n{structured_error}\n```"
)
state.success = False
state.current_phase = "finalizing"
return state
# ========== 4. 初始化状态节点 ==========
def init_state_node(state: MainGraphState) -> MainGraphState:
"""
初始化状态节点:在流程开始时设置初始值
"""
state.current_phase = "initializing"
state.reasoning_step = 0
state.start_time = datetime.now().isoformat()
# 从 messages 中提取用户查询
if not state.user_query and state.messages:
last_msg = state.messages[-1]
state.user_query = getattr(last_msg, "content", str(last_msg))
return state
# ========== 5. 条件路由函数 ==========
def route_by_reasoning(state: MainGraphState) -> str:
"""
根据推理结果决定下一步路由
Returns: 路由标识,对应 graph_builder.py 中的边
"""
# 先检查特殊情况
if state.current_phase == "max_steps_exceeded":
return "llm_call"
if state.current_phase == "error_handling" or state.current_error:
return "handle_error"
if state.current_phase == "finalizing" or state.current_phase == "done":
return "llm_call"
if state.current_phase == "retrying":
if state.retry_action and "rag" in state.retry_action.lower():
return "rag_retrieve"
return "react_reason"
# 关键修复:检查是否已经执行过子图,如果是,直接去 llm_call
previous_actions = [h.get("action") for h in state.reasoning_history]
if "subgraph_completed" in previous_actions or state.final_result:
return "llm_call"
# 关键修复:检测 RAG 重复循环 - 如果发现"RETRIEVE_RAG"出现超过1次直接去 LLM
rag_count = previous_actions.count("RETRIEVE_RAG")
if rag_count >= 2:
info(f"[route_by_reasoning] 检测到 RAG 重复循环({rag_count}次),直接去 llm_call")
return "llm_call"
# 关键修复:如果已经有 rag_docs 或 rag_context说明已经检索过了直接去 LLM
if (state.rag_docs and len(state.rag_docs) > 0) or (state.rag_context and len(state.rag_context) > 0):
info(f"[route_by_reasoning] 检测到已存在 RAG 检索结果,直接去 llm_call")
return "llm_call"
# 关键修复:限制最多 3 次推理,避免无限循环
if len(previous_actions) >= 3:
info(f"[route_by_reasoning] 已达到最大推理次数 ({len(previous_actions)}),直接去 llm_call")
return "llm_call"
# 获取推理结果
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
if not reasoning_result:
return "llm_call"
# 使用 intent.py 提供的路由函数
route = get_route_by_reasoning(reasoning_result)
# 映射到我们的节点名称
# 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致
route_mapping = {
"direct_response": "llm_call",
"retrieve_rag": "rag_retrieve",
"re_retrieve_rag": "rag_retrieve",
"web_search": "web_search",
"clarify": "llm_call",
"call_tool": "llm_call",
"contact": "contact_subgraph",
"dictionary": "dictionary_subgraph",
"news_analysis": "news_analysis_subgraph",
}
info(f"[route_by_reasoning] 推理结果={reasoning_result.action.name}, 路由={route_mapping.get(route, 'llm_call')}, 历史动作={previous_actions}")
return route_mapping.get(route, "llm_call")
# ========== 导出 ==========
__all__ = [
"init_state_node",
"react_reason_node",
"web_search_node",
"error_handling_node",
"route_by_reasoning"
]