refactor: 整理文件夹结构,修复 create_serde 导入问题
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m50s

- 移动 main_graph/tools/ 到 deprecated/main_graph_tools/(旧架构工具)
- 移动 rag_initializer.py 和 retry_utils.py 到 core/
- 清理 main_graph/nodes/ 里的旧节点到 deprecated/
- 修复 backend.py 中 create_serde 导入问题
This commit is contained in:
2026-05-07 01:19:15 +08:00
parent 22fdb625a4
commit 2d62bf956b
15 changed files with 9 additions and 1 deletions

View File

@@ -1,56 +0,0 @@
"""
主图节点通用工具模块
包含事件发送、状态更新等通用功能
"""
from typing import Dict, Any, Optional
from langchain_core.runnables.config import RunnableConfig
async def dispatch_custom_event(
    event_name: str,
    data: Dict[str, Any],
    config: Optional[RunnableConfig] = None,
) -> None:
    """
    Dispatch a LangChain custom event on a best-effort basis.

    Args:
        event_name: Name of the custom event.
        data: Event payload.
        config: LangChain runnable config. When it is falsy (``None`` or
            empty) nothing is dispatched.

    Returns:
        None. A failed dispatch never propagates to the caller.
    """
    if not config:
        return
    try:
        from langchain_core.callbacks.manager import adispatch_custom_event

        await adispatch_custom_event(event_name, data, config=config)
    except Exception as exc:
        # Event delivery must not interrupt the main flow, but a silent
        # `pass` hides broken wiring; leave a debug-level trace instead.
        import logging

        logging.getLogger(__name__).debug(
            "custom event %r could not be dispatched: %s", event_name, exc
        )
def make_react_event(
    step: int,
    action: str,
    confidence: float = 1.0,
    reasoning: str = ""
) -> Dict[str, Any]:
    """
    Build the standard payload of a reasoning event.

    Args:
        step: Current step number.
        action: Action name.
        confidence: Confidence value attached to the event.
        reasoning: Free-text reasoning attached to the event.

    Returns:
        A dict with the keys ``step``, ``action``, ``confidence`` and
        ``reasoning`` holding the given arguments.
    """
    return dict(
        step=step,
        action=action,
        confidence=confidence,
        reasoning=reasoning,
    )

View File

@@ -1,95 +0,0 @@
"""
错误处理节点 - 处理子图/工具调用错误
"""
from ...main_graph.state import MainGraphState, ErrorSeverity
from backend.app.logger import info
def error_handling_node(state: MainGraphState) -> MainGraphState:
    """
    Handle an error recorded by a sub-graph or tool call.

    When ``state.current_error`` is empty the node only switches the phase
    back to ``"react_reasoning"``. Otherwise it stores a structured
    description of the error in ``state.debug_info["structured_error"]``
    (keys: ``tool``, ``status``, ``error``, ``retries_exceeded``,
    ``retry_count``, ``max_retries``, ``suggestion``) and applies, in order:

    1. Retry: severity is WARNING/ERROR and retries remain. The retry
       counter is incremented and ``last_action`` is chosen from the error
       type/source.
    2. Degrade: non-fatal error with no retries left. The problem is
       reported but the turn is marked successful.
    3. Fatal: the service is reported as unavailable.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The same state object.
    """
    # Local import: the module's import block is not changed by this edit.
    import json

    state.current_phase = "error_handling"
    if not state.current_error:
        state.current_phase = "react_reasoning"
        return state

    error = state.current_error
    state.error_message = f"{error.error_type}: {error.error_message}"

    structured_error = {
        "tool": error.source,
        "status": "failed",
        "error": error.error_message,
        "retries_exceeded": error.retry_count >= error.max_retries,
        "retry_count": error.retry_count,
        "max_retries": error.max_retries
    }

    # Pick a user-facing suggestion; the first matching rule wins.
    if "RAG" in error.error_type:
        structured_error["suggestion"] = "尝试重新表述问题或直接询问"
    elif "subgraph" in error.source or "contact" in error.source:
        structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
    elif "timeout" in error.error_message.lower():
        structured_error["suggestion"] = "请求超时,请稍后再试"
    else:
        structured_error["suggestion"] = "请尝试其他方式提问"
    state.debug_info["structured_error"] = structured_error

    # Strategy 1: retry while the error is recoverable and budget remains.
    can_retry = (
        error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
        and error.retry_count < error.max_retries
    )
    if can_retry:
        error.retry_count += 1
        state.retry_action = error.source
        state.debug_info["retry_count"] = error.retry_count
        if "RAG" in error.error_type:
            state.last_action = "RE_RETRIEVE_RAG"
        elif "subgraph" in error.source:
            state.last_action = "DIRECT_RESPONSE"
        else:
            state.last_action = "REASON"
        state.current_phase = "retrying"
        return state

    # The message fences the payload as ```json, so emit real JSON rather
    # than the Python dict repr. `default=str` is deliberate: the error
    # handler must not itself fail on an unexpected field type.
    error_json = json.dumps(
        structured_error, ensure_ascii=False, indent=2, default=str
    )

    # Strategy 2: no retry possible, degrade gracefully.
    if error.severity != ErrorSeverity.FATAL:
        state.final_result = (
            f"⚠️ 遇到一些问题:\n"
            f"```json\n{error_json}\n```\n"
            f"但我会尽力用现有信息回答您。"
        )
        state.success = True
        state.current_phase = "finalizing"
        return state

    # Strategy 3: fatal error.
    state.final_result = (
        f"❌ 服务暂时不可用,请稍后再试。\n"
        f"```json\n{error_json}\n```"
    )
    state.success = False
    state.current_phase = "finalizing"
    return state

View File

@@ -1,59 +0,0 @@
"""
完成事件节点模块(新架构版本)
负责发送完成事件
"""
from typing import Any, Dict
from datetime import datetime
# 本地模块
from .state import AgentState
from backend.app.logger import info, warning
from langchain_core.runnables.config import RunnableConfig
async def finalize_node(state: AgentState, config: RunnableConfig) -> Dict[str, Any]:
    """
    Completion node (new-architecture version): emit the "done" stream event.

    Args:
        state: Current conversation state.
        config: Runtime config (not read by this node).

    Returns:
        An empty dict; the state is not modified.
    """
    info("[Finalize] 进入完成节点")
    try:
        from backend.app.main_graph.config import get_stream_writer

        stream_writer = get_stream_writer()

        # The final reply is the content of the last message, if any.
        reply_text = ""
        if state.messages:
            tail = state.messages[-1]
            if hasattr(tail, 'content'):
                reply_text = tail.content
            else:
                reply_text = str(tail)

        # Only emit when a writer object is present and can be called.
        if stream_writer and hasattr(stream_writer, '__call__'):
            try:
                stream_writer({
                    "type": "custom",
                    "data": {
                        "type": "done",
                        "token_usage": state.last_token_usage,
                        "elapsed_time": state.last_elapsed_time,
                        "final_result": reply_text
                    }
                })
                info("🏁 [完成事件] 已发送完成事件")
            except Exception as send_err:
                warning(f"⚠️ [完成事件] 发送完成事件失败 (非致命): {send_err}")
    except Exception as node_err:
        warning(f"⚠️ [完成事件] 处理失败 (非致命): {node_err}")
    info("[Finalize] 离开完成节点")
    return {}

View File

@@ -1,214 +0,0 @@
"""
LLM 调用节点模块
负责调用大语言模型并处理响应
"""
import time
from typing import Any, Dict
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
# 本地模块
from ...main_graph.state import MainGraphState
from ...agent.prompts import create_system_prompt
from ...utils.logging import log_state_change
from backend.app.logger import debug, info, error
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
    """
    Factory: build the LLM-call node that picks its model from ``state.current_model``.

    Args:
        chat_services: Mapping of model name -> chat model instance.
        tools: Tool list. Not used: this node only generates the answer and
            never binds tools. The parameter is kept for existing call sites.

    Returns:
        The async node function ``call_llm``.
    """
    # The node talks to the raw models; no tools are bound.
    models = chat_services
    # The prompt (without tool descriptions) is built once per factory call.
    prompt = create_system_prompt()
    from langchain_core.runnables.config import RunnableConfig

    async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
        """
        LLM-call node with dynamic model selection.

        Args:
            state: Current conversation state.
            config: Config injected by LangChain/LangGraph (callbacks etc.);
                forwarded to ``astream``.

        Returns:
            A state-update dict. On failure it carries an apology message
            and ``success=False`` instead of raising.
        """
        log_state_change("llm_call", state, "进入")
        memory_context = getattr(state, "memory_context", "暂无用户信息")
        start_time = time.time()

        # If a final result already exists (e.g. produced by a sub-graph),
        # return it without calling the model.
        if state.final_result:
            info(f"[llm_call] 检测到已有最终结果,直接返回: {state.final_result[:100]}...")
            elapsed_time = time.time() - start_time
            return {
                "final_result": state.final_result,
                "success": True,
                "current_phase": "done",
                "llm_calls": getattr(state, 'llm_calls', 0) + 1,
                "last_elapsed_time": elapsed_time,
                "turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
            }

        # Dynamic model selection, falling back to the first configured model.
        model_name = getattr(state, "current_model", "")
        if not model_name or model_name not in models:
            fallback_name = next(iter(models.keys()))
            info(f"[llm_call] 模型 '{model_name}' 不可用,回退到 '{fallback_name}'")
            model_name = fallback_name
        llm = models[model_name]
        info(f"[llm_call] 使用模型(无工具): {model_name}")

        try:
            # Work on a copy so the state's message list is not mutated.
            messages_with_context = list(state.messages)
            info(f"[llm_call] 原始消息数量: {len(messages_with_context)}")
            for i, msg in enumerate(messages_with_context):
                msg_type = getattr(msg, 'type', 'unknown')
                msg_content = getattr(msg, 'content', '')[:100] if hasattr(msg, 'content') else str(msg)[:100]
                info(f"[llm_call] msg[{i}] type={msg_type}, content={repr(msg_content)}")

            if state.rag_context:
                from langchain_core.messages import SystemMessage
                rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}")
                # Insert the retrieved context right before the first human
                # message, or at the very front when there is none.
                inserted = False
                for i, msg in enumerate(messages_with_context):
                    if msg.type == "human":
                        messages_with_context.insert(i, rag_system_msg)
                        inserted = True
                        break
                if not inserted:
                    messages_with_context.insert(0, rag_system_msg)
                info(f"[llm_call] RAG上下文已添加长度: {len(state.rag_context)}")

            # Stream manually and stitch the chunks into one response.
            chain = prompt | llm
            chunks = []
            info(f"[llm_call] 开始调用 LLM astream...")
            async for chunk in chain.astream(
                {
                    "messages": messages_with_context,
                    "memory_context": memory_context
                },
                config=config
            ):
                chunks.append(chunk)

            # Merge all chunks into the final message. The summary log sits
            # inside the non-empty branch: indexing chunks[0] on an empty
            # stream raised IndexError and made the fallback unreachable.
            if chunks:
                info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks,info:{chunks[0].content[:50]}...{chunks[-1].content[:50]}")
                response = chunks[0]
                for chunk in chunks[1:]:
                    response = response + chunk
            else:
                response = AIMessage(content="")
                info(f"[llm_call] ⚠️ 警告: 没有收到任何 chunks")

            elapsed_time = time.time() - start_time

            # Extract token usage; providers expose it under different keys.
            token_usage = {}
            input_tokens = 0
            output_tokens = 0
            if hasattr(response, 'response_metadata') and response.response_metadata:
                meta = response.response_metadata
                if 'token_usage' in meta:
                    token_usage = meta['token_usage']
                elif 'usage' in meta:
                    token_usage = meta['usage']
            if not token_usage and hasattr(response, 'additional_kwargs'):
                add_kwargs = response.additional_kwargs
                if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
                    token_usage = add_kwargs['llm_output']['token_usage']
            if token_usage:
                input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
                output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))

            # Dump the full model output at debug level.
            debug("\n" + "="*80)
            debug(f"📥 [LLM输出] 模型: {model_name} 返回的完整响应:")
            debug(f" 消息类型: {response.type.upper()}")
            debug(f" 内容长度: {len(str(response.content))} 字符")
            debug("-"*80)
            debug(f"{response.content}")
            # Timing and token statistics.
            info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}")
            info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
            if token_usage:
                debug(f"📋 [LLM统计] 详细用量: {token_usage}")
            debug("="*80 + "\n")

            result = {
                "messages": [response],
                "llm_calls": getattr(state, 'llm_calls', 0) + 1,
                "last_token_usage": token_usage,
                "last_elapsed_time": elapsed_time,
                "turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
                "final_result": response.content,
                "success": True,
                "current_phase": "done",
                "current_model": model_name  # the model actually used
            }
            log_state_change("llm_call", state, "离开")
            return result
        except Exception as e:
            elapsed_time = time.time() - start_time
            error(f"\n❌ [LLM错误] 模型 {model_name} 调用失败 (耗时: {elapsed_time:.2f}秒)")
            error(f" 错误类型: {type(e).__name__}")
            error(f" 错误信息: {str(e)}")
            import traceback
            error(f"📋 堆栈: {traceback.format_exc()}")
            debug("="*80 + "\n")
            # Return a friendly error message instead of raising.
            error_response = AIMessage(
                content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
            )
            error_result = {
                "messages": [error_response],
                # Not incremented here, unlike the success path.
                "llm_calls": getattr(state, 'llm_calls', 0),
                "last_token_usage": {},
                "last_elapsed_time": elapsed_time,
                "turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
                "final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。",
                "success": False,
                "current_phase": "done",
                "current_model": model_name
            }
            log_state_change("llm_call", state, "离开(异常)")
            return error_result

    return call_llm

View File

@@ -1,269 +0,0 @@
"""
RAG 检索节点模块
包含RAG 检索、置信度判断、重检索等节点
"""
import time
import asyncio
from typing import Optional
from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from backend.app.logger import info, debug
from ...model_services import get_small_llm_service
from ._utils import dispatch_custom_event, make_react_event
# Confidence threshold configuration
RAG_CONFIDENCE_THRESHOLD = 0.6 # below this value the retrieval is treated as not relevant
# Module-level pipeline instance, created lazily by _get_rag_pipeline()
_rag_pipeline = None
def _get_rag_pipeline():
    """Return the module-level RAGPipeline, creating it on first use."""
    global _rag_pipeline
    if _rag_pipeline is not None:
        return _rag_pipeline

    from backend.app.rag.pipeline import RAGPipeline

    pipeline_options = dict(
        num_queries=3,
        rerank_top_n=5,
        use_rerank=True,
        return_parent_docs=True,
    )
    _rag_pipeline = RAGPipeline(**pipeline_options)
    return _rag_pipeline
def _get_rag_tool() -> Optional[callable]:
    """Return whatever ``rag_initializer.get_rag_tool()`` provides."""
    from backend.app.main_graph.utils import rag_initializer

    rag_tool = rag_initializer.get_rag_tool()
    return rag_tool
# ========== RAG 检索核心逻辑 ==========
async def _rag_retrieve_core(state: MainGraphState, pipeline) -> MainGraphState:
    """
    Run one retrieval through ``pipeline`` and store the outcome on the state.

    The query is ``state.user_query`` unless the latest reasoning result
    carries a non-empty ``retrieval_config.retrieval_query``.

    Args:
        state: Main graph state; ``rag_context``, ``rag_docs``,
            ``rag_retrieved`` and ``rag_attempts`` are updated in place.
        pipeline: Object providing ``aretrieve(query)`` and
            ``format_context(documents)``.

    Returns:
        The same state object.
    """
    info(f"[RAG Core] _rag_retrieve_core 开始")
    retrieval_query = state.user_query

    # Prefer the optimised query from the structured reasoning result.
    reasoning_result = state.react_reasoning.reasoning_result
    retrieval_cfg = (
        getattr(reasoning_result, "retrieval_config", None)
        if reasoning_result
        else None
    )
    if retrieval_cfg:
        optimised_query = retrieval_cfg.retrieval_query
        if optimised_query:
            retrieval_query = optimised_query
    info(f"[RAG Core] 使用检索查询: {retrieval_query[:50]}...")

    # Fetch documents, then format them into a context string.
    info(f"[RAG Core] 调用 pipeline.aretrieve")
    documents = await pipeline.aretrieve(retrieval_query)
    info(f"[RAG Core] pipeline.aretrieve 返回,得到 {len(documents)} 个文档")
    info(f"[RAG Core] 调用 pipeline.format_context")
    rag_context = pipeline.format_context(documents)
    info(f"[RAG Core] pipeline.format_context 返回")
    info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
    info(f"[RAG Core] 获取到 rag_docs: {len(documents)} 个文档")

    # Update the state.
    state.rag_context = rag_context
    state.rag_docs = documents  # kept for the confidence evaluation
    state.rag_retrieved = bool(documents)  # success only when documents came back
    state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1
    info(f"[RAG Core] _rag_retrieve_core 结束")
    return state
# ========== RAG 检索节点 ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
    """
    RAG retrieval node: retrieve, score the result and record the step.

    Args:
        state: Main graph state; mutated in place.
        config: Optional runnable config forwarded to the event dispatcher.

    Returns:
        The same state object. On any exception ``rag_confidence`` is set to
        0.0 and ``rag_retrieved`` to False instead of raising.
    """
    info(f"[RAG] rag_retrieve_node 开始")
    state.current_phase = "rag_retrieving"
    start_time = time.time()  # captured but not read anywhere in this node
    info(f"[RAG] 调用 _get_rag_pipeline")
    pipeline = _get_rag_pipeline()

    start_event = make_react_event(
        state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."
    )
    await dispatch_custom_event("react_reasoning", start_event, config)

    try:
        info(f"[RAG] 调用 _rag_retrieve_core")
        state = await _rag_retrieve_core(state, pipeline)
        info(f"[RAG] _rag_retrieve_core 返回")

        # Score the retrieval.
        info(f"[RAG] 调用 _evaluate_rag_confidence")
        confidence = await _evaluate_rag_confidence(state)
        state.rag_confidence = confidence
        info(f"[RAG] 检索完成,置信度={confidence:.2f}RAG尝试次数={state.rag_attempts}")

        summary = f"RAG 检索完成,置信度={confidence:.2f}"
        history_entry = {
            "step": state.reasoning_step,
            "action": "RETRIEVE_RAG",
            "confidence": confidence,
            "reasoning": summary,
            "timestamp": datetime.now().isoformat()
        }
        state.reasoning_history.append(history_entry)

        done_event = make_react_event(
            state.reasoning_step, "rag_retrieve_complete", confidence, summary
        )
        await dispatch_custom_event("react_reasoning", done_event, config)
    except Exception as e:
        info(f"[RAG] 检索失败: {e}", exc_info=True)
        state.rag_confidence = 0.0
        state.rag_retrieved = False

    info(f"[RAG] rag_retrieve_node 结束")
    return state
async def _evaluate_rag_confidence(state: MainGraphState) -> float:
    """
    Score the retrieval as a weighted mix of three signals.

    Weights: embedding similarity 0.3, rerank score 0.3, small-LLM
    judgement 0.4.

    Args:
        state: Main graph state holding ``rag_context`` and the documents.

    Returns:
        0.0 when there is no RAG context, otherwise the weighted sum.
    """
    query = state.user_query or ""  # read for parity; not used in the score
    context_text = state.rag_context or ""
    if not context_text:
        return 0.0

    # Signal 1: embedding similarity taken from rag_docs.
    embedding_score = _get_embedding_similarity(state)
    info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}")
    # Signal 2: rerank score taken from rag_docs.
    rerank_score = _get_rerank_score(state)
    info(f"[RAG Confidence] 重排分数={rerank_score:.3f}")
    # Signal 3: small-model judgement.
    llm_score = await _get_llm_score(state)
    info(f"[RAG Confidence] LLM评估={llm_score:.3f}")

    # Weighted combination.
    final_score = embedding_score * 0.3 + rerank_score * 0.3 + llm_score * 0.4
    info(f"[RAG Confidence] 综合置信度={final_score:.3f} (embedding={embedding_score:.3f}*0.3 + rerank={rerank_score:.3f}*0.3 + llm={llm_score:.3f}*0.4)")
    return final_score
def _get_embedding_similarity(state: MainGraphState) -> float:
"""从 rag_docs 中获取向量相似度分数(不再从 debug_info 获取)"""
# 降级:从 rag_docs 中获取
rag_docs = getattr(state, "rag_docs", [])
scores = []
for doc in rag_docs:
if isinstance(doc, dict):
score = doc.get("score", 0.0)
elif hasattr(doc, "metadata"):
score = doc.metadata.get("embedding_score", doc.metadata.get("score", 0.0))
else:
continue
if score > 1.0:
score = min(score / 10.0, 1.0)
scores.append(score)
return max(scores) if scores else 0.0
def _get_rerank_score(state: MainGraphState) -> float:
"""从 rag_docs 中获取重排序分数(不再从 debug_info 获取)"""
# 降级:从 rag_docs 中获取
rag_docs = getattr(state, "rag_docs", [])
scores = []
for doc in rag_docs:
if isinstance(doc, dict):
score = doc.get("rerank_score", 0.0)
elif hasattr(doc, "metadata"):
score = doc.metadata.get("rerank_score", 0.0)
else:
continue
if score > 0:
scores.append(score)
return max(scores) if scores else 0.0
async def _get_llm_score(state: MainGraphState) -> float:
    """
    Ask the small LLM to rate how relevant the retrieved context is.

    The first number found in the model reply is clamped to [0.0, 1.0].

    Returns:
        The clamped score, or 0.5 when the call fails or the reply contains
        no number.
    """
    query = state.user_query or ""
    rag_context = state.rag_context or ""
    try:
        small_llm = get_small_llm_service()
        prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数:
- 1.0 = 完全相关,能直接回答问题
- 0.5 = 部分相关,有一定参考价值
- 0.0 = 完全不相关,无法回答问题
用户问题:{query}
检索结果:{rag_context[:1500]}
只返回一个数字:"""
        reply = await small_llm.ainvoke(prompt)
        reply_text = reply.content.strip()
        import re
        number = re.search(r'(\d+\.?\d*)', reply_text)
        if number:
            parsed = float(number.group(1))
            return max(0.0, min(1.0, parsed))
    except Exception as e:
        info(f"[RAG Confidence] LLM评估失败: {e}")
    # Default: medium confidence.
    return 0.5
# ========== 置信度判断节点 ==========
def check_rag_confidence(state: MainGraphState) -> str:
    """
    Choose the next step from the RAG confidence.

    Returns:
        "no_rag"          - nothing was retrieved or the context is empty
        "low_confidence"  - below the threshold after at least 2 attempts
        "retry_rag"       - below the threshold with fewer than 2 attempts
        "high_confidence" - at or above the threshold
    """
    attempts = getattr(state, 'rag_attempts', 0)
    confidence = getattr(state, 'rag_confidence', 0.0)
    info(f"[Confidence Check] rag_attempts={attempts}, rag_confidence={confidence:.2f}")

    # Case 1: no retrieval result.
    has_result = getattr(state, 'rag_retrieved', False) and state.rag_context
    if not has_result:
        info("[Confidence Check] 无检索结果,走联网")
        return "no_rag"

    # Case 3: high confidence.
    if not confidence < RAG_CONFIDENCE_THRESHOLD:
        info(f"[Confidence Check] 高置信度={confidence:.2f}>={RAG_CONFIDENCE_THRESHOLD},直接生成回答")
        return "high_confidence"

    # Case 2: below the threshold.
    if attempts >= 2:
        info(f"[Confidence Check] 置信度={confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD}且RAG尝试{attempts}次,走联网")
        return "low_confidence"
    info(f"[Confidence Check] 置信度={confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD}可再尝试RAG一次")
    return "retry_rag"
# ========== Public exports ==========
__all__ = [
    "rag_retrieve_node",
    "check_rag_confidence",
    "RAG_CONFIDENCE_THRESHOLD",
]

View File

@@ -1,220 +0,0 @@
"""
路由与初始化模块
包含状态初始化节点和条件路由函数
三层统一循环防护:
1. 全局步数硬上限reasoning_step > max_steps
2. 路由模式检测A→B→A→B 交替循环)
3. 状态停滞检测(连续相同动作)
"""
from datetime import datetime
from backend.app.core.intent import get_route_by_reasoning, ReasoningAction
from ...main_graph.state import (
MainGraphState,
CurrentAction,
ReactReasoningState,
HybridRouterState,
FastPathState
)
from backend.app.logger import info
# ========== 初始化状态节点 ==========
def init_state_node(state: MainGraphState) -> MainGraphState:
    """
    Initialise the state at the start of a run.

    Persistent fields (``messages``, ``turns_since_last_summary``,
    ``user_id``) are left untouched. Every per-run field is reset, and
    ``user_query`` is then taken from the last message when one exists.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The same state object.
    """
    # ---- main-graph control fields ----
    state.user_query = ""
    state.current_action = CurrentAction.NONE
    state.current_model = ""
    state.intent_confidence = 0.0

    # ---- ReAct reasoning fields ----
    state.reasoning_step = 0
    state.last_action = ""
    state.reasoning_history = []

    # ---- RAG fields ----
    state.rag_context = ""
    state.rag_retrieved = False
    state.rag_docs = []
    state.rag_confidence = 0.0
    state.rag_attempts = 0

    # ---- web search ----
    state.web_search_results = []

    # ---- error handling ----
    state.errors = []
    state.current_error = None
    state.retry_action = None
    state.error_message = ""

    # ---- sub-graph results ----
    state.news_result = None
    state.dictionary_result = None
    state.contact_result = None

    # ---- execution status ----
    state.current_phase = "initializing"
    state.final_result = ""
    state.success = False

    # ---- metadata: the run starts now ----
    state.start_time = datetime.now().isoformat()
    state.end_time = None

    # ---- structured sub-states ----
    state.react_reasoning = ReactReasoningState()
    state.hybrid_router = HybridRouterState()
    state.fast_path = FastPathState()

    # ---- statistics ----
    state.llm_calls = 0
    state.last_token_usage = {}
    state.last_elapsed_time = 0.0
    state.memory_context = ""

    # ---- backward-compatibility field ----
    state.debug_info = {}

    # Derive user_query from the last message when it is still empty.
    if not state.user_query and state.messages:
        newest = state.messages[-1]
        state.user_query = getattr(newest, "content", str(newest))
    return state
# ========== 条件路由函数 ==========
def route_by_reasoning(state: MainGraphState) -> str:
    """
    Decide the next node from the latest reasoning result.

    Checks run in this order, and the first hit wins:

    1. DIRECT_RESPONSE (latest result or last history entry) -> "llm_call"
    2. Sub-graph completed or a final result exists -> "llm_call"
    3. Step limit exceeded -> "llm_call"
    4. Terminal phases -> "llm_call"; error phase -> "handle_error"
    5. No reasoning result -> "llm_call"
    6. Otherwise the mapped route, subject to the RAG attempt cap, the
       loop guards and the "RAG result already present" shortcut.

    Args:
        state: Main graph state (read only).

    Returns:
        Name of the next node.
    """
    history = [entry.get("action") for entry in state.reasoning_history]
    info(f"[条件路由] step={state.reasoning_step}, phase={state.current_phase}, history={history}")

    # Latest reasoning result from the structured field.
    reasoning_result = state.react_reasoning.reasoning_result
    latest_action = reasoning_result.action.name if reasoning_result else None

    # ---- DIRECT_RESPONSE takes priority ----
    if latest_action == "DIRECT_RESPONSE":
        info(f"[条件路由] 推理结果为 DIRECT_RESPONSE直接去 llm_call")
        return "llm_call"
    if history and history[-1] == "DIRECT_RESPONSE":
        info(f"[条件路由] 历史记录最新动作为 DIRECT_RESPONSE直接去 llm_call")
        return "llm_call"

    # ---- sub-graph completed / result already available ----
    if "subgraph_completed" in history or state.final_result:
        info("[条件路由] 子图已完成或已有结果,直接终止")
        return "llm_call"

    # ---- step limit ----
    if state.reasoning_step > state.max_steps:
        info(f"[条件路由] 步数超限 ({state.reasoning_step}/{state.max_steps}),强制终止")
        return "llm_call"

    # ---- fast lanes for special phases ----
    if state.current_phase in ("max_steps_exceeded", "finalizing", "done"):
        return "llm_call"
    if state.current_phase == "error_handling" or state.current_error:
        return "handle_error"

    # ---- no reasoning result: terminate by default ----
    if not reasoning_result:
        info("[条件路由] 无推理结果,默认去 llm_call")
        return "llm_call"

    # ---- compute the target route ----
    node_for_route = {
        "direct_response": "llm_call",
        "retrieve_rag": "rag_retrieve",
        "re_retrieve_rag": "rag_retrieve",
        "web_search": "web_search",
        "clarify": "llm_call",
        "call_tool": "llm_call",
        "contact": "contact_subgraph",
        "dictionary": "dictionary_subgraph",
        "news_analysis": "news_analysis_subgraph",
    }
    route = get_route_by_reasoning(reasoning_result)
    target = node_for_route.get(route, "llm_call")

    # ---- hard cap on RAG attempts ----
    rag_attempts = getattr(state, 'rag_attempts', 0)
    if target == "rag_retrieve" and rag_attempts >= 2:
        info(f"[条件路由] RAG已尝试{rag_attempts}次,强制走联网搜索")
        target = "web_search"

    # ---- loop guard 1: A->B->A->B alternation ----
    if len(history) >= 4:
        first_a, first_b, second_a, second_b = history[-4:]
        if first_a == second_a and first_b == second_b and second_a != second_b:
            info(f"[条件路由] 检测到路由循环: {history[-4:]},强制终止")
            return "llm_call"

    # ---- loop guard 2: the same action three times in a row ----
    # (the original TODO notes the intended count was 2)
    if len(history) >= 3 and history[-1] == history[-2] == history[-3]:
        info(f"[条件路由] 连续相同动作 '{history[-1]}',强制终止")
        return "llm_call"

    # ---- shortcut: do not retrieve again when RAG output exists ----
    if target == "rag_retrieve" and (state.rag_docs or state.rag_context):
        info("[条件路由] RAG 结果已存在,跳过检索")
        return "llm_call"

    info(f"[条件路由] 动作={latest_action}, 目标={target}")
    return target
# ========== 完成阶段条件路由函数 ==========
def should_summarize(state: MainGraphState) -> str:
    """
    Decide whether the conversation should be summarised now.

    Args:
        state: Current graph state.

    Returns:
        "summarize" once ``turns_since_last_summary`` reaches 5, otherwise
        "finalize".
    """
    # Summarise once every 5 turns.
    return "summarize" if state.turns_since_last_summary >= 5 else "finalize"

View File

@@ -1,100 +0,0 @@
"""
工具执行节点模块
负责执行 AI 调用的工具函数
"""
import asyncio
from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage
from ...main_graph.config import get_stream_writer
# 本地模块
from ...main_graph.state import MainGraphState
from ...utils.logging import log_state_change
from backend.app.logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]):
    """
    Factory: build the node that executes the tool calls requested by the model.

    Args:
        tools_by_name: Mapping of tool name -> tool object.

    Returns:
        The async node function ``call_tools``.
    """
    from langchain_core.runnables.config import RunnableConfig

    def _emit(writer: Any, payload: Dict[str, Any]) -> None:
        """Push one custom stream event; a failing writer is logged, not raised."""
        try:
            writer({"type": "custom", "data": payload})
        except Exception as exc:
            debug(f" └─ ⚠️ stream event {payload.get('type')} not delivered: {exc}")

    async def call_tools(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
        """
        Tool execution node.

        Args:
            state: Current conversation state; the last message is inspected
                for tool calls.
            config: Runtime config (not read by this node).

        Returns:
            ``{"messages": [...]}`` with one ToolMessage per tool call, or an
            empty list when the last message requests no tools.
        """
        log_state_change("tool_node", state, "进入")
        last_message = state.messages[-1]
        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
            log_state_change("tool_node", state, "离开(无工具调用)")
            return {"messages": []}

        results = []
        # This runs inside a coroutine, so the running loop is the right one.
        loop = asyncio.get_running_loop()
        info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")

        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool_id = tool_call["id"]
            tool_func = tools_by_name.get(tool_name)
            debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")
            if tool_func is None:
                err_msg = f"Tool {tool_name} not found"
                debug(f" └─ ❌ {err_msg}")
                results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
                continue

            # Announce the start of the tool call.
            writer = get_stream_writer()
            _emit(writer, {"type": "tool_start", "tool": tool_name})

            try:
                # Prefer the async entry point; otherwise run the sync
                # `invoke` in the default executor.
                if hasattr(tool_func, 'ainvoke'):
                    observation = await tool_func.ainvoke(tool_args)
                else:
                    observation = await loop.run_in_executor(
                        None, tool_func.invoke, tool_args
                    )
                content = str(observation)
            except Exception as e:
                debug(f" └─ ❌ 异常: {e}")
                results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
                _emit(writer, {"type": "tool_end", "tool": tool_name, "success": False, "error": str(e)})
            else:
                # The success event is sent outside the try block: a failing
                # writer used to fall into the except branch and add a second
                # ToolMessage for the same tool_call_id.
                result_preview = content.replace("\n", " ")
                debug(f" └─ ✅ 结果: {result_preview}")
                results.append(ToolMessage(content=content, tool_call_id=tool_id))
                _emit(writer, {"type": "tool_end", "tool": tool_name, "success": True})

        info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
        result = {"messages": results}
        log_state_change("tool_node", state, "离开")
        return result

    return call_tools

View File

@@ -1,116 +0,0 @@
"""
联网搜索节点 - 执行搜索并将结果保存到状态
"""
from typing import Optional
from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from backend.app.logger import info
async def _emit_web_search_event(
    state: MainGraphState,
    config: Optional[RunnableConfig],
    action: str,
    reasoning: str,
    label: str,
) -> None:
    """Best-effort dispatch of one "react_reasoning" event for the web-search node."""
    if not config:
        return
    try:
        from langchain_core.callbacks.manager import adispatch_custom_event

        # Dispatch only when callbacks are configured.
        if config.get("callbacks"):
            # adispatch_custom_event takes the config as keyword-only
            # `config`; the previous `callbacks=` keyword raised TypeError,
            # so no event was ever delivered.
            await adispatch_custom_event(
                "react_reasoning",
                {
                    "step": state.reasoning_step,
                    "action": action,
                    "confidence": 1.0,
                    "reasoning": reasoning
                },
                config=config
            )
    except Exception as e:
        info(f"[web_search_node] 无法发送{label}事件: {e}")


async def web_search_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
    """
    Web-search node: run a search and store the result on the state.

    The query is ``metadata["search_query"]`` of the latest reasoning result
    when one exists, otherwise ``state.user_query``. The result is appended
    to ``state.web_search_results`` and merged into ``state.rag_context``.
    On failure an ``ErrorRecord`` is stored and the phase switches to
    ``"error_handling"``.

    Args:
        state: Main graph state; mutated in place.
        config: Optional runnable config used for progress events.

    Returns:
        The same state object.
    """
    state.current_phase = "web_searching"
    await _emit_web_search_event(
        state, config, "web_search_start", "开始执行联网搜索...", "开始"
    )

    # Search query from the structured reasoning result, if any.
    reasoning_result = state.react_reasoning.reasoning_result
    search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query

    try:
        from backend.app.core import web_search

        info(f"[WebSearch] 搜索: {search_query}")
        # NOTE(review): web_search is called synchronously inside a coroutine;
        # if it blocks on network I/O it stalls the event loop. Confirm.
        search_result = web_search(search_query, max_results=5)

        # Keep the raw search result on the state.
        if not hasattr(state, "web_search_results"):
            state.web_search_results = []
        state.web_search_results.append(search_result)

        # Merge the result into rag_context so the LLM node can use it.
        if state.rag_context:
            state.rag_context = f"{state.rag_context}\n\n---\n\n## 🌐 联网搜索结果:\n{search_result}"
        else:
            state.rag_context = f"## 🌐 联网搜索结果:\n{search_result}"
        state.success = True
        info(f"[WebSearch] 搜索完成")

        await _emit_web_search_event(
            state,
            config,
            "web_search_complete",
            f"联网搜索完成,找到 {len(search_result) if isinstance(search_result, list) else 1} 条结果",
            "完成",
        )
    except Exception as e:
        error_record = ErrorRecord(
            error_type="WebSearchError",
            error_message=str(e),
            severity=ErrorSeverity.WARNING,
            source="web_search_node",
            timestamp=datetime.now().isoformat(),
            retry_count=0,
            max_retries=2,
            context={"search_query": search_query}
        )
        state.errors.append(error_record)
        state.current_error = error_record
        state.current_phase = "error_handling"
        state.success = False

        await _emit_web_search_event(
            state, config, "web_search_error", f"联网搜索失败: {str(e)}", "错误"
        )
    return state

View File

@@ -1,10 +0,0 @@
"""主图工具"""
from .graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from .subgraph_tools import SUBGRAPH_TOOLS, SUBGRAPH_TOOLS_BY_NAME
# Names re-exported from graph_tools and subgraph_tools.
__all__ = [
    "AVAILABLE_TOOLS",
    "TOOLS_BY_NAME",
    "SUBGRAPH_TOOLS",
    "SUBGRAPH_TOOLS_BY_NAME",
]

View File

@@ -1,55 +0,0 @@
"""
公共工具模块 - 联网搜索、可视化图表等公共功能
Common Tools Module - Web search, visualization, etc.
"""
from langchain_core.tools import tool
from typing import Optional
@tool
def web_search_tool(query: str, max_results: int = 5) -> str:
    """
    联网搜索工具 - 无需 API Key使用 DuckDuckGo 免费搜索
    Args:
    query: 搜索关键词或问题
    max_results: 返回结果数量,默认 5 条
    Returns:
    格式化的搜索结果,包含引用溯源
    """
    # NOTE: the docstring above is kept verbatim; the @tool decorator uses
    # it as the tool description, so it is runtime text.
    # Delegates to backend.app.core.web_search; any failure is returned as
    # an error string instead of being raised.
    try:
        from backend.app.core import web_search

        formatted = web_search(query, max_results)
    except Exception as e:
        return f"联网搜索出错:{str(e)}"
    return formatted
@tool
def generate_chart_tool(data_text: str, chart_type: str = "bar") -> str:
    """
    可视化图表工具 - 生成 Mermaid 格式图表
    Args:
    data_text: 图表数据文本,格式:标题,标签1:值1,标签2:值2,...
    示例:月度销售额,1月:100,2月:150,3月:200
    chart_type: 图表类型可选bar柱状图、line折线图、pie饼图
    Returns:
    格式化的图表输出Mermaid 格式)
    """
    # NOTE: the docstring above is kept verbatim; the @tool decorator uses
    # it as the tool description, so it is runtime text.
    # Delegates to backend.app.core.generate_chart; any failure is returned
    # as an error string that repeats the expected input format.
    try:
        from backend.app.core import generate_chart

        chart_output = generate_chart(data_text, chart_type)
    except Exception as e:
        return f"生成图表出错:{str(e)}\n\n请使用格式:标题,标签1:值1,标签2:值2,..."
    return chart_output
# List of the common tools defined in this module
COMMON_TOOLS = [
    web_search_tool,
    generate_chart_tool
]
# Name -> tool lookup. The loop variable `tool` shares its name with the
# imported decorator but is scoped to the comprehension.
COMMON_TOOLS_BY_NAME = {tool.name: tool for tool in COMMON_TOOLS}

View File

@@ -1,25 +0,0 @@
"""
工具定义模块 - 子图工具 + RAG 工具 + 公共工具
Subgraph Tools + RAG Tools + Common Tools
"""
# 子图工具
from .subgraph_tools import (
SUBGRAPH_TOOLS,
SUBGRAPH_TOOLS_BY_NAME,
dictionary_tool,
news_analysis_tool,
contact_tool
)
# 公共工具
from .common_tools import (
COMMON_TOOLS,
COMMON_TOOLS_BY_NAME,
web_search_tool,
generate_chart_tool
)
# Combined tool list and name -> tool mapping (module-level constants).
# Sub-graph tools come first; in the mapping a common tool replaces a
# sub-graph tool that has the same name.
AVAILABLE_TOOLS = SUBGRAPH_TOOLS.copy() + COMMON_TOOLS.copy()
TOOLS_BY_NAME = {**SUBGRAPH_TOOLS_BY_NAME, **COMMON_TOOLS_BY_NAME}

View File

@@ -1,193 +0,0 @@
"""
子图工具模块 - 将三个子图包装成 LangChain 工具
Subgraph Tools Module - Wrap three subgraphs as LangChain tools
"""
from langchain_core.tools import tool
from typing import Optional
# ============== 词典子图工具 ==============
@tool
def dictionary_tool(query: str, action: Optional[str] = None) -> str:
    """
    词典/翻译工具 - 查询单词、翻译文本、提取术语、获取每日一词
    Args:
    query: 用户查询内容(单词、句子、文本等)
    action: 可选,指定操作类型("query" 查询单词,"translate" 翻译,
    "extract" 提取术语,"daily" 每日一词,不指定则自动识别)
    Returns:
    格式化的结果文本
    """
    # NOTE: the docstring above is left untouched; the @tool decorator uses
    # it as the tool description, so it is runtime text.
    # Runs the dictionary sub-graph nodes directly: an explicit `action`
    # selects the node, otherwise parse_intent() decides. Any exception is
    # returned as an error string.
    try:
        from backend.app.subgraphs.dictionary import (
            DictionaryState,
            DictionaryAction,
            parse_intent,
            format_result
        )
        from backend.app.subgraphs.dictionary.nodes import (
            query_word, translate_text, extract_terms, get_daily_word
        )
        # Create the initial state (the user id is fixed to "default")
        state = DictionaryState(user_query=query, user_id="default")
        # Explicit action: set the fields the chosen node reads, then run it
        if action == "query":
            state.action = DictionaryAction.QUERY
            state.action_params = {"word": query}
            state = query_word(state)
        elif action == "translate":
            state.action = DictionaryAction.TRANSLATE
            state.source_text = query
            state = translate_text(state)
        elif action == "daily":
            state.action = DictionaryAction.DAILY_WORD
            state = get_daily_word(state)
        elif action == "extract":
            state.action = DictionaryAction.EXTRACT
            state.action_params = {"text": query}
            state = extract_terms(state)
        else:
            # No (or unknown) action: let the sub-graph parse the intent
            state = parse_intent(state)
            if state.action == DictionaryAction.QUERY:
                state = query_word(state)
            elif state.action == DictionaryAction.TRANSLATE:
                state = translate_text(state)
            elif state.action == DictionaryAction.DAILY_WORD:
                state = get_daily_word(state)
            elif state.action == DictionaryAction.EXTRACT:
                state = extract_terms(state)
        # Format the result; fall back to a generic message when it is empty
        state = format_result(state)
        return state.final_result or "操作完成"
    except Exception as e:
        return f"词典工具执行出错:{str(e)}"
# ============== 资讯分析子图工具 ==============
@tool
def news_analysis_tool(query: str, action: Optional[str] = None) -> str:
    """
    资讯分析工具 - 查询新闻、分析URL、提取关键词、生成报告
    Args:
    query: 用户查询内容关键词、URL、文本等
    action: 可选,指定操作类型("query" 查询新闻,"analyze" 分析URL
    "keywords" 提取关键词,"report" 生成报告,不指定则自动识别)
    Returns:
    格式化的结果文本
    """
    # NOTE: the docstring above is left untouched; the @tool decorator uses
    # it as the tool description, so it is runtime text.
    # Runs the news-analysis sub-graph nodes directly: an explicit `action`
    # selects the node, otherwise parse_intent() decides. Any exception is
    # returned as an error string.
    try:
        from backend.app.subgraphs.news_analysis import (
            NewsAnalysisState,
            NewsAction,
            parse_intent,
            format_result
        )
        from backend.app.subgraphs.news_analysis.nodes import (
            query_news, analyze_url, extract_keywords, generate_report
        )
        # Create the initial state (the user id is fixed to "default")
        state = NewsAnalysisState(user_query=query, user_id="default")
        # Explicit action
        if action == "query":
            state.action = NewsAction.QUERY_NEWS
            state = query_news(state)
        elif action == "analyze":
            state.action = NewsAction.ANALYZE_URL
            # The query itself is treated as the URL to analyse
            state.custom_urls = [query]
            state = analyze_url(state)
        elif action == "keywords":
            state.action = NewsAction.EXTRACT_KEYWORDS
            state = extract_keywords(state)
        elif action == "report":
            state.action = NewsAction.GENERATE_REPORT
            state = generate_report(state)
        else:
            # No (or unknown) action: let the sub-graph parse the intent
            state = parse_intent(state)
            if state.action == NewsAction.QUERY_NEWS:
                state = query_news(state)
            elif state.action == NewsAction.ANALYZE_URL:
                state.custom_urls = [query]
                state = analyze_url(state)
            elif state.action == NewsAction.EXTRACT_KEYWORDS:
                state = extract_keywords(state)
            elif state.action == NewsAction.GENERATE_REPORT:
                state = generate_report(state)
        # Format the result; fall back to a generic message when it is empty
        state = format_result(state)
        return state.final_result or "操作完成"
    except Exception as e:
        return f"资讯分析工具执行出错:{str(e)}"
# ============== 通讯录子图工具 ==============
@tool
def contact_tool(query: str, action: Optional[str] = None) -> str:
    """
    通讯录工具 - 查询联系人、添加联系人、管理通讯录
    Args:
        query: 用户查询内容(姓名、电话、信息等)
        action: 可选,指定操作类型(不指定则自动识别)
    Returns:
        格式化的结果文本
    """
    # NOTE: the docstring above is left untranslated on purpose: the @tool
    # decorator exposes it as the tool description, so it is runtime text.
    #
    # NOTE(review): `action` is accepted but never read here; the operation is
    # always inferred by parse_intent. Confirm whether that is intended.
    try:
        # Imported lazily inside the call, exactly as the original did.
        from backend.app.subgraphs.contact import (
            ContactState,
            ContactAction,
            parse_intent,
            format_result
        )
        from backend.app.subgraphs.contact.nodes import (
            query_contact, add_contact, list_contacts
        )

        state = parse_intent(ContactState(user_query=query, user_id="default"))

        # Run the first node whose action matches; unknown actions run nothing.
        routes = (
            (ContactAction.QUERY, query_contact),
            (ContactAction.ADD, add_contact),
            (ContactAction.LIST, list_contacts),
        )
        for member, node in routes:
            if state.action == member:
                state = node(state)
                break

        state = format_result(state)
        return state.final_result or "操作完成"
    except Exception as e:
        return f"通讯录工具执行出错:{str(e)}"
# ============== 工具列表 ==============
# The three subgraph-backed tools defined in this module.
SUBGRAPH_TOOLS = [
    dictionary_tool,
    news_analysis_tool,
    contact_tool,
]
# Lookup from each tool's `name` attribute to the tool object. The loop
# variable is deliberately not called `tool`, to avoid confusion with the
# imported decorator of the same name.
SUBGRAPH_TOOLS_BY_NAME = {entry.name: entry for entry in SUBGRAPH_TOOLS}

View File

@@ -1,73 +0,0 @@
# app/rag_initializer.py
from ...rag.tools import create_rag_tool
from ...rag.retriever import create_parent_hybrid_retriever
from ...model_services import get_embedding_service
from backend.app.logger import info, warning
import sys
# Module-level RAG tool; assigned by init_rag_tool() and cleared by reset().
_rag_tool = None
# Set to True by init_rag_tool() after a successful build; cleared by reset().
_initialized = False
def get_rag_tool() -> callable:
    """Return the module-level RAG tool.

    The value is whatever is currently stored in ``_rag_tool``; that is
    ``None`` until ``init_rag_tool`` has stored a tool (and again after
    ``reset``).
    """
    return _rag_tool
def is_initialized() -> bool:
    """Report whether the module-level ``_initialized`` flag is set."""
    return _initialized
async def init_rag_tool(force: bool = False):
    """
    Build the RAG tool and store it in the module-level ``_rag_tool``.

    The embedding service, retriever and chat service are obtained inside
    this function, so callers pass nothing but the ``force`` flag.

    Args:
        force: When True, rebuild even if a previous call already succeeded.

    Returns:
        The RAG tool, or None when building it raised an exception (the
        failure is logged as a warning and not re-raised).
    """
    global _rag_tool, _initialized

    # Guard against repeated initialization unless explicitly forced.
    already_done = _initialized and not force
    if already_done:
        info("[RAG] 已初始化,跳过")
        return _rag_tool

    try:
        # Imported here rather than at module top, as in the original.
        from backend.app.model_services.chat_services import get_chat_service

        info("🔄 正在初始化 RAG 检索系统...")

        retriever = create_parent_hybrid_retriever(
            collection_name="rag_documents",
            search_k=5,
            embeddings=get_embedding_service(),
        )
        rag_tool = create_rag_tool(
            retriever=retriever,
            llm=get_chat_service(),
            num_queries=3,
            rerank_top_n=5,
        )

        # Publish the tool only after every step above succeeded.
        _rag_tool = rag_tool
        _initialized = True
        info(f"✅ RAG 检索工具初始化成功 (id={id(rag_tool)})")
        return rag_tool
    except Exception as e:
        warning(f"⚠️ RAG 检索工具初始化失败: {e}")
        return None
def reset():
    """Clear the stored RAG tool and the initialized flag (meant for tests)."""
    global _rag_tool, _initialized
    _initialized = False
    _rag_tool = None

View File

@@ -1,332 +0,0 @@
"""
超时和重试工具模块
为 React 模式提供超时控制和重试机制
"""
import time
import asyncio
from functools import wraps
from typing import Callable, Any, Optional, Type, Tuple, Union
from dataclasses import dataclass, field
from enum import Enum, auto
class RetryStrategy(Enum):
    """How the wait between consecutive attempts is derived from base_delay."""

    # Same delay before every retry.
    FIXED = auto()
    # Delay doubles with each attempt (base_delay * 2 ** attempt).
    EXPONENTIAL = auto()
    # Delay grows by one base_delay per attempt (base_delay * (attempt + 1)).
    LINEAR = auto()
@dataclass
class RetryConfig:
    """Settings shared by the retry decorators and the node wrapper."""

    # Number of retries after the first attempt (so max_retries + 1 calls).
    max_retries: int = 3
    # Base delay in seconds, scaled according to `strategy`.
    base_delay: float = 1.0
    # Upper bound, in seconds, applied to every computed delay.
    max_delay: float = 10.0
    # How the delay grows from one attempt to the next.
    strategy: RetryStrategy = RetryStrategy.EXPONENTIAL
    # Per-call timeout in seconds; a falsy value means no timeout.
    timeout: Optional[float] = 30.0
    # Exception types regarded as retryable.
    # NOTE(review): the retry helpers in this module only check
    # `unrecoverable_exceptions`; this field is not read by them.
    recoverable_exceptions: Tuple[Type[Exception], ...] = field(
        default_factory=lambda: (Exception,)
    )
    # Exception types that end the retry loop immediately.
    unrecoverable_exceptions: Tuple[Type[Exception], ...] = field(
        default_factory=tuple
    )
@dataclass
class RetryResult:
    """Outcome of a call made through one of the retry decorators."""

    # True when the wrapped call eventually returned without raising.
    success: bool
    # Return value of the wrapped call (None on failure).
    result: Any = None
    # Last exception seen (None on success).
    error: Optional[Exception] = None
    # Number of retries performed.
    retry_count: int = 0
    # Elapsed time in seconds across all attempts, including waits.
    total_time: float = 0.0
    # True when the final failure was a timeout.
    timed_out: bool = False
# ========== 同步重试装饰器 ==========
def with_retry(
    config: Optional[RetryConfig] = None,
    max_retries: int = 3,
    timeout: Optional[float] = 30.0,
    base_delay: float = 1.0,
    on_retry: Optional[Callable[[int, Exception], None]] = None
):
    """
    Decorator factory adding retries to a synchronous function.

    The decorated function no longer raises ``Exception`` subclasses or
    returns its own value directly: it returns a ``RetryResult`` describing
    the outcome.

    Args:
        config: Complete retry configuration. When given, ``max_retries``,
            ``timeout`` and ``base_delay`` are ignored.
        max_retries: Retries after the first attempt (used without ``config``).
        timeout: Stored in the configuration, but NOT enforced by this
            synchronous variant; the wrapped call runs to completion.
        base_delay: Base delay in seconds (used without ``config``).
        on_retry: Optional callback ``(retry_number, exception)`` invoked
            before each wait; ``retry_number`` starts at 1.

    Returns:
        A decorator producing a wrapper that returns ``RetryResult``.
    """
    if config is None:
        config = RetryConfig(
            max_retries=max_retries,
            timeout=timeout,
            base_delay=base_delay
        )

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> RetryResult:
            # Monotonic clock: elapsed time must not be affected by wall-clock
            # adjustments.
            start_time = time.monotonic()
            last_error = None
            # Retries actually performed; reported on failure so an early
            # stop (unrecoverable error) is not counted as max_retries.
            retries_done = 0

            for attempt in range(config.max_retries + 1):
                retries_done = attempt
                try:
                    # No timeout is applied here; see the docstring.
                    result = func(*args, **kwargs)
                except Exception as e:
                    last_error = e

                    # Unrecoverable errors stop the loop immediately.
                    if isinstance(e, config.unrecoverable_exceptions):
                        break

                    # Out of retries.
                    if attempt >= config.max_retries:
                        break

                    delay = _calculate_delay(attempt, config)

                    if on_retry:
                        on_retry(attempt + 1, e)

                    time.sleep(delay)
                else:
                    return RetryResult(
                        success=True,
                        result=result,
                        retry_count=attempt,
                        total_time=time.monotonic() - start_time
                    )

            # Every attempt failed, or an unrecoverable error was raised.
            return RetryResult(
                success=False,
                error=last_error,
                retry_count=retries_done,
                total_time=time.monotonic() - start_time
            )

        return wrapper

    return decorator
# ========== 异步重试装饰器 ==========
def with_async_retry(
    config: Optional[RetryConfig] = None,
    max_retries: int = 3,
    timeout: Optional[float] = 30.0,
    base_delay: float = 1.0,
    on_retry: Optional[Callable[[int, Exception], None]] = None
):
    """
    Decorator factory adding retries and a per-call timeout to a coroutine.

    The decorated coroutine function returns a ``RetryResult`` instead of
    raising ``Exception`` subclasses or returning its own value directly.

    Args:
        config: Complete retry configuration. When given, ``max_retries``,
            ``timeout`` and ``base_delay`` are ignored.
        max_retries: Retries after the first attempt (used without ``config``).
        timeout: Per-attempt timeout in seconds, enforced with
            ``asyncio.wait_for``; a falsy value disables it.
        base_delay: Base delay in seconds (used without ``config``).
        on_retry: Optional callback ``(retry_number, exception)`` invoked
            before each wait; ``retry_number`` starts at 1.

    Returns:
        A decorator producing an async wrapper that returns ``RetryResult``.
    """
    if config is None:
        config = RetryConfig(
            max_retries=max_retries,
            timeout=timeout,
            base_delay=base_delay
        )

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> RetryResult:
            # Monotonic clock: elapsed time must not be affected by wall-clock
            # adjustments.
            start_time = time.monotonic()
            last_error = None
            # Retries actually performed; reported on failure so an early
            # stop (unrecoverable error) is not counted as max_retries.
            retries_done = 0

            for attempt in range(config.max_retries + 1):
                retries_done = attempt
                try:
                    if config.timeout:
                        result = await asyncio.wait_for(
                            func(*args, **kwargs),
                            timeout=config.timeout
                        )
                    else:
                        result = await func(*args, **kwargs)
                except Exception as e:
                    # asyncio.TimeoutError is an Exception subclass, so
                    # timeouts and ordinary errors share one code path. The
                    # exception is kept in last_error because `e` is unbound
                    # once this clause ends.
                    last_error = e
                else:
                    return RetryResult(
                        success=True,
                        result=result,
                        retry_count=attempt,
                        total_time=time.monotonic() - start_time
                    )

                # Unrecoverable errors stop the loop immediately.
                if isinstance(last_error, config.unrecoverable_exceptions):
                    break

                # Out of retries.
                if attempt >= config.max_retries:
                    break

                delay = _calculate_delay(attempt, config)

                if on_retry:
                    on_retry(attempt + 1, last_error)

                await asyncio.sleep(delay)

            # Every attempt failed, or an unrecoverable error was raised.
            return RetryResult(
                success=False,
                error=last_error,
                retry_count=retries_done,
                total_time=time.monotonic() - start_time,
                timed_out=isinstance(last_error, asyncio.TimeoutError)
            )

        return wrapper

    return decorator
# ========== 辅助函数 ==========
def _calculate_delay(attempt: int, config: RetryConfig) -> float:
    """Return the wait in seconds before the retry that follows ``attempt``.

    ``attempt`` is the zero-based index of the attempt that just failed. The
    raw delay depends on ``config.strategy`` and is capped at
    ``config.max_delay``. FIXED and any unrecognized strategy both use
    ``config.base_delay`` unchanged.
    """
    strategy = config.strategy
    if strategy == RetryStrategy.FIXED:
        raw = config.base_delay
    elif strategy == RetryStrategy.LINEAR:
        raw = config.base_delay * (attempt + 1)
    elif strategy == RetryStrategy.EXPONENTIAL:
        raw = config.base_delay * (2 ** attempt)
    else:
        raw = config.base_delay
    # Never wait longer than the configured ceiling.
    return min(raw, config.max_delay)
# ========== 为 React 节点设计的超时重试包装器 ==========
def create_retry_wrapper_for_node(
    node_func: Callable,
    node_name: str,
    max_retries: int = 2,
    timeout: float = 30.0
):
    """
    Wrap a synchronous node function with retries and exponential backoff.

    An attempt counts as failed when ``node_func`` raises, or when the state
    object has a truthy ``current_error`` after the call. If the node still
    reports an error through ``current_error`` on the last attempt, its
    result is returned as-is. If every attempt raises, an ``ErrorRecord`` is
    written onto the state and the state itself is returned.

    Args:
        node_func: Original node function, called as ``node_func(state)``.
        node_name: Node name, used as the error source and in the error type.
        max_retries: Retries after the first attempt.
        timeout: Recorded in the retry config and in the error context only;
            this wrapper does not interrupt a node that is still running.

    Returns:
        The wrapped node function.
    """
    config = RetryConfig(
        max_retries=max_retries,
        timeout=timeout,
        strategy=RetryStrategy.EXPONENTIAL
    )

    @wraps(node_func)
    def wrapped_node(state):
        # Monotonic clock: elapsed time must not be affected by wall-clock
        # adjustments.
        start_time = time.monotonic()
        last_error = None

        for attempt in range(config.max_retries + 1):
            try:
                result = node_func(state)

                # The node may signal failure through the state rather than
                # by raising.
                if hasattr(state, "current_error") and state.current_error:
                    last_error = Exception(state.current_error.error_message)
                    if attempt < config.max_retries:
                        time.sleep(_calculate_delay(attempt, config))
                        continue

                # Success, or the final attempt's own error state.
                return result

            except Exception as e:
                last_error = e
                if attempt >= config.max_retries:
                    break
                time.sleep(_calculate_delay(attempt, config))

        # All attempts raised: record the failure on the state.
        from backend.app.main_graph.state import ErrorRecord, ErrorSeverity

        # One message for both the record and state.error_message, so the
        # latter never becomes the literal text "None".
        message = str(last_error) if last_error else f"{node_name} 执行超时"

        error_record = ErrorRecord(
            error_type=f"{node_name}TimeoutError",
            error_message=message,
            severity=ErrorSeverity.ERROR,
            source=node_name,
            retry_count=config.max_retries,
            max_retries=config.max_retries,
            context={
                "timeout": timeout,
                "total_time": time.monotonic() - start_time
            }
        )

        # Only touch attributes this state object actually has.
        if hasattr(state, "errors"):
            state.errors.append(error_record)
        if hasattr(state, "current_error"):
            state.current_error = error_record
        if hasattr(state, "error_message"):
            state.error_message = message
        if hasattr(state, "current_phase"):
            state.current_phase = "error_handling"

        return state

    return wrapped_node
# ========== 预配置的 RAG 重试配置 ==========
RAG_RETRY_CONFIG = RetryConfig(
    strategy=RetryStrategy.EXPONENTIAL,
    base_delay=2.0,
    max_retries=2,
    timeout=60.0,  # RAG calls are allowed a somewhat longer timeout
)
# ========== 预配置的子图重试配置 ==========
SUBGRAPH_RETRY_CONFIG = RetryConfig(
    base_delay=3.0,
    timeout=120.0,  # subgraph runs take longer
    max_retries=1,  # subgraphs are usually not suited to repeated retries
)