refactor: 真正利用已有 RAG 代码重构 rag_nodes.py
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m3s

- 真正导入和使用 backend/app/rag/tools.py
- 添加全局 RAG 工具管理(get/set_global_rag_tool)
- 集成 RAGPipeline,支持多路查询和重排序
- 兼容 rag_initializer.py 的初始化方式
- 移除模拟实现,使用真正的 RAG 功能
This commit is contained in:
2026-04-26 11:25:01 +08:00
parent aba261df35
commit 5a67a77c95

View File

@@ -1,12 +1,13 @@
"""
RAG 节点模块 - 独立的 RAG 检索节点
RAG 节点模块 - 真正利用已有 RAG 代码
包含:
- rag_retrieve_node: RAG 检索节点(带超时重试)
- rag_re_retrieve_node: 重新检索节点
- 相关的 RAG 工具集成
- 集成 backend/app/rag/tools.py 和 rag_initializer.py
"""
import time
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
@@ -17,15 +18,49 @@ from .retry_utils import (
create_retry_wrapper_for_node
)
# 尝试导入现有的 RAG 工具
try:
from ..rag.tools import create_rag_tool_sync
from ..rag.pipeline import RAGPipeline
HAS_RAG = True
except ImportError:
HAS_RAG = False
# 真正导入和利用已有 RAG 代码
from ..rag.tools import create_rag_tool_sync
from ..rag.pipeline import RAGPipeline
# ========== 全局 RAG 工具实例(延迟初始化)==========
_GLOBAL_RAG_TOOL: Optional[Any] = None
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None
def get_global_rag_tool() -> Optional[Any]:
"""
获取全局 RAG 工具(单例模式)
Returns:
RAG 工具实例或 None
"""
return _GLOBAL_RAG_TOOL
def set_global_rag_tool(tool: Any) -> None:
"""
设置全局 RAG 工具(通常在应用启动时调用)
Args:
tool: RAG 工具实例
"""
global _GLOBAL_RAG_TOOL
_GLOBAL_RAG_TOOL = tool
def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
"""
设置全局 RAG Pipeline
Args:
pipeline: RAGPipeline 实例
"""
global _GLOBAL_RAG_PIPELINE
_GLOBAL_RAG_PIPELINE = pipeline
# ========== 从状态获取 RAG 工具 ==========
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
"""
从状态中获取 RAG 工具(如果有)
@@ -36,15 +71,34 @@ def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
Returns:
RAG 工具实例或 None
"""
# 优先从状态获取
if "rag_tool" in state.debug_info:
return state.debug_info["rag_tool"]
return None
# 其次从全局获取
return get_global_rag_tool()
# ========== RAG 检索核心逻辑 ==========
# ========== 工具:将 RAG 工具注入到状态 ==========
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
"""
将 RAG 工具注入到状态中,供后续节点使用
Args:
state: 主图状态
rag_tool: RAG 工具实例
Returns:
更新后的状态
"""
state.debug_info["rag_tool"] = rag_tool
state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
return state
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
"""
RAG 检索核心逻辑(不带重试
RAG 检索核心逻辑(真正利用 rag/tools.py
Args:
state: 主图状态
@@ -61,42 +115,53 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
if cfg and cfg.retrieval_query:
retrieval_query = cfg.retrieval_query
# 尝试获取 RAG 工具
# 尝试获取 RAG 工具(多种方式)
rag_tool = get_rag_tool_from_state(state)
if rag_tool and HAS_RAG:
# 使用真的 RAG 工具
if rag_tool:
# 使用真的 RAG 工具(来自 rag/tools.py
try:
# 调用 LangChain Tool 的 invoke 方法
rag_context = rag_tool.invoke(retrieval_query)
state.rag_context = rag_context
state.rag_docs = [
{"source": "rag_doc", "content": rag_context}
{"source": "rag_retrieval", "content": rag_context}
]
state.rag_retrieved = True
state.success = True
state.debug_info["rag_source"] = "rag_tool"
return state
except Exception as e:
raise RuntimeError(f"RAG 调用失败: {str(e)}") from e
else:
# 没有 RAG 工具,使用模拟数据(演示用)
state.rag_context = (
f"[RAG 检索结果]\n"
f"查询: {retrieval_query}\n"
f"这是来自知识库的相关信息。"
)
raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
elif _GLOBAL_RAG_PIPELINE:
# 使用 RAG Pipeline 直接检索
try:
documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query)
if documents:
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
state.rag_context = rag_context
state.rag_docs = [
{"source": "doc1.txt", "content": "LangGraph 是一个用于构建 Agent 的框架"},
{"source": "doc2.txt", "content": "React 模式是 '思考→行动→观察' 循环"}
{"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
for doc in documents
]
else:
state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
state.rag_docs = []
state.rag_retrieved = True
state.success = True
state.debug_info["rag_source"] = "rag_pipeline"
return state
except Exception as e:
raise RuntimeError(f"RAG Pipeline 调用失败: {str(e)}") from e
else:
# 没有可用的 RAG 工具/Pipeline
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
# ========== RAG 检索节点(带超时和重试) ==========
# ========== RAG 检索节点(带超时和重试)==========
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
"""
RAG 检索节点:带超时和重试
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
Args:
state: 主图状态
@@ -144,7 +209,9 @@ def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
context={
"query": state.user_query,
"total_time": time.time() - start_time,
"timeout": RAG_RETRY_CONFIG.timeout
"timeout": RAG_RETRY_CONFIG.timeout,
"has_rag_tool": get_global_rag_tool() is not None,
"has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None
}
)
@@ -168,37 +235,60 @@ def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
"""
state.current_phase = "rag_re_retrieving"
# 可以在这里修改检索参数(例如:扩大范围、调整查询)
# 记录原始检索信息
state.debug_info["rag_re_retrieve"] = {
"original_retrieved": state.rag_retrieved,
"original_docs_count": len(state.rag_docs)
}
# 使用相同的检索逻辑
# 可以在这里修改检索参数(例如:调整查询、增加 k 值)
# 暂时复用同一个检索逻辑
return rag_retrieve_node(state)
# ========== 工具:将 RAG 工具注入到状态 ==========
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
# ========== 便捷函数:从 rag_initializer 初始化 ==========
async def initialize_rag_from_initializer() -> None:
"""
将 RAG 工具注入到状态中,供后续节点使用
从 rag_initializer 初始化 RAG便捷函数
Args:
state: 主图状态
rag_tool: RAG 工具实例
Returns:
更新后的状态
注意:这是示例代码,实际使用时需要提供 local_llm_creator
"""
state.debug_info["rag_tool"] = rag_tool
state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
return state
try:
from ..agent.rag_initializer import init_rag_tool
# 注意:这里需要传入 local_llm_creator
# 示例:
# def my_llm_creator():
# from ..model_services import get_llm
# return get_llm()
#
# rag_tool = await init_rag_tool(my_llm_creator)
# set_global_rag_tool(rag_tool)
print("⚠️ initialize_rag_from_initializer 需要传入 local_llm_creator")
print("⚠️ 请在应用启动时调用 init_rag_tool() 并设置全局 RAG 工具")
except ImportError as e:
print(f"⚠️ 无法导入 rag_initializer: {e}")
except Exception as e:
print(f"⚠️ RAG 初始化失败: {e}")
# ========== 导出 ==========
__all__ = [
# 节点函数
"rag_retrieve_node",
"rag_re_retrieve_node",
# 工具函数
"inject_rag_tool_to_state",
"get_rag_tool_from_state"
"get_rag_tool_from_state",
# 全局 RAG 管理
"get_global_rag_tool",
"set_global_rag_tool",
"set_global_rag_pipeline",
# 初始化
"initialize_rag_from_initializer"
]