This commit is contained in:
@@ -95,10 +95,10 @@ def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphS
|
||||
return state
|
||||
|
||||
|
||||
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
|
||||
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
# ========== RAG 检索核心逻辑(真正利用已有代码) ==========
|
||||
async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
"""
|
||||
RAG 检索核心逻辑(真正利用 rag/tools.py)
|
||||
RAG 检索核心逻辑(真正利用 rag/tools.py) - 异步版本
|
||||
|
||||
Args:
|
||||
state: 主图状态
|
||||
@@ -119,10 +119,10 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
rag_tool = get_rag_tool_from_state(state)
|
||||
|
||||
if rag_tool:
|
||||
# 使用真正的 RAG 工具(来自 rag/tools.py)
|
||||
# 使用真正的 RAG 工具(来自 rag/tools.py)- 异步版本
|
||||
try:
|
||||
# 调用 LangChain Tool 的 invoke 方法
|
||||
rag_context = rag_tool.invoke(retrieval_query)
|
||||
# 直接 await 异步工具的 ainvoke 方法
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [
|
||||
{"source": "rag_retrieval", "content": rag_context}
|
||||
@@ -134,9 +134,9 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
|
||||
elif _GLOBAL_RAG_PIPELINE:
|
||||
# 使用 RAG Pipeline 直接检索
|
||||
# 使用 RAG Pipeline 直接检索 - 直接用异步方法
|
||||
try:
|
||||
documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query)
|
||||
documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
|
||||
if documents:
|
||||
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
|
||||
state.rag_context = rag_context
|
||||
@@ -158,7 +158,7 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
|
||||
|
||||
|
||||
# ========== RAG 检索节点(带超时和重试)==========
|
||||
# ========== RAG 检索节点(带超时和重试) ==========
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
"""
|
||||
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
|
||||
@@ -196,8 +196,13 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
|
||||
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
|
||||
try:
|
||||
# 执行核心逻辑
|
||||
result = _rag_retrieve_core(state)
|
||||
# 执行核心逻辑 - 异步 await
|
||||
result = await _rag_retrieve_core(state)
|
||||
|
||||
info(f"[rag_retrieve_node] RAG 检索成功,获取到上下文长度: {len(result.rag_context)} 字符")
|
||||
if result.rag_docs:
|
||||
for i, doc in enumerate(result.rag_docs[:3]): # 只显示前3条
|
||||
info(f"[rag_retrieve_node] 文档 {i+1}: {doc.get('content', '')[:100]}...")
|
||||
|
||||
# 成功
|
||||
state.debug_info["rag_retrieval"] = {
|
||||
@@ -226,6 +231,15 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
except Exception as e:
|
||||
info(f"[rag_retrieve_node] 无法发送完成事件: {e}")
|
||||
|
||||
# 关键修复:把 rag_retrieve 加到 reasoning_history 里,让下次推理知道
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "rag_retrieve",
|
||||
"confidence": 1.0,
|
||||
"reasoning": "RAG 检索完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
@@ -255,7 +269,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
|
||||
# 指数退避等待
|
||||
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
|
||||
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
||||
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
||||
|
||||
# 所有重试都失败,记录结构化错误
|
||||
error_record = ErrorRecord(
|
||||
|
||||
@@ -364,20 +364,27 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
if "subgraph_completed" in previous_actions or state.final_result:
|
||||
return "llm_call"
|
||||
|
||||
# 检查是否刚刚执行完 rag 或 web search,应该继续推理一次然后去 llm_call
|
||||
# 但为了避免死循环,我们设置一个简单的规则
|
||||
if len(previous_actions) > 3:
|
||||
# 关键修复:如果已经执行过 rag_retrieve 并且又执行过推理,直接去 LLM_CALL
|
||||
# 这样的流程:推理1 → RAG → 推理2(判断 RAG 结果) → LLM_CALL
|
||||
rag_count = previous_actions.count("rag_retrieve")
|
||||
if rag_count >= 1 and len(previous_actions) >= rag_count + 1:
|
||||
info(f"[route_by_reasoning] 已完成 RAG 检索和结果判断,直接去 llm_call")
|
||||
return "llm_call"
|
||||
|
||||
|
||||
# 关键修复:限制最多 3 次推理,避免无限循环
|
||||
if len(previous_actions) >= 3:
|
||||
info(f"[route_by_reasoning] 已达到最大推理次数 ({len(previous_actions)}),直接去 llm_call")
|
||||
return "llm_call"
|
||||
|
||||
# 获取推理结果
|
||||
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
|
||||
|
||||
|
||||
if not reasoning_result:
|
||||
return "llm_call"
|
||||
|
||||
|
||||
# 使用 intent.py 提供的路由函数
|
||||
route = get_route_by_reasoning(reasoning_result)
|
||||
|
||||
|
||||
# 映射到我们的节点名称
|
||||
# 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致
|
||||
route_mapping = {
|
||||
@@ -391,7 +398,8 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
"dictionary": "dictionary_subgraph",
|
||||
"news_analysis": "news_analysis_subgraph",
|
||||
}
|
||||
|
||||
|
||||
info(f"[route_by_reasoning] 推理结果={reasoning_result.action.name}, 路由={route_mapping.get(route, 'llm_call')}, 历史动作={previous_actions}")
|
||||
return route_mapping.get(route, "llm_call")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user