修复循环推理bug

This commit is contained in:
2026-05-05 00:54:04 +08:00
parent acc8d801f3
commit b64dade9e9
11 changed files with 605 additions and 766 deletions

View File

@@ -2,25 +2,30 @@
主图节点模块导出
"""
# 新的 React 模式节点
from .react_nodes import (
init_state_node,
react_reason_node,
web_search_node,
error_handling_node,
route_by_reasoning
)
# React 模式节点
from .reasoning import react_reason_node
from .web_search import web_search_node
from .error_handling import error_handling_node
from .routing import init_state_node, route_by_reasoning
from .llm_call import create_llm_call_node
from .rag_nodes import rag_retrieve_node
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node
# 记忆节点(已更新到 MainGraphState
# 记忆节点
from .retrieve_memory import create_retrieve_memory_node
from .memory_trigger import memory_trigger_node, set_mem0_client
from .summarize import create_summarize_node
from .finalize import finalize_node
# 路由(已更新到 MainGraphState
from .router import should_continue
# 混合路由节点
from .hybrid_router import (
hybrid_router_node,
fast_chitchat_node,
fast_rag_node,
fast_tool_node,
)
# 通用工具
from ._utils import dispatch_custom_event, make_react_event
__all__ = [
# React 模式节点
@@ -29,15 +34,21 @@ __all__ = [
"web_search_node",
"error_handling_node",
"route_by_reasoning",
# 通用节点
"create_llm_call_node",
"rag_retrieve_node",
"rag_re_retrieve_node",
# 记忆节点
"create_retrieve_memory_node",
"memory_trigger_node",
"set_mem0_client",
"create_summarize_node",
"finalize_node",
# 路由
"should_continue",
# 混合路由节点
"hybrid_router_node",
"fast_chitchat_node",
"fast_rag_node",
"fast_tool_node",
# 通用工具
"dispatch_custom_event",
"make_react_event",
]

View File

@@ -0,0 +1,57 @@
"""
主图节点通用工具模块
包含事件发送、状态更新等通用功能
"""
from typing import Dict, Any, Optional
async def dispatch_custom_event(
event_name: str,
data: Dict[str, Any],
config: Optional[Dict[str, Any]] = None,
) -> None:
"""
安全地发送自定义事件,忽略发送失败
Args:
event_name: 事件名称
data: 事件数据
config: LangChain 配置
"""
if not config:
return
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(event_name, data, callbacks=callbacks)
except Exception:
# 事件发送失败不应中断主流程
pass
def make_react_event(
step: int,
action: str,
confidence: float = 1.0,
reasoning: str = ""
) -> Dict[str, Any]:
"""
构造标准推理事件数据
Args:
step: 当前步数
action: 动作名称
confidence: 置信度
reasoning: 推理过程
Returns:
事件数据字典
"""
return {
"step": step,
"action": action,
"confidence": confidence,
"reasoning": reasoning
}

View File

@@ -0,0 +1,95 @@
"""
错误处理节点 - 处理子图/工具调用错误
"""
from app.main_graph.state import MainGraphState, ErrorSeverity
from app.logger import info
def error_handling_node(state: MainGraphState) -> MainGraphState:
"""
错误处理节点:处理子图/工具调用错误
返回结构化错误信息,格式如下:
{
"tool/node": "...",
"status": "failed",
"error": "...",
"retries_exceeded": true/false,
"suggestion": "..."
}
"""
state.current_phase = "error_handling"
if not state.current_error:
state.current_phase = "react_reasoning"
return state
error = state.current_error
# 更新错误状态
state.error_message = f"{error.error_type}: {error.error_message}"
# 记录结构化错误信息
structured_error = {
"tool": error.source,
"status": "failed",
"error": error.error_message,
"retries_exceeded": error.retry_count >= error.max_retries,
"retry_count": error.retry_count,
"max_retries": error.max_retries
}
# 根据错误类型添加建议
if "RAG" in error.error_type:
structured_error["suggestion"] = "尝试重新表述问题或直接询问"
elif "subgraph" in error.source or "contact" in error.source:
structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
elif "timeout" in error.error_message.lower():
structured_error["suggestion"] = "请求超时,请稍后再试"
else:
structured_error["suggestion"] = "请尝试其他方式提问"
state.debug_info["structured_error"] = structured_error
# 策略1: 检查是否可以重试
can_retry = (
error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
and error.retry_count < error.max_retries
)
if can_retry:
error.retry_count += 1
state.retry_action = error.source
state.debug_info["retry_count"] = error.retry_count
if "RAG" in error.error_type:
state.last_action = "RE_RETRIEVE_RAG"
elif "subgraph" in error.source:
state.last_action = "DIRECT_RESPONSE"
else:
state.last_action = "REASON"
state.current_phase = "retrying"
return state
# 策略2: 无法重试,尝试降级方案
if error.severity != ErrorSeverity.FATAL:
state.final_result = (
f"⚠️ 遇到一些问题:\n"
f"```json\n{structured_error}\n```\n"
f"但我会尽力用现有信息回答您。"
)
state.success = True
state.current_phase = "finalizing"
return state
# 策略3: 致命错误
state.final_result = (
f"❌ 服务暂时不可用,请稍后再试。\n"
f"```json\n{structured_error}\n```"
)
state.success = False
state.current_phase = "finalizing"
return state

View File

@@ -1,9 +1,6 @@
"""
RAG 节点模块 - 真正利用已有 RAG 代码
包含
- rag_retrieve_node: RAG 检索节点(带超时重试)
- rag_re_retrieve_node: 重新检索节点
- 集成 backend/app/rag/tools.py 和 rag_initializer.py
RAG 检索节点模块
包含 RAG 检索节点(带超时重试)
"""
import time
@@ -12,267 +9,163 @@ from typing import Dict, Any, Optional
from datetime import datetime
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.main_graph.utils.retry_utils import (
RetryConfig,
RAG_RETRY_CONFIG,
create_retry_wrapper_for_node
)
from app.main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from app.logger import info
from ._utils import dispatch_custom_event, make_react_event
# 真正导入和利用已有 RAG 代码
from app.rag.tools import create_rag_tool
from app.rag.pipeline import RAGPipeline
# ========== 全局 RAG 工具实例(延迟初始化)==========
# ========== 全局 RAG 工具实例 ==========
_GLOBAL_RAG_TOOL: Optional[Any] = None
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None
def get_global_rag_tool() -> Optional[Any]:
"""
获取全局 RAG 工具(单例模式)
Returns:
RAG 工具实例或 None
"""
return _GLOBAL_RAG_TOOL
def set_global_rag_tool(tool: Any) -> None:
"""
设置全局 RAG 工具(通常在应用启动时调用)
Args:
tool: RAG 工具实例
"""
global _GLOBAL_RAG_TOOL
_GLOBAL_RAG_TOOL = tool
def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
"""
设置全局 RAG Pipeline
Args:
pipeline: RAGPipeline 实例
"""
global _GLOBAL_RAG_PIPELINE
_GLOBAL_RAG_PIPELINE = pipeline
# ========== 从状态获取 RAG 工具 ==========
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
"""
从状态中获取 RAG 工具(如果有)
Args:
state: 主图状态
Returns:
RAG 工具实例或 None
"""
# 优先从状态获取
if "rag_tool" in state.debug_info:
return state.debug_info["rag_tool"]
# 其次从全局获取
return get_global_rag_tool()
"""从状态或全局获取 RAG 工具"""
return state.debug_info.get("rag_tool") or get_global_rag_tool()
# ========== 工具:将 RAG 工具注入到状态 ==========
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
"""
将 RAG 工具注入到状态中,供后续节点使用
Args:
state: 主图状态
rag_tool: RAG 工具实例
Returns:
更新后的状态
"""
"""将 RAG 工具注入到状态中"""
state.debug_info["rag_tool"] = rag_tool
state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
return state
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
# ========== RAG 检索核心逻辑 ==========
async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
"""
RAG 检索核心逻辑(真正利用 rag/tools.py - 异步版本
Args:
state: 主图状态
Returns:
更新后的状态
"""
# 获取检索查询(优先使用推理结果中的优化查询)
"""执行 RAG 检索的核心逻辑"""
retrieval_query = state.user_query
if "reasoning_result" in state.debug_info:
reasoning_result = state.debug_info["reasoning_result"]
if hasattr(reasoning_result, "retrieval_config"):
cfg = reasoning_result.retrieval_config
if cfg and cfg.retrieval_query:
retrieval_query = cfg.retrieval_query
# 尝试获取 RAG 工具(多种方式)
# 优先使用推理结果中的优化查询
reasoning_result = state.debug_info.get("reasoning_result")
if reasoning_result and hasattr(reasoning_result, "retrieval_config"):
cfg = reasoning_result.retrieval_config
if cfg and cfg.retrieval_query:
retrieval_query = cfg.retrieval_query
rag_tool = get_rag_tool_from_state(state)
if rag_tool:
# 使用真正的 RAG 工具(来自 rag/tools.py- 异步版本
try:
# 直接 await 异步工具的 ainvoke 方法
rag_context = await rag_tool.ainvoke(retrieval_query)
rag_context = await rag_tool.ainvoke(retrieval_query)
state.rag_context = rag_context
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
state.rag_retrieved = True
state.success = True
state.debug_info["rag_source"] = "rag_tool"
return state
if _GLOBAL_RAG_PIPELINE:
documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
if documents:
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
state.rag_context = rag_context
state.rag_docs = [
{"source": "rag_retrieval", "content": rag_context}
{"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
for doc in documents
]
state.rag_retrieved = True
state.success = True
state.debug_info["rag_source"] = "rag_tool"
return state
except Exception as e:
raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
elif _GLOBAL_RAG_PIPELINE:
# 使用 RAG Pipeline 直接检索 - 直接用异步方法
try:
documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
if documents:
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
state.rag_context = rag_context
state.rag_docs = [
{"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
for doc in documents
]
else:
state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
state.rag_docs = []
state.rag_retrieved = True
state.success = True
state.debug_info["rag_source"] = "rag_pipeline"
return state
except Exception as e:
raise RuntimeError(f"RAG Pipeline 调用失败: {str(e)}") from e
else:
# 没有可用的 RAG 工具/Pipeline
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
else:
state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
state.rag_docs = []
state.rag_retrieved = True
state.success = True
state.debug_info["rag_source"] = "rag_pipeline"
return state
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
# ========== RAG 检索节点(带超时和重试)==========
# ========== RAG 检索节点 ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
Args:
state: 主图状态
config: LangChain 配置
Returns:
更新后的状态
"""
"""RAG 检索节点:带超时和重试"""
state.current_phase = "rag_retrieving"
# 发送开始事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "rag_retrieve_start",
"confidence": 1.0,
"reasoning": "开始执行 RAG 检索..."
},
callbacks=callbacks
)
except Exception as e:
info(f"[rag_retrieve_node] 无法发送开始事件: {e}")
start_time = time.time()
last_error = None
# 步骤1: 发送开始事件
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."),
config
)
# 步骤2: 执行检索(带重试)
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
try:
# 执行核心逻辑 - 异步 await
result = await _rag_retrieve_core(state)
info(f"[rag_retrieve_node] RAG 检索成功,获取到上下文长度: {len(result.rag_context)} 字符")
if result.rag_docs:
for i, doc in enumerate(result.rag_docs[:3]): # 只显示前3条
info(f"[rag_retrieve_node] 文档 {i+1}: {doc.get('content', '')[:100]}...")
# 成功
info(f"[RAG] 检索成功,上下文长度: {len(result.rag_context)} 字符")
state.debug_info["rag_retrieval"] = {
"attempt": attempt + 1,
"success": True,
"time": time.time() - start_time
}
# 发送完成事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
doc_count = len(result.rag_docs) if result.rag_docs else 0
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "rag_retrieve_complete",
"confidence": 1.0,
"reasoning": f"RAG 检索完成,找到 {doc_count} 条相关文档"
},
callbacks=callbacks
)
except Exception as e:
info(f"[rag_retrieve_node] 无法发送完成事件: {e}")
# 关键修复:把 rag_retrieve 加到 reasoning_history 里,让下次推理知道
# 记录成功到历史
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "RETRIEVE_RAG", # 大写,和推理结果保持一致
"action": "RETRIEVE_RAG",
"confidence": 1.0,
"reasoning": "RAG 检索完成",
"timestamp": datetime.now().isoformat()
})
# 发送完成事件
doc_count = len(result.rag_docs) if result.rag_docs else 0
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0,
f"RAG 检索完成,找到 {doc_count} 条相关文档"),
config
)
return result
except Exception as e:
last_error = e
if attempt >= RAG_RETRY_CONFIG.max_retries:
break
# 发送重试事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "rag_retrieve_retry",
"confidence": 1.0,
"reasoning": f"RAG 检索失败,第 {attempt + 1} 次重试..."
},
callbacks=callbacks
)
except Exception as e:
info(f"[rag_retrieve_node] 无法发送重试事件: {e}")
# 指数退避等待
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
f"RAG 检索失败,第 {attempt + 1} 次重试..."),
config
)
# 指数退避
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
# 所有重试失败,记录结构化错误
# 步骤3: 所有重试失败,记录到历史(避免推理循环)
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "RETRIEVE_RAG",
"confidence": 0.0,
"reasoning": f"RAG 检索失败: {str(last_error) if last_error else '超时'}",
"timestamp": datetime.now().isoformat()
})
# 步骤4: 记录错误
error_record = ErrorRecord(
error_type="RAGRetrievalError",
error_message=str(last_error) if last_error else "RAG 检索超时",
@@ -284,105 +177,46 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
context={
"query": state.user_query,
"total_time": time.time() - start_time,
"timeout": RAG_RETRY_CONFIG.timeout,
"has_rag_tool": get_global_rag_tool() is not None,
"has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None
}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
# 发送错误事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "rag_retrieve_error",
"confidence": 1.0,
"reasoning": f"RAG 检索失败: {str(last_error)}"
},
callbacks=callbacks
)
except Exception as e:
info(f"[rag_retrieve_node] 无法发送错误事件: {e}")
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
f"RAG 检索失败: {str(last_error)}"),
config
)
return state
# ========== 重新检索节点 ==========
def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
"""
重新检索节点:用于第二次检索(不同的参数)
Args:
state: 主图状态
Returns:
更新后的状态
"""
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""重新检索节点"""
state.current_phase = "rag_re_retrieving"
# 记录原始检索信息
state.debug_info["rag_re_retrieve"] = {
"original_retrieved": state.rag_retrieved,
"original_docs_count": len(state.rag_docs)
}
# 可以在这里修改检索参数(例如:调整查询、增加 k 值)
# 暂时复用同一个检索逻辑
return rag_retrieve_node(state)
# ========== 便捷函数:从 rag_initializer 初始化 ==========
async def initialize_rag_from_initializer() -> None:
"""
从 rag_initializer 初始化 RAG便捷函数
注意:这是示例代码,实际使用时需要提供 local_llm_creator
"""
try:
from app.main_graph.utils.rag_initializer import init_rag_tool
# 注意:这里需要传入 local_llm_creator
# 示例:
# def my_llm_creator():
# from ..model_services import get_llm
# return get_llm()
#
# rag_tool = await init_rag_tool(my_llm_creator)
# set_global_rag_tool(rag_tool)
print("⚠️ initialize_rag_from_initializer 需要传入 local_llm_creator")
print("⚠️ 请在应用启动时调用 init_rag_tool() 并设置全局 RAG 工具")
except ImportError as e:
print(f"⚠️ 无法导入 rag_initializer: {e}")
except Exception as e:
print(f"⚠️ RAG 初始化失败: {e}")
return await rag_retrieve_node(state, config)
# ========== 导出 ==========
__all__ = [
# 节点函数
"rag_retrieve_node",
"rag_re_retrieve_node",
# 工具函数
"inject_rag_tool_to_state",
"get_rag_tool_from_state",
# 全局 RAG 管理
"get_global_rag_tool",
"set_global_rag_tool",
"set_global_rag_pipeline",
# 初始化
"initialize_rag_from_initializer"
]

View File

@@ -1,424 +0,0 @@
"""
React 模式节点模块 - 带超时和重试功能
包含:
- react_reason_node: 使用 intent.py 进行推理
- error_handling_node: 错误处理节点
- init_state_node: 初始化状态节点
注意:为了兼容 LangGraph 的同步接口,我们保留了同步的 react_reason 调用
但内部会根据情况使用规则推理或尝试异步调用
"""
import sys
from typing import Dict, Any, Optional
from datetime import datetime
# 导入我们的 intent.py
from app.core.intent import (
react_reason,
react_reason_async,
get_route_by_reasoning,
ReasoningAction,
ReasoningResult
)
from app.core.state_base import StateUtils
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.main_graph.utils.retry_utils import (
RetryConfig,
SUBGRAPH_RETRY_CONFIG
)
from app.logger import info
# ========== 1. React 推理节点 ==========
async def react_reason_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
React 模式推理节点:判断下一步做什么(异步版本)
Returns: 更新后的状态
"""
state.current_phase = "react_reasoning"
state.reasoning_step += 1
info(f"[react_reason] 第 {state.reasoning_step} 次推理开始")
# 检查是否超过最大步数
if state.reasoning_step > state.max_steps:
state.current_phase = "max_steps_exceeded"
state.final_result = (
f"❌ 推理步数超过限制(最大 {state.max_steps} 步),"
f"已执行 {state.reasoning_step - 1} 步。"
f"请简化您的问题或分批提问。"
)
state.success = False
return state
# 准备上下文
context = {
"retrieved_docs": state.rag_docs,
"previous_actions": [h.get("action") for h in state.reasoning_history],
"messages": state.messages,
"errors": state.errors
}
# 使用 intent.py 进行推理(现在直接用异步版本)
result: ReasoningResult = await react_reason_async(state.user_query, context)
info(f"[react_reason] 推理结果: action={result.action.name}, confidence={result.confidence}")
if result.reasoning:
info(f"[react_reason] 推理过程: {result.reasoning}")
# 关键修复:直接发送自定义事件给 agent_service而不是通过 state
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
info(f"[react_reason] 直接发送推理事件 #{state.reasoning_step}")
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning
},
callbacks=callbacks
)
except Exception as e:
info(f"[react_reason] 无法发送自定义事件: {e}")
# 记录推理历史
state.reasoning_history.append({
"step": state.reasoning_step,
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning,
"timestamp": datetime.now().isoformat()
})
# 更新状态
state.debug_info["last_reasoning"] = {
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning
}
# 保存推理结果到状态
state.debug_info["reasoning_result"] = result
# 确定下一步动作
state.last_action = result.action.name
# 关键修复:不再设置 latest_reasoning避免 agent_service 重复读取
if "latest_reasoning" in state.debug_info:
del state.debug_info["latest_reasoning"]
return state
# ========== 2. 联网搜索节点 ==========
async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
联网搜索节点:执行搜索并将结果保存到状态
"""
state.current_phase = "web_searching"
# 发送开始事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_start",
"confidence": 1.0,
"reasoning": "开始执行联网搜索..."
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送开始事件: {e}")
# 获取搜索查询
reasoning_result = state.debug_info.get("reasoning_result")
search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query
try:
from app.core import web_search
print(f"[WebSearch] 搜索: {search_query}")
search_result = web_search(search_query, max_results=5)
# 保存搜索结果到状态
if not hasattr(state, "web_search_results"):
state.web_search_results = []
state.web_search_results.append(search_result)
# 将搜索结果添加到 rag_context供 LLM 使用
if state.rag_context:
state.rag_context = f"{state.rag_context}\n\n---\n\n## 🌐 联网搜索结果:\n{search_result}"
else:
state.rag_context = f"## 🌐 联网搜索结果:\n{search_result}"
state.success = True
print(f"[WebSearch] 搜索完成")
# 发送完成事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_complete",
"confidence": 1.0,
"reasoning": f"联网搜索完成,找到 {len(search_result) if isinstance(search_result, list) else 1} 条结果"
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送完成事件: {e}")
except Exception as e:
from app.main_graph.state import ErrorRecord, ErrorSeverity
from datetime import datetime
error_record = ErrorRecord(
error_type="WebSearchError",
error_message=str(e),
severity=ErrorSeverity.WARNING,
source="web_search_node",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=2,
context={"search_query": search_query}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
state.success = False
# 发送错误事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_error",
"confidence": 1.0,
"reasoning": f"联网搜索失败: {str(e)}"
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送错误事件: {e}")
return state
# ========== 3. 错误处理节点 ==========
def error_handling_node(state: MainGraphState) -> MainGraphState:
"""
错误处理节点:处理子图/工具调用错误
返回结构化错误信息,格式如下:
{
"tool/node": "...",
"status": "failed",
"error": "...",
"retries_exceeded": true/false,
"suggestion": "..."
}
"""
state.current_phase = "error_handling"
if not state.current_error:
state.current_phase = "react_reasoning"
return state
error = state.current_error
# 更新错误状态
state.error_message = f"{error.error_type}: {error.error_message}"
# 记录结构化错误信息
structured_error = {
"tool": error.source,
"status": "failed",
"error": error.error_message,
"retries_exceeded": error.retry_count >= error.max_retries,
"retry_count": error.retry_count,
"max_retries": error.max_retries
}
# 根据错误类型添加建议
if "RAG" in error.error_type:
structured_error["suggestion"] = "尝试重新表述问题或直接询问"
elif "subgraph" in error.source or "contact" in error.source:
structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
elif "timeout" in error.error_message.lower():
structured_error["suggestion"] = "请求超时,请稍后再试"
else:
structured_error["suggestion"] = "请尝试其他方式提问"
state.debug_info["structured_error"] = structured_error
# 策略1: 检查是否可以重试
can_retry = (
error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
and error.retry_count < error.max_retries
)
if can_retry:
error.retry_count += 1
state.retry_action = error.source
state.debug_info["retry_count"] = error.retry_count
if "RAG" in error.error_type:
state.last_action = "RE_RETRIEVE_RAG"
elif "subgraph" in error.source:
state.last_action = "DIRECT_RESPONSE"
else:
state.last_action = "REASON"
state.current_phase = "retrying"
return state
# 策略2: 无法重试,尝试降级方案
if error.severity != ErrorSeverity.FATAL:
state.final_result = (
f"⚠️ 遇到一些问题:\n"
f"```json\n{structured_error}\n```\n"
f"但我会尽力用现有信息回答您。"
)
state.success = True
state.current_phase = "finalizing"
return state
# 策略3: 致命错误
state.final_result = (
f"❌ 服务暂时不可用,请稍后再试。\n"
f"```json\n{structured_error}\n```"
)
state.success = False
state.current_phase = "finalizing"
return state
# ========== 4. 初始化状态节点 ==========
def init_state_node(state: MainGraphState) -> MainGraphState:
"""
初始化状态节点:在流程开始时设置初始值
"""
state.current_phase = "initializing"
state.reasoning_step = 0
state.start_time = datetime.now().isoformat()
# 从 messages 中提取用户查询
if not state.user_query and state.messages:
last_msg = state.messages[-1]
state.user_query = getattr(last_msg, "content", str(last_msg))
return state
# ========== 5. 条件路由函数 ==========
def route_by_reasoning(state: MainGraphState) -> str:
"""
根据推理结果决定下一步路由
Returns: 路由标识,对应 graph_builder.py 中的边
"""
# 先检查特殊情况
if state.current_phase == "max_steps_exceeded":
return "llm_call"
if state.current_phase == "error_handling" or state.current_error:
return "handle_error"
if state.current_phase == "finalizing" or state.current_phase == "done":
return "llm_call"
if state.current_phase == "retrying":
if state.retry_action and "rag" in state.retry_action.lower():
return "rag_retrieve"
return "react_reason"
# ========== 关键修复:优先检查当前推理结果 ==========
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
if reasoning_result and reasoning_result.action == ReasoningAction.DIRECT_RESPONSE:
info(f"[route_by_reasoning] 当前推理结果=DIRECT_RESPONSE直接去 llm_call")
return "llm_call"
# ========== 然后检查历史和状态 ==========
previous_actions = [h.get("action") for h in state.reasoning_history]
# 检查是否已经执行过子图
if "subgraph_completed" in previous_actions or state.final_result:
return "llm_call"
# 检测 RAG 重复循环 - 如果发现"RETRIEVE_RAG"出现超过1次直接去 LLM
rag_count = previous_actions.count("RETRIEVE_RAG")
if rag_count >= 2:
info(f"[route_by_reasoning] 检测到 RAG 重复循环({rag_count}次),直接去 llm_call")
return "llm_call"
# 如果已经有 rag_docs 或 rag_context说明已经检索过了直接去 LLM
if (state.rag_docs and len(state.rag_docs) > 0) or (state.rag_context and len(state.rag_context) > 0):
info(f"[route_by_reasoning] 检测到已存在 RAG 检索结果,直接去 llm_call")
return "llm_call"
# 限制最多 3 次推理,避免无限循环
if len(previous_actions) >= 3:
info(f"[route_by_reasoning] 已达到最大推理次数 ({len(previous_actions)}),直接去 llm_call")
return "llm_call"
# ========== 最后处理其他推理结果 ==========
if not reasoning_result:
return "llm_call"
# 使用 intent.py 提供的路由函数
route = get_route_by_reasoning(reasoning_result)
# 映射到我们的节点名称
# 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致
route_mapping = {
"direct_response": "llm_call",
"retrieve_rag": "rag_retrieve",
"re_retrieve_rag": "rag_retrieve",
"web_search": "web_search",
"clarify": "llm_call",
"call_tool": "llm_call",
"contact": "contact_subgraph",
"dictionary": "dictionary_subgraph",
"news_analysis": "news_analysis_subgraph",
}
info(f"[route_by_reasoning] 推理结果={reasoning_result.action.name}, 路由={route_mapping.get(route, 'llm_call')}, 历史动作={previous_actions}")
return route_mapping.get(route, "llm_call")
# ========== 导出 ==========
__all__ = [
"init_state_node",
"react_reason_node",
"web_search_node",
"error_handling_node",
"route_by_reasoning"
]

View File

@@ -0,0 +1,68 @@
"""
React 推理节点
使用 intent.py 进行意图推理
"""
from typing import Dict, Any, Optional
from datetime import datetime
from app.core.intent import react_reason_async, ReasoningResult
from app.main_graph.state import MainGraphState
from app.logger import info
from ._utils import dispatch_custom_event, make_react_event
async def react_reason_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""React 模式推理节点:判断下一步做什么"""
state.current_phase = "react_reasoning"
state.reasoning_step += 1
info(f"[推理] 第 {state.reasoning_step} 次推理开始")
# 步骤1: 准备上下文
context = {
"retrieved_docs": state.rag_docs,
"previous_actions": [h.get("action") for h in state.reasoning_history],
"reasoning_history": state.reasoning_history,
"messages": state.messages,
"errors": state.errors
}
# 步骤2: 执行推理
result: ReasoningResult = await react_reason_async(state.user_query, context)
info(f"[推理] 推理结果: action={result.action.name}, confidence={result.confidence}")
if result.reasoning:
info(f"[推理] 推理过程: {result.reasoning}")
# 步骤3: 记录推理历史
state.reasoning_history.append({
"step": state.reasoning_step,
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning,
"timestamp": datetime.now().isoformat()
})
# 步骤4: 更新调试信息
state.debug_info["last_reasoning"] = {
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning
}
state.debug_info["reasoning_result"] = result
state.last_action = result.action.name
# 步骤5: 发送推理事件
await dispatch_custom_event(
"react_reasoning",
make_react_event(
state.reasoning_step,
result.action.name,
result.confidence,
result.reasoning
),
config
)
return state

View File

@@ -1,48 +0,0 @@
"""
路由决策节点
根据当前状态决定下一步走向
"""
from typing import Literal
from langchain_core.messages import AIMessage
# 本地模块
from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
from app.main_graph.state import MainGraphState
from app.logger import info
def should_continue(state: MainGraphState) -> Literal['tool_node', 'summarize', 'finalize']:
"""
决定下一步:工具调用、生成摘要还是结束
Args:
state: 当前对话状态
Returns:
下一个节点名称
"""
last_message = state.messages[-1]
# 1. 如果需要调用工具,优先进入工具节点
if isinstance(last_message, AIMessage) and last_message.tool_calls:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
return 'tool_node'
# 2. 如果是 AI 的最终回复,判断是否达到摘要生成阈值
if isinstance(last_message, AIMessage):
turns = state.turns_since_last_summary
if turns >= MEMORY_SUMMARIZE_INTERVAL:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 收到 AI 最终回复,已达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 转向 'summarize'")
return 'summarize'
else:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程")
return 'finalize'
# 3. 其他情况(如只有用户消息)直接结束
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
return 'finalize'

View File

@@ -0,0 +1,120 @@
"""
路由与初始化模块
包含状态初始化节点和条件路由函数
三层统一循环防护:
1. 全局步数硬上限reasoning_step > max_steps
2. 路由模式检测A→B→A→B 交替循环)
3. 状态停滞检测(连续相同动作)
"""
from datetime import datetime
from app.core.intent import get_route_by_reasoning, ReasoningAction
from app.main_graph.state import MainGraphState
from app.logger import info
# ========== 初始化状态节点 ==========
def init_state_node(state: MainGraphState) -> MainGraphState:
"""初始化状态节点:在流程开始时设置初始值"""
state.current_phase = "initializing"
state.reasoning_step = 0
state.start_time = datetime.now().isoformat()
if not state.user_query and state.messages:
last_msg = state.messages[-1]
state.user_query = getattr(last_msg, "content", str(last_msg))
return state
# ========== 条件路由函数 ==========
def route_by_reasoning(state: MainGraphState) -> str:
"""
根据推理结果决定下一步路由,带三层统一循环防护
核心逻辑:
1. DIRECT_RESPONSE → 直接返回 llm_call
2. 子图完成/已有结果 → 直接返回 llm_call
3. 步数超限 → 直接返回 llm_call
4. 其他 → 正常路由
"""
# 获取历史动作
previous_actions = [h.get("action") for h in state.reasoning_history]
info(f"[条件路由] step={state.reasoning_step}, phase={state.current_phase}, history={previous_actions}")
# ========== 获取推理结果 ==========
reasoning_result = state.debug_info.get("reasoning_result")
latest_action = reasoning_result.action.name if reasoning_result else None
# ========== 核心检查DIRECT_RESPONSE 优先 ==========
# 从 reasoning_result 检查(最新)
if latest_action == "DIRECT_RESPONSE":
info(f"[条件路由] 推理结果为 DIRECT_RESPONSE直接去 llm_call")
return "llm_call"
# 备用:从历史记录检查
if previous_actions and previous_actions[-1] == "DIRECT_RESPONSE":
info(f"[条件路由] 历史记录最新动作为 DIRECT_RESPONSE直接去 llm_call")
return "llm_call"
# ========== 子图完成/已有结果 ==========
if "subgraph_completed" in previous_actions or state.final_result:
info("[条件路由] 子图已完成或已有结果,直接终止")
return "llm_call"
# ========== 步数超限 ==========
if state.reasoning_step > state.max_steps:
info(f"[条件路由] 步数超限 ({state.reasoning_step}/{state.max_steps}),强制终止")
return "llm_call"
# ========== 特殊阶段快速通道 ==========
if state.current_phase in ("max_steps_exceeded", "finalizing", "done"):
return "llm_call"
if state.current_phase == "error_handling" or state.current_error:
return "handle_error"
# ========== 无推理结果,默认终止 ==========
if not reasoning_result:
info("[条件路由] 无推理结果,默认去 llm_call")
return "llm_call"
# ========== 计算目标路由 ==========
route = get_route_by_reasoning(reasoning_result)
route_mapping = {
"direct_response": "llm_call",
"retrieve_rag": "rag_retrieve",
"re_retrieve_rag": "rag_retrieve",
"web_search": "web_search",
"clarify": "llm_call",
"call_tool": "llm_call",
"contact": "contact_subgraph",
"dictionary": "dictionary_subgraph",
"news_analysis": "news_analysis_subgraph",
}
target = route_mapping.get(route, "llm_call")
# ========== 循环防护检测 ==========
# 1. 路由模式检测A→B→A→B 交替)
if len(previous_actions) >= 4:
if (previous_actions[-4] == previous_actions[-2]
and previous_actions[-3] == previous_actions[-1]
and previous_actions[-2] != previous_actions[-1]):
info(f"[条件路由] 检测到路由循环: {previous_actions[-4:]},强制终止")
return "llm_call"
# 2. 状态停滞检测(连续相同动作)
if len(previous_actions) >= 2 and previous_actions[-1] == previous_actions[-2]:
info(f"[条件路由] 连续相同动作 '{previous_actions[-1]}',强制终止")
return "llm_call"
# ========== 智能优化 ==========
if target == "rag_retrieve" and (state.rag_docs or state.rag_context):
info("[条件路由] RAG 结果已存在,跳过检索")
return "llm_call"
info(f"[条件路由] 动作={latest_action}, 目标={target}")
return target

View File

@@ -0,0 +1,115 @@
"""
联网搜索节点 - 执行搜索并将结果保存到状态
"""
from typing import Dict, Any, Optional
from datetime import datetime
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.logger import info
async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
联网搜索节点:执行搜索并将结果保存到状态
"""
state.current_phase = "web_searching"
# 发送开始事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_start",
"confidence": 1.0,
"reasoning": "开始执行联网搜索..."
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送开始事件: {e}")
# 获取搜索查询
reasoning_result = state.debug_info.get("reasoning_result")
search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query
try:
from app.core import web_search
print(f"[WebSearch] 搜索: {search_query}")
search_result = web_search(search_query, max_results=5)
# 保存搜索结果到状态
if not hasattr(state, "web_search_results"):
state.web_search_results = []
state.web_search_results.append(search_result)
# 将搜索结果添加到 rag_context供 LLM 使用
if state.rag_context:
state.rag_context = f"{state.rag_context}\n\n---\n\n## 🌐 联网搜索结果:\n{search_result}"
else:
state.rag_context = f"## 🌐 联网搜索结果:\n{search_result}"
state.success = True
print(f"[WebSearch] 搜索完成")
# 发送完成事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_complete",
"confidence": 1.0,
"reasoning": f"联网搜索完成,找到 {len(search_result) if isinstance(search_result, list) else 1} 条结果"
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送完成事件: {e}")
except Exception as e:
error_record = ErrorRecord(
error_type="WebSearchError",
error_message=str(e),
severity=ErrorSeverity.WARNING,
source="web_search_node",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=2,
context={"search_query": search_query}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
state.success = False
# 发送错误事件
if config:
try:
from langchain_core.callbacks.manager import adispatch_custom_event
callbacks = config.get("callbacks")
if callbacks:
await adispatch_custom_event(
"react_reasoning",
{
"step": state.reasoning_step,
"action": "web_search_error",
"confidence": 1.0,
"reasoning": f"联网搜索失败: {str(e)}"
},
callbacks=callbacks
)
except Exception as e:
info(f"[web_search_node] 无法发送错误事件: {e}")
return state