"""
RAG node module - standalone RAG retrieval nodes.

Contains:
- rag_retrieve_node: RAG retrieval node (with retry and exponential backoff)
- rag_re_retrieve_node: re-retrieval node
- helpers for injecting / reading the RAG tool through the graph state
"""

import time
from typing import Dict, Any, Optional
from datetime import datetime

from .state import MainGraphState, ErrorRecord, ErrorSeverity
from .retry_utils import (
    RetryConfig,
    RAG_RETRY_CONFIG,
    create_retry_wrapper_for_node
)

# Try to import the existing RAG tooling; HAS_RAG records whether it is available.
try:
    from ..rag.tools import create_rag_tool_sync
    from ..rag.pipeline import RAGPipeline
    HAS_RAG = True
except ImportError:
    HAS_RAG = False


def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
    """
    Fetch the RAG tool stored in the state, if any.

    Args:
        state: main graph state

    Returns:
        The object stored under ``debug_info["rag_tool"]``, or None when the
        key is absent.
    """
    return state.debug_info.get("rag_tool")


# ========== RAG retrieval core logic ==========

def _resolve_retrieval_query(state: MainGraphState) -> str:
    """
    Pick the query to send to the retriever.

    Prefers ``debug_info["reasoning_result"].retrieval_config.retrieval_query``
    when it is present and truthy; otherwise falls back to ``state.user_query``.
    """
    reasoning_result = state.debug_info.get("reasoning_result")
    cfg = getattr(reasoning_result, "retrieval_config", None)
    if cfg and cfg.retrieval_query:
        return cfg.retrieval_query
    return state.user_query


def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
    """
    RAG retrieval core logic (no retry).

    Args:
        state: main graph state

    Returns:
        The same state object, updated with ``rag_context``, ``rag_docs``,
        ``rag_retrieved`` and ``success``.

    Raises:
        RuntimeError: if the injected RAG tool raises while being invoked
            (the original exception is chained as the cause).
    """
    retrieval_query = _resolve_retrieval_query(state)

    rag_tool = get_rag_tool_from_state(state)

    # NOTE(review): an injected tool is only used when the RAG package import
    # succeeded (HAS_RAG); otherwise the demo data below is returned instead.
    if rag_tool and HAS_RAG:
        # Use the real RAG tool.
        try:
            rag_context = rag_tool.invoke(retrieval_query)
            state.rag_context = rag_context
            state.rag_docs = [
                {"source": "rag_doc", "content": rag_context}
            ]
            state.rag_retrieved = True
            state.success = True
            return state
        except Exception as e:
            raise RuntimeError(f"RAG 调用失败: {e}") from e

    # No usable RAG tool: fall back to mock data (for demonstration).
    state.rag_context = (
        f"[RAG 检索结果]\n"
        f"查询: {retrieval_query}\n"
        f"这是来自知识库的相关信息。"
    )
    state.rag_docs = [
        {"source": "doc1.txt", "content": "LangGraph 是一个用于构建 Agent 的框架"},
        {"source": "doc2.txt", "content": "React 模式是 '思考→行动→观察' 循环"}
    ]
    state.rag_retrieved = True
    state.success = True
    return state


# ========== RAG retrieval node (with retry) ==========

def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
    """
    RAG retrieval node: runs the core retrieval with retries and exponential
    backoff.

    Makes up to ``RAG_RETRY_CONFIG.max_retries + 1`` attempts. Between failed
    attempts it sleeps ``base_delay * 2 ** attempt`` seconds, capped at
    ``max_delay``. On success it records attempt statistics under
    ``debug_info["rag_retrieval"]``. When every attempt fails it appends an
    ``ErrorRecord`` to ``state.errors``, sets ``state.current_error`` and
    switches ``state.current_phase`` to ``"error_handling"``.

    Args:
        state: main graph state

    Returns:
        The updated state.
    """
    state.current_phase = "rag_retrieving"

    # Monotonic clock: elapsed-time measurements must not be affected by
    # system clock adjustments (NTP, manual changes).
    started = time.monotonic()
    max_retries = RAG_RETRY_CONFIG.max_retries
    last_error = None

    for attempt in range(max_retries + 1):
        try:
            result = _rag_retrieve_core(state)
            state.debug_info["rag_retrieval"] = {
                "attempt": attempt + 1,
                "success": True,
                "time": time.monotonic() - started
            }
            return result
        except Exception as e:
            last_error = e
            if attempt >= max_retries:
                # Last attempt failed: do not sleep, go record the error.
                break
            # Exponential backoff before the next attempt.
            delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
            time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))

    # All attempts failed: record a structured error.
    # NOTE(review): RAG_RETRY_CONFIG.timeout is only reported in the error
    # context here; this function does not enforce it on the tool call.
    error_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=str(last_error) if last_error else "RAG 检索超时",
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=max_retries,
        max_retries=max_retries,
        context={
            "query": state.user_query,
            "total_time": time.monotonic() - started,
            "timeout": RAG_RETRY_CONFIG.timeout
        }
    )
    state.errors.append(error_record)
    state.current_error = error_record
    state.current_phase = "error_handling"
    return state


# ========== Re-retrieval node ==========

def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
    """
    Re-retrieval node: used for a second retrieval pass.

    Records the outcome of the previous retrieval under
    ``debug_info["rag_re_retrieve"]`` and then delegates to
    ``rag_retrieve_node``.

    Args:
        state: main graph state

    Returns:
        The updated state.
    """
    # NOTE(review): rag_retrieve_node overwrites current_phase immediately,
    # so this value is only visible until the delegated call starts.
    state.current_phase = "rag_re_retrieving"

    # Retrieval parameters could be adjusted here (e.g. widen the scope or
    # rewrite the query); currently only the previous outcome is recorded.
    state.debug_info["rag_re_retrieve"] = {
        "original_retrieved": state.rag_retrieved,
        "original_docs_count": len(state.rag_docs)
    }

    # Reuse the same retrieval logic.
    return rag_retrieve_node(state)


# ========== Utility: inject the RAG tool into the state ==========

def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
    """
    Store a RAG tool in the state so that later nodes can use it.

    Writes the tool to ``debug_info["rag_tool"]`` and the injection time
    (ISO-8601 string) to ``debug_info["rag_tool_injected"]``.

    Args:
        state: main graph state
        rag_tool: RAG tool instance

    Returns:
        The updated state.
    """
    state.debug_info["rag_tool"] = rag_tool
    state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
    return state


# ========== Exports ==========

__all__ = [
    "rag_retrieve_node",
    "rag_re_retrieve_node",
    "inject_rag_tool_to_state",
    "get_rag_tool_from_state"
]