refactor: 真正利用已有 RAG 代码重构 rag_nodes.py
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m3s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m3s
- 真正导入和使用 backend/app/rag/tools.py - 添加全局 RAG 工具管理(get/set_global_rag_tool) - 集成 RAGPipeline,支持多路查询和重排序 - 兼容 rag_initializer.py 的初始化方式 - 移除模拟实现,使用真正的 RAG 功能
This commit is contained in:
@@ -1,12 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
RAG 节点模块 - 独立的 RAG 检索节点
|
RAG 节点模块 - 真正利用已有 RAG 代码
|
||||||
包含:
|
包含:
|
||||||
- rag_retrieve_node: RAG 检索节点(带超时重试)
|
- rag_retrieve_node: RAG 检索节点(带超时重试)
|
||||||
- rag_re_retrieve_node: 重新检索节点
|
- rag_re_retrieve_node: 重新检索节点
|
||||||
- 相关的 RAG 工具集成
|
- 集成 backend/app/rag/tools.py 和 rag_initializer.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
import asyncio
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -17,15 +18,49 @@ from .retry_utils import (
|
|||||||
create_retry_wrapper_for_node
|
create_retry_wrapper_for_node
|
||||||
)
|
)
|
||||||
|
|
||||||
# 尝试导入现有的 RAG 工具
|
# 真正导入和利用已有 RAG 代码
|
||||||
try:
|
from ..rag.tools import create_rag_tool_sync
|
||||||
from ..rag.tools import create_rag_tool_sync
|
from ..rag.pipeline import RAGPipeline
|
||||||
from ..rag.pipeline import RAGPipeline
|
|
||||||
HAS_RAG = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_RAG = False
|
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 全局 RAG 工具实例(延迟初始化)==========
|
||||||
|
_GLOBAL_RAG_TOOL: Optional[Any] = None
|
||||||
|
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_rag_tool() -> Optional[Any]:
|
||||||
|
"""
|
||||||
|
获取全局 RAG 工具(单例模式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RAG 工具实例或 None
|
||||||
|
"""
|
||||||
|
return _GLOBAL_RAG_TOOL
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_rag_tool(tool: Any) -> None:
|
||||||
|
"""
|
||||||
|
设置全局 RAG 工具(通常在应用启动时调用)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool: RAG 工具实例
|
||||||
|
"""
|
||||||
|
global _GLOBAL_RAG_TOOL
|
||||||
|
_GLOBAL_RAG_TOOL = tool
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
|
||||||
|
"""
|
||||||
|
设置全局 RAG Pipeline
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipeline: RAGPipeline 实例
|
||||||
|
"""
|
||||||
|
global _GLOBAL_RAG_PIPELINE
|
||||||
|
_GLOBAL_RAG_PIPELINE = pipeline
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 从状态获取 RAG 工具 ==========
|
||||||
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
|
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
从状态中获取 RAG 工具(如果有)
|
从状态中获取 RAG 工具(如果有)
|
||||||
@@ -36,15 +71,34 @@ def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
|
|||||||
Returns:
|
Returns:
|
||||||
RAG 工具实例或 None
|
RAG 工具实例或 None
|
||||||
"""
|
"""
|
||||||
|
# 优先从状态获取
|
||||||
if "rag_tool" in state.debug_info:
|
if "rag_tool" in state.debug_info:
|
||||||
return state.debug_info["rag_tool"]
|
return state.debug_info["rag_tool"]
|
||||||
return None
|
# 其次从全局获取
|
||||||
|
return get_global_rag_tool()
|
||||||
|
|
||||||
|
|
||||||
# ========== RAG 检索核心逻辑 ==========
|
# ========== 工具:将 RAG 工具注入到状态 ==========
|
||||||
|
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
|
||||||
|
"""
|
||||||
|
将 RAG 工具注入到状态中,供后续节点使用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 主图状态
|
||||||
|
rag_tool: RAG 工具实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的状态
|
||||||
|
"""
|
||||||
|
state.debug_info["rag_tool"] = rag_tool
|
||||||
|
state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
|
||||||
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||||
"""
|
"""
|
||||||
RAG 检索核心逻辑(不带重试)
|
RAG 检索核心逻辑(真正利用 rag/tools.py)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 主图状态
|
state: 主图状态
|
||||||
@@ -61,42 +115,53 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
|||||||
if cfg and cfg.retrieval_query:
|
if cfg and cfg.retrieval_query:
|
||||||
retrieval_query = cfg.retrieval_query
|
retrieval_query = cfg.retrieval_query
|
||||||
|
|
||||||
# 尝试获取 RAG 工具
|
# 尝试获取 RAG 工具(多种方式)
|
||||||
rag_tool = get_rag_tool_from_state(state)
|
rag_tool = get_rag_tool_from_state(state)
|
||||||
|
|
||||||
if rag_tool and HAS_RAG:
|
if rag_tool:
|
||||||
# 使用真实的 RAG 工具
|
# 使用真正的 RAG 工具(来自 rag/tools.py)
|
||||||
try:
|
try:
|
||||||
|
# 调用 LangChain Tool 的 invoke 方法
|
||||||
rag_context = rag_tool.invoke(retrieval_query)
|
rag_context = rag_tool.invoke(retrieval_query)
|
||||||
state.rag_context = rag_context
|
state.rag_context = rag_context
|
||||||
state.rag_docs = [
|
state.rag_docs = [
|
||||||
{"source": "rag_doc", "content": rag_context}
|
{"source": "rag_retrieval", "content": rag_context}
|
||||||
]
|
]
|
||||||
state.rag_retrieved = True
|
state.rag_retrieved = True
|
||||||
state.success = True
|
state.success = True
|
||||||
|
state.debug_info["rag_source"] = "rag_tool"
|
||||||
return state
|
return state
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"RAG 调用失败: {str(e)}") from e
|
raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
|
||||||
|
elif _GLOBAL_RAG_PIPELINE:
|
||||||
|
# 使用 RAG Pipeline 直接检索
|
||||||
|
try:
|
||||||
|
documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query)
|
||||||
|
if documents:
|
||||||
|
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
|
||||||
|
state.rag_context = rag_context
|
||||||
|
state.rag_docs = [
|
||||||
|
{"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
|
||||||
|
for doc in documents
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
|
||||||
|
state.rag_docs = []
|
||||||
|
state.rag_retrieved = True
|
||||||
|
state.success = True
|
||||||
|
state.debug_info["rag_source"] = "rag_pipeline"
|
||||||
|
return state
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"RAG Pipeline 调用失败: {str(e)}") from e
|
||||||
else:
|
else:
|
||||||
# 没有 RAG 工具,使用模拟数据(演示用)
|
# 没有可用的 RAG 工具/Pipeline
|
||||||
state.rag_context = (
|
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
|
||||||
f"[RAG 检索结果]\n"
|
|
||||||
f"查询: {retrieval_query}\n"
|
|
||||||
f"这是来自知识库的相关信息。"
|
|
||||||
)
|
|
||||||
state.rag_docs = [
|
|
||||||
{"source": "doc1.txt", "content": "LangGraph 是一个用于构建 Agent 的框架"},
|
|
||||||
{"source": "doc2.txt", "content": "React 模式是 '思考→行动→观察' 循环"}
|
|
||||||
]
|
|
||||||
state.rag_retrieved = True
|
|
||||||
state.success = True
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
# ========== RAG 检索节点(带超时和重试) ==========
|
# ========== RAG 检索节点(带超时和重试)==========
|
||||||
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
|
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
|
||||||
"""
|
"""
|
||||||
RAG 检索节点:带超时和重试
|
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 主图状态
|
state: 主图状态
|
||||||
@@ -144,7 +209,9 @@ def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
|
|||||||
context={
|
context={
|
||||||
"query": state.user_query,
|
"query": state.user_query,
|
||||||
"total_time": time.time() - start_time,
|
"total_time": time.time() - start_time,
|
||||||
"timeout": RAG_RETRY_CONFIG.timeout
|
"timeout": RAG_RETRY_CONFIG.timeout,
|
||||||
|
"has_rag_tool": get_global_rag_tool() is not None,
|
||||||
|
"has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -168,37 +235,60 @@ def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
|
|||||||
"""
|
"""
|
||||||
state.current_phase = "rag_re_retrieving"
|
state.current_phase = "rag_re_retrieving"
|
||||||
|
|
||||||
# 可以在这里修改检索参数(例如:扩大范围、调整查询)
|
# 记录原始检索信息
|
||||||
state.debug_info["rag_re_retrieve"] = {
|
state.debug_info["rag_re_retrieve"] = {
|
||||||
"original_retrieved": state.rag_retrieved,
|
"original_retrieved": state.rag_retrieved,
|
||||||
"original_docs_count": len(state.rag_docs)
|
"original_docs_count": len(state.rag_docs)
|
||||||
}
|
}
|
||||||
|
|
||||||
# 使用相同的检索逻辑
|
# 可以在这里修改检索参数(例如:调整查询、增加 k 值)
|
||||||
|
# 暂时复用同一个检索逻辑
|
||||||
return rag_retrieve_node(state)
|
return rag_retrieve_node(state)
|
||||||
|
|
||||||
|
|
||||||
# ========== 工具:将 RAG 工具注入到状态 ==========
|
# ========== 便捷函数:从 rag_initializer 初始化 ==========
|
||||||
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
|
async def initialize_rag_from_initializer() -> None:
|
||||||
"""
|
"""
|
||||||
将 RAG 工具注入到状态中,供后续节点使用
|
从 rag_initializer 初始化 RAG(便捷函数)
|
||||||
|
|
||||||
Args:
|
注意:这是示例代码,实际使用时需要提供 local_llm_creator
|
||||||
state: 主图状态
|
|
||||||
rag_tool: RAG 工具实例
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
更新后的状态
|
|
||||||
"""
|
"""
|
||||||
state.debug_info["rag_tool"] = rag_tool
|
try:
|
||||||
state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
|
from ..agent.rag_initializer import init_rag_tool
|
||||||
return state
|
|
||||||
|
# 注意:这里需要传入 local_llm_creator
|
||||||
|
# 示例:
|
||||||
|
# def my_llm_creator():
|
||||||
|
# from ..model_services import get_llm
|
||||||
|
# return get_llm()
|
||||||
|
#
|
||||||
|
# rag_tool = await init_rag_tool(my_llm_creator)
|
||||||
|
# set_global_rag_tool(rag_tool)
|
||||||
|
|
||||||
|
print("⚠️ initialize_rag_from_initializer 需要传入 local_llm_creator")
|
||||||
|
print("⚠️ 请在应用启动时调用 init_rag_tool() 并设置全局 RAG 工具")
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"⚠️ 无法导入 rag_initializer: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ RAG 初始化失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
# ========== 导出 ==========
|
# ========== 导出 ==========
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# 节点函数
|
||||||
"rag_retrieve_node",
|
"rag_retrieve_node",
|
||||||
"rag_re_retrieve_node",
|
"rag_re_retrieve_node",
|
||||||
|
|
||||||
|
# 工具函数
|
||||||
"inject_rag_tool_to_state",
|
"inject_rag_tool_to_state",
|
||||||
"get_rag_tool_from_state"
|
"get_rag_tool_from_state",
|
||||||
|
|
||||||
|
# 全局 RAG 管理
|
||||||
|
"get_global_rag_tool",
|
||||||
|
"set_global_rag_tool",
|
||||||
|
"set_global_rag_pipeline",
|
||||||
|
|
||||||
|
# 初始化
|
||||||
|
"initialize_rag_from_initializer"
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user