180 lines
6.0 KiB
Python
180 lines
6.0 KiB
Python
"""
|
|
RAG 检索节点模块
|
|
使用模块级变量管理 RAG 工具
|
|
"""
|
|
|
|
import time
|
|
import asyncio
|
|
from typing import Optional
|
|
from datetime import datetime
|
|
from langchain_core.runnables.config import RunnableConfig
|
|
|
|
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
|
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
|
|
from ...logger import info
|
|
from ._utils import dispatch_custom_event, make_react_event
|
|
|
|
|
|
def _get_rag_tool() -> Optional[callable]:
    """Fetch the RAG tool from the ``rag_initializer`` module.

    The import is done at call time, inside the function body, rather than
    at module import time.

    Returns:
        Whatever ``get_rag_tool()`` returns. ``rag_retrieve_node`` treats a
        falsy value as "tool not initialised".
    """
    from backend.app.main_graph.utils.rag_initializer import (
        get_rag_tool as load_rag_tool,
    )

    tool = load_rag_tool()
    return tool
|
|
|
|
|
|
# ========== RAG 检索核心逻辑 ==========
|
|
async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainGraphState:
    """Run one RAG retrieval and store its outcome on ``state``.

    The query defaults to ``state.user_query``. If
    ``state.debug_info["reasoning_result"]`` is truthy and carries a truthy
    ``retrieval_config`` whose ``retrieval_query`` is non-empty, that query
    is used instead.

    Args:
        state: Graph state, mutated in place. Sets ``rag_context``,
            ``rag_docs``, ``rag_retrieved``, ``success`` and
            ``debug_info["rag_source"]``.
        rag_tool: Object exposing an awaitable ``ainvoke(query)`` method.

    Returns:
        The same ``state`` object that was passed in.

    Raises:
        Exception: Whatever ``rag_tool.ainvoke`` raises is propagated
            unchanged; this function does no retrying of its own.
    """
    retrieval_query = state.user_query

    # Prefer the optimised query produced by the reasoning step, if present.
    # getattr() with a default keeps a config object that lacks
    # ``retrieval_query`` from raising AttributeError here.
    reasoning_result = state.debug_info.get("reasoning_result")
    cfg = getattr(reasoning_result, "retrieval_config", None) if reasoning_result else None
    optimized_query = getattr(cfg, "retrieval_query", None) if cfg else None
    if optimized_query:
        retrieval_query = optimized_query

    # Call the RAG tool.
    rag_context = await rag_tool.ainvoke(retrieval_query)

    # Compute the length once, with a falsy guard, and reuse it for the
    # preview below so that a None/empty result cannot raise TypeError from
    # inside a log statement.
    context_len = len(rag_context) if rag_context else 0
    info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={context_len}")
    info("[RAG Core] ========== RAG 返回的知识内容 ==========")
    # Log at most the first 500 items/characters of the result.
    if context_len > 500:
        preview = f"{rag_context[:500]}..."
    else:
        preview = rag_context
    info(preview)
    info("[RAG Core] ========================================")

    state.rag_context = rag_context
    state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
    state.rag_retrieved = True
    state.success = True
    state.debug_info["rag_source"] = "tool"

    info(f"[RAG Core] state.rag_docs 长度: {len(state.rag_docs)}")
    return state
|
|
|
|
|
|
# ========== RAG 检索节点 ==========
|
|
def _record_rag_failure(state: MainGraphState, message: str, retry_count: int) -> MainGraphState:
    """Attach a WARNING-level RAG ``ErrorRecord`` to ``state`` and switch it to error handling."""
    error_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=message,
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=retry_count,
        max_retries=RAG_RETRY_CONFIG.max_retries,
    )
    state.errors.append(error_record)
    state.current_error = error_record
    state.current_phase = "error_handling"
    return state


async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
    """RAG retrieval graph node with retry and exponential backoff.

    Runs ``_rag_retrieve_core`` up to ``RAG_RETRY_CONFIG.max_retries + 1``
    times. Between attempts it sleeps
    ``min(base_delay * 2 ** attempt, max_delay)`` seconds. Progress is
    reported through ``react_reasoning`` custom events.

    This function does not apply a timeout itself; an attempt lasts as long
    as the tool's ``ainvoke`` call does.

    Only the retrieval call is retried. Exceptions raised by the success
    bookkeeping or by event dispatch propagate to the caller, so a failed
    event can no longer trigger a second retrieval or duplicate
    ``reasoning_history`` entries.

    Args:
        state: Graph state, mutated in place.
        config: Runnable config forwarded to ``dispatch_custom_event``.

    Returns:
        The state after a successful retrieval. Otherwise the state with an
        ``ErrorRecord`` appended to ``errors``, ``current_error`` set and
        ``current_phase`` set to ``"error_handling"``.
    """
    state.current_phase = "rag_retrieving"
    # A monotonic clock is used because only the elapsed duration is stored.
    start_time = time.monotonic()
    last_error: Optional[Exception] = None

    # Fetch the RAG tool; without it there is nothing to retry.
    rag_tool = _get_rag_tool()
    if not rag_tool:
        return _record_rag_failure(state, "RAG 工具未初始化", retry_count=0)

    await dispatch_custom_event(
        "react_reasoning",
        make_react_event(state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."),
        config,
    )

    max_retries = RAG_RETRY_CONFIG.max_retries
    for attempt in range(max_retries + 1):
        try:
            result = await _rag_retrieve_core(state, rag_tool)
        except Exception as e:
            last_error = e

            # Out of attempts: fall through to the failure path below.
            if attempt >= max_retries:
                break

            await dispatch_custom_event(
                "react_reasoning",
                make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
                                 f"RAG 检索失败,第 {attempt + 1} 次重试..."),
                config,
            )

            # Exponential backoff, capped at the configured maximum delay.
            delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
            await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
            continue

        # ----- success path (deliberately outside the try block) -----
        # Guard against a falsy context so logging cannot raise TypeError.
        context_len = len(result.rag_context) if result.rag_context else 0
        info(f"[RAG] 检索成功,上下文长度: {context_len} 字符")

        state.debug_info["rag_retrieval"] = {
            "attempt": attempt + 1,
            "success": True,
            "time": time.monotonic() - start_time,
        }

        state.reasoning_history.append({
            "step": state.reasoning_step,
            "action": "RETRIEVE_RAG",
            "confidence": 1.0,
            "reasoning": "RAG 检索完成",
            "timestamp": datetime.now().isoformat(),
        })

        doc_count = len(result.rag_docs) if result.rag_docs else 0
        await dispatch_custom_event(
            "react_reasoning",
            make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0,
                             f"RAG 检索完成,找到 {doc_count} 条相关文档"),
            config,
        )

        return result

    # ----- failure path -----
    # Build the failure text once so the history entry, the error record and
    # the error event all agree. ``last_error`` is None only if the loop ran
    # no attempts at all (a negative ``max_retries``).
    if last_error is None:
        reason = "超时"
        error_message = "RAG 检索超时"
    else:
        # Some exceptions stringify to "" (e.g. a bare TimeoutError); fall
        # back to the class name so the message is never empty.
        reason = error_message = str(last_error) or type(last_error).__name__

    state.reasoning_history.append({
        "step": state.reasoning_step,
        "action": "RETRIEVE_RAG",
        "confidence": 0.0,
        "reasoning": f"RAG 检索失败: {reason}",
        "timestamp": datetime.now().isoformat(),
    })

    _record_rag_failure(state, error_message, retry_count=max_retries)

    await dispatch_custom_event(
        "react_reasoning",
        make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
                         f"RAG 检索失败: {reason}"),
        config,
    )

    return state
|
|
|
|
|
|
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
    """Re-retrieval graph node: snapshot the previous result, then retrieve again.

    Before delegating to ``rag_retrieve_node`` it records, under
    ``state.debug_info["rag_re_retrieve"]``, whether a retrieval had already
    happened and how many documents it produced.

    Args:
        state: Graph state, mutated in place.
        config: Runnable config forwarded to ``rag_retrieve_node``.

    Returns:
        The state returned by ``rag_retrieve_node``.
    """
    state.current_phase = "rag_re_retrieving"

    state.debug_info["rag_re_retrieve"] = {
        "original_retrieved": state.rag_retrieved,
        # Same falsy guard that rag_retrieve_node applies to rag_docs, so an
        # unset (None) document list is counted as 0 instead of raising.
        "original_docs_count": len(state.rag_docs) if state.rag_docs else 0,
    }

    return await rag_retrieve_node(state, config)
|
|
|
|
|
|
# ========== 导出 ==========
|
|
# Public API of this module: the two graph-node coroutines.
__all__ = ["rag_retrieve_node", "rag_re_retrieve_node"]
|