修改rag,实现混合检索
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m42s

This commit is contained in:
2026-05-04 04:28:32 +08:00
parent d0590240f9
commit 82dde7113e
15 changed files with 536 additions and 65 deletions

View File

@@ -95,10 +95,10 @@ def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphS
return state
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
# ========== RAG 检索核心逻辑(真正利用已有代码) ==========
async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
"""
RAG 检索核心逻辑(真正利用 rag/tools.py)
RAG 检索核心逻辑(真正利用 rag/tools.py) - 异步版本
Args:
state: 主图状态
@@ -119,10 +119,10 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
rag_tool = get_rag_tool_from_state(state)
if rag_tool:
# 使用真正的 RAG 工具(来自 rag/tools.py)
# 使用真正的 RAG 工具(来自 rag/tools.py) - 异步版本
try:
# 调用 LangChain Tool 的 invoke 方法
rag_context = rag_tool.invoke(retrieval_query)
# 直接 await 异步工具的 ainvoke 方法
rag_context = await rag_tool.ainvoke(retrieval_query)
state.rag_context = rag_context
state.rag_docs = [
{"source": "rag_retrieval", "content": rag_context}
@@ -134,9 +134,9 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
except Exception as e:
raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
elif _GLOBAL_RAG_PIPELINE:
# 使用 RAG Pipeline 直接检索
# 使用 RAG Pipeline 直接检索 - 直接用异步方法
try:
documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query)
documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
if documents:
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
state.rag_context = rag_context
@@ -158,7 +158,7 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
# ========== RAG 检索节点(带超时和重试)==========
# ========== RAG 检索节点(带超时和重试) ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
@@ -196,8 +196,13 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
try:
# 执行核心逻辑
result = _rag_retrieve_core(state)
# 执行核心逻辑 - 异步 await
result = await _rag_retrieve_core(state)
info(f"[rag_retrieve_node] RAG 检索成功,获取到上下文长度: {len(result.rag_context)} 字符")
if result.rag_docs:
for i, doc in enumerate(result.rag_docs[:3]): # 只显示前3条
info(f"[rag_retrieve_node] 文档 {i+1}: {doc.get('content', '')[:100]}...")
# 成功
state.debug_info["rag_retrieval"] = {
@@ -226,6 +231,15 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
except Exception as e:
info(f"[rag_retrieve_node] 无法发送完成事件: {e}")
# 关键修复:把 rag_retrieve 加到 reasoning_history 里,让下次推理知道
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "rag_retrieve",
"confidence": 1.0,
"reasoning": "RAG 检索完成",
"timestamp": datetime.now().isoformat()
})
return result
except Exception as e:
@@ -255,7 +269,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
# 指数退避等待
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
# 所有重试都失败,记录结构化错误
error_record = ErrorRecord(

View File

@@ -364,20 +364,27 @@ def route_by_reasoning(state: MainGraphState) -> str:
if "subgraph_completed" in previous_actions or state.final_result:
return "llm_call"
# 检查是否刚刚执行 rag 或 web search,应该继续推理一次,然后去 llm_call
# 但为了避免死循环,我们设置一个简单的规则
if len(previous_actions) > 3:
# 关键修复:如果已经执行 rag_retrieve 并且又执行过推理,直接去 LLM_CALL
# 这样的流程:推理1 → RAG → 推理2(判断 RAG 结果) → LLM_CALL
rag_count = previous_actions.count("rag_retrieve")
if rag_count >= 1 and len(previous_actions) >= rag_count + 1:
info(f"[route_by_reasoning] 已完成 RAG 检索和结果判断,直接去 llm_call")
return "llm_call"
# 关键修复:限制最多 3 次推理,避免无限循环
if len(previous_actions) >= 3:
info(f"[route_by_reasoning] 已达到最大推理次数 ({len(previous_actions)}),直接去 llm_call")
return "llm_call"
# 获取推理结果
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
if not reasoning_result:
return "llm_call"
# 使用 intent.py 提供的路由函数
route = get_route_by_reasoning(reasoning_result)
# 映射到我们的节点名称
# 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致
route_mapping = {
@@ -391,7 +398,8 @@ def route_by_reasoning(state: MainGraphState) -> str:
"dictionary": "dictionary_subgraph",
"news_analysis": "news_analysis_subgraph",
}
info(f"[route_by_reasoning] 推理结果={reasoning_result.action.name}, 路由={route_mapping.get(route, 'llm_call')}, 历史动作={previous_actions}")
return route_mapping.get(route, "llm_call")