Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m3s
- 真正导入和使用 backend/app/rag/tools.py - 添加全局 RAG 工具管理(get/set_global_rag_tool) - 集成 RAGPipeline,支持多路查询和重排序 - 兼容 rag_initializer.py 的初始化方式 - 移除模拟实现,使用真正的 RAG 功能
295 lines
8.8 KiB
Python
295 lines
8.8 KiB
Python
"""
|
||
RAG 节点模块 - 真正利用已有 RAG 代码
|
||
包含:
|
||
- rag_retrieve_node: RAG 检索节点(带超时重试)
|
||
- rag_re_retrieve_node: 重新检索节点
|
||
- 集成 backend/app/rag/tools.py 和 rag_initializer.py
|
||
"""
|
||
|
||
import time
|
||
import asyncio
|
||
from typing import Dict, Any, Optional
|
||
from datetime import datetime
|
||
|
||
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
||
from .retry_utils import (
|
||
RetryConfig,
|
||
RAG_RETRY_CONFIG,
|
||
create_retry_wrapper_for_node
|
||
)
|
||
|
||
# 真正导入和利用已有 RAG 代码
|
||
from ..rag.tools import create_rag_tool_sync
|
||
from ..rag.pipeline import RAGPipeline
|
||
|
||
|
||
# ========== Global RAG singletons (populated lazily) ==========
# Both slots start out empty (None). They are filled by
# set_global_rag_tool() / set_global_rag_pipeline() and read by the
# retrieval logic in this module.
_GLOBAL_RAG_TOOL: Optional[Any]
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline]
_GLOBAL_RAG_TOOL = _GLOBAL_RAG_PIPELINE = None
|
||
|
||
|
||
def get_global_rag_tool() -> Optional[Any]:
    """
    Return the process-wide RAG tool registered via set_global_rag_tool().

    Returns:
        The registered RAG tool, or None if nothing has been registered.
    """
    registered_tool = _GLOBAL_RAG_TOOL
    return registered_tool
|
||
|
||
|
||
def set_global_rag_tool(tool: Any) -> None:
    """
    Register the process-wide RAG tool (typically once, at application startup).

    Args:
        tool: RAG tool instance to store; subsequently returned by
            get_global_rag_tool(). Passing None clears the registration.
    """
    global _GLOBAL_RAG_TOOL
    # Plain rebinding of the module-level slot; no validation is performed.
    _GLOBAL_RAG_TOOL = tool
|
||
|
||
|
||
def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
    """
    Register the process-wide RAGPipeline used when no RAG tool is available.

    Args:
        pipeline: RAGPipeline instance to store in the module-level slot.
    """
    global _GLOBAL_RAG_PIPELINE
    # Plain rebinding of the module-level slot; no validation is performed.
    _GLOBAL_RAG_PIPELINE = pipeline
|
||
|
||
|
||
# ========== Resolve the RAG tool for a given state ==========
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
    """
    Resolve which RAG tool the current graph run should use.

    A tool stored under ``state.debug_info["rag_tool"]`` (for example by
    inject_rag_tool_to_state) takes precedence over the process-wide tool
    returned by get_global_rag_tool().

    Note: if the ``"rag_tool"`` key is present, its value is returned as-is,
    even when that value is None; there is no fallback in that case.

    Args:
        state: Main graph state whose ``debug_info`` mapping is inspected.

    Returns:
        The resolved RAG tool instance, or None.
    """
    debug_info = state.debug_info
    if "rag_tool" not in debug_info:
        # No per-run override: fall back to the global registration.
        return get_global_rag_tool()
    return debug_info["rag_tool"]
|
||
|
||
|
||
# ========== Helper: inject a RAG tool into the state ==========
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
    """
    Attach a RAG tool to the state so downstream nodes can pick it up.

    The tool is stored under ``debug_info["rag_tool"]``. The injection moment
    is stored as a local-time ISO-8601 string under
    ``debug_info["rag_tool_injected"]``.

    Args:
        state: Main graph state; mutated in place.
        rag_tool: RAG tool instance to attach.

    Returns:
        The same ``state`` object, after mutation.
    """
    info = state.debug_info
    info["rag_tool"] = rag_tool
    info["rag_tool_injected"] = datetime.now().isoformat()
    return state
|
||
|
||
|
||
# ========== Core RAG retrieval logic (rag/tools.py tool or RAGPipeline) ==========
def _not_found_message(retrieval_query: str) -> str:
    """Build the placeholder context used when retrieval yields nothing."""
    return f"未找到与 '{retrieval_query}' 相关的知识库信息。"


def _resolve_retrieval_query(state: MainGraphState) -> str:
    """Pick the query: the reasoning step's optimized query if set, else ``state.user_query``."""
    reasoning_result = state.debug_info.get("reasoning_result")
    # getattr with defaults: a reasoning result / config object that lacks
    # these attributes falls back to the raw user query instead of raising
    # (a raise here would be treated as a retrieval failure and retried).
    cfg = getattr(reasoning_result, "retrieval_config", None)
    optimized_query = getattr(cfg, "retrieval_query", None) if cfg else None
    return optimized_query if optimized_query else state.user_query


def _store_rag_result(
    state: MainGraphState, rag_context: Any, rag_docs: list, source: str
) -> MainGraphState:
    """Write a successful retrieval outcome onto ``state`` and return it."""
    state.rag_context = rag_context
    state.rag_docs = rag_docs
    state.rag_retrieved = True
    state.success = True
    state.debug_info["rag_source"] = source
    return state


def _retrieve_via_tool(state: MainGraphState, rag_tool: Any, retrieval_query: str) -> MainGraphState:
    """Retrieve through a LangChain Tool (``invoke(query)`` returns the context)."""
    try:
        rag_context = rag_tool.invoke(retrieval_query)
    except Exception as e:
        raise RuntimeError(f"RAG 工具调用失败: {e}") from e

    if not rag_context:
        # Treat an empty tool result like the pipeline's "no documents" case
        # rather than recording a single empty document.
        return _store_rag_result(state, _not_found_message(retrieval_query), [], "rag_tool")
    return _store_rag_result(
        state,
        rag_context,
        [{"source": "rag_retrieval", "content": rag_context}],
        "rag_tool",
    )


def _retrieve_via_pipeline(state: MainGraphState, pipeline: Any, retrieval_query: str) -> MainGraphState:
    """Retrieve directly through a RAGPipeline (``retrieve`` + ``format_context``)."""
    try:
        documents = pipeline.retrieve(retrieval_query)
        if documents:
            rag_context = pipeline.format_context(documents)
            rag_docs = [
                {"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
                for doc in documents
            ]
        else:
            rag_context = _not_found_message(retrieval_query)
            rag_docs = []
    except Exception as e:
        raise RuntimeError(f"RAG Pipeline 调用失败: {e}") from e
    return _store_rag_result(state, rag_context, rag_docs, "rag_pipeline")


def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
    """
    Run a single RAG retrieval attempt (no retries) and record it on the state.

    The query is the reasoning result's ``retrieval_config.retrieval_query``
    when present and non-empty, otherwise ``state.user_query``. The backend is
    chosen in this order:

    1. the RAG tool from get_rag_tool_from_state() (state-injected or global);
    2. the global RAGPipeline set via set_global_rag_pipeline().

    On success, this sets ``rag_context``, ``rag_docs``, ``rag_retrieved``,
    ``success`` and ``debug_info["rag_source"]``.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The updated state.

    Raises:
        RuntimeError: if the selected backend fails, or if neither a RAG tool
            nor a pipeline has been configured.
    """
    retrieval_query = _resolve_retrieval_query(state)

    rag_tool = get_rag_tool_from_state(state)
    if rag_tool is not None:
        return _retrieve_via_tool(state, rag_tool, retrieval_query)

    # Read the global once so the None check and the call see the same object.
    pipeline = _GLOBAL_RAG_PIPELINE
    if pipeline is not None:
        return _retrieve_via_pipeline(state, pipeline, retrieval_query)

    raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
|
||
|
||
|
||
# ========== RAG retrieval node (with retries) ==========
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
    """
    Graph node: run RAG retrieval with retries and exponential backoff.

    _rag_retrieve_core() is attempted up to ``RAG_RETRY_CONFIG.max_retries + 1``
    times. After each failed attempt except the last, the node sleeps for
    ``base_delay * 2**attempt`` seconds, capped at ``max_delay``.

    On success, ``debug_info["rag_retrieval"]`` records the attempt number and
    the elapsed time. If every attempt fails, no exception escapes. Instead:

    - an ``ErrorRecord`` (severity WARNING) is appended to ``state.errors``;
    - the same record is set as ``state.current_error``;
    - ``state.current_phase`` becomes ``"error_handling"``.

    NOTE(review): ``RAG_RETRY_CONFIG.timeout`` is only reported in the error
    context; this node does not itself enforce a per-attempt or total timeout.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The updated state.
    """
    state.current_phase = "rag_retrieving"

    start_time = time.time()
    last_error: Optional[Exception] = None
    max_retries = RAG_RETRY_CONFIG.max_retries

    for attempt in range(max_retries + 1):
        try:
            result = _rag_retrieve_core(state)
        except Exception as e:
            last_error = e
            if attempt >= max_retries:
                break
            # Exponential backoff, capped at max_delay.
            delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
            time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
            continue

        state.debug_info["rag_retrieval"] = {
            "attempt": attempt + 1,
            "success": True,
            "time": time.time() - start_time,
        }
        return result

    # Every attempt failed: record a structured error instead of raising.
    error_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=str(last_error) if last_error is not None else "RAG 检索超时",
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=max_retries,
        max_retries=max_retries,
        context={
            "query": state.user_query,
            "total_time": time.time() - start_time,
            "timeout": RAG_RETRY_CONFIG.timeout,
            # Report what _rag_retrieve_core actually resolves, including a
            # tool injected into state.debug_info, not only the global one.
            "has_rag_tool": get_rag_tool_from_state(state) is not None,
            "has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None,
        },
    )

    state.errors.append(error_record)
    state.current_error = error_record
    state.current_phase = "error_handling"

    return state
|
||
|
||
|
||
# ========== Re-retrieval node ==========
def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
    """
    Graph node: run RAG retrieval a second time.

    Before retrying, this node snapshots the previous retrieval outcome into
    ``debug_info["rag_re_retrieve"]``: whether it succeeded and how many
    documents it produced. It then delegates to rag_retrieve_node().

    NOTE: the intent is a second pass with different parameters, but the
    current implementation reuses exactly the same query resolution and
    retrieval logic (TODO: adjust the query / k here). rag_retrieve_node()
    also immediately overwrites ``current_phase``, so ``"rag_re_retrieving"``
    is only visible until the delegated call starts.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The updated state, as returned by rag_retrieve_node().
    """
    state.current_phase = "rag_re_retrieving"

    # Guard against an unset (None) rag_docs so the snapshot itself never raises.
    previous_docs = state.rag_docs or []
    state.debug_info["rag_re_retrieve"] = {
        "original_retrieved": state.rag_retrieved,
        "original_docs_count": len(previous_docs),
    }

    # TODO: vary retrieval parameters (e.g. rewrite the query, raise k) here.
    return rag_retrieve_node(state)
|
||
|
||
|
||
# ========== Convenience: initialize via rag_initializer ==========
async def initialize_rag_from_initializer(local_llm_creator: Optional[Any] = None) -> Optional[Any]:
    """
    Best-effort initialization of the global RAG tool via rag_initializer.

    When ``local_llm_creator`` is given, ``init_rag_tool(local_llm_creator)``
    is awaited and its result is registered with set_global_rag_tool(). When
    it is omitted (the legacy call form), only warnings are printed and
    nothing is initialized.

    Failures are reported with ``print`` and never raised. This covers import
    errors as well as errors raised by ``init_rag_tool``.

    Args:
        local_llm_creator: Zero-argument callable passed through to
            ``init_rag_tool``. Presumably it returns an LLM instance; this is
            inferred from the usage example and should be confirmed against
            rag_initializer.

    Returns:
        The RAG tool that was registered, or None when initialization was
        skipped or failed.
    """
    try:
        from ..agent.rag_initializer import init_rag_tool
    except ImportError as e:
        print(f"⚠️ 无法导入 rag_initializer: {e}")
        return None
    except Exception as e:
        print(f"⚠️ RAG 初始化失败: {e}")
        return None

    if local_llm_creator is None:
        # Legacy call form: nothing to initialize with, so just point the way.
        print("⚠️ initialize_rag_from_initializer 需要传入 local_llm_creator")
        print("⚠️ 请在应用启动时调用 init_rag_tool() 并设置全局 RAG 工具")
        return None

    try:
        rag_tool = await init_rag_tool(local_llm_creator)
    except Exception as e:
        # Best-effort: report and leave any existing global tool untouched.
        print(f"⚠️ RAG 初始化失败: {e}")
        return None

    set_global_rag_tool(rag_tool)
    return rag_tool
|
||
|
||
|
||
# ========== Public API ==========
__all__ = [
    # Graph node functions
    "rag_retrieve_node",
    "rag_re_retrieve_node",
    # State helpers
    "inject_rag_tool_to_state",
    "get_rag_tool_from_state",
    # Global RAG registry
    "get_global_rag_tool",
    "set_global_rag_tool",
    "set_global_rag_pipeline",
    # Initialization
    "initialize_rag_from_initializer",
]
|