# Source: ailine/backend/app/main_graph/nodes/rag_nodes.py
# Last modified: 2026-05-05 00:54:04 +08:00 (repository viewer: 223 lines, 7.3 KiB, Python)
"""
RAG 检索节点模块
包含 RAG 检索节点(带超时重试)
"""
import time
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from app.logger import info
from ._utils import dispatch_custom_event, make_react_event
from app.rag.tools import create_rag_tool
from app.rag.pipeline import RAGPipeline
# ========== Global RAG backend instances ==========
# Process-wide fallbacks consulted by the retrieval core when nothing more
# specific is available. Register them through set_global_rag_tool() and
# set_global_rag_pipeline(); both start out unset (None).
_GLOBAL_RAG_TOOL: Optional[Any] = None
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None
def get_global_rag_tool() -> Optional[Any]:
    """Return the process-wide RAG tool, or ``None`` when none has been registered.

    The value is whatever was last passed to ``set_global_rag_tool``.
    """
    registered_tool = _GLOBAL_RAG_TOOL
    return registered_tool
def set_global_rag_tool(tool: Any) -> None:
    """Register ``tool`` as the process-wide RAG tool, replacing any previous one.

    The RAG tool is the first backend the retrieval core tries. A tool
    injected into the graph state takes precedence over this one.
    """
    global _GLOBAL_RAG_TOOL
    _GLOBAL_RAG_TOOL = tool
def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
    """Register ``pipeline`` as the process-wide fallback retriever.

    The retrieval core uses it only when no RAG tool is available.
    Calling this again replaces the previously registered pipeline.
    """
    global _GLOBAL_RAG_PIPELINE
    _GLOBAL_RAG_PIPELINE = pipeline
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
    """Resolve which RAG tool this graph run should use.

    A truthy ``state.debug_info["rag_tool"]`` (as set by
    ``inject_rag_tool_to_state``) wins. Otherwise the process-wide tool from
    ``get_global_rag_tool()`` is returned, which may be ``None``.
    """
    injected = state.debug_info.get("rag_tool")
    if injected:
        return injected
    return get_global_rag_tool()
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
    """Attach ``rag_tool`` to this run's state so it is preferred over the global tool.

    Writes the tool to ``state.debug_info["rag_tool"]``. It also stamps
    ``state.debug_info["rag_tool_injected"]`` with the (naive, local-time) ISO
    timestamp of the injection. The state is mutated in place and returned
    for chaining.
    """
    debug = state.debug_info
    debug["rag_tool"] = rag_tool
    debug["rag_tool_injected"] = datetime.now().isoformat()
    return state
# ========== RAG retrieval core logic ==========
async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
    """Run one RAG retrieval attempt and write the results onto ``state``.

    Query: ``state.user_query``, unless ``debug_info["reasoning_result"]``
    carries a ``retrieval_config`` whose ``retrieval_query`` is non-empty.

    Backends are tried in order:
      1. the RAG tool resolved by ``get_rag_tool_from_state`` (state-injected
         or global). Its ``ainvoke`` result becomes ``rag_context``;
      2. the global ``RAGPipeline`` registered via ``set_global_rag_pipeline``.

    On success sets ``rag_context``, ``rag_docs``, ``rag_retrieved = True``,
    ``success = True`` and ``debug_info["rag_source"]``, then returns the
    same (mutated) state object.

    Raises:
        RuntimeError: if neither backend has been configured.
        Exceptions raised by the backend call propagate unchanged. Retrying
        is the caller's job.
    """
    retrieval_query = state.user_query
    # Prefer the optimized query produced by the reasoning step, if present.
    reasoning_result = state.debug_info.get("reasoning_result")
    cfg = getattr(reasoning_result, "retrieval_config", None) if reasoning_result else None
    if cfg and cfg.retrieval_query:
        retrieval_query = cfg.retrieval_query

    rag_tool = get_rag_tool_from_state(state)
    if rag_tool is not None:
        rag_context = await rag_tool.ainvoke(retrieval_query)
        state.rag_context = rag_context
        state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
        state.rag_retrieved = True
        state.success = True
        state.debug_info["rag_source"] = "rag_tool"
        return state

    # Bind the global once: it is used on both sides of an await, and a
    # concurrent re-registration must not swap it out mid-retrieval.
    pipeline = _GLOBAL_RAG_PIPELINE
    if pipeline is not None:
        documents = await pipeline.aretrieve(retrieval_query)
        if documents:
            state.rag_context = pipeline.format_context(documents)
            state.rag_docs = [
                {"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
                for doc in documents
            ]
        else:
            state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
            state.rag_docs = []
        state.rag_retrieved = True
        state.success = True
        state.debug_info["rag_source"] = "rag_pipeline"
        return state

    raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
# ========== RAG retrieval node ==========
async def _record_rag_failure(
    state: MainGraphState,
    last_error: Optional[Exception],
    start_time: float,
    config: Optional[Dict[str, Any]],
) -> MainGraphState:
    """Record an exhausted-retries RAG failure on ``state`` and emit the error event.

    Appends a zero-confidence history entry, appends an ``ErrorRecord`` to
    ``state.errors`` and sets it as ``current_error``, and moves the state
    to ``"error_handling"``.
    """
    # Record the failed attempt in history to avoid a reasoning loop that
    # keeps re-requesting retrieval.
    state.reasoning_history.append({
        "step": state.reasoning_step,
        "action": "RETRIEVE_RAG",
        "confidence": 0.0,
        "reasoning": f"RAG 检索失败: {str(last_error) if last_error is not None else '超时'}",
        "timestamp": datetime.now().isoformat()
    })

    error_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=str(last_error) if last_error is not None else "RAG 检索超时",
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=RAG_RETRY_CONFIG.max_retries,
        max_retries=RAG_RETRY_CONFIG.max_retries,
        context={
            "query": state.user_query,
            "total_time": time.time() - start_time,
            # Check the same tool source _rag_retrieve_core uses, which
            # includes a tool injected into the state, not only the global.
            "has_rag_tool": get_rag_tool_from_state(state) is not None,
            "has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None
        }
    )
    state.errors.append(error_record)
    state.current_error = error_record
    state.current_phase = "error_handling"

    await dispatch_custom_event(
        "react_reasoning",
        make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
                         f"RAG 检索失败: {str(last_error)}"),
        config
    )
    return state


async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
    """RAG retrieval node with retries and capped exponential backoff.

    Calls ``_rag_retrieve_core`` up to ``RAG_RETRY_CONFIG.max_retries + 1``
    times. Before retry ``attempt`` (0-based) it sleeps
    ``min(base_delay * 2**attempt, max_delay)``. Progress is reported through
    ``dispatch_custom_event`` on the ``"react_reasoning"`` channel
    (start / retry / complete / error).

    Only the retrieval call is retried. The success-path bookkeeping
    (logging, history, completion event) runs outside the ``try``, so a
    failure there is not mistaken for a retrieval failure. It therefore
    never re-runs retrieval or appends duplicate history entries.

    NOTE(review): despite the "timeout" wording in messages, no per-attempt
    timeout is enforced here. Consider ``asyncio.wait_for`` if
    ``RAG_RETRY_CONFIG`` defines one. TODO confirm.

    Args:
        state: Graph state, mutated in place.
        config: Optional runnable config forwarded to ``dispatch_custom_event``.

    Returns:
        The updated state. Retrieval errors are not raised. After all attempts
        fail, the state carries an ``ErrorRecord`` and
        ``current_phase == "error_handling"``.
    """
    state.current_phase = "rag_retrieving"
    start_time = time.time()
    last_error: Optional[Exception] = None
    max_retries = RAG_RETRY_CONFIG.max_retries

    # Step 1: announce the start of retrieval
    await dispatch_custom_event(
        "react_reasoning",
        make_react_event(state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."),
        config
    )

    # Step 2: run retrieval, retrying only the retrieval call itself
    for attempt in range(max_retries + 1):
        try:
            result = await _rag_retrieve_core(state)
        except Exception as e:  # any backend failure is retryable
            last_error = e
            if attempt >= max_retries:
                break
            await dispatch_custom_event(
                "react_reasoning",
                make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
                                 f"RAG 检索失败,第 {attempt + 1} 次重试..."),
                config
            )
            # Exponential backoff, capped at max_delay
            delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
            await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
            continue

        # Success path: deliberately outside the try block (see docstring).
        # rag_context comes straight from the backend and may be None.
        context_len = len(str(result.rag_context or ""))
        info(f"[RAG] 检索成功,上下文长度: {context_len} 字符")
        state.debug_info["rag_retrieval"] = {
            "attempt": attempt + 1,
            "success": True,
            "time": time.time() - start_time
        }
        state.reasoning_history.append({
            "step": state.reasoning_step,
            "action": "RETRIEVE_RAG",
            "confidence": 1.0,
            "reasoning": "RAG 检索完成",
            "timestamp": datetime.now().isoformat()
        })
        doc_count = len(result.rag_docs) if result.rag_docs else 0
        await dispatch_custom_event(
            "react_reasoning",
            make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0,
                             f"RAG 检索完成,找到 {doc_count} 条相关文档"),
            config
        )
        return result

    # Steps 3-4: every attempt failed. Record it and hand off to error handling.
    return await _record_rag_failure(state, last_error, start_time, config)
# ========== Re-retrieval node ==========
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
    """Re-run RAG retrieval, keeping a snapshot of the previous result.

    Stores the prior ``rag_retrieved`` flag and document count in
    ``state.debug_info["rag_re_retrieve"]``, then delegates to
    ``rag_retrieve_node`` with the same ``config``.
    """
    # rag_retrieve_node immediately overwrites current_phase with
    # "rag_retrieving", so this phase value is only transient.
    state.current_phase = "rag_re_retrieving"
    state.debug_info["rag_re_retrieve"] = {
        "original_retrieved": state.rag_retrieved,
        # rag_docs may be None/empty; rag_retrieve_node guards it the same way.
        "original_docs_count": len(state.rag_docs) if state.rag_docs else 0
    }
    return await rag_retrieve_node(state, config)
# ========== Public API ==========
__all__ = [
    # Graph nodes
    "rag_retrieve_node",
    "rag_re_retrieve_node",
    # RAG tool / pipeline wiring
    "inject_rag_tool_to_state",
    "get_rag_tool_from_state",
    "get_global_rag_tool",
    "set_global_rag_tool",
    "set_global_rag_pipeline",
]