"""
RAG node module - graph nodes that reuse the existing RAG code.

Contents:
- rag_retrieve_node: RAG retrieval node (with timeout and retry)
- rag_re_retrieve_node: re-retrieval node
- Integrates backend/app/rag/tools.py and rag_initializer.py
"""
import asyncio
import logging
import time
from datetime import datetime
from typing import Any, Dict, Optional

from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.main_graph.utils.retry_utils import (
    RetryConfig,
    RAG_RETRY_CONFIG,
    create_retry_wrapper_for_node,
)

# Existing RAG code reused by this module.
from app.rag.tools import create_rag_tool_sync
from app.rag.pipeline import RAGPipeline

# Module logger; the previous code called an undefined `info(...)` function,
# which raised NameError on every successful retrieval.
logger = logging.getLogger(__name__)


# ========== Global RAG instances (set at application startup) ==========

_GLOBAL_RAG_TOOL: Optional[Any] = None
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None


def get_global_rag_tool() -> Optional[Any]:
    """
    Return the module-wide RAG tool registered with set_global_rag_tool().

    Returns:
        The RAG tool instance, or None if none has been registered.
    """
    return _GLOBAL_RAG_TOOL


def set_global_rag_tool(tool: Any) -> None:
    """
    Register the module-wide RAG tool (usually called at application startup).

    Args:
        tool: RAG tool instance; it must expose an async ``ainvoke(query)``.
    """
    global _GLOBAL_RAG_TOOL
    _GLOBAL_RAG_TOOL = tool


def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
    """
    Register the module-wide RAG pipeline, used when no RAG tool is available.

    Args:
        pipeline: RAGPipeline instance.
    """
    global _GLOBAL_RAG_PIPELINE
    _GLOBAL_RAG_PIPELINE = pipeline


# ========== Resolve the RAG tool for a given state ==========

def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
    """
    Resolve the RAG tool to use for ``state``.

    A non-None tool stored in ``state.debug_info["rag_tool"]`` (see
    inject_rag_tool_to_state) takes precedence; otherwise the global tool is
    returned.

    Args:
        state: Main graph state.

    Returns:
        RAG tool instance or None.
    """
    injected = state.debug_info.get("rag_tool")
    if injected is not None:
        return injected
    return get_global_rag_tool()


# ========== Helper: inject a RAG tool into the state ==========

def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
    """
    Store ``rag_tool`` in the state so later nodes use it instead of the global tool.

    Also records the injection time (ISO format) under
    ``debug_info["rag_tool_injected"]``.

    Args:
        state: Main graph state (mutated in place).
        rag_tool: RAG tool instance.

    Returns:
        The same (updated) state object.
    """
    state.debug_info["rag_tool"] = rag_tool
    state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
    return state


# ========== Core RAG retrieval logic ==========

def _resolve_retrieval_query(state: MainGraphState) -> str:
    """Return the optimized query from the reasoning result if present, else the user query."""
    retrieval_query = state.user_query
    reasoning_result = state.debug_info.get("reasoning_result")
    if reasoning_result is not None and hasattr(reasoning_result, "retrieval_config"):
        cfg = reasoning_result.retrieval_config
        if cfg and cfg.retrieval_query:
            retrieval_query = cfg.retrieval_query
    return retrieval_query


async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
    """
    Run one retrieval attempt using the RAG tool, or the RAG pipeline as a fallback.

    On success, sets ``rag_context``, ``rag_docs``, ``rag_retrieved``,
    ``success`` and ``debug_info["rag_source"]`` on the state.

    Args:
        state: Main graph state (mutated in place).

    Returns:
        The same (updated) state object.

    Raises:
        RuntimeError: if the tool or pipeline call fails, or neither is configured.
    """
    retrieval_query = _resolve_retrieval_query(state)

    rag_tool = get_rag_tool_from_state(state)

    if rag_tool:
        # Preferred path: the async tool built from rag/tools.py.
        try:
            rag_context = await rag_tool.ainvoke(retrieval_query)
        except Exception as e:
            raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e

        state.rag_context = rag_context
        state.rag_docs = [
            {"source": "rag_retrieval", "content": rag_context}
        ]
        state.rag_retrieved = True
        state.success = True
        state.debug_info["rag_source"] = "rag_tool"
        return state

    if _GLOBAL_RAG_PIPELINE:
        # Fallback path: query the pipeline directly. Formatting stays inside
        # the try so any pipeline-side error is wrapped with the same message.
        try:
            documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
            if documents:
                state.rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
                state.rag_docs = [
                    {"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
                    for doc in documents
                ]
            else:
                state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
                state.rag_docs = []
        except Exception as e:
            raise RuntimeError(f"RAG Pipeline 调用失败: {str(e)}") from e

        state.rag_retrieved = True
        state.success = True
        state.debug_info["rag_source"] = "rag_pipeline"
        return state

    # Neither a RAG tool nor a pipeline has been configured.
    raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")


# ========== Progress events ==========

async def _emit_reasoning_event(
    config: Optional[Dict[str, Any]],
    state: MainGraphState,
    action: str,
    reasoning: str,
    label: str,
) -> None:
    """
    Best-effort dispatch of a "react_reasoning" custom event to LangChain callbacks.

    Does nothing without a config or callbacks. Any failure (including a
    missing langchain_core) is logged and swallowed so progress reporting
    never breaks retrieval.

    Args:
        config: LangChain runnable config (may be None).
        state: Main graph state; ``reasoning_step`` is included in the event.
        action: Event action name.
        reasoning: Human-readable event message.
        label: Event kind used in the failure log message.
    """
    if not config:
        return
    try:
        from langchain_core.callbacks.manager import adispatch_custom_event

        callbacks = config.get("callbacks")
        if callbacks:
            await adispatch_custom_event(
                "react_reasoning",
                {
                    "step": state.reasoning_step,
                    "action": action,
                    "confidence": 1.0,
                    "reasoning": reasoning,
                },
                callbacks=callbacks,
            )
    except Exception as e:
        logger.info("[rag_retrieve_node] 无法发送%s事件: %s", label, e)


# ========== RAG retrieval node (with timeout and retry) ==========

async def _record_rag_success(
    state: MainGraphState,
    result: MainGraphState,
    attempt: int,
    start_time: float,
    config: Optional[Dict[str, Any]],
) -> None:
    """Log the result, record timing, emit the completion event and append reasoning history."""
    context_len = len(result.rag_context) if result.rag_context else 0
    logger.info("[rag_retrieve_node] RAG 检索成功,获取到上下文长度: %d 字符", context_len)
    if result.rag_docs:
        for i, doc in enumerate(result.rag_docs[:3]):  # only log the first 3 docs
            preview = str(doc.get("content") or "")[:100]
            logger.info("[rag_retrieve_node] 文档 %d: %s...", i + 1, preview)

    state.debug_info["rag_retrieval"] = {
        "attempt": attempt + 1,
        "success": True,
        "time": time.time() - start_time,
    }

    doc_count = len(result.rag_docs) if result.rag_docs else 0
    await _emit_reasoning_event(
        config, state, "rag_retrieve_complete",
        f"RAG 检索完成,找到 {doc_count} 条相关文档", "完成",
    )

    # Record the retrieval in reasoning_history so the next reasoning step
    # knows retrieval has already happened.
    state.reasoning_history.append({
        "step": state.reasoning_step,
        "action": "rag_retrieve",
        "confidence": 1.0,
        "reasoning": "RAG 检索完成",
        "timestamp": datetime.now().isoformat(),
    })


async def _record_rag_failure(
    state: MainGraphState,
    last_error: Optional[BaseException],
    start_time: float,
    config: Optional[Dict[str, Any]],
) -> None:
    """Append a structured ErrorRecord, switch to error handling and emit the error event."""
    # A timeout exception stringifies to "", so fall back to an explicit message.
    error_message = (str(last_error) if last_error is not None else "") or "RAG 检索超时"
    logger.warning("[rag_retrieve_node] RAG retrieval failed after retries: %s", error_message)

    error_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=error_message,
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=RAG_RETRY_CONFIG.max_retries,
        max_retries=RAG_RETRY_CONFIG.max_retries,
        context={
            "query": state.user_query,
            "total_time": time.time() - start_time,
            "timeout": RAG_RETRY_CONFIG.timeout,
            # Count a state-injected tool as well, matching what the core uses.
            "has_rag_tool": get_rag_tool_from_state(state) is not None,
            "has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None,
        },
    )
    state.errors.append(error_record)
    state.current_error = error_record
    state.current_phase = "error_handling"

    await _emit_reasoning_event(
        config, state, "rag_retrieve_error", f"RAG 检索失败: {error_message}", "错误",
    )


async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
    """
    RAG retrieval node with per-attempt timeout and exponential-backoff retry.

    Each attempt runs _rag_retrieve_core under ``RAG_RETRY_CONFIG.timeout``.
    Failed attempts are retried up to ``RAG_RETRY_CONFIG.max_retries`` times.
    The wait between attempts is ``base_delay * 2**attempt``, capped at
    ``max_delay``. Only the retrieval itself is retried: logging or event
    failures after a successful retrieval never trigger another attempt.
    When every attempt fails, an ErrorRecord is appended and
    ``current_phase`` is set to "error_handling". The node does not raise.

    Args:
        state: Main graph state (mutated in place).
        config: LangChain runnable config; its callbacks receive progress events.

    Returns:
        The updated state.
    """
    state.current_phase = "rag_retrieving"
    await _emit_reasoning_event(config, state, "rag_retrieve_start", "开始执行 RAG 检索...", "开始")

    start_time = time.time()
    last_error: Optional[BaseException] = None

    for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
        try:
            result = await asyncio.wait_for(
                _rag_retrieve_core(state), timeout=RAG_RETRY_CONFIG.timeout
            )
        except Exception as exc:
            last_error = exc
            logger.info("[rag_retrieve_node] attempt %d failed: %r", attempt + 1, exc)
            if attempt >= RAG_RETRY_CONFIG.max_retries:
                break
            await _emit_reasoning_event(
                config, state, "rag_retrieve_retry",
                f"RAG 检索失败,第 {attempt + 1} 次重试...", "重试",
            )
            # Exponential backoff, capped at max_delay.
            delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
            await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
            continue

        await _record_rag_success(state, result, attempt, start_time, config)
        return result

    await _record_rag_failure(state, last_error, start_time, config)
    return state


# ========== Re-retrieval node ==========

async def rag_re_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
    """
    Re-retrieval node: runs a second retrieval after recording the first one's outcome.

    Stores whether the first retrieval succeeded and its document count under
    ``debug_info["rag_re_retrieve"]``, then awaits rag_retrieve_node. The
    previous sync version returned an un-awaited coroutine. Note that
    rag_retrieve_node resets ``current_phase`` to "rag_retrieving".

    Args:
        state: Main graph state (mutated in place).
        config: LangChain runnable config, forwarded to rag_retrieve_node.

    Returns:
        The updated state.
    """
    state.current_phase = "rag_re_retrieving"

    state.debug_info["rag_re_retrieve"] = {
        "original_retrieved": state.rag_retrieved,
        "original_docs_count": len(state.rag_docs or []),
    }

    # Retrieval parameters (query, k, ...) could be adjusted here; for now the
    # same retrieval logic is reused.
    return await rag_retrieve_node(state, config)


# ========== Convenience: initialize from rag_initializer ==========

async def initialize_rag_from_initializer() -> None:
    """
    Placeholder showing how to initialize RAG via rag_initializer.

    It only checks that rag_initializer is importable and prints guidance. It
    does not register any tool, because init_rag_tool needs a
    ``local_llm_creator`` supplied by the application.
    """
    try:
        from app.main_graph.utils.rag_initializer import init_rag_tool  # noqa: F401

        # A local_llm_creator must be passed in, e.g.:
        # def my_llm_creator():
        #     from ..model_services import get_llm
        #     return get_llm()
        #
        # rag_tool = await init_rag_tool(my_llm_creator)
        # set_global_rag_tool(rag_tool)

        print("⚠️ initialize_rag_from_initializer 需要传入 local_llm_creator")
        print("⚠️ 请在应用启动时调用 init_rag_tool() 并设置全局 RAG 工具")

    except ImportError as e:
        print(f"⚠️ 无法导入 rag_initializer: {e}")
    except Exception as e:
        print(f"⚠️ RAG 初始化失败: {e}")


# ========== Exports ==========

__all__ = [
    # Node functions
    "rag_retrieve_node",
    "rag_re_retrieve_node",
    # Tool helpers
    "inject_rag_tool_to_state",
    "get_rag_tool_from_state",
    # Global RAG management
    "get_global_rag_tool",
    "set_global_rag_tool",
    "set_global_rag_pipeline",
    # Initialization
    "initialize_rag_from_initializer",
]