Files
ailine/backend/app/graph/rag_nodes.py
root 5a67a77c95
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m3s
refactor: 真正利用已有 RAG 代码重构 rag_nodes.py
- 真正导入和使用 backend/app/rag/tools.py
- 添加全局 RAG 工具管理(get/set_global_rag_tool)
- 集成 RAGPipeline,支持多路查询和重排序
- 兼容 rag_initializer.py 的初始化方式
- 移除模拟实现,使用真正的 RAG 功能
2026-04-26 11:25:01 +08:00

295 lines
8.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 节点模块 - 真正利用已有 RAG 代码
包含:
- rag_retrieve_node: RAG 检索节点(带超时重试)
- rag_re_retrieve_node: 重新检索节点
- 集成 backend/app/rag/tools.py 和 rag_initializer.py
"""
import time
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
from .state import MainGraphState, ErrorRecord, ErrorSeverity
from .retry_utils import (
RetryConfig,
RAG_RETRY_CONFIG,
create_retry_wrapper_for_node
)
# 真正导入和利用已有 RAG 代码
from ..rag.tools import create_rag_tool_sync
from ..rag.pipeline import RAGPipeline
# ========== 全局 RAG 工具实例(延迟初始化)==========
_GLOBAL_RAG_TOOL: Optional[Any] = None
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None
def get_global_rag_tool() -> Optional[Any]:
"""
获取全局 RAG 工具(单例模式)
Returns:
RAG 工具实例或 None
"""
return _GLOBAL_RAG_TOOL
def set_global_rag_tool(tool: Any) -> None:
"""
设置全局 RAG 工具(通常在应用启动时调用)
Args:
tool: RAG 工具实例
"""
global _GLOBAL_RAG_TOOL
_GLOBAL_RAG_TOOL = tool
def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
"""
设置全局 RAG Pipeline
Args:
pipeline: RAGPipeline 实例
"""
global _GLOBAL_RAG_PIPELINE
_GLOBAL_RAG_PIPELINE = pipeline
# ========== 从状态获取 RAG 工具 ==========
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
"""
从状态中获取 RAG 工具(如果有)
Args:
state: 主图状态
Returns:
RAG 工具实例或 None
"""
# 优先从状态获取
if "rag_tool" in state.debug_info:
return state.debug_info["rag_tool"]
# 其次从全局获取
return get_global_rag_tool()
# ========== 工具:将 RAG 工具注入到状态 ==========
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
"""
将 RAG 工具注入到状态中,供后续节点使用
Args:
state: 主图状态
rag_tool: RAG 工具实例
Returns:
更新后的状态
"""
state.debug_info["rag_tool"] = rag_tool
state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
return state
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
"""
RAG 检索核心逻辑(真正利用 rag/tools.py
Args:
state: 主图状态
Returns:
更新后的状态
"""
# 获取检索查询(优先使用推理结果中的优化查询)
retrieval_query = state.user_query
if "reasoning_result" in state.debug_info:
reasoning_result = state.debug_info["reasoning_result"]
if hasattr(reasoning_result, "retrieval_config"):
cfg = reasoning_result.retrieval_config
if cfg and cfg.retrieval_query:
retrieval_query = cfg.retrieval_query
# 尝试获取 RAG 工具(多种方式)
rag_tool = get_rag_tool_from_state(state)
if rag_tool:
# 使用真正的 RAG 工具(来自 rag/tools.py
try:
# 调用 LangChain Tool 的 invoke 方法
rag_context = rag_tool.invoke(retrieval_query)
state.rag_context = rag_context
state.rag_docs = [
{"source": "rag_retrieval", "content": rag_context}
]
state.rag_retrieved = True
state.success = True
state.debug_info["rag_source"] = "rag_tool"
return state
except Exception as e:
raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
elif _GLOBAL_RAG_PIPELINE:
# 使用 RAG Pipeline 直接检索
try:
documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query)
if documents:
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
state.rag_context = rag_context
state.rag_docs = [
{"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
for doc in documents
]
else:
state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
state.rag_docs = []
state.rag_retrieved = True
state.success = True
state.debug_info["rag_source"] = "rag_pipeline"
return state
except Exception as e:
raise RuntimeError(f"RAG Pipeline 调用失败: {str(e)}") from e
else:
# 没有可用的 RAG 工具/Pipeline
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
# ========== RAG 检索节点(带超时和重试)==========
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
"""
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
Args:
state: 主图状态
Returns:
更新后的状态
"""
state.current_phase = "rag_retrieving"
start_time = time.time()
last_error = None
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
try:
# 执行核心逻辑
result = _rag_retrieve_core(state)
# 成功
state.debug_info["rag_retrieval"] = {
"attempt": attempt + 1,
"success": True,
"time": time.time() - start_time
}
return result
except Exception as e:
last_error = e
if attempt >= RAG_RETRY_CONFIG.max_retries:
break
# 指数退避等待
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
# 所有重试都失败,记录结构化错误
error_record = ErrorRecord(
error_type="RAGRetrievalError",
error_message=str(last_error) if last_error else "RAG 检索超时",
severity=ErrorSeverity.WARNING,
source="rag_retrieve_node",
timestamp=datetime.now().isoformat(),
retry_count=RAG_RETRY_CONFIG.max_retries,
max_retries=RAG_RETRY_CONFIG.max_retries,
context={
"query": state.user_query,
"total_time": time.time() - start_time,
"timeout": RAG_RETRY_CONFIG.timeout,
"has_rag_tool": get_global_rag_tool() is not None,
"has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None
}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
return state
# ========== 重新检索节点 ==========
def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
"""
重新检索节点:用于第二次检索(不同的参数)
Args:
state: 主图状态
Returns:
更新后的状态
"""
state.current_phase = "rag_re_retrieving"
# 记录原始检索信息
state.debug_info["rag_re_retrieve"] = {
"original_retrieved": state.rag_retrieved,
"original_docs_count": len(state.rag_docs)
}
# 可以在这里修改检索参数(例如:调整查询、增加 k 值)
# 暂时复用同一个检索逻辑
return rag_retrieve_node(state)
# ========== 便捷函数:从 rag_initializer 初始化 ==========
async def initialize_rag_from_initializer() -> None:
"""
从 rag_initializer 初始化 RAG便捷函数
注意:这是示例代码,实际使用时需要提供 local_llm_creator
"""
try:
from ..agent.rag_initializer import init_rag_tool
# 注意:这里需要传入 local_llm_creator
# 示例:
# def my_llm_creator():
# from ..model_services import get_llm
# return get_llm()
#
# rag_tool = await init_rag_tool(my_llm_creator)
# set_global_rag_tool(rag_tool)
print("⚠️ initialize_rag_from_initializer 需要传入 local_llm_creator")
print("⚠️ 请在应用启动时调用 init_rag_tool() 并设置全局 RAG 工具")
except ImportError as e:
print(f"⚠️ 无法导入 rag_initializer: {e}")
except Exception as e:
print(f"⚠️ RAG 初始化失败: {e}")
# ========== 导出 ==========
__all__ = [
# 节点函数
"rag_retrieve_node",
"rag_re_retrieve_node",
# 工具函数
"inject_rag_tool_to_state",
"get_rag_tool_from_state",
# 全局 RAG 管理
"get_global_rag_tool",
"set_global_rag_tool",
"set_global_rag_pipeline",
# 初始化
"initialize_rag_from_initializer"
]