refactor: 将 RAG 节点拆分为独立模块
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled

- 新增 rag_nodes.py: 独立的 RAG 检索节点
- 从 react_nodes.py 移除 RAG 相关代码
- 更新导入和导出
- rag_nodes.py 包含 rag_retrieve_node 和 rag_re_retrieve_node
- 添加 inject_rag_tool_to_state 和 get_rag_tool_from_state 工具函数
This commit is contained in:
2026-04-26 11:23:12 +08:00
parent e3adb45454
commit aba261df35
4 changed files with 250 additions and 138 deletions

View File

@@ -7,11 +7,16 @@ from .subgraph_builder import build_main_graph, build_react_main_graph
from .react_nodes import ( from .react_nodes import (
init_state_node, init_state_node,
react_reason_node, react_reason_node,
rag_retrieve_node,
error_handling_node, error_handling_node,
final_response_node, final_response_node,
route_by_reasoning route_by_reasoning
) )
from .rag_nodes import (
rag_retrieve_node,
rag_re_retrieve_node,
inject_rag_tool_to_state,
get_rag_tool_from_state
)
from .state import ( from .state import (
MessagesState, MessagesState,
GraphContext, GraphContext,
@@ -44,13 +49,18 @@ __all__ = [
"build_react_main_graph", "build_react_main_graph",
"init_state_node", "init_state_node",
"react_reason_node", "react_reason_node",
"rag_retrieve_node",
"error_handling_node", "error_handling_node",
"final_response_node", "final_response_node",
"route_by_reasoning", "route_by_reasoning",
"ErrorRecord", "ErrorRecord",
"ErrorSeverity", "ErrorSeverity",
# RAG 节点(独立模块)
"rag_retrieve_node",
"rag_re_retrieve_node",
"inject_rag_tool_to_state",
"get_rag_tool_from_state",
# 超时和重试工具 # 超时和重试工具
"RetryConfig", "RetryConfig",
"RetryResult", "RetryResult",

View File

@@ -0,0 +1,204 @@
"""
RAG 节点模块 - 独立的 RAG 检索节点
包含:
- rag_retrieve_node: RAG 检索节点(带超时重试)
- rag_re_retrieve_node: 重新检索节点
- 相关的 RAG 工具集成
"""
import time
from typing import Dict, Any, Optional
from datetime import datetime
from .state import MainGraphState, ErrorRecord, ErrorSeverity
from .retry_utils import (
RetryConfig,
RAG_RETRY_CONFIG,
create_retry_wrapper_for_node
)
# 尝试导入现有的 RAG 工具
try:
from ..rag.tools import create_rag_tool_sync
from ..rag.pipeline import RAGPipeline
HAS_RAG = True
except ImportError:
HAS_RAG = False
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
    """Return the RAG tool stored on the state, if one is present.

    The tool is looked up under the ``"rag_tool"`` key of
    ``state.debug_info``.

    Args:
        state: Main graph state whose ``debug_info`` mapping is consulted.

    Returns:
        The object stored under ``"rag_tool"``, or ``None`` when the key is
        absent.
    """
    # One ``get`` call instead of a membership test followed by an index.
    return state.debug_info.get("rag_tool")
# ========== RAG 检索核心逻辑 ==========
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
    """Run a single RAG retrieval attempt and store the result on ``state``.

    This function performs no retry; callers are expected to wrap it.

    The query defaults to ``state.user_query``. If
    ``state.debug_info["reasoning_result"]`` carries a ``retrieval_config``
    with a non-empty ``retrieval_query``, that optimized query is used
    instead.

    When a RAG tool is available on the state it is invoked with the query.
    Otherwise placeholder demo documents are written so the graph can still
    run without a configured knowledge base.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The same ``state`` object with ``rag_context``, ``rag_docs``,
        ``rag_retrieved`` and ``success`` updated.

    Raises:
        RuntimeError: If invoking the RAG tool fails. The original exception
            is chained as ``__cause__``.
    """
    # Prefer the optimized query produced by the reasoning step, if any.
    retrieval_query = state.user_query
    reasoning_result = state.debug_info.get("reasoning_result")
    cfg = getattr(reasoning_result, "retrieval_config", None)
    if cfg and cfg.retrieval_query:
        retrieval_query = cfg.retrieval_query

    rag_tool = get_rag_tool_from_state(state)

    # An injected tool is used whenever one is present. The module-level
    # HAS_RAG flag only records whether optional imports succeeded; gating on
    # it would silently replace a real tool with fabricated demo documents.
    if rag_tool is not None:
        try:
            # Only the tool call is guarded, so the error message stays accurate.
            rag_context = rag_tool.invoke(retrieval_query)
        except Exception as e:
            raise RuntimeError(f"RAG 调用失败: {e}") from e
        state.rag_context = rag_context
        state.rag_docs = [
            {"source": "rag_doc", "content": rag_context},
        ]
    else:
        # No RAG tool available: fall back to mock data (demo only).
        # NOTE(review): these documents are fabricated yet reported as a
        # successful retrieval -- confirm this is acceptable outside demos.
        state.rag_context = (
            f"[RAG 检索结果]\n"
            f"查询: {retrieval_query}\n"
            f"这是来自知识库的相关信息。"
        )
        state.rag_docs = [
            {"source": "doc1.txt", "content": "LangGraph 是一个用于构建 Agent 的框架"},
            {"source": "doc2.txt", "content": "React 模式是 '思考→行动→观察' 循环"},
        ]

    state.rag_retrieved = True
    state.success = True
    return state
# ========== RAG 检索节点(带超时和重试) ==========
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
    """RAG retrieval node: run the core retrieval with retry and backoff.

    ``_rag_retrieve_core`` is attempted up to
    ``RAG_RETRY_CONFIG.max_retries + 1`` times. Between failed attempts the
    node sleeps for ``base_delay * 2 ** attempt`` seconds, capped at
    ``max_delay``.

    NOTE: ``RAG_RETRY_CONFIG.timeout`` is only recorded in the error context;
    this function does not itself enforce a per-attempt timeout.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The updated state. On success ``debug_info["rag_retrieval"]`` holds
        the attempt number and elapsed time. When every attempt fails, a
        failure record is written to the same key, a WARNING-severity
        ``ErrorRecord`` is appended to ``state.errors`` and set as
        ``state.current_error``, and ``state.current_phase`` becomes
        ``"error_handling"``.
    """
    state.current_phase = "rag_retrieving"

    # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
    started = time.monotonic()
    max_retries = RAG_RETRY_CONFIG.max_retries
    last_error = None
    attempts_made = 0

    for attempt in range(max_retries + 1):
        attempts_made = attempt + 1
        try:
            result = _rag_retrieve_core(state)
        except Exception as e:
            last_error = e
            if attempt >= max_retries:
                break
            # Exponential backoff, capped at the configured maximum.
            delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
            time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
        else:
            # Bookkeeping sits outside the ``try`` so it can never trigger a
            # spurious retry of a retrieval that already succeeded.
            state.debug_info["rag_retrieval"] = {
                "attempt": attempts_made,
                "success": True,
                "time": time.monotonic() - started,
            }
            return result

    total_time = time.monotonic() - started

    # Overwrite any success record left by an earlier retrieval so the debug
    # info reflects this run.
    state.debug_info["rag_retrieval"] = {
        "attempt": attempts_made,
        "success": False,
        "time": total_time,
    }

    # All attempts failed: record a structured error for the error handler.
    error_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=str(last_error) if last_error else "RAG 检索超时",
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=max_retries,
        max_retries=max_retries,
        context={
            "query": state.user_query,
            "total_time": total_time,
            "timeout": RAG_RETRY_CONFIG.timeout,
        },
    )
    state.errors.append(error_record)
    state.current_error = error_record
    state.current_phase = "error_handling"
    return state
# ========== 重新检索节点 ==========
def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
    """Re-retrieval node: snapshot the previous result, then retrieve again.

    A summary of the previous retrieval is stored under
    ``debug_info["rag_re_retrieve"]`` before delegating to
    ``rag_retrieve_node``. The retrieval itself currently uses the same logic
    and parameters as the first pass.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The state returned by ``rag_retrieve_node``.
    """
    # NOTE: rag_retrieve_node resets current_phase to "rag_retrieving" as its
    # first step, so this marker is only visible until the delegate runs.
    state.current_phase = "rag_re_retrieving"

    # Hook point for adjusting retrieval parameters (wider scope, a rewritten
    # query, ...) before retrying.
    state.debug_info["rag_re_retrieve"] = {
        "original_retrieved": state.rag_retrieved,
        # Guard against rag_docs being None (state defaults are not visible
        # here); a debug snapshot must not abort the node.
        "original_docs_count": len(state.rag_docs or []),
    }

    # Reuse the standard retrieval logic, including its retries.
    return rag_retrieve_node(state)
# ========== 工具:将 RAG 工具注入到状态 ==========
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
    """Attach a RAG tool to the state so later nodes can use it.

    Args:
        state: Main graph state; its ``debug_info`` mapping is updated in place.
        rag_tool: RAG tool instance to store under ``"rag_tool"``.

    Returns:
        The same ``state`` object, with the tool and an ISO-format injection
        timestamp (``"rag_tool_injected"``) recorded in ``debug_info``.
    """
    state.debug_info.update(
        rag_tool=rag_tool,
        rag_tool_injected=datetime.now().isoformat(),
    )
    return state
# ========== Exports ==========
# Public API of the RAG node module; the private core helper is not exported.
__all__ = [
"rag_retrieve_node",
"rag_re_retrieve_node",
"inject_rag_tool_to_state",
"get_rag_tool_from_state"
]

View File

@@ -2,50 +2,30 @@
React 模式节点模块 - 带超时和重试功能 React 模式节点模块 - 带超时和重试功能
包含: 包含:
- react_reason_node: 使用 intent.py 进行推理 - react_reason_node: 使用 intent.py 进行推理
- rag_retrieve_node: RAG 检索节点(带重试)
- error_handling_node: 错误处理节点 - error_handling_node: 错误处理节点
- final_response_node: 最终回答节点 - final_response_node: 最终回答节点
- init_state_node: 初始化节点
""" """
import sys import sys
import time
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from datetime import datetime from datetime import datetime
from functools import wraps
# 导入我们的 intent.py # 导入我们的 intent.py
from ..agent_subgraphs.common.intent import ( from ..agent_subgraphs.common.intent import (
react_reason, react_reason,
get_route_by_reasoning, get_route_by_reasoning,
ReasoningAction, ReasoningAction,
RetrievalConfig,
ReasoningResult ReasoningResult
) )
from ..agent_subgraphs.common.state_base import StateUtils from ..agent_subgraphs.common.state_base import StateUtils
from .state import MainGraphState, ErrorRecord, ErrorSeverity from .state import MainGraphState, ErrorRecord, ErrorSeverity
from .retry_utils import ( from .retry_utils import (
RetryConfig, RetryConfig,
RetryResult,
with_retry,
create_retry_wrapper_for_node,
RAG_RETRY_CONFIG,
SUBGRAPH_RETRY_CONFIG SUBGRAPH_RETRY_CONFIG
) )
def get_rag_tool():
"""
获取 RAG 工具(延迟导入,避免循环依赖)
"""
try:
# 尝试导入现有的 RAG 工具
from ..rag.tools import create_rag_tool_sync
# 注意:这里简化处理,实际使用时应该从全局获取初始化好的工具
return None # 先返回 None后面通过注入方式
except Exception:
return None
# ========== 1. React 推理节点 ========== # ========== 1. React 推理节点 ==========
def react_reason_node(state: MainGraphState) -> MainGraphState: def react_reason_node(state: MainGraphState) -> MainGraphState:
""" """
@@ -94,7 +74,7 @@ def react_reason_node(state: MainGraphState) -> MainGraphState:
"reasoning": result.reasoning "reasoning": result.reasoning
} }
# 保存推理结果到状态(供条件路由使用) # 保存推理结果到状态
state.debug_info["reasoning_result"] = result state.debug_info["reasoning_result"] = result
# 确定下一步动作 # 确定下一步动作
@@ -103,97 +83,7 @@ def react_reason_node(state: MainGraphState) -> MainGraphState:
return state return state
# ========== 2. RAG 检索节点(带超时和重试) ========== # ========== 2. 错误处理节点 ==========
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
"""
RAG 检索核心逻辑(不带重试)
"""
# 获取推理结果中的检索配置
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
retrieval_query = state.user_query
if reasoning_result and reasoning_result.retrieval_config:
cfg: RetrievalConfig = reasoning_result.retrieval_config
if cfg.retrieval_query:
retrieval_query = cfg.retrieval_query
# 尝试获取 RAG 工具并调用
# 这里演示如何调用,实际使用时需要确保 RAG 已初始化
# 暂时用模拟数据
state.rag_context = (
f"[模拟RAG检索结果]\n"
f"查询: {retrieval_query}\n"
f"这是一个来自知识库的示例回答。"
)
state.rag_docs = [
{"source": "doc1.txt", "content": "示例内容1"},
{"source": "doc2.txt", "content": "示例内容2"}
]
state.rag_retrieved = True
state.success = True
return state
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
"""
RAG 检索节点:带超时和重试
Returns: 更新后的状态
"""
state.current_phase = "rag_retrieving"
# 使用重试包装器
start_time = time.time()
last_error = None
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
try:
# 执行核心逻辑
result = _rag_retrieve_core(state)
# 成功
state.debug_info["rag_retrieval"] = {
"attempt": attempt + 1,
"success": True,
"time": time.time() - start_time
}
return result
except Exception as e:
last_error = e
if attempt >= RAG_RETRY_CONFIG.max_retries:
break
# 等待后重试(指数退避)
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
# 所有重试都失败,记录结构化错误
error_record = ErrorRecord(
error_type="RAGRetrievalError",
error_message=str(last_error) if last_error else "RAG 检索超时",
severity=ErrorSeverity.WARNING,
source="rag_retrieve_node",
timestamp=datetime.now().isoformat(),
retry_count=RAG_RETRY_CONFIG.max_retries,
max_retries=RAG_RETRY_CONFIG.max_retries,
context={
"query": state.user_query,
"total_time": time.time() - start_time,
"timeout": RAG_RETRY_CONFIG.timeout
}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
return state
# ========== 3. 错误处理节点 ==========
def error_handling_node(state: MainGraphState) -> MainGraphState: def error_handling_node(state: MainGraphState) -> MainGraphState:
""" """
错误处理节点:处理子图/工具调用错误 错误处理节点:处理子图/工具调用错误
@@ -210,7 +100,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
state.current_phase = "error_handling" state.current_phase = "error_handling"
if not state.current_error: if not state.current_error:
# 没有错误,直接返回
state.current_phase = "react_reasoning" state.current_phase = "react_reasoning"
return state return state
@@ -219,7 +108,7 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
# 更新错误状态 # 更新错误状态
state.error_message = f"{error.error_type}: {error.error_message}" state.error_message = f"{error.error_type}: {error.error_message}"
# 记录结构化错误信息(用于 LLM 决策) # 记录结构化错误信息
structured_error = { structured_error = {
"tool": error.source, "tool": error.source,
"status": "failed", "status": "failed",
@@ -231,11 +120,11 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
# 根据错误类型添加建议 # 根据错误类型添加建议
if "RAG" in error.error_type: if "RAG" in error.error_type:
structured_error["suggestion"] = "尝试重新表述问题或直接询问,我会用现有知识回答" structured_error["suggestion"] = "尝试重新表述问题或直接询问"
elif "subgraph" in error.source or "contact" in error.source: elif "subgraph" in error.source or "contact" in error.source:
structured_error["suggestion"] = "子图执行失败,请尝试简化查询或使用其他功能" structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
elif "timeout" in error.error_message.lower(): elif "timeout" in error.error_message.lower():
structured_error["suggestion"] = "请求超时,请稍后再试或简化查询" structured_error["suggestion"] = "请求超时,请稍后再试"
else: else:
structured_error["suggestion"] = "请尝试其他方式提问" structured_error["suggestion"] = "请尝试其他方式提问"
@@ -248,7 +137,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
) )
if can_retry: if can_retry:
# 重试策略
error.retry_count += 1 error.retry_count += 1
state.retry_action = error.source state.retry_action = error.source
state.debug_info["retry_count"] = error.retry_count state.debug_info["retry_count"] = error.retry_count
@@ -265,7 +153,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
# 策略2: 无法重试,尝试降级方案 # 策略2: 无法重试,尝试降级方案
if error.severity != ErrorSeverity.FATAL: if error.severity != ErrorSeverity.FATAL:
# 降级到直接回答模式
state.final_result = ( state.final_result = (
f"⚠️ 遇到一些问题:\n" f"⚠️ 遇到一些问题:\n"
f"```json\n{structured_error}\n```\n" f"```json\n{structured_error}\n```\n"
@@ -275,7 +162,7 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
state.current_phase = "finalizing" state.current_phase = "finalizing"
return state return state
# 策略3: 致命错误,无法继续 # 策略3: 致命错误
state.final_result = ( state.final_result = (
f"❌ 服务暂时不可用,请稍后再试。\n" f"❌ 服务暂时不可用,请稍后再试。\n"
f"```json\n{structured_error}\n```" f"```json\n{structured_error}\n```"
@@ -286,7 +173,7 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
return state return state
# ========== 4. 最终回答节点 ========== # ========== 3. 最终回答节点 ==========
def final_response_node(state: MainGraphState) -> MainGraphState: def final_response_node(state: MainGraphState) -> MainGraphState:
""" """
最终回答节点:整理并生成最终回答 最终回答节点:整理并生成最终回答
@@ -307,12 +194,15 @@ def final_response_node(state: MainGraphState) -> MainGraphState:
parts.append("---") parts.append("---")
# 添加子图结果(如果有) # 添加子图结果(如果有)
if state.contact_result and state.contact_result.get("final_result"): if state.contact_result and hasattr(state.contact_result, "get"):
parts.append(state.contact_result["final_result"]) if state.contact_result.get("final_result"):
if state.dictionary_result and state.dictionary_result.get("final_result"): parts.append(state.contact_result["final_result"])
parts.append(state.dictionary_result["final_result"]) if state.dictionary_result and hasattr(state.dictionary_result, "get"):
if state.news_result and state.news_result.get("final_result"): if state.dictionary_result.get("final_result"):
parts.append(state.news_result["final_result"]) parts.append(state.dictionary_result["final_result"])
if state.news_result and hasattr(state.news_result, "get"):
if state.news_result.get("final_result"):
parts.append(state.news_result["final_result"])
# 如果都没有,用默认回答 # 如果都没有,用默认回答
if not parts: if not parts:
@@ -326,7 +216,7 @@ def final_response_node(state: MainGraphState) -> MainGraphState:
return state return state
# ========== 5. 初始化状态节点 ========== # ========== 4. 初始化状态节点 ==========
def init_state_node(state: MainGraphState) -> MainGraphState: def init_state_node(state: MainGraphState) -> MainGraphState:
""" """
初始化状态节点:在流程开始时设置初始值 初始化状态节点:在流程开始时设置初始值
@@ -335,7 +225,7 @@ def init_state_node(state: MainGraphState) -> MainGraphState:
state.reasoning_step = 0 state.reasoning_step = 0
state.start_time = datetime.now().isoformat() state.start_time = datetime.now().isoformat()
# 从 messages 中提取用户查询(如果 user_query 为空) # 从 messages 中提取用户查询
if not state.user_query and state.messages: if not state.user_query and state.messages:
last_msg = state.messages[-1] last_msg = state.messages[-1]
state.user_query = getattr(last_msg, "content", str(last_msg)) state.user_query = getattr(last_msg, "content", str(last_msg))
@@ -343,7 +233,7 @@ def init_state_node(state: MainGraphState) -> MainGraphState:
return state return state
# ========== 6. 条件路由函数 ========== # ========== 5. 条件路由函数 ==========
def route_by_reasoning(state: MainGraphState) -> str: def route_by_reasoning(state: MainGraphState) -> str:
""" """
根据推理结果决定下一步路由 根据推理结果决定下一步路由
@@ -358,7 +248,6 @@ def route_by_reasoning(state: MainGraphState) -> str:
if state.current_phase == "finalizing" or state.current_phase == "done": if state.current_phase == "finalizing" or state.current_phase == "done":
return "final_response" return "final_response"
if state.current_phase == "retrying": if state.current_phase == "retrying":
# 重试路由
if state.retry_action and "rag" in state.retry_action.lower(): if state.retry_action and "rag" in state.retry_action.lower():
return "rag_retrieve" return "rag_retrieve"
return "react_reason" return "react_reason"
@@ -367,7 +256,6 @@ def route_by_reasoning(state: MainGraphState) -> str:
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result") reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
if not reasoning_result: if not reasoning_result:
# 没有推理结果,直接结束
return "final_response" return "final_response"
# 使用 intent.py 提供的路由函数 # 使用 intent.py 提供的路由函数
@@ -378,11 +266,21 @@ def route_by_reasoning(state: MainGraphState) -> str:
"direct_response": "final_response", "direct_response": "final_response",
"retrieve_rag": "rag_retrieve", "retrieve_rag": "rag_retrieve",
"re_retrieve_rag": "rag_retrieve", "re_retrieve_rag": "rag_retrieve",
"clarify": "final_response", # 简化:澄清直接回答让用户补充 "clarify": "final_response",
"call_tool": "final_response", # 简化:工具调用暂未实现 "call_tool": "final_response",
"contact": "contact_subgraph", "contact": "contact_subgraph",
"dictionary": "dictionary_subgraph", "dictionary": "dictionary_subgraph",
"news_analysis": "news_analysis_subgraph", "news_analysis": "news_analysis_subgraph",
} }
return route_mapping.get(route, "final_response") return route_mapping.get(route, "final_response")
# ========== 导出 ==========
__all__ = [
"init_state_node",
"react_reason_node",
"error_handling_node",
"final_response_node",
"route_by_reasoning"
]

View File

@@ -10,11 +10,11 @@ from .state import MainGraphState, CurrentAction
from .react_nodes import ( from .react_nodes import (
init_state_node, init_state_node,
react_reason_node, react_reason_node,
rag_retrieve_node,
error_handling_node, error_handling_node,
final_response_node, final_response_node,
route_by_reasoning route_by_reasoning
) )
from .rag_nodes import rag_retrieve_node
from ..agent_subgraphs.contact import build_contact_subgraph from ..agent_subgraphs.contact import build_contact_subgraph
from ..agent_subgraphs.dictionary import build_dictionary_subgraph from ..agent_subgraphs.dictionary import build_dictionary_subgraph
from ..agent_subgraphs.news_analysis import build_news_analysis_subgraph from ..agent_subgraphs.news_analysis import build_news_analysis_subgraph