Files
ailine/backend/app/graph/rag_nodes.py

295 lines
8.8 KiB
Python
Raw Normal View History

"""
RAG 节点模块 - 真正利用已有 RAG 代码
包含
- rag_retrieve_node: RAG 检索节点带超时重试
- rag_re_retrieve_node: 重新检索节点
- 集成 backend/app/rag/tools.py rag_initializer.py
"""
import time
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
from .state import MainGraphState, ErrorRecord, ErrorSeverity
from .retry_utils import (
RetryConfig,
RAG_RETRY_CONFIG,
create_retry_wrapper_for_node
)
# 真正导入和利用已有 RAG 代码
from ..rag.tools import create_rag_tool_sync
from ..rag.pipeline import RAGPipeline
# ========== Global RAG backends (lazily initialized) ==========
# Module-level singletons. Both start as None and are filled in by
# set_global_rag_tool() / set_global_rag_pipeline(), usually at app startup.
_GLOBAL_RAG_TOOL: Optional[Any] = None
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None
def get_global_rag_tool() -> Optional[Any]:
    """
    Return the process-wide RAG tool registered with ``set_global_rag_tool``.

    Returns:
        The registered tool, or ``None`` when nothing has been registered.
    """
    return _GLOBAL_RAG_TOOL
def set_global_rag_tool(tool: Any) -> None:
    """
    Register ``tool`` as the process-wide RAG tool.

    Usually called once while the application starts up;
    ``get_rag_tool_from_state`` falls back to this tool whenever the state
    carries no injected one.

    Args:
        tool: RAG tool instance, stored as-is (passing ``None`` clears it).
    """
    global _GLOBAL_RAG_TOOL
    _GLOBAL_RAG_TOOL = tool
def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
    """
    Register ``pipeline`` as the process-wide RAG pipeline.

    The core retrieval logic uses this pipeline only when no RAG tool is
    available (neither injected into the state nor registered globally).

    Args:
        pipeline: ``RAGPipeline`` instance, stored as-is.
    """
    global _GLOBAL_RAG_PIPELINE
    _GLOBAL_RAG_PIPELINE = pipeline
# ========== Look up the RAG tool from state ==========
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
    """
    Resolve which RAG tool applies to ``state``.

    A tool stored under ``state.debug_info["rag_tool"]`` wins. Its value is
    returned whenever the key exists, even if that value is ``None``.
    Otherwise the globally registered tool is returned.

    Args:
        state: Main graph state.

    Returns:
        The injected tool, else the global tool (which may be ``None``).
    """
    debug_info = state.debug_info
    if "rag_tool" not in debug_info:
        return get_global_rag_tool()
    return debug_info["rag_tool"]
# ========== Utility: inject the RAG tool into state ==========
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
    """
    Attach ``rag_tool`` to ``state`` so later nodes can pick it up.

    The tool goes into ``debug_info["rag_tool"]``, which
    ``get_rag_tool_from_state`` reads. A naive local-time ISO timestamp of the
    injection goes into ``debug_info["rag_tool_injected"]``.

    Args:
        state: Main graph state, mutated in place.
        rag_tool: RAG tool instance to attach.

    Returns:
        The same ``state`` object.
    """
    info = state.debug_info
    info["rag_tool"] = rag_tool
    info["rag_tool_injected"] = datetime.now().isoformat()
    return state
# ========== Core RAG retrieval logic (delegates to the existing RAG code) ==========
def _resolve_retrieval_query(state: MainGraphState) -> Any:
    """
    Pick the query to retrieve with.

    Uses ``retrieval_config.retrieval_query`` of the reasoning result stored in
    ``state.debug_info["reasoning_result"]`` when that config exists and the
    query is truthy; otherwise falls back to ``state.user_query``.
    """
    if "reasoning_result" not in state.debug_info:
        return state.user_query
    # getattr with a None default is equivalent to the hasattr() guard.
    cfg = getattr(state.debug_info["reasoning_result"], "retrieval_config", None)
    if cfg and cfg.retrieval_query:
        return cfg.retrieval_query
    return state.user_query


def _retrieve_with_tool(state: MainGraphState, rag_tool: Any, query: Any) -> MainGraphState:
    """
    Retrieve through a LangChain-style tool (``rag_tool.invoke(query)``).

    The tool output becomes both ``state.rag_context`` and the single entry
    of ``state.rag_docs``.

    Raises:
        RuntimeError: wrapping any exception raised by ``rag_tool.invoke``.
    """
    try:
        rag_context = rag_tool.invoke(query)
    except Exception as e:
        raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
    state.rag_context = rag_context
    state.rag_docs = [
        {"source": "rag_retrieval", "content": rag_context}
    ]
    state.rag_retrieved = True
    state.success = True
    state.debug_info["rag_source"] = "rag_tool"
    return state


def _retrieve_with_pipeline(state: MainGraphState, pipeline: RAGPipeline, query: Any) -> MainGraphState:
    """
    Retrieve through ``pipeline.retrieve`` / ``pipeline.format_context``.

    With no documents found, a "nothing found" message is stored as the
    context and ``rag_docs`` is emptied. The state is written only after the
    whole pipeline step succeeds, so a failure never leaves a half-updated
    state behind for the next retry.

    Raises:
        RuntimeError: wrapping any exception raised while querying the
            pipeline or converting its documents.
    """
    try:
        documents = pipeline.retrieve(query)
        if documents:
            rag_context = pipeline.format_context(documents)
            rag_docs = [
                {"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
                for doc in documents
            ]
        else:
            rag_context = f"未找到与 '{query}' 相关的知识库信息。"
            rag_docs = []
    except Exception as e:
        raise RuntimeError(f"RAG Pipeline 调用失败: {str(e)}") from e
    state.rag_context = rag_context
    state.rag_docs = rag_docs
    state.rag_retrieved = True
    state.success = True
    state.debug_info["rag_source"] = "rag_pipeline"
    return state


def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
    """
    Run one RAG retrieval attempt and write the result into ``state``.

    Backend priority: the tool returned by ``get_rag_tool_from_state``
    (state-injected first, then global), then the global ``RAGPipeline``.

    Args:
        state: Main graph state.

    Returns:
        The same ``state`` with ``rag_context``, ``rag_docs``,
        ``rag_retrieved``, ``success`` and ``debug_info["rag_source"]`` set.

    Raises:
        RuntimeError: if the selected backend fails, or no backend is set up.
    """
    retrieval_query = _resolve_retrieval_query(state)

    # Explicit None checks: a configured backend object must never be skipped
    # just because it happens to evaluate as falsy.
    rag_tool = get_rag_tool_from_state(state)
    if rag_tool is not None:
        return _retrieve_with_tool(state, rag_tool, retrieval_query)
    if _GLOBAL_RAG_PIPELINE is not None:
        return _retrieve_with_pipeline(state, _GLOBAL_RAG_PIPELINE, retrieval_query)
    raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
# ========== RAG retrieval node (with retry) ==========
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
    """
    RAG retrieval node with retry and exponential backoff.

    Calls ``_rag_retrieve_core`` up to ``max(RAG_RETRY_CONFIG.max_retries, 0) + 1``
    times. Between failed attempts it sleeps ``base_delay * 2**attempt``
    seconds, capped at ``max_delay``.

    On success, the 1-based attempt number and the elapsed seconds are
    recorded in ``state.debug_info["rag_retrieval"]``.

    If every attempt fails, nothing is raised. Instead a WARNING-severity
    ``ErrorRecord`` is appended to ``state.errors`` and also stored in
    ``state.current_error``, and ``state.current_phase`` becomes
    ``"error_handling"``.

    NOTE(review): ``RAG_RETRY_CONFIG.timeout`` is only reported in the error
    context; no per-attempt timeout is enforced here.

    Args:
        state: Main graph state.

    Returns:
        The updated state.
    """
    state.current_phase = "rag_retrieving"
    # monotonic() is unaffected by wall-clock adjustments, so durations stay valid.
    start_time = time.monotonic()
    # Clamp to >= 0 so at least one attempt always runs; a negative value would
    # otherwise skip the loop and report a failure that never happened.
    max_retries = max(RAG_RETRY_CONFIG.max_retries, 0)
    last_error: Optional[Exception] = None

    for attempt in range(max_retries + 1):
        try:
            result = _rag_retrieve_core(state)
        except Exception as e:
            # Every backend failure is treated as retryable.
            last_error = e
            if attempt < max_retries:
                # Exponential backoff, capped at max_delay.
                delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
                time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
            continue
        state.debug_info["rag_retrieval"] = {
            "attempt": attempt + 1,
            "success": True,
            "time": time.monotonic() - start_time,
        }
        return result

    # All attempts failed: record a structured error and hand over to the
    # error-handling phase instead of raising.
    error_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=str(last_error) if last_error is not None else "RAG 检索超时",
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=max_retries,
        max_retries=RAG_RETRY_CONFIG.max_retries,
        context={
            "query": state.user_query,
            "total_time": time.monotonic() - start_time,
            "timeout": RAG_RETRY_CONFIG.timeout,
            # Use the same lookup as the core (state-injected tool first), so a
            # failing injected tool is not misreported as "no tool available".
            "has_rag_tool": get_rag_tool_from_state(state) is not None,
            "has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None,
        },
    )
    state.errors.append(error_record)
    state.current_error = error_record
    state.current_phase = "error_handling"
    return state
# ========== Re-retrieval node ==========
def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
    """
    Re-retrieval node, meant for a second retrieval pass.

    Snapshots the outcome of the previous retrieval into
    ``state.debug_info["rag_re_retrieve"]`` and then delegates to
    ``rag_retrieve_node``. The retrieval parameters are not changed yet.

    Args:
        state: Main graph state.

    Returns:
        The state returned by ``rag_retrieve_node``.
    """
    state.current_phase = "rag_re_retrieving"
    # NOTE(review): rag_retrieve_node overwrites current_phase with
    # "rag_retrieving" right away, so this label does not survive the call.
    snapshot = {
        "original_retrieved": state.rag_retrieved,
        "original_docs_count": len(state.rag_docs),
    }
    state.debug_info["rag_re_retrieve"] = snapshot
    # TODO: tweak the retrieval here (e.g. rewrite the query or raise k).
    return rag_retrieve_node(state)
# ========== Convenience: initialize RAG from rag_initializer ==========
async def initialize_rag_from_initializer() -> None:
    """
    Convenience hook for initializing RAG via ``rag_initializer``.

    This is currently a stub. It only checks that ``init_rag_tool`` can be
    imported and prints reminders; it does not create or register a tool.
    Real initialization needs a ``local_llm_creator``, for example::

        def my_llm_creator():
            from ..model_services import get_llm
            return get_llm()

        rag_tool = await init_rag_tool(my_llm_creator)
        set_global_rag_tool(rag_tool)

    Import errors and any other exception are reported with ``print`` and
    swallowed.
    """
    reminders = (
        "⚠️ initialize_rag_from_initializer 需要传入 local_llm_creator",
        "⚠️ 请在应用启动时调用 init_rag_tool() 并设置全局 RAG 工具",
    )
    try:
        # Import probe only: the name is not used yet (see docstring example).
        from ..agent.rag_initializer import init_rag_tool  # noqa: F401
        for message in reminders:
            print(message)
    except ImportError as exc:
        print(f"⚠️ 无法导入 rag_initializer: {exc}")
    except Exception as exc:
        print(f"⚠️ RAG 初始化失败: {exc}")
# ========== Public exports ==========
__all__ = [
    # Graph node functions
    "rag_retrieve_node",
    "rag_re_retrieve_node",
    # State helpers
    "inject_rag_tool_to_state",
    "get_rag_tool_from_state",
    # Global RAG backend management
    "get_global_rag_tool",
    "set_global_rag_tool",
    "set_global_rag_pipeline",
    # Initialization
    "initialize_rag_from_initializer"
]