"""
RAG retrieval node module.

Contains the RAG retrieval node (retry with exponential backoff), the
re-retrieval node, and helpers for registering the RAG tool / pipeline the
nodes fall back to.

NOTE(review): earlier docs described this as "timeout + retry", but no
per-attempt timeout is enforced in this module -- a hung tool/pipeline call
blocks the node. Confirm whether the tool or pipeline applies its own timeout.
"""

import time
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime

from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from app.logger import info
from ._utils import dispatch_custom_event, make_react_event
from app.rag.tools import create_rag_tool
from app.rag.pipeline import RAGPipeline


# ========== Global RAG tool instances ==========
# Process-wide fallbacks used when the graph state carries no injected tool.
_GLOBAL_RAG_TOOL: Optional[Any] = None
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None


def get_global_rag_tool() -> Optional[Any]:
    """Return the globally registered RAG tool, or None if none was set."""
    return _GLOBAL_RAG_TOOL


def set_global_rag_tool(tool: Any) -> None:
    """Register *tool* as the process-wide RAG tool fallback."""
    global _GLOBAL_RAG_TOOL
    _GLOBAL_RAG_TOOL = tool


def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
    """Register *pipeline* as the process-wide RAG pipeline fallback."""
    global _GLOBAL_RAG_PIPELINE
    _GLOBAL_RAG_PIPELINE = pipeline


def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
    """Return the RAG tool injected into *state*, else the global one.

    A falsy value stored under ``debug_info["rag_tool"]`` also falls through
    to the global tool.
    """
    return state.debug_info.get("rag_tool") or get_global_rag_tool()


def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
    """Store *rag_tool* in ``state.debug_info`` together with an injection timestamp.

    Mutates and returns the same *state* object.
    """
    state.debug_info["rag_tool"] = rag_tool
    state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
    return state


# ========== Core RAG retrieval logic ==========
def _resolve_retrieval_query(state: MainGraphState) -> Any:
    """Pick the query to retrieve with.

    Prefers ``retrieval_config.retrieval_query`` from the reasoning result in
    ``debug_info`` (when present and non-empty); otherwise ``state.user_query``.
    """
    retrieval_query = state.user_query
    reasoning_result = state.debug_info.get("reasoning_result")
    if reasoning_result and hasattr(reasoning_result, "retrieval_config"):
        cfg = reasoning_result.retrieval_config
        if cfg and cfg.retrieval_query:
            retrieval_query = cfg.retrieval_query
    return retrieval_query


async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
    """Run a single RAG retrieval attempt and write the results into *state*.

    Uses the state-injected or global RAG tool when available, otherwise the
    global RAG pipeline. Mutates and returns the same *state* object.

    Raises:
        RuntimeError: if neither a RAG tool nor a RAG pipeline is configured.
        Exception: anything raised by the tool or pipeline propagates so the
            caller can retry.
    """
    retrieval_query = _resolve_retrieval_query(state)

    rag_tool = get_rag_tool_from_state(state)
    if rag_tool:
        rag_context = await rag_tool.ainvoke(retrieval_query)
        state.rag_context = rag_context
        state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
        state.rag_retrieved = True
        state.success = True
        state.debug_info["rag_source"] = "rag_tool"
        return state

    if _GLOBAL_RAG_PIPELINE:
        documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
        if documents:
            state.rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
            state.rag_docs = [
                {"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
                for doc in documents
            ]
        else:
            state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
            state.rag_docs = []
        state.rag_retrieved = True
        state.success = True
        state.debug_info["rag_source"] = "rag_pipeline"
        return state

    raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")


def _describe_error(error: Optional[BaseException]) -> Optional[str]:
    """Return a non-empty description of *error*, or None when there is none.

    Some exceptions (e.g. a bare ``TimeoutError()``) stringify to ``""``; the
    class name is used instead so messages never end in an empty string.
    """
    if error is None:
        return None
    return str(error) or type(error).__name__


async def _record_success(
    state: MainGraphState,
    result: MainGraphState,
    attempt: int,
    start_time: float,
    config: Optional[Dict[str, Any]],
) -> MainGraphState:
    """Log, record history and emit the completion event for a successful retrieval."""
    context = result.rag_context
    # rag_context comes straight from the tool, so it may be None or a
    # non-string object; len() on it directly would raise.
    context_len = len(str(context)) if context is not None else 0
    info(f"[RAG] 检索成功,上下文长度: {context_len} 字符")

    state.debug_info["rag_retrieval"] = {
        "attempt": attempt + 1,
        "success": True,
        "time": time.monotonic() - start_time,
    }

    # Record the success in the reasoning history.
    state.reasoning_history.append({
        "step": state.reasoning_step,
        "action": "RETRIEVE_RAG",
        "confidence": 1.0,
        "reasoning": "RAG 检索完成",
        "timestamp": datetime.now().isoformat(),
    })

    doc_count = len(result.rag_docs) if result.rag_docs else 0
    await dispatch_custom_event(
        "react_reasoning",
        make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0,
                         f"RAG 检索完成,找到 {doc_count} 条相关文档"),
        config,
    )
    return result


async def _record_failure(
    state: MainGraphState,
    last_error: Optional[Exception],
    start_time: float,
    config: Optional[Dict[str, Any]],
) -> MainGraphState:
    """Record history, an ErrorRecord and an error event after all attempts failed."""
    error_text = _describe_error(last_error)

    # Record the failure in history so the reasoning loop does not keep
    # re-selecting RAG retrieval.
    state.reasoning_history.append({
        "step": state.reasoning_step,
        "action": "RETRIEVE_RAG",
        "confidence": 0.0,
        "reasoning": f"RAG 检索失败: {error_text or '超时'}",
        "timestamp": datetime.now().isoformat(),
    })

    error_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=error_text or "RAG 检索超时",
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=RAG_RETRY_CONFIG.max_retries,
        max_retries=RAG_RETRY_CONFIG.max_retries,
        context={
            "query": state.user_query,
            "total_time": time.monotonic() - start_time,
            # Use the same lookup as _rag_retrieve_core so a state-injected
            # tool is reported too, not only the global one.
            "has_rag_tool": get_rag_tool_from_state(state) is not None,
            "has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None,
        },
    )
    state.errors.append(error_record)
    state.current_error = error_record
    state.current_phase = "error_handling"

    await dispatch_custom_event(
        "react_reasoning",
        make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
                         f"RAG 检索失败: {error_text or '超时'}"),
        config,
    )
    return state


# ========== RAG retrieval node ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
    """RAG retrieval node: retries failed retrievals with exponential backoff.

    Makes up to ``RAG_RETRY_CONFIG.max_retries + 1`` attempts. Between
    attempts it sleeps ``base_delay * 2**attempt`` seconds, capped at
    ``max_delay``. On success the retrieval results are already written into
    *state* by ``_rag_retrieve_core``. If every attempt fails, an
    ``ErrorRecord`` is appended and ``current_phase`` becomes
    ``"error_handling"``. Retrieval errors are not raised to the caller.

    NOTE(review): no per-attempt timeout is applied here; confirm whether
    one is expected.

    Args:
        state: Graph state; mutated in place and returned.
        config: Runnable config forwarded to ``dispatch_custom_event``.
    """
    state.current_phase = "rag_retrieving"
    start_time = time.monotonic()
    last_error: Optional[Exception] = None

    # Step 1: emit the start event.
    await dispatch_custom_event(
        "react_reasoning",
        make_react_event(state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."),
        config,
    )

    # Step 2: retrieve with retries. Only the retrieval itself is guarded, so
    # a failure in the success bookkeeping can't trigger a re-retrieval or
    # duplicate success entries in the history.
    max_retries = RAG_RETRY_CONFIG.max_retries
    for attempt in range(max_retries + 1):
        try:
            result = await _rag_retrieve_core(state)
        except Exception as e:
            last_error = e
            info(f"[RAG] 第 {attempt + 1} 次检索失败: {_describe_error(e)}")
            if attempt >= max_retries:
                break
            await dispatch_custom_event(
                "react_reasoning",
                make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
                                 f"RAG 检索失败,第 {attempt + 1} 次重试..."),
                config,
            )
            # Exponential backoff, capped at max_delay.
            delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
            await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
            continue
        return await _record_success(state, result, attempt, start_time, config)

    # Steps 3-4: all attempts failed.
    return await _record_failure(state, last_error, start_time, config)


# ========== Re-retrieval node ==========
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
    """Re-run RAG retrieval, first snapshotting the previous retrieval outcome.

    The snapshot goes into ``debug_info["rag_re_retrieve"]``. Note that
    ``rag_retrieve_node`` immediately overwrites ``current_phase`` with
    ``"rag_retrieving"``.
    """
    state.current_phase = "rag_re_retrieving"
    state.debug_info["rag_re_retrieve"] = {
        "original_retrieved": state.rag_retrieved,
        # rag_docs may be None before any retrieval has run.
        "original_docs_count": len(state.rag_docs) if state.rag_docs else 0,
    }
    return await rag_retrieve_node(state, config)


# ========== Exports ==========
__all__ = [
    "rag_retrieve_node",
    "rag_re_retrieve_node",
    "inject_rag_tool_to_state",
    "get_rag_tool_from_state",
    "get_global_rag_tool",
    "set_global_rag_tool",
    "set_global_rag_pipeline",
]