refactor: 重构快速路径流程,统一通过 llm_call 输出
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m31s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m31s
- 重构 fast_paths.py,让 fast_chitchat 和 fast_rag 都进入 llm_call 而不是直接设置 final_result - 修改 check_fast_path_success 函数返回 'llm_call' 而不是 'success' - 更新 main_graph_builder.py 的条件边配置,支持路由到 llm_call - 在快速路径节点中添加清除 state.final_result 的逻辑,避免复用旧结果 - 重构 RAG 工具初始化方式,使用模块级变量管理 - 修改 finalize.py 让它返回 final_result - 更新 agent_service.py 的 RAG 工具注入方式 - 简化 hybrid_router.py 的代码结构 - 清理 rag_nodes.py 的全局变量相关代码 - 更新相关测试文件
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
RAG 检索节点模块
|
||||
包含 RAG 检索节点(带超时重试)
|
||||
使用模块级变量管理 RAG 工具
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
@@ -13,43 +13,15 @@ from app.main_graph.utils.retry_utils import RAG_RETRY_CONFIG
|
||||
from app.logger import info
|
||||
from ._utils import dispatch_custom_event, make_react_event
|
||||
|
||||
from app.rag.tools import create_rag_tool
|
||||
from app.rag.pipeline import RAGPipeline
|
||||
|
||||
|
||||
# ========== 全局 RAG 工具实例 ==========
|
||||
_GLOBAL_RAG_TOOL: Optional[Any] = None
|
||||
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None
|
||||
|
||||
|
||||
def get_global_rag_tool() -> Optional[Any]:
|
||||
return _GLOBAL_RAG_TOOL
|
||||
|
||||
|
||||
def set_global_rag_tool(tool: Any) -> None:
|
||||
global _GLOBAL_RAG_TOOL
|
||||
_GLOBAL_RAG_TOOL = tool
|
||||
|
||||
|
||||
def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
|
||||
global _GLOBAL_RAG_PIPELINE
|
||||
_GLOBAL_RAG_PIPELINE = pipeline
|
||||
|
||||
|
||||
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
|
||||
"""从状态或全局获取 RAG 工具"""
|
||||
return state.debug_info.get("rag_tool") or get_global_rag_tool()
|
||||
|
||||
|
||||
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
|
||||
"""将 RAG 工具注入到状态中"""
|
||||
state.debug_info["rag_tool"] = rag_tool
|
||||
state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
|
||||
return state
|
||||
def _get_rag_tool() -> Optional[callable]:
|
||||
"""获取 RAG 工具"""
|
||||
from app.main_graph.utils.rag_initializer import get_rag_tool
|
||||
return get_rag_tool()
|
||||
|
||||
|
||||
# ========== RAG 检索核心逻辑 ==========
|
||||
async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainGraphState:
|
||||
"""执行 RAG 检索的核心逻辑"""
|
||||
retrieval_query = state.user_query
|
||||
|
||||
@@ -60,55 +32,54 @@ async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
if cfg and cfg.retrieval_query:
|
||||
retrieval_query = cfg.retrieval_query
|
||||
|
||||
rag_tool = get_rag_tool_from_state(state)
|
||||
# 调用 RAG 工具
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
|
||||
|
||||
if rag_tool:
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
|
||||
state.rag_retrieved = True
|
||||
state.success = True
|
||||
state.debug_info["rag_source"] = "rag_tool"
|
||||
return state
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
|
||||
state.rag_retrieved = True
|
||||
state.success = True
|
||||
state.debug_info["rag_source"] = "tool"
|
||||
|
||||
if _GLOBAL_RAG_PIPELINE:
|
||||
documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
|
||||
if documents:
|
||||
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [
|
||||
{"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
|
||||
for doc in documents
|
||||
]
|
||||
else:
|
||||
state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
|
||||
state.rag_docs = []
|
||||
state.rag_retrieved = True
|
||||
state.success = True
|
||||
state.debug_info["rag_source"] = "rag_pipeline"
|
||||
return state
|
||||
|
||||
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
|
||||
info(f"[RAG Core] state.rag_docs 长度: {len(state.rag_docs)}")
|
||||
return state
|
||||
|
||||
|
||||
# ========== RAG 检索节点 ==========
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
||||
"""RAG 检索节点:带超时和重试"""
|
||||
state.current_phase = "rag_retrieving"
|
||||
start_time = time.time()
|
||||
last_error = None
|
||||
|
||||
# 步骤1: 发送开始事件
|
||||
# 获取 RAG 工具
|
||||
rag_tool = _get_rag_tool()
|
||||
|
||||
if not rag_tool:
|
||||
error_record = ErrorRecord(
|
||||
error_type="RAGRetrievalError",
|
||||
error_message="RAG 工具未初始化",
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source="rag_retrieve_node",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=0,
|
||||
max_retries=RAG_RETRY_CONFIG.max_retries,
|
||||
)
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
return state
|
||||
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."),
|
||||
config
|
||||
)
|
||||
|
||||
# 步骤2: 执行检索(带重试)
|
||||
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
|
||||
try:
|
||||
result = await _rag_retrieve_core(state)
|
||||
result = await _rag_retrieve_core(state, rag_tool)
|
||||
|
||||
info(f"[RAG] 检索成功,上下文长度: {len(result.rag_context)} 字符")
|
||||
|
||||
@@ -118,7 +89,6 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
"time": time.time() - start_time
|
||||
}
|
||||
|
||||
# 记录成功到历史
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "RETRIEVE_RAG",
|
||||
@@ -127,7 +97,6 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 发送完成事件
|
||||
doc_count = len(result.rag_docs) if result.rag_docs else 0
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
@@ -144,7 +113,6 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
if attempt >= RAG_RETRY_CONFIG.max_retries:
|
||||
break
|
||||
|
||||
# 发送重试事件
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
|
||||
@@ -152,11 +120,10 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
config
|
||||
)
|
||||
|
||||
# 指数退避
|
||||
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
|
||||
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
||||
|
||||
# 步骤3: 所有重试失败,记录到历史(避免推理循环)
|
||||
# 失败记录
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "RETRIEVE_RAG",
|
||||
@@ -165,7 +132,6 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 步骤4: 记录错误
|
||||
error_record = ErrorRecord(
|
||||
error_type="RAGRetrievalError",
|
||||
error_message=str(last_error) if last_error else "RAG 检索超时",
|
||||
@@ -174,19 +140,12 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=RAG_RETRY_CONFIG.max_retries,
|
||||
max_retries=RAG_RETRY_CONFIG.max_retries,
|
||||
context={
|
||||
"query": state.user_query,
|
||||
"total_time": time.time() - start_time,
|
||||
"has_rag_tool": get_global_rag_tool() is not None,
|
||||
"has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None
|
||||
}
|
||||
)
|
||||
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
|
||||
# 发送错误事件
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
|
||||
@@ -197,8 +156,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
return state
|
||||
|
||||
|
||||
# ========== 重新检索节点 ==========
|
||||
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
||||
"""重新检索节点"""
|
||||
state.current_phase = "rag_re_retrieving"
|
||||
|
||||
@@ -214,9 +172,4 @@ async def rag_re_retrieve_node(state: MainGraphState, config: Optional[Dict[str,
|
||||
__all__ = [
|
||||
"rag_retrieve_node",
|
||||
"rag_re_retrieve_node",
|
||||
"inject_rag_tool_to_state",
|
||||
"get_rag_tool_from_state",
|
||||
"get_global_rag_tool",
|
||||
"set_global_rag_tool",
|
||||
"set_global_rag_pipeline",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user