修复循环推理bug
This commit is contained in:
@@ -151,6 +151,20 @@ class ReactIntentReasoner:
|
||||
result.reasoning = "已获取足够信息,直接回答"
|
||||
return result
|
||||
|
||||
# 检查 RAG 是否多次失败(reasoning_history 中有失败的 RAG 记录)
|
||||
# 失败的 RAG 记录特征:confidence = 0.0
|
||||
rag_history = context.get("reasoning_history", [])
|
||||
rag_fail_count = sum(
|
||||
1 for h in rag_history
|
||||
if h.get("action") in ("RETRIEVE_RAG", "RE_RETRIEVE_RAG") and h.get("confidence", 1.0) == 0.0
|
||||
)
|
||||
if rag_fail_count >= 2:
|
||||
# RAG 多次失败,应该直接回答而不是继续重试
|
||||
result.action = ReasoningAction.DIRECT_RESPONSE
|
||||
result.confidence = 0.7
|
||||
result.reasoning = f"RAG 已尝试 {rag_fail_count} 次均失败,知识库无相关内容,直接基于常识回答"
|
||||
return result
|
||||
|
||||
# 策略1:尝试使用 LLM 推理
|
||||
try:
|
||||
llm_result = await self._reason_with_llm(query, context)
|
||||
@@ -347,7 +361,7 @@ _reasoner: Optional[ReactIntentReasoner] = None
|
||||
_small_reasoner: Optional[ReactIntentReasoner] = None
|
||||
|
||||
|
||||
def _get_reasoner(use_small_llm: bool = False) -> ReactIntentReasoner:
|
||||
def _get_reasoner(use_small_llm: bool = True) -> ReactIntentReasoner:
|
||||
"""
|
||||
获取推理器实例
|
||||
|
||||
@@ -371,7 +385,7 @@ def _get_reasoner(use_small_llm: bool = False) -> ReactIntentReasoner:
|
||||
async def react_reason_async(
|
||||
query: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
use_small_llm: bool = False
|
||||
use_small_llm: bool = True
|
||||
) -> ReasoningResult:
|
||||
"""
|
||||
便捷函数:异步 React 推理(推荐使用)
|
||||
|
||||
@@ -2,25 +2,30 @@
|
||||
主图节点模块导出
|
||||
"""
|
||||
|
||||
# 新的 React 模式节点
|
||||
from .react_nodes import (
|
||||
init_state_node,
|
||||
react_reason_node,
|
||||
web_search_node,
|
||||
error_handling_node,
|
||||
route_by_reasoning
|
||||
)
|
||||
# React 模式节点
|
||||
from .reasoning import react_reason_node
|
||||
from .web_search import web_search_node
|
||||
from .error_handling import error_handling_node
|
||||
from .routing import init_state_node, route_by_reasoning
|
||||
from .llm_call import create_llm_call_node
|
||||
from .rag_nodes import rag_retrieve_node
|
||||
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node
|
||||
|
||||
# 记忆节点(已更新到 MainGraphState)
|
||||
# 记忆节点
|
||||
from .retrieve_memory import create_retrieve_memory_node
|
||||
from .memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from .summarize import create_summarize_node
|
||||
from .finalize import finalize_node
|
||||
|
||||
# 路由(已更新到 MainGraphState)
|
||||
from .router import should_continue
|
||||
# 混合路由节点
|
||||
from .hybrid_router import (
|
||||
hybrid_router_node,
|
||||
fast_chitchat_node,
|
||||
fast_rag_node,
|
||||
fast_tool_node,
|
||||
)
|
||||
|
||||
# 通用工具
|
||||
from ._utils import dispatch_custom_event, make_react_event
|
||||
|
||||
__all__ = [
|
||||
# React 模式节点
|
||||
@@ -29,15 +34,21 @@ __all__ = [
|
||||
"web_search_node",
|
||||
"error_handling_node",
|
||||
"route_by_reasoning",
|
||||
# 通用节点
|
||||
"create_llm_call_node",
|
||||
"rag_retrieve_node",
|
||||
"rag_re_retrieve_node",
|
||||
# 记忆节点
|
||||
"create_retrieve_memory_node",
|
||||
"memory_trigger_node",
|
||||
"set_mem0_client",
|
||||
"create_summarize_node",
|
||||
"finalize_node",
|
||||
# 路由
|
||||
"should_continue",
|
||||
# 混合路由节点
|
||||
"hybrid_router_node",
|
||||
"fast_chitchat_node",
|
||||
"fast_rag_node",
|
||||
"fast_tool_node",
|
||||
# 通用工具
|
||||
"dispatch_custom_event",
|
||||
"make_react_event",
|
||||
]
|
||||
|
||||
57
backend/app/main_graph/nodes/_utils.py
Normal file
57
backend/app/main_graph/nodes/_utils.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
主图节点通用工具模块
|
||||
包含事件发送、状态更新等通用功能
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
async def dispatch_custom_event(
    event_name: str,
    data: Dict[str, Any],
    config: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Best-effort dispatch of a LangChain custom event.

    The event is only sent when ``config`` is truthy and carries a truthy
    ``"callbacks"`` entry. Any failure (missing ``langchain_core``, dispatch
    error, ...) is logged and suppressed so that event delivery can never
    interrupt the calling node.

    Args:
        event_name: Name of the custom event.
        data: Event payload.
        config: LangChain run config; ``None`` or empty disables dispatch.
    """
    if not config:
        return
    try:
        if not config.get("callbacks"):
            return

        from langchain_core.callbacks.manager import adispatch_custom_event

        # adispatch_custom_event receives the run config through its
        # keyword-only ``config`` parameter; it has no ``callbacks`` keyword,
        # so passing ``callbacks=...`` raised a TypeError that the handler
        # below used to swallow silently (no event was ever delivered).
        await adispatch_custom_event(event_name, data, config=config)
    except Exception as exc:
        # Event delivery must not break the main flow, but leave a trace so
        # that dropped events can be diagnosed.
        import logging

        logging.getLogger(__name__).warning(
            "dispatch_custom_event(%r) failed: %s", event_name, exc
        )
|
||||
|
||||
|
||||
def make_react_event(
    step: int,
    action: str,
    confidence: float = 1.0,
    reasoning: str = ""
) -> Dict[str, Any]:
    """
    Build the payload dict for a reasoning event.

    Args:
        step: Current reasoning step number.
        action: Name of the action being reported.
        confidence: Confidence value attached to the event.
        reasoning: Free-text description of the reasoning.

    Returns:
        A dict with the keys ``step``, ``action``, ``confidence`` and
        ``reasoning`` (in that order) holding the given arguments.
    """
    payload: Dict[str, Any] = dict(
        step=step,
        action=action,
        confidence=confidence,
        reasoning=reasoning,
    )
    return payload
|
||||
95
backend/app/main_graph/nodes/error_handling.py
Normal file
95
backend/app/main_graph/nodes/error_handling.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
错误处理节点 - 处理子图/工具调用错误
|
||||
"""
|
||||
|
||||
from app.main_graph.state import MainGraphState, ErrorSeverity
|
||||
from app.logger import info
|
||||
|
||||
|
||||
def error_handling_node(state: MainGraphState) -> MainGraphState:
    """
    Handle the error recorded in ``state.current_error``.

    Depending on the error's severity and remaining retry budget the node
    schedules a retry, produces a degraded answer, or reports a fatal failure.
    A structured summary is stored in ``state.debug_info["structured_error"]``
    with these keys::

        {
            "tool": <error.source>,
            "status": "failed",
            "error": <error.error_message>,
            "retries_exceeded": true/false,
            "retry_count": <int>,
            "max_retries": <int>,
            "suggestion": "..."
        }

    Args:
        state: Main graph state; it is mutated in place.

    Returns:
        The same ``state`` object after the update.
    """
    import json  # function-scope import keeps this block self-contained

    state.current_phase = "error_handling"

    if not state.current_error:
        # Nothing to handle: hand control back to the reasoning phase.
        state.current_phase = "react_reasoning"
        return state

    error = state.current_error

    # Normalise the text fields once so the substring checks below cannot
    # raise TypeError if a record was created with a None value.
    source = error.source or ""
    message = error.error_message or ""
    is_rag_error = "RAG" in (error.error_type or "")

    state.error_message = f"{error.error_type}: {error.error_message}"

    # Structured error summary (taken before any retry counter increment).
    structured_error = {
        "tool": error.source,
        "status": "failed",
        "error": error.error_message,
        "retries_exceeded": error.retry_count >= error.max_retries,
        "retry_count": error.retry_count,
        "max_retries": error.max_retries
    }

    # Pick a user-facing suggestion based on where/how the error happened.
    if is_rag_error:
        structured_error["suggestion"] = "尝试重新表述问题或直接询问"
    elif "subgraph" in source or "contact" in source:
        structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
    elif "timeout" in message.lower():
        structured_error["suggestion"] = "请求超时,请稍后再试"
    else:
        structured_error["suggestion"] = "请尝试其他方式提问"

    state.debug_info["structured_error"] = structured_error

    # Strategy 1: retry while the severity allows it and budget remains.
    can_retry = (
        error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
        and error.retry_count < error.max_retries
    )

    if can_retry:
        error.retry_count += 1
        state.retry_action = error.source
        state.debug_info["retry_count"] = error.retry_count

        if is_rag_error:
            state.last_action = "RE_RETRIEVE_RAG"
        elif "subgraph" in source:
            state.last_action = "DIRECT_RESPONSE"
        else:
            state.last_action = "REASON"

        # NOTE(review): state.current_error is left set on this path; confirm
        # the router evaluates the "retrying" phase before current_error,
        # otherwise the graph comes straight back to this node.
        state.current_phase = "retrying"
        return state

    # The fence is labelled ``json``, so emit real JSON instead of the Python
    # dict repr (single quotes, True/False). default=str keeps this error
    # path from raising on a non-serialisable field value.
    error_json = json.dumps(
        structured_error, ensure_ascii=False, indent=2, default=str
    )

    # Strategy 2: no retry possible, degrade gracefully for non-fatal errors.
    if error.severity != ErrorSeverity.FATAL:
        state.final_result = (
            f"⚠️ 遇到一些问题:\n"
            f"```json\n{error_json}\n```\n"
            f"但我会尽力用现有信息回答您。"
        )
        state.success = True
        state.current_phase = "finalizing"
        return state

    # Strategy 3: fatal error.
    state.final_result = (
        f"❌ 服务暂时不可用,请稍后再试。\n"
        f"```json\n{error_json}\n```"
    )
    state.success = False
    state.current_phase = "finalizing"

    return state
|
||||
@@ -1,9 +1,6 @@
|
||||
"""
|
||||
RAG 节点模块 - 真正利用已有 RAG 代码
|
||||
包含:
|
||||
- rag_retrieve_node: RAG 检索节点(带超时重试)
|
||||
- rag_re_retrieve_node: 重新检索节点
|
||||
- 集成 backend/app/rag/tools.py 和 rag_initializer.py
|
||||
RAG 检索节点模块
|
||||
包含 RAG 检索节点(带超时重试)
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -12,267 +9,163 @@ from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from app.main_graph.utils.retry_utils import (
|
||||
RetryConfig,
|
||||
RAG_RETRY_CONFIG,
|
||||
create_retry_wrapper_for_node
|
||||
)
|
||||
from app.main_graph.utils.retry_utils import RAG_RETRY_CONFIG
|
||||
from app.logger import info
|
||||
from ._utils import dispatch_custom_event, make_react_event
|
||||
|
||||
# 真正导入和利用已有 RAG 代码
|
||||
from app.rag.tools import create_rag_tool
|
||||
from app.rag.pipeline import RAGPipeline
|
||||
|
||||
|
||||
# ========== 全局 RAG 工具实例(延迟初始化)==========
|
||||
# ========== 全局 RAG 工具实例 ==========
|
||||
_GLOBAL_RAG_TOOL: Optional[Any] = None
|
||||
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None
|
||||
|
||||
|
||||
def get_global_rag_tool() -> Optional[Any]:
    """
    Return the RAG tool stored at module level.

    Returns:
        The current value of ``_GLOBAL_RAG_TOOL``: the tool instance, or
        ``None`` when no tool is stored.
    """
    return _GLOBAL_RAG_TOOL
|
||||
|
||||
|
||||
def set_global_rag_tool(tool: Any) -> None:
    """
    Store ``tool`` as the module-level RAG tool.

    Args:
        tool: RAG tool instance to store; replaces any previous value.
    """
    global _GLOBAL_RAG_TOOL
    _GLOBAL_RAG_TOOL = tool
|
||||
|
||||
|
||||
def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
    """
    Store ``pipeline`` as the module-level RAG pipeline.

    Args:
        pipeline: ``RAGPipeline`` instance to store; replaces any previous
            value.
    """
    global _GLOBAL_RAG_PIPELINE
    _GLOBAL_RAG_PIPELINE = pipeline
|
||||
|
||||
|
||||
# ========== 从状态获取 RAG 工具 ==========
|
||||
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
|
||||
"""
|
||||
从状态中获取 RAG 工具(如果有)
|
||||
|
||||
Args:
|
||||
state: 主图状态
|
||||
|
||||
Returns:
|
||||
RAG 工具实例或 None
|
||||
"""
|
||||
# 优先从状态获取
|
||||
if "rag_tool" in state.debug_info:
|
||||
return state.debug_info["rag_tool"]
|
||||
# 其次从全局获取
|
||||
return get_global_rag_tool()
|
||||
"""从状态或全局获取 RAG 工具"""
|
||||
return state.debug_info.get("rag_tool") or get_global_rag_tool()
|
||||
|
||||
|
||||
# ========== 工具:将 RAG 工具注入到状态 ==========
|
||||
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
|
||||
"""
|
||||
将 RAG 工具注入到状态中,供后续节点使用
|
||||
|
||||
Args:
|
||||
state: 主图状态
|
||||
rag_tool: RAG 工具实例
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
"""将 RAG 工具注入到状态中"""
|
||||
state.debug_info["rag_tool"] = rag_tool
|
||||
state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
|
||||
return state
|
||||
|
||||
|
||||
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
|
||||
# ========== RAG 检索核心逻辑 ==========
|
||||
async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
"""
|
||||
RAG 检索核心逻辑(真正利用 rag/tools.py) - 异步版本
|
||||
|
||||
Args:
|
||||
state: 主图状态
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
# 获取检索查询(优先使用推理结果中的优化查询)
|
||||
"""执行 RAG 检索的核心逻辑"""
|
||||
retrieval_query = state.user_query
|
||||
if "reasoning_result" in state.debug_info:
|
||||
reasoning_result = state.debug_info["reasoning_result"]
|
||||
if hasattr(reasoning_result, "retrieval_config"):
|
||||
cfg = reasoning_result.retrieval_config
|
||||
if cfg and cfg.retrieval_query:
|
||||
retrieval_query = cfg.retrieval_query
|
||||
|
||||
# 尝试获取 RAG 工具(多种方式)
|
||||
|
||||
# 优先使用推理结果中的优化查询
|
||||
reasoning_result = state.debug_info.get("reasoning_result")
|
||||
if reasoning_result and hasattr(reasoning_result, "retrieval_config"):
|
||||
cfg = reasoning_result.retrieval_config
|
||||
if cfg and cfg.retrieval_query:
|
||||
retrieval_query = cfg.retrieval_query
|
||||
|
||||
rag_tool = get_rag_tool_from_state(state)
|
||||
|
||||
|
||||
if rag_tool:
|
||||
# 使用真正的 RAG 工具(来自 rag/tools.py)- 异步版本
|
||||
try:
|
||||
# 直接 await 异步工具的 ainvoke 方法
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
|
||||
state.rag_retrieved = True
|
||||
state.success = True
|
||||
state.debug_info["rag_source"] = "rag_tool"
|
||||
return state
|
||||
|
||||
if _GLOBAL_RAG_PIPELINE:
|
||||
documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
|
||||
if documents:
|
||||
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [
|
||||
{"source": "rag_retrieval", "content": rag_context}
|
||||
{"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
|
||||
for doc in documents
|
||||
]
|
||||
state.rag_retrieved = True
|
||||
state.success = True
|
||||
state.debug_info["rag_source"] = "rag_tool"
|
||||
return state
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
|
||||
elif _GLOBAL_RAG_PIPELINE:
|
||||
# 使用 RAG Pipeline 直接检索 - 直接用异步方法
|
||||
try:
|
||||
documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
|
||||
if documents:
|
||||
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [
|
||||
{"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
|
||||
for doc in documents
|
||||
]
|
||||
else:
|
||||
state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
|
||||
state.rag_docs = []
|
||||
state.rag_retrieved = True
|
||||
state.success = True
|
||||
state.debug_info["rag_source"] = "rag_pipeline"
|
||||
return state
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"RAG Pipeline 调用失败: {str(e)}") from e
|
||||
else:
|
||||
# 没有可用的 RAG 工具/Pipeline
|
||||
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
|
||||
else:
|
||||
state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
|
||||
state.rag_docs = []
|
||||
state.rag_retrieved = True
|
||||
state.success = True
|
||||
state.debug_info["rag_source"] = "rag_pipeline"
|
||||
return state
|
||||
|
||||
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
|
||||
|
||||
|
||||
# ========== RAG 检索节点(带超时和重试)==========
|
||||
# ========== RAG 检索节点 ==========
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
"""
|
||||
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
|
||||
|
||||
Args:
|
||||
state: 主图状态
|
||||
config: LangChain 配置
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
"""RAG 检索节点:带超时和重试"""
|
||||
state.current_phase = "rag_retrieving"
|
||||
|
||||
# 发送开始事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": "rag_retrieve_start",
|
||||
"confidence": 1.0,
|
||||
"reasoning": "开始执行 RAG 检索..."
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[rag_retrieve_node] 无法发送开始事件: {e}")
|
||||
|
||||
start_time = time.time()
|
||||
last_error = None
|
||||
|
||||
|
||||
# 步骤1: 发送开始事件
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."),
|
||||
config
|
||||
)
|
||||
|
||||
# 步骤2: 执行检索(带重试)
|
||||
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
|
||||
try:
|
||||
# 执行核心逻辑 - 异步 await
|
||||
result = await _rag_retrieve_core(state)
|
||||
|
||||
info(f"[rag_retrieve_node] RAG 检索成功,获取到上下文长度: {len(result.rag_context)} 字符")
|
||||
if result.rag_docs:
|
||||
for i, doc in enumerate(result.rag_docs[:3]): # 只显示前3条
|
||||
info(f"[rag_retrieve_node] 文档 {i+1}: {doc.get('content', '')[:100]}...")
|
||||
|
||||
# 成功
|
||||
|
||||
info(f"[RAG] 检索成功,上下文长度: {len(result.rag_context)} 字符")
|
||||
|
||||
state.debug_info["rag_retrieval"] = {
|
||||
"attempt": attempt + 1,
|
||||
"success": True,
|
||||
"time": time.time() - start_time
|
||||
}
|
||||
|
||||
# 发送完成事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
doc_count = len(result.rag_docs) if result.rag_docs else 0
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": "rag_retrieve_complete",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"RAG 检索完成,找到 {doc_count} 条相关文档"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[rag_retrieve_node] 无法发送完成事件: {e}")
|
||||
|
||||
# 关键修复:把 rag_retrieve 加到 reasoning_history 里,让下次推理知道
|
||||
|
||||
# 记录成功到历史
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "RETRIEVE_RAG", # 大写,和推理结果保持一致
|
||||
"action": "RETRIEVE_RAG",
|
||||
"confidence": 1.0,
|
||||
"reasoning": "RAG 检索完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
|
||||
# 发送完成事件
|
||||
doc_count = len(result.rag_docs) if result.rag_docs else 0
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0,
|
||||
f"RAG 检索完成,找到 {doc_count} 条相关文档"),
|
||||
config
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
|
||||
if attempt >= RAG_RETRY_CONFIG.max_retries:
|
||||
break
|
||||
|
||||
|
||||
# 发送重试事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": "rag_retrieve_retry",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"RAG 检索失败,第 {attempt + 1} 次重试..."
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[rag_retrieve_node] 无法发送重试事件: {e}")
|
||||
|
||||
# 指数退避等待
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
|
||||
f"RAG 检索失败,第 {attempt + 1} 次重试..."),
|
||||
config
|
||||
)
|
||||
|
||||
# 指数退避
|
||||
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
|
||||
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
||||
|
||||
# 所有重试都失败,记录结构化错误
|
||||
|
||||
# 步骤3: 所有重试失败,记录到历史(避免推理循环)
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "RETRIEVE_RAG",
|
||||
"confidence": 0.0,
|
||||
"reasoning": f"RAG 检索失败: {str(last_error) if last_error else '超时'}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 步骤4: 记录错误
|
||||
error_record = ErrorRecord(
|
||||
error_type="RAGRetrievalError",
|
||||
error_message=str(last_error) if last_error else "RAG 检索超时",
|
||||
@@ -284,105 +177,46 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
context={
|
||||
"query": state.user_query,
|
||||
"total_time": time.time() - start_time,
|
||||
"timeout": RAG_RETRY_CONFIG.timeout,
|
||||
"has_rag_tool": get_global_rag_tool() is not None,
|
||||
"has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
|
||||
|
||||
# 发送错误事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": "rag_retrieve_error",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"RAG 检索失败: {str(last_error)}"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[rag_retrieve_node] 无法发送错误事件: {e}")
|
||||
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
|
||||
f"RAG 检索失败: {str(last_error)}"),
|
||||
config
|
||||
)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
# ========== 重新检索节点 ==========
|
||||
def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
|
||||
"""
|
||||
重新检索节点:用于第二次检索(不同的参数)
|
||||
|
||||
Args:
|
||||
state: 主图状态
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
"""重新检索节点"""
|
||||
state.current_phase = "rag_re_retrieving"
|
||||
|
||||
# 记录原始检索信息
|
||||
|
||||
state.debug_info["rag_re_retrieve"] = {
|
||||
"original_retrieved": state.rag_retrieved,
|
||||
"original_docs_count": len(state.rag_docs)
|
||||
}
|
||||
|
||||
# 可以在这里修改检索参数(例如:调整查询、增加 k 值)
|
||||
# 暂时复用同一个检索逻辑
|
||||
return rag_retrieve_node(state)
|
||||
|
||||
|
||||
# ========== 便捷函数:从 rag_initializer 初始化 ==========
|
||||
async def initialize_rag_from_initializer() -> None:
|
||||
"""
|
||||
从 rag_initializer 初始化 RAG(便捷函数)
|
||||
|
||||
注意:这是示例代码,实际使用时需要提供 local_llm_creator
|
||||
"""
|
||||
try:
|
||||
from app.main_graph.utils.rag_initializer import init_rag_tool
|
||||
|
||||
# 注意:这里需要传入 local_llm_creator
|
||||
# 示例:
|
||||
# def my_llm_creator():
|
||||
# from ..model_services import get_llm
|
||||
# return get_llm()
|
||||
#
|
||||
# rag_tool = await init_rag_tool(my_llm_creator)
|
||||
# set_global_rag_tool(rag_tool)
|
||||
|
||||
print("⚠️ initialize_rag_from_initializer 需要传入 local_llm_creator")
|
||||
print("⚠️ 请在应用启动时调用 init_rag_tool() 并设置全局 RAG 工具")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"⚠️ 无法导入 rag_initializer: {e}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ RAG 初始化失败: {e}")
|
||||
return await rag_retrieve_node(state, config)
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
__all__ = [
|
||||
# 节点函数
|
||||
"rag_retrieve_node",
|
||||
"rag_re_retrieve_node",
|
||||
|
||||
# 工具函数
|
||||
"inject_rag_tool_to_state",
|
||||
"get_rag_tool_from_state",
|
||||
|
||||
# 全局 RAG 管理
|
||||
"get_global_rag_tool",
|
||||
"set_global_rag_tool",
|
||||
"set_global_rag_pipeline",
|
||||
|
||||
# 初始化
|
||||
"initialize_rag_from_initializer"
|
||||
]
|
||||
|
||||
@@ -1,424 +0,0 @@
|
||||
"""
|
||||
React 模式节点模块 - 带超时和重试功能
|
||||
包含:
|
||||
- react_reason_node: 使用 intent.py 进行推理
|
||||
- error_handling_node: 错误处理节点
|
||||
- init_state_node: 初始化状态节点
|
||||
|
||||
注意:为了兼容 LangGraph 的同步接口,我们保留了同步的 react_reason 调用
|
||||
但内部会根据情况使用规则推理或尝试异步调用
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# 导入我们的 intent.py
|
||||
from app.core.intent import (
|
||||
react_reason,
|
||||
react_reason_async,
|
||||
get_route_by_reasoning,
|
||||
ReasoningAction,
|
||||
ReasoningResult
|
||||
)
|
||||
from app.core.state_base import StateUtils
|
||||
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from app.main_graph.utils.retry_utils import (
|
||||
RetryConfig,
|
||||
SUBGRAPH_RETRY_CONFIG
|
||||
)
|
||||
from app.logger import info
|
||||
|
||||
|
||||
# ========== 1. React 推理节点 ==========
|
||||
|
||||
async def react_reason_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
"""
|
||||
React 模式推理节点:判断下一步做什么(异步版本)
|
||||
|
||||
Returns: 更新后的状态
|
||||
"""
|
||||
state.current_phase = "react_reasoning"
|
||||
state.reasoning_step += 1
|
||||
|
||||
info(f"[react_reason] 第 {state.reasoning_step} 次推理开始")
|
||||
|
||||
# 检查是否超过最大步数
|
||||
if state.reasoning_step > state.max_steps:
|
||||
state.current_phase = "max_steps_exceeded"
|
||||
state.final_result = (
|
||||
f"❌ 推理步数超过限制(最大 {state.max_steps} 步),"
|
||||
f"已执行 {state.reasoning_step - 1} 步。"
|
||||
f"请简化您的问题或分批提问。"
|
||||
)
|
||||
state.success = False
|
||||
return state
|
||||
|
||||
# 准备上下文
|
||||
context = {
|
||||
"retrieved_docs": state.rag_docs,
|
||||
"previous_actions": [h.get("action") for h in state.reasoning_history],
|
||||
"messages": state.messages,
|
||||
"errors": state.errors
|
||||
}
|
||||
|
||||
# 使用 intent.py 进行推理(现在直接用异步版本)
|
||||
result: ReasoningResult = await react_reason_async(state.user_query, context)
|
||||
|
||||
info(f"[react_reason] 推理结果: action={result.action.name}, confidence={result.confidence}")
|
||||
if result.reasoning:
|
||||
info(f"[react_reason] 推理过程: {result.reasoning}")
|
||||
|
||||
# 关键修复:直接发送自定义事件给 agent_service,而不是通过 state
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
info(f"[react_reason] 直接发送推理事件 #{state.reasoning_step}")
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": result.action.name,
|
||||
"confidence": result.confidence,
|
||||
"reasoning": result.reasoning
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[react_reason] 无法发送自定义事件: {e}")
|
||||
|
||||
# 记录推理历史
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": result.action.name,
|
||||
"confidence": result.confidence,
|
||||
"reasoning": result.reasoning,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 更新状态
|
||||
state.debug_info["last_reasoning"] = {
|
||||
"action": result.action.name,
|
||||
"confidence": result.confidence,
|
||||
"reasoning": result.reasoning
|
||||
}
|
||||
|
||||
# 保存推理结果到状态
|
||||
state.debug_info["reasoning_result"] = result
|
||||
|
||||
# 确定下一步动作
|
||||
state.last_action = result.action.name
|
||||
|
||||
# 关键修复:不再设置 latest_reasoning,避免 agent_service 重复读取
|
||||
if "latest_reasoning" in state.debug_info:
|
||||
del state.debug_info["latest_reasoning"]
|
||||
|
||||
return state
|
||||
|
||||
|
||||
# ========== 2. 联网搜索节点 ==========
|
||||
|
||||
async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
"""
|
||||
联网搜索节点:执行搜索并将结果保存到状态
|
||||
"""
|
||||
state.current_phase = "web_searching"
|
||||
|
||||
# 发送开始事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": "web_search_start",
|
||||
"confidence": 1.0,
|
||||
"reasoning": "开始执行联网搜索..."
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[web_search_node] 无法发送开始事件: {e}")
|
||||
|
||||
# 获取搜索查询
|
||||
reasoning_result = state.debug_info.get("reasoning_result")
|
||||
search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query
|
||||
|
||||
try:
|
||||
from app.core import web_search
|
||||
|
||||
print(f"[WebSearch] 搜索: {search_query}")
|
||||
search_result = web_search(search_query, max_results=5)
|
||||
|
||||
# 保存搜索结果到状态
|
||||
if not hasattr(state, "web_search_results"):
|
||||
state.web_search_results = []
|
||||
state.web_search_results.append(search_result)
|
||||
|
||||
# 将搜索结果添加到 rag_context,供 LLM 使用
|
||||
if state.rag_context:
|
||||
state.rag_context = f"{state.rag_context}\n\n---\n\n## 🌐 联网搜索结果:\n{search_result}"
|
||||
else:
|
||||
state.rag_context = f"## 🌐 联网搜索结果:\n{search_result}"
|
||||
|
||||
state.success = True
|
||||
print(f"[WebSearch] 搜索完成")
|
||||
|
||||
# 发送完成事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": "web_search_complete",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"联网搜索完成,找到 {len(search_result) if isinstance(search_result, list) else 1} 条结果"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[web_search_node] 无法发送完成事件: {e}")
|
||||
|
||||
except Exception as e:
|
||||
from app.main_graph.state import ErrorRecord, ErrorSeverity
|
||||
from datetime import datetime
|
||||
|
||||
error_record = ErrorRecord(
|
||||
error_type="WebSearchError",
|
||||
error_message=str(e),
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source="web_search_node",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=0,
|
||||
max_retries=2,
|
||||
context={"search_query": search_query}
|
||||
)
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
state.success = False
|
||||
|
||||
# 发送错误事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": "web_search_error",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"联网搜索失败: {str(e)}"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[web_search_node] 无法发送错误事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
|
||||
# ========== 3. 错误处理节点 ==========
|
||||
|
||||
def error_handling_node(state: MainGraphState) -> MainGraphState:
|
||||
"""
|
||||
错误处理节点:处理子图/工具调用错误
|
||||
|
||||
返回结构化错误信息,格式如下:
|
||||
{
|
||||
"tool/node": "...",
|
||||
"status": "failed",
|
||||
"error": "...",
|
||||
"retries_exceeded": true/false,
|
||||
"suggestion": "..."
|
||||
}
|
||||
"""
|
||||
state.current_phase = "error_handling"
|
||||
|
||||
if not state.current_error:
|
||||
state.current_phase = "react_reasoning"
|
||||
return state
|
||||
|
||||
error = state.current_error
|
||||
|
||||
# 更新错误状态
|
||||
state.error_message = f"{error.error_type}: {error.error_message}"
|
||||
|
||||
# 记录结构化错误信息
|
||||
structured_error = {
|
||||
"tool": error.source,
|
||||
"status": "failed",
|
||||
"error": error.error_message,
|
||||
"retries_exceeded": error.retry_count >= error.max_retries,
|
||||
"retry_count": error.retry_count,
|
||||
"max_retries": error.max_retries
|
||||
}
|
||||
|
||||
# 根据错误类型添加建议
|
||||
if "RAG" in error.error_type:
|
||||
structured_error["suggestion"] = "尝试重新表述问题或直接询问"
|
||||
elif "subgraph" in error.source or "contact" in error.source:
|
||||
structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
|
||||
elif "timeout" in error.error_message.lower():
|
||||
structured_error["suggestion"] = "请求超时,请稍后再试"
|
||||
else:
|
||||
structured_error["suggestion"] = "请尝试其他方式提问"
|
||||
|
||||
state.debug_info["structured_error"] = structured_error
|
||||
|
||||
# 策略1: 检查是否可以重试
|
||||
can_retry = (
|
||||
error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
|
||||
and error.retry_count < error.max_retries
|
||||
)
|
||||
|
||||
if can_retry:
|
||||
error.retry_count += 1
|
||||
state.retry_action = error.source
|
||||
state.debug_info["retry_count"] = error.retry_count
|
||||
|
||||
if "RAG" in error.error_type:
|
||||
state.last_action = "RE_RETRIEVE_RAG"
|
||||
elif "subgraph" in error.source:
|
||||
state.last_action = "DIRECT_RESPONSE"
|
||||
else:
|
||||
state.last_action = "REASON"
|
||||
|
||||
state.current_phase = "retrying"
|
||||
return state
|
||||
|
||||
# 策略2: 无法重试,尝试降级方案
|
||||
if error.severity != ErrorSeverity.FATAL:
|
||||
state.final_result = (
|
||||
f"⚠️ 遇到一些问题:\n"
|
||||
f"```json\n{structured_error}\n```\n"
|
||||
f"但我会尽力用现有信息回答您。"
|
||||
)
|
||||
state.success = True
|
||||
state.current_phase = "finalizing"
|
||||
return state
|
||||
|
||||
# 策略3: 致命错误
|
||||
state.final_result = (
|
||||
f"❌ 服务暂时不可用,请稍后再试。\n"
|
||||
f"```json\n{structured_error}\n```"
|
||||
)
|
||||
state.success = False
|
||||
state.current_phase = "finalizing"
|
||||
|
||||
return state
|
||||
|
||||
|
||||
# ========== 4. 初始化状态节点 ==========
|
||||
|
||||
def init_state_node(state: MainGraphState) -> MainGraphState:
    """Reset per-run bookkeeping on ``state`` before the graph starts.

    Sets the phase to ``"initializing"``, zeroes the reasoning step counter and
    stamps the start time as an ISO-8601 string. When no user query has been
    set yet and messages exist, the query is taken from the latest message.

    Returns:
        The same ``state`` object, mutated in place.
    """
    state.current_phase = "initializing"
    state.reasoning_step = 0
    state.start_time = datetime.now().isoformat()

    # Nothing to derive when a query is already present or there are no messages.
    if state.user_query or not state.messages:
        return state

    # Prefer the message's ``content`` attribute; fall back to its string form.
    latest = state.messages[-1]
    fallback_text = str(latest)
    state.user_query = getattr(latest, "content", fallback_text)
    return state
|
||||
|
||||
|
||||
# ========== 5. 条件路由函数 ==========
|
||||
|
||||
def route_by_reasoning(state: MainGraphState) -> str:
    """Pick the next graph node from the current phase and reasoning result.

    Phase-driven shortcuts are evaluated first, then the latest reasoning
    result, then history-based loop guards. The first matching rule wins.

    Returns:
        The name of the next node. Any route that has no entry in the local
        mapping table falls back to ``"llm_call"``.
    """
    phase = state.current_phase

    # --- Phase shortcuts -------------------------------------------------
    if phase == "max_steps_exceeded":
        return "llm_call"
    if phase == "error_handling" or state.current_error:
        return "handle_error"
    if phase in ("finalizing", "done"):
        return "llm_call"
    if phase == "retrying":
        retry_target = state.retry_action
        wants_rag = bool(retry_target) and "rag" in retry_target.lower()
        return "rag_retrieve" if wants_rag else "react_reason"

    # --- The current reasoning result takes priority over history --------
    reasoning_result = state.debug_info.get("reasoning_result")
    if reasoning_result and reasoning_result.action == ReasoningAction.DIRECT_RESPONSE:
        info("[route_by_reasoning] 当前推理结果=DIRECT_RESPONSE,直接去 llm_call")
        return "llm_call"

    # --- History and state based guards ----------------------------------
    history = [entry.get("action") for entry in state.reasoning_history]

    # A sub-graph already ran, or an answer is already available.
    if "subgraph_completed" in history or state.final_result:
        return "llm_call"

    # Repeated plain RAG retrieval is treated as a loop.
    rag_count = history.count("RETRIEVE_RAG")
    if rag_count >= 2:
        info(f"[route_by_reasoning] 检测到 RAG 重复循环({rag_count}次),直接去 llm_call")
        return "llm_call"

    # Retrieval output already exists, so go straight to the LLM.
    has_docs = bool(state.rag_docs) and len(state.rag_docs) > 0
    has_context = bool(state.rag_context) and len(state.rag_context) > 0
    if has_docs or has_context:
        info("[route_by_reasoning] 检测到已存在 RAG 检索结果,直接去 llm_call")
        return "llm_call"

    # Hard cap of three recorded reasoning actions.
    if len(history) >= 3:
        info(f"[route_by_reasoning] 已达到最大推理次数 ({len(history)}),直接去 llm_call")
        return "llm_call"

    if not reasoning_result:
        return "llm_call"

    # --- Translate the intent route into a node name ---------------------
    # NOTE(review): the original comment says these names must match the nodes
    # registered in main_graph_builder.py; that file is not visible here.
    node_for_route = {
        "direct_response": "llm_call",
        "retrieve_rag": "rag_retrieve",
        "re_retrieve_rag": "rag_retrieve",
        "web_search": "web_search",
        "clarify": "llm_call",
        "call_tool": "llm_call",
        "contact": "contact_subgraph",
        "dictionary": "dictionary_subgraph",
        "news_analysis": "news_analysis_subgraph",
    }
    route = get_route_by_reasoning(reasoning_result)
    destination = node_for_route.get(route, "llm_call")

    info(f"[route_by_reasoning] 推理结果={reasoning_result.action.name}, 路由={destination}, 历史动作={history}")
    return destination
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
|
||||
__all__ = [
|
||||
"init_state_node",
|
||||
"react_reason_node",
|
||||
"web_search_node",
|
||||
"error_handling_node",
|
||||
"route_by_reasoning"
|
||||
]
|
||||
68
backend/app/main_graph/nodes/reasoning.py
Normal file
68
backend/app/main_graph/nodes/reasoning.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
React 推理节点
|
||||
使用 intent.py 进行意图推理
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.intent import react_reason_async, ReasoningResult
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.logger import info
|
||||
from ._utils import dispatch_custom_event, make_react_event
|
||||
|
||||
|
||||
async def react_reason_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
    """Run one ReAct reasoning step and record its outcome on ``state``.

    The step counter is incremented, the reasoner is called with a context
    built from the current state, and the result is appended to
    ``state.reasoning_history``, mirrored into ``state.debug_info`` and
    announced through a ``react_reasoning`` custom event.

    Args:
        state: Graph state, mutated in place.
        config: Optional runnable config, forwarded to the event dispatcher.

    Returns:
        The same ``state`` object.
    """
    state.current_phase = "react_reasoning"
    state.reasoning_step += 1
    step_no = state.reasoning_step

    info(f"[推理] 第 {step_no} 次推理开始")

    # Snapshot of prior actions, taken before this step is appended to history.
    prior_actions = [entry.get("action") for entry in state.reasoning_history]
    context = {
        "retrieved_docs": state.rag_docs,
        "previous_actions": prior_actions,
        "reasoning_history": state.reasoning_history,
        "messages": state.messages,
        "errors": state.errors,
    }

    result: ReasoningResult = await react_reason_async(state.user_query, context)
    action_name = result.action.name

    info(f"[推理] 推理结果: action={action_name}, confidence={result.confidence}")
    if result.reasoning:
        info(f"[推理] 推理过程: {result.reasoning}")

    # Persist the step in the history.
    history_entry = {
        "step": state.reasoning_step,
        "action": action_name,
        "confidence": result.confidence,
        "reasoning": result.reasoning,
        "timestamp": datetime.now().isoformat(),
    }
    state.reasoning_history.append(history_entry)

    # Debug info keeps both a plain summary and the raw result object.
    state.debug_info["last_reasoning"] = {
        "action": action_name,
        "confidence": result.confidence,
        "reasoning": result.reasoning,
    }
    state.debug_info["reasoning_result"] = result
    state.last_action = action_name

    # Notify listeners about this reasoning step.
    event_payload = make_react_event(
        state.reasoning_step,
        action_name,
        result.confidence,
        result.reasoning,
    )
    await dispatch_custom_event("react_reasoning", event_payload, config)

    return state
|
||||
@@ -1,48 +0,0 @@
|
||||
"""
|
||||
路由决策节点
|
||||
根据当前状态决定下一步走向
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.logger import info
|
||||
|
||||
|
||||
def should_continue(state: MainGraphState) -> Literal['tool_node', 'summarize', 'finalize']:
    """Decide the next step: run tools, summarize memory, or finish.

    Args:
        state: Current conversation state. Only ``messages`` and
            ``turns_since_last_summary`` are read.

    Returns:
        ``'tool_node'`` when the last message is an ``AIMessage`` carrying tool
        calls; ``'summarize'`` when it is an ``AIMessage`` without tool calls
        and ``turns_since_last_summary`` has reached
        ``MEMORY_SUMMARIZE_INTERVAL``; ``'finalize'`` otherwise, including when
        there are no messages at all.
    """
    # An empty history has no last message; treat it like a non-AI message
    # instead of raising IndexError on ``messages[-1]``.
    last_message = state.messages[-1] if state.messages else None
    is_ai_reply = isinstance(last_message, AIMessage)

    # 1. Pending tool calls take priority over everything else.
    if is_ai_reply and last_message.tool_calls:
        if ENABLE_GRAPH_TRACE:
            info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
        return 'tool_node'

    # 2. A final AI reply: summarize once the turn threshold is reached.
    if is_ai_reply:
        turns = state.turns_since_last_summary
        if turns >= MEMORY_SUMMARIZE_INTERVAL:
            if ENABLE_GRAPH_TRACE:
                info(f"🔀 [路由决策] 收到 AI 最终回复,已达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 转向 'summarize'")
            return 'summarize'
        if ENABLE_GRAPH_TRACE:
            info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程")
        return 'finalize'

    # 3. Anything else (e.g. a lone user message, or no messages) just ends.
    if ENABLE_GRAPH_TRACE:
        info("🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
    return 'finalize'
|
||||
120
backend/app/main_graph/nodes/routing.py
Normal file
120
backend/app/main_graph/nodes/routing.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
路由与初始化模块
|
||||
包含状态初始化节点和条件路由函数
|
||||
|
||||
三层统一循环防护:
|
||||
1. 全局步数硬上限(reasoning_step > max_steps)
|
||||
2. 路由模式检测(A→B→A→B 交替循环)
|
||||
3. 状态停滞检测(连续相同动作)
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.intent import get_route_by_reasoning, ReasoningAction
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.logger import info
|
||||
|
||||
|
||||
# ========== 初始化状态节点 ==========
|
||||
def init_state_node(state: MainGraphState) -> MainGraphState:
    """Initialise run-level fields on ``state`` at the start of the flow.

    Resets the phase and the reasoning step counter, records the start time,
    and, if no user query is set, derives it from the most recent message.

    Returns:
        The same ``state`` object, mutated in place.
    """
    state.current_phase = "initializing"
    state.reasoning_step = 0
    state.start_time = datetime.now().isoformat()

    needs_query = not state.user_query and state.messages
    if needs_query:
        newest = state.messages[-1]
        # Use ``content`` when the message has it, else the message's str() form.
        state.user_query = getattr(newest, "content", str(newest))

    return state
|
||||
|
||||
|
||||
# ========== 条件路由函数 ==========
|
||||
def route_by_reasoning(state: MainGraphState) -> str:
    """Choose the next graph node, short-circuiting to ``llm_call`` on loops.

    Rules are applied in order and the first match wins:

    1. The latest reasoning result, or the last history entry, is
       ``DIRECT_RESPONSE``.
    2. History contains ``subgraph_completed`` or a final result is set.
    3. ``reasoning_step`` is greater than ``max_steps``.
    4. The phase is ``max_steps_exceeded``, ``finalizing`` or ``done``.
    5. The phase is ``error_handling`` or an error is pending
       (returns ``handle_error``).
    6. No reasoning result is available.
    7. The last four actions alternate (A, B, A, B), or the last two actions
       are identical.
    8. The target is RAG retrieval but RAG output already exists.

    If none applies, the node mapped from ``get_route_by_reasoning`` is used.

    Returns:
        The name of the next node; unmapped routes fall back to ``"llm_call"``.
    """
    history = [entry.get("action") for entry in state.reasoning_history]

    info(f"[条件路由] step={state.reasoning_step}, phase={state.current_phase}, history={history}")

    reasoning_result = state.debug_info.get("reasoning_result")
    latest_action = None
    if reasoning_result:
        latest_action = reasoning_result.action.name

    # --- DIRECT_RESPONSE wins over everything else -----------------------
    if latest_action == "DIRECT_RESPONSE":
        info("[条件路由] 推理结果为 DIRECT_RESPONSE,直接去 llm_call")
        return "llm_call"

    # Fallback: the same check against the recorded history.
    if history and history[-1] == "DIRECT_RESPONSE":
        info("[条件路由] 历史记录最新动作为 DIRECT_RESPONSE,直接去 llm_call")
        return "llm_call"

    # --- A sub-graph finished or an answer already exists ----------------
    if "subgraph_completed" in history or state.final_result:
        info("[条件路由] 子图已完成或已有结果,直接终止")
        return "llm_call"

    # --- Global step ceiling ---------------------------------------------
    if state.reasoning_step > state.max_steps:
        info(f"[条件路由] 步数超限 ({state.reasoning_step}/{state.max_steps}),强制终止")
        return "llm_call"

    # --- Phase fast paths ------------------------------------------------
    if state.current_phase in ("max_steps_exceeded", "finalizing", "done"):
        return "llm_call"
    if state.current_phase == "error_handling" or state.current_error:
        return "handle_error"

    # --- Without a reasoning result there is nothing to route on ---------
    if not reasoning_result:
        info("[条件路由] 无推理结果,默认去 llm_call")
        return "llm_call"

    # --- Resolve the target node -----------------------------------------
    node_for_route = {
        "direct_response": "llm_call",
        "retrieve_rag": "rag_retrieve",
        "re_retrieve_rag": "rag_retrieve",
        "web_search": "web_search",
        "clarify": "llm_call",
        "call_tool": "llm_call",
        "contact": "contact_subgraph",
        "dictionary": "dictionary_subgraph",
        "news_analysis": "news_analysis_subgraph",
    }
    route = get_route_by_reasoning(reasoning_result)
    target = node_for_route.get(route, "llm_call")

    # --- Loop guards -----------------------------------------------------
    # 1. Alternating pattern A -> B -> A -> B over the last four actions.
    if len(history) >= 4:
        first_a, first_b, second_a, second_b = history[-4:]
        if first_a == second_a and first_b == second_b and second_a != second_b:
            info(f"[条件路由] 检测到路由循环: {history[-4:]},强制终止")
            return "llm_call"

    # 2. Stagnation: the same action twice in a row.
    if len(history) >= 2 and history[-1] == history[-2]:
        info(f"[条件路由] 连续相同动作 '{history[-1]}',强制终止")
        return "llm_call"

    # --- Skip retrieval when RAG output is already present ---------------
    if target == "rag_retrieve" and (state.rag_docs or state.rag_context):
        info("[条件路由] RAG 结果已存在,跳过检索")
        return "llm_call"

    info(f"[条件路由] 动作={latest_action}, 目标={target}")
    return target
|
||||
115
backend/app/main_graph/nodes/web_search.py
Normal file
115
backend/app/main_graph/nodes/web_search.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
联网搜索节点 - 执行搜索并将结果保存到状态
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from app.logger import info
|
||||
|
||||
|
||||
async def _emit_web_search_event(
    config: Optional[Dict[str, Any]],
    payload: Dict[str, Any],
    label: str,
) -> None:
    """Best-effort dispatch of a ``react_reasoning`` custom event.

    Does nothing when ``config`` is falsy or carries no ``callbacks``. Any
    failure while dispatching is logged and swallowed, so progress reporting
    cannot break the search itself.

    Args:
        config: Runnable config that may contain a ``callbacks`` entry.
        payload: Event data to send.
        label: Short phase name inserted into the failure log message.
    """
    if not config:
        return
    try:
        from langchain_core.callbacks.manager import adispatch_custom_event
        callbacks = config.get("callbacks")
        if callbacks:
            await adispatch_custom_event("react_reasoning", payload, callbacks=callbacks)
    except Exception as exc:  # deliberate: event delivery is optional
        info(f"[web_search_node] 无法发送{label}事件: {exc}")


async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
    """Run a web search and store the outcome on ``state``.

    The query comes from the ``search_query`` entry of the latest reasoning
    result's metadata; when that is missing or empty, ``state.user_query`` is
    used. On success the result is appended to ``state.web_search_results``
    and merged into ``state.rag_context``. On failure an ``ErrorRecord`` is
    appended to ``state.errors``, set as ``state.current_error``, and the phase
    switches to ``"error_handling"``.

    Start, completion and error events are emitted when ``config`` provides
    callbacks.

    Args:
        state: Graph state, mutated in place.
        config: Optional runnable config used only for event dispatch.

    Returns:
        The same ``state`` object.
    """
    state.current_phase = "web_searching"

    await _emit_web_search_event(
        config,
        {
            "step": state.reasoning_step,
            "action": "web_search_start",
            "confidence": 1.0,
            "reasoning": "开始执行联网搜索...",
        },
        "开始",
    )

    # Resolve the query; tolerate a missing result, missing metadata, or a
    # blank ``search_query`` by falling back to the user's own query.
    reasoning_result = state.debug_info.get("reasoning_result")
    metadata = getattr(reasoning_result, "metadata", None) or {}
    search_query = metadata.get("search_query") or state.user_query

    try:
        from app.core import web_search

        info(f"[WebSearch] 搜索: {search_query}")
        search_result = web_search(search_query, max_results=5)

        # Keep the raw result on the state.
        if not hasattr(state, "web_search_results"):
            state.web_search_results = []
        state.web_search_results.append(search_result)

        # Fold the result into rag_context, appending to any existing context.
        if state.rag_context:
            state.rag_context = f"{state.rag_context}\n\n---\n\n## 🌐 联网搜索结果:\n{search_result}"
        else:
            state.rag_context = f"## 🌐 联网搜索结果:\n{search_result}"

        state.success = True
        info("[WebSearch] 搜索完成")

        result_count = len(search_result) if isinstance(search_result, list) else 1
        await _emit_web_search_event(
            config,
            {
                "step": state.reasoning_step,
                "action": "web_search_complete",
                "confidence": 1.0,
                "reasoning": f"联网搜索完成,找到 {result_count} 条结果",
            },
            "完成",
        )

    except Exception as e:
        # Record the failure and hand control to the error-handling phase.
        error_record = ErrorRecord(
            error_type="WebSearchError",
            error_message=str(e),
            severity=ErrorSeverity.WARNING,
            source="web_search_node",
            timestamp=datetime.now().isoformat(),
            retry_count=0,
            max_retries=2,
            context={"search_query": search_query},
        )
        state.errors.append(error_record)
        state.current_error = error_record
        state.current_phase = "error_handling"
        state.success = False

        await _emit_web_search_event(
            config,
            {
                "step": state.reasoning_step,
                "action": "web_search_error",
                "confidence": 1.0,
                "reasoning": f"联网搜索失败: {str(e)}",
            },
            "错误",
        )

    return state
|
||||
@@ -7,13 +7,10 @@ from typing import Dict, Any, Optional
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ..state import MainGraphState
|
||||
from ..nodes.react_nodes import (
|
||||
init_state_node,
|
||||
react_reason_node,
|
||||
web_search_node,
|
||||
error_handling_node,
|
||||
route_by_reasoning
|
||||
)
|
||||
from ..nodes.reasoning import react_reason_node
|
||||
from ..nodes.web_search import web_search_node
|
||||
from ..nodes.error_handling import error_handling_node
|
||||
from ..nodes.routing import init_state_node, route_by_reasoning
|
||||
from ..nodes.hybrid_router import (
|
||||
hybrid_router_node,
|
||||
fast_chitchat_node,
|
||||
|
||||
Reference in New Issue
Block a user