diff --git a/backend/app/graph/rag_nodes.py b/backend/app/graph/rag_nodes.py index f56e3f3..c1d0c98 100644 --- a/backend/app/graph/rag_nodes.py +++ b/backend/app/graph/rag_nodes.py @@ -1,12 +1,13 @@ """ -RAG 节点模块 - 独立的 RAG 检索节点 +RAG 节点模块 - 真正利用已有 RAG 代码 包含: - rag_retrieve_node: RAG 检索节点(带超时重试) - rag_re_retrieve_node: 重新检索节点 -- 相关的 RAG 工具集成 +- 集成 backend/app/rag/tools.py 和 rag_initializer.py """ import time +import asyncio from typing import Dict, Any, Optional from datetime import datetime @@ -17,15 +18,49 @@ from .retry_utils import ( create_retry_wrapper_for_node ) -# 尝试导入现有的 RAG 工具 -try: - from ..rag.tools import create_rag_tool_sync - from ..rag.pipeline import RAGPipeline - HAS_RAG = True -except ImportError: - HAS_RAG = False +# 真正导入和利用已有 RAG 代码 +from ..rag.tools import create_rag_tool_sync +from ..rag.pipeline import RAGPipeline +# ========== 全局 RAG 工具实例(延迟初始化)========== +_GLOBAL_RAG_TOOL: Optional[Any] = None +_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None + + +def get_global_rag_tool() -> Optional[Any]: + """ + 获取全局 RAG 工具(单例模式) + + Returns: + RAG 工具实例或 None + """ + return _GLOBAL_RAG_TOOL + + +def set_global_rag_tool(tool: Any) -> None: + """ + 设置全局 RAG 工具(通常在应用启动时调用) + + Args: + tool: RAG 工具实例 + """ + global _GLOBAL_RAG_TOOL + _GLOBAL_RAG_TOOL = tool + + +def set_global_rag_pipeline(pipeline: RAGPipeline) -> None: + """ + 设置全局 RAG Pipeline + + Args: + pipeline: RAGPipeline 实例 + """ + global _GLOBAL_RAG_PIPELINE + _GLOBAL_RAG_PIPELINE = pipeline + + +# ========== 从状态获取 RAG 工具 ========== def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]: """ 从状态中获取 RAG 工具(如果有) @@ -36,15 +71,34 @@ def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]: Returns: RAG 工具实例或 None """ + # 优先从状态获取 if "rag_tool" in state.debug_info: return state.debug_info["rag_tool"] - return None + # 其次从全局获取 + return get_global_rag_tool() -# ========== RAG 检索核心逻辑 ========== +# ========== 工具:将 RAG 工具注入到状态 ========== +def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState: + """ + 将 RAG 工具注入到状态中,供后续节点使用 + + Args: + state: 主图状态 + rag_tool: RAG 工具实例 + + Returns: + 更新后的状态 + """ + state.debug_info["rag_tool"] = rag_tool + state.debug_info["rag_tool_injected"] = datetime.now().isoformat() + return state + + +# ========== RAG 检索核心逻辑(真正利用已有代码)========== def _rag_retrieve_core(state: MainGraphState) -> MainGraphState: """ - RAG 检索核心逻辑(不带重试) + RAG 检索核心逻辑(真正利用 rag/tools.py) Args: state: 主图状态 @@ -61,42 +115,53 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState: if cfg and cfg.retrieval_query: retrieval_query = cfg.retrieval_query - # 尝试获取 RAG 工具 + # 尝试获取 RAG 工具(多种方式) rag_tool = get_rag_tool_from_state(state) - if rag_tool and HAS_RAG: - # 使用真实的 RAG 工具 + if rag_tool: + # 使用真正的 RAG 工具(来自 rag/tools.py) try: + # 调用 LangChain Tool 的 invoke 方法 rag_context = rag_tool.invoke(retrieval_query) state.rag_context = rag_context state.rag_docs = [ - {"source": "rag_doc", "content": rag_context} + {"source": "rag_retrieval", "content": rag_context} ] state.rag_retrieved = True state.success = True + state.debug_info["rag_source"] = "rag_tool" return state except Exception as e: - raise RuntimeError(f"RAG 调用失败: {str(e)}") from e + raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e + elif _GLOBAL_RAG_PIPELINE: + # 使用 RAG Pipeline 直接检索 + try: + documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query) + if documents: + rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents) + state.rag_context = rag_context + state.rag_docs = [ + {"source": doc.metadata.get("source", "unknown"), "content": doc.page_content} + for doc in documents + ] + else: + state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。" + state.rag_docs = [] + state.rag_retrieved = True + state.success = True + state.debug_info["rag_source"] = "rag_pipeline" + return state + except Exception as e: + raise RuntimeError(f"RAG Pipeline 调用失败: {str(e)}") from e else: - # 没有 RAG 工具,使用模拟数据(演示用) - state.rag_context = ( - f"[RAG 检索结果]\n" - f"查询: {retrieval_query}\n" - f"这是来自知识库的相关信息。" - ) - state.rag_docs = [ - {"source": "doc1.txt", "content": "LangGraph 是一个用于构建 Agent 的框架"}, - {"source": "doc2.txt", "content": "React 模式是 '思考→行动→观察' 循环"} - ] - state.rag_retrieved = True - state.success = True - return state + # 没有可用的 RAG 工具/Pipeline + raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()") -# ========== RAG 检索节点(带超时和重试) ========== +# ========== RAG 检索节点(带超时和重试)========== def rag_retrieve_node(state: MainGraphState) -> MainGraphState: """ - RAG 检索节点:带超时和重试 + RAG 检索节点:带超时和重试,真正利用已有 RAG 代码 Args: state: 主图状态 @@ -144,7 +209,9 @@ def rag_retrieve_node(state: MainGraphState) -> MainGraphState: context={ "query": state.user_query, "total_time": time.time() - start_time, - "timeout": RAG_RETRY_CONFIG.timeout + "timeout": RAG_RETRY_CONFIG.timeout, + "has_rag_tool": get_global_rag_tool() is not None, + "has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None } ) @@ -168,37 +235,60 @@ def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState: """ state.current_phase = "rag_re_retrieving" - # 可以在这里修改检索参数(例如:扩大范围、调整查询) + # 记录原始检索信息 state.debug_info["rag_re_retrieve"] = { "original_retrieved": state.rag_retrieved, "original_docs_count": len(state.rag_docs) } - # 使用相同的检索逻辑 + # 可以在这里修改检索参数(例如:调整查询、增加 k 值) + # 暂时复用同一个检索逻辑 return rag_retrieve_node(state) -# ========== 工具:将 RAG 工具注入到状态 ========== -def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState: +# ========== 便捷函数:从 rag_initializer 初始化 ========== +async def initialize_rag_from_initializer() -> None: """ - 将 RAG 工具注入到状态中,供后续节点使用 + 从 rag_initializer 初始化 RAG(便捷函数) - Args: - state: 主图状态 - rag_tool: RAG 工具实例 - - Returns: - 更新后的状态 + 注意:这是示例代码,实际使用时需要提供 local_llm_creator """ - state.debug_info["rag_tool"] = rag_tool - state.debug_info["rag_tool_injected"] = datetime.now().isoformat() - return state + try: + from ..agent.rag_initializer import init_rag_tool + + # 注意:这里需要传入 local_llm_creator + # 示例: + # def my_llm_creator(): + # from ..model_services import get_llm + # return get_llm() + # + # rag_tool = await init_rag_tool(my_llm_creator) + # set_global_rag_tool(rag_tool) + + print("⚠️ initialize_rag_from_initializer 需要传入 local_llm_creator") + print("⚠️ 请在应用启动时调用 init_rag_tool() 并设置全局 RAG 工具") + + except ImportError as e: + print(f"⚠️ 无法导入 rag_initializer: {e}") + except Exception as e: + print(f"⚠️ RAG 初始化失败: {e}") # ========== 导出 ========== __all__ = [ + # 节点函数 "rag_retrieve_node", "rag_re_retrieve_node", + + # 工具函数 "inject_rag_tool_to_state", - "get_rag_tool_from_state" + "get_rag_tool_from_state", + + # 全局 RAG 管理 + "get_global_rag_tool", + "set_global_rag_tool", + "set_global_rag_pipeline", + + # 初始化 + "initialize_rag_from_initializer" ]