Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
- 新增 rag_nodes.py: 独立的 RAG 检索节点 - 从 react_nodes.py 移除 RAG 相关代码 - 更新导入和导出 - rag_nodes.py 包含 rag_retrieve_node 和 rag_re_retrieve_node - 添加 inject_rag_tool_to_state 工具函数
205 lines
5.8 KiB
Python
205 lines
5.8 KiB
Python
"""
|
|
RAG 节点模块 - 独立的 RAG 检索节点
|
|
包含:
|
|
- rag_retrieve_node: RAG 检索节点(带超时重试)
|
|
- rag_re_retrieve_node: 重新检索节点
|
|
- 相关的 RAG 工具集成
|
|
"""
|
|
|
|
import time
|
|
from typing import Dict, Any, Optional
|
|
from datetime import datetime
|
|
|
|
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
|
from .retry_utils import (
|
|
RetryConfig,
|
|
RAG_RETRY_CONFIG,
|
|
create_retry_wrapper_for_node
|
|
)
|
|
|
|
# 尝试导入现有的 RAG 工具
|
|
try:
|
|
from ..rag.tools import create_rag_tool_sync
|
|
from ..rag.pipeline import RAGPipeline
|
|
HAS_RAG = True
|
|
except ImportError:
|
|
HAS_RAG = False
|
|
|
|
|
|
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
    """
    Look up the RAG tool stored on the graph state, if there is one.

    Args:
        state: Main graph state; its ``debug_info`` mapping is searched
            for the ``"rag_tool"`` key.

    Returns:
        The value stored under ``"rag_tool"``, or ``None`` when the key
        is absent.
    """
    # A single mapping lookup; a missing key yields None.
    return state.debug_info.get("rag_tool")
|
|
|
|
|
|
# ========== RAG 检索核心逻辑 ==========
|
|
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
    """
    Run one RAG retrieval attempt (no retry) and store the outcome on the state.

    The query defaults to ``state.user_query``. If
    ``state.debug_info["reasoning_result"]`` exposes a ``retrieval_config``
    whose ``retrieval_query`` is truthy, that query is used instead.

    When a RAG tool is found on the state and the module-level ``HAS_RAG``
    flag is set, the tool is invoked with the query. Otherwise hard-coded
    demo data is written to the state and the call still reports success.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The same state object with ``rag_context``, ``rag_docs``,
        ``rag_retrieved`` and ``success`` set.

    Raises:
        RuntimeError: If invoking the RAG tool, or storing its result on
            the state, raises; the original exception is chained as the cause.
    """
    # Prefer the optimized query from the reasoning result when one exists.
    query = state.user_query
    reasoning = state.debug_info.get("reasoning_result")
    if hasattr(reasoning, "retrieval_config"):
        retrieval_cfg = reasoning.retrieval_config
        if retrieval_cfg and retrieval_cfg.retrieval_query:
            query = retrieval_cfg.retrieval_query

    tool = get_rag_tool_from_state(state)

    if not tool or not HAS_RAG:
        # No usable RAG tool: fall back to canned demo data.
        state.rag_context = (
            "[RAG 检索结果]\n"
            f"查询: {query}\n"
            "这是来自知识库的相关信息。"
        )
        state.rag_docs = [
            {"source": "doc1.txt", "content": "LangGraph 是一个用于构建 Agent 的框架"},
            {"source": "doc2.txt", "content": "React 模式是 '思考→行动→观察' 循环"},
        ]
        state.rag_retrieved = True
        state.success = True
        return state

    # Real retrieval: the tool's raw output becomes both the context and
    # the single entry of rag_docs.
    try:
        context = tool.invoke(query)
        state.rag_context = context
        state.rag_docs = [{"source": "rag_doc", "content": context}]
        state.rag_retrieved = True
        state.success = True
    except Exception as exc:
        raise RuntimeError(f"RAG 调用失败: {exc!s}") from exc
    return state
|
|
|
|
|
|
# ========== RAG 检索节点(带超时和重试) ==========
|
|
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
    """
    RAG retrieval node: run the core retrieval with retries and backoff.

    ``_rag_retrieve_core`` is attempted up to
    ``RAG_RETRY_CONFIG.max_retries + 1`` times. After each failed attempt
    except the last, the function sleeps for
    ``base_delay * 2 ** (attempts so far - 1)`` seconds, capped at
    ``max_delay``.

    On success, attempt statistics are stored under
    ``state.debug_info["rag_retrieval"]``. When every attempt fails, an
    ``ErrorRecord`` with WARNING severity is appended to ``state.errors``,
    set as ``state.current_error``, and the phase becomes
    ``"error_handling"``.

    NOTE: ``RAG_RETRY_CONFIG.timeout`` is only copied into the error
    context; this function does not itself enforce a timeout.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The state returned by the core retrieval on success, otherwise
        the input state carrying the recorded error.
    """
    state.current_phase = "rag_retrieving"

    started_at = time.time()
    failure = None

    # attempt_no is 1-based: the first try plus max_retries retries.
    for attempt_no in range(1, RAG_RETRY_CONFIG.max_retries + 2):
        try:
            updated = _rag_retrieve_core(state)
            state.debug_info["rag_retrieval"] = {
                "attempt": attempt_no,
                "success": True,
                "time": time.time() - started_at,
            }
            return updated
        except Exception as exc:
            failure = exc

        # Final attempt failed: stop without sleeping again.
        if attempt_no > RAG_RETRY_CONFIG.max_retries:
            break

        # Exponential backoff, capped at the configured maximum delay.
        backoff = RAG_RETRY_CONFIG.base_delay * (2 ** (attempt_no - 1))
        time.sleep(min(backoff, RAG_RETRY_CONFIG.max_delay))

    # Every attempt failed: record a structured error on the state.
    failure_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=str(failure) if failure else "RAG 检索超时",
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=RAG_RETRY_CONFIG.max_retries,
        max_retries=RAG_RETRY_CONFIG.max_retries,
        context={
            "query": state.user_query,
            "total_time": time.time() - started_at,
            "timeout": RAG_RETRY_CONFIG.timeout,
        },
    )

    state.errors.append(failure_record)
    state.current_error = failure_record
    state.current_phase = "error_handling"
    return state
|
|
|
|
|
|
# ========== 重新检索节点 ==========
|
|
def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
    """
    Re-retrieval node: snapshot the previous result, then retrieve again.

    Before delegating, the previous ``rag_retrieved`` flag and the number
    of entries in ``rag_docs`` are saved under
    ``state.debug_info["rag_re_retrieve"]``. Retrieval parameters are not
    changed here; the call goes through ``rag_retrieve_node`` unchanged,
    which also overwrites ``current_phase``.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        Whatever ``rag_retrieve_node`` returns for this state.
    """
    state.current_phase = "rag_re_retrieving"

    # Capture the outcome of the earlier retrieval before it is overwritten.
    previous_run = {
        "original_retrieved": state.rag_retrieved,
        "original_docs_count": len(state.rag_docs),
    }
    state.debug_info["rag_re_retrieve"] = previous_run

    # Same retrieval path as the first pass.
    return rag_retrieve_node(state)
|
|
|
|
|
|
# ========== 工具:将 RAG 工具注入到状态 ==========
|
|
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
    """
    Store a RAG tool on the state so later nodes can pick it up.

    Args:
        state: Main graph state; its ``debug_info`` mapping is mutated.
        rag_tool: RAG tool instance, stored under ``"rag_tool"``.

    Returns:
        The same state, with the tool and an ISO-format injection
        timestamp (``"rag_tool_injected"``) added to ``debug_info``.
    """
    state.debug_info.update(
        {
            "rag_tool": rag_tool,
            "rag_tool_injected": datetime.now().isoformat(),
        }
    )
    return state
|
|
|
|
|
|
# ========== Exports ==========
# Public API of this module: the two retrieval nodes plus the helpers for
# storing a RAG tool on the state and reading it back.
__all__ = [
    "rag_retrieve_node",
    "rag_re_retrieve_node",
    "inject_rag_tool_to_state",
    "get_rag_tool_from_state"
]
|