- 新增 rag_nodes.py: 独立的 RAG 检索节点 - 从 react_nodes.py 移除 RAG 相关代码 - 更新导入和导出 - rag_nodes.py 包含 rag_retrieve_node 和 rag_re_retrieve_node - 添加 inject_rag_tool_to_state 工具函数
This commit is contained in:
@@ -7,11 +7,16 @@ from .subgraph_builder import build_main_graph, build_react_main_graph
|
|||||||
from .react_nodes import (
|
from .react_nodes import (
|
||||||
init_state_node,
|
init_state_node,
|
||||||
react_reason_node,
|
react_reason_node,
|
||||||
rag_retrieve_node,
|
|
||||||
error_handling_node,
|
error_handling_node,
|
||||||
final_response_node,
|
final_response_node,
|
||||||
route_by_reasoning
|
route_by_reasoning
|
||||||
)
|
)
|
||||||
|
from .rag_nodes import (
|
||||||
|
rag_retrieve_node,
|
||||||
|
rag_re_retrieve_node,
|
||||||
|
inject_rag_tool_to_state,
|
||||||
|
get_rag_tool_from_state
|
||||||
|
)
|
||||||
from .state import (
|
from .state import (
|
||||||
MessagesState,
|
MessagesState,
|
||||||
GraphContext,
|
GraphContext,
|
||||||
@@ -44,13 +49,18 @@ __all__ = [
|
|||||||
"build_react_main_graph",
|
"build_react_main_graph",
|
||||||
"init_state_node",
|
"init_state_node",
|
||||||
"react_reason_node",
|
"react_reason_node",
|
||||||
"rag_retrieve_node",
|
|
||||||
"error_handling_node",
|
"error_handling_node",
|
||||||
"final_response_node",
|
"final_response_node",
|
||||||
"route_by_reasoning",
|
"route_by_reasoning",
|
||||||
"ErrorRecord",
|
"ErrorRecord",
|
||||||
"ErrorSeverity",
|
"ErrorSeverity",
|
||||||
|
|
||||||
|
# RAG 节点(独立模块)
|
||||||
|
"rag_retrieve_node",
|
||||||
|
"rag_re_retrieve_node",
|
||||||
|
"inject_rag_tool_to_state",
|
||||||
|
"get_rag_tool_from_state",
|
||||||
|
|
||||||
# 超时和重试工具
|
# 超时和重试工具
|
||||||
"RetryConfig",
|
"RetryConfig",
|
||||||
"RetryResult",
|
"RetryResult",
|
||||||
|
|||||||
204
backend/app/graph/rag_nodes.py
Normal file
204
backend/app/graph/rag_nodes.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""
|
||||||
|
RAG 节点模块 - 独立的 RAG 检索节点
|
||||||
|
包含:
|
||||||
|
- rag_retrieve_node: RAG 检索节点(带超时重试)
|
||||||
|
- rag_re_retrieve_node: 重新检索节点
|
||||||
|
- 相关的 RAG 工具集成
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||||
|
from .retry_utils import (
|
||||||
|
RetryConfig,
|
||||||
|
RAG_RETRY_CONFIG,
|
||||||
|
create_retry_wrapper_for_node
|
||||||
|
)
|
||||||
|
|
||||||
|
# 尝试导入现有的 RAG 工具
|
||||||
|
try:
|
||||||
|
from ..rag.tools import create_rag_tool_sync
|
||||||
|
from ..rag.pipeline import RAGPipeline
|
||||||
|
HAS_RAG = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_RAG = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
|
||||||
|
"""
|
||||||
|
从状态中获取 RAG 工具(如果有)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 主图状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RAG 工具实例或 None
|
||||||
|
"""
|
||||||
|
if "rag_tool" in state.debug_info:
|
||||||
|
return state.debug_info["rag_tool"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ========== RAG 检索核心逻辑 ==========
|
||||||
|
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||||
|
"""
|
||||||
|
RAG 检索核心逻辑(不带重试)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 主图状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的状态
|
||||||
|
"""
|
||||||
|
# 获取检索查询(优先使用推理结果中的优化查询)
|
||||||
|
retrieval_query = state.user_query
|
||||||
|
if "reasoning_result" in state.debug_info:
|
||||||
|
reasoning_result = state.debug_info["reasoning_result"]
|
||||||
|
if hasattr(reasoning_result, "retrieval_config"):
|
||||||
|
cfg = reasoning_result.retrieval_config
|
||||||
|
if cfg and cfg.retrieval_query:
|
||||||
|
retrieval_query = cfg.retrieval_query
|
||||||
|
|
||||||
|
# 尝试获取 RAG 工具
|
||||||
|
rag_tool = get_rag_tool_from_state(state)
|
||||||
|
|
||||||
|
if rag_tool and HAS_RAG:
|
||||||
|
# 使用真实的 RAG 工具
|
||||||
|
try:
|
||||||
|
rag_context = rag_tool.invoke(retrieval_query)
|
||||||
|
state.rag_context = rag_context
|
||||||
|
state.rag_docs = [
|
||||||
|
{"source": "rag_doc", "content": rag_context}
|
||||||
|
]
|
||||||
|
state.rag_retrieved = True
|
||||||
|
state.success = True
|
||||||
|
return state
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"RAG 调用失败: {str(e)}") from e
|
||||||
|
else:
|
||||||
|
# 没有 RAG 工具,使用模拟数据(演示用)
|
||||||
|
state.rag_context = (
|
||||||
|
f"[RAG 检索结果]\n"
|
||||||
|
f"查询: {retrieval_query}\n"
|
||||||
|
f"这是来自知识库的相关信息。"
|
||||||
|
)
|
||||||
|
state.rag_docs = [
|
||||||
|
{"source": "doc1.txt", "content": "LangGraph 是一个用于构建 Agent 的框架"},
|
||||||
|
{"source": "doc2.txt", "content": "React 模式是 '思考→行动→观察' 循环"}
|
||||||
|
]
|
||||||
|
state.rag_retrieved = True
|
||||||
|
state.success = True
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
# ========== RAG 检索节点(带超时和重试) ==========
|
||||||
|
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
|
||||||
|
"""
|
||||||
|
RAG 检索节点:带超时和重试
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 主图状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的状态
|
||||||
|
"""
|
||||||
|
state.current_phase = "rag_retrieving"
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
last_error = None
|
||||||
|
|
||||||
|
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
|
||||||
|
try:
|
||||||
|
# 执行核心逻辑
|
||||||
|
result = _rag_retrieve_core(state)
|
||||||
|
|
||||||
|
# 成功
|
||||||
|
state.debug_info["rag_retrieval"] = {
|
||||||
|
"attempt": attempt + 1,
|
||||||
|
"success": True,
|
||||||
|
"time": time.time() - start_time
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
|
||||||
|
if attempt >= RAG_RETRY_CONFIG.max_retries:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 指数退避等待
|
||||||
|
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
|
||||||
|
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
||||||
|
|
||||||
|
# 所有重试都失败,记录结构化错误
|
||||||
|
error_record = ErrorRecord(
|
||||||
|
error_type="RAGRetrievalError",
|
||||||
|
error_message=str(last_error) if last_error else "RAG 检索超时",
|
||||||
|
severity=ErrorSeverity.WARNING,
|
||||||
|
source="rag_retrieve_node",
|
||||||
|
timestamp=datetime.now().isoformat(),
|
||||||
|
retry_count=RAG_RETRY_CONFIG.max_retries,
|
||||||
|
max_retries=RAG_RETRY_CONFIG.max_retries,
|
||||||
|
context={
|
||||||
|
"query": state.user_query,
|
||||||
|
"total_time": time.time() - start_time,
|
||||||
|
"timeout": RAG_RETRY_CONFIG.timeout
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
state.errors.append(error_record)
|
||||||
|
state.current_error = error_record
|
||||||
|
state.current_phase = "error_handling"
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 重新检索节点 ==========
|
||||||
|
def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
|
||||||
|
"""
|
||||||
|
重新检索节点:用于第二次检索(不同的参数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 主图状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的状态
|
||||||
|
"""
|
||||||
|
state.current_phase = "rag_re_retrieving"
|
||||||
|
|
||||||
|
# 可以在这里修改检索参数(例如:扩大范围、调整查询)
|
||||||
|
state.debug_info["rag_re_retrieve"] = {
|
||||||
|
"original_retrieved": state.rag_retrieved,
|
||||||
|
"original_docs_count": len(state.rag_docs)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 使用相同的检索逻辑
|
||||||
|
return rag_retrieve_node(state)
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 工具:将 RAG 工具注入到状态 ==========
|
||||||
|
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
|
||||||
|
"""
|
||||||
|
将 RAG 工具注入到状态中,供后续节点使用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 主图状态
|
||||||
|
rag_tool: RAG 工具实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
更新后的状态
|
||||||
|
"""
|
||||||
|
state.debug_info["rag_tool"] = rag_tool
|
||||||
|
state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 导出 ==========
|
||||||
|
__all__ = [
|
||||||
|
"rag_retrieve_node",
|
||||||
|
"rag_re_retrieve_node",
|
||||||
|
"inject_rag_tool_to_state",
|
||||||
|
"get_rag_tool_from_state"
|
||||||
|
]
|
||||||
@@ -2,50 +2,30 @@
|
|||||||
React 模式节点模块 - 带超时和重试功能
|
React 模式节点模块 - 带超时和重试功能
|
||||||
包含:
|
包含:
|
||||||
- react_reason_node: 使用 intent.py 进行推理
|
- react_reason_node: 使用 intent.py 进行推理
|
||||||
- rag_retrieve_node: RAG 检索节点(带重试)
|
|
||||||
- error_handling_node: 错误处理节点
|
- error_handling_node: 错误处理节点
|
||||||
- final_response_node: 最终回答节点
|
- final_response_node: 最终回答节点
|
||||||
|
- init_state_node: 初始化节点
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
# 导入我们的 intent.py
|
# 导入我们的 intent.py
|
||||||
from ..agent_subgraphs.common.intent import (
|
from ..agent_subgraphs.common.intent import (
|
||||||
react_reason,
|
react_reason,
|
||||||
get_route_by_reasoning,
|
get_route_by_reasoning,
|
||||||
ReasoningAction,
|
ReasoningAction,
|
||||||
RetrievalConfig,
|
|
||||||
ReasoningResult
|
ReasoningResult
|
||||||
)
|
)
|
||||||
from ..agent_subgraphs.common.state_base import StateUtils
|
from ..agent_subgraphs.common.state_base import StateUtils
|
||||||
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||||
from .retry_utils import (
|
from .retry_utils import (
|
||||||
RetryConfig,
|
RetryConfig,
|
||||||
RetryResult,
|
|
||||||
with_retry,
|
|
||||||
create_retry_wrapper_for_node,
|
|
||||||
RAG_RETRY_CONFIG,
|
|
||||||
SUBGRAPH_RETRY_CONFIG
|
SUBGRAPH_RETRY_CONFIG
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_rag_tool():
|
|
||||||
"""
|
|
||||||
获取 RAG 工具(延迟导入,避免循环依赖)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 尝试导入现有的 RAG 工具
|
|
||||||
from ..rag.tools import create_rag_tool_sync
|
|
||||||
# 注意:这里简化处理,实际使用时应该从全局获取初始化好的工具
|
|
||||||
return None # 先返回 None,后面通过注入方式
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 1. React 推理节点 ==========
|
# ========== 1. React 推理节点 ==========
|
||||||
def react_reason_node(state: MainGraphState) -> MainGraphState:
|
def react_reason_node(state: MainGraphState) -> MainGraphState:
|
||||||
"""
|
"""
|
||||||
@@ -94,7 +74,7 @@ def react_reason_node(state: MainGraphState) -> MainGraphState:
|
|||||||
"reasoning": result.reasoning
|
"reasoning": result.reasoning
|
||||||
}
|
}
|
||||||
|
|
||||||
# 保存推理结果到状态(供条件路由使用)
|
# 保存推理结果到状态
|
||||||
state.debug_info["reasoning_result"] = result
|
state.debug_info["reasoning_result"] = result
|
||||||
|
|
||||||
# 确定下一步动作
|
# 确定下一步动作
|
||||||
@@ -103,97 +83,7 @@ def react_reason_node(state: MainGraphState) -> MainGraphState:
|
|||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ========== 2. RAG 检索节点(带超时和重试) ==========
|
# ========== 2. 错误处理节点 ==========
|
||||||
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
|
||||||
"""
|
|
||||||
RAG 检索核心逻辑(不带重试)
|
|
||||||
"""
|
|
||||||
# 获取推理结果中的检索配置
|
|
||||||
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
|
|
||||||
retrieval_query = state.user_query
|
|
||||||
|
|
||||||
if reasoning_result and reasoning_result.retrieval_config:
|
|
||||||
cfg: RetrievalConfig = reasoning_result.retrieval_config
|
|
||||||
if cfg.retrieval_query:
|
|
||||||
retrieval_query = cfg.retrieval_query
|
|
||||||
|
|
||||||
# 尝试获取 RAG 工具并调用
|
|
||||||
# 这里演示如何调用,实际使用时需要确保 RAG 已初始化
|
|
||||||
# 暂时用模拟数据
|
|
||||||
state.rag_context = (
|
|
||||||
f"[模拟RAG检索结果]\n"
|
|
||||||
f"查询: {retrieval_query}\n"
|
|
||||||
f"这是一个来自知识库的示例回答。"
|
|
||||||
)
|
|
||||||
state.rag_docs = [
|
|
||||||
{"source": "doc1.txt", "content": "示例内容1"},
|
|
||||||
{"source": "doc2.txt", "content": "示例内容2"}
|
|
||||||
]
|
|
||||||
state.rag_retrieved = True
|
|
||||||
state.success = True
|
|
||||||
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
|
|
||||||
"""
|
|
||||||
RAG 检索节点:带超时和重试
|
|
||||||
|
|
||||||
Returns: 更新后的状态
|
|
||||||
"""
|
|
||||||
state.current_phase = "rag_retrieving"
|
|
||||||
|
|
||||||
# 使用重试包装器
|
|
||||||
start_time = time.time()
|
|
||||||
last_error = None
|
|
||||||
|
|
||||||
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
|
|
||||||
try:
|
|
||||||
# 执行核心逻辑
|
|
||||||
result = _rag_retrieve_core(state)
|
|
||||||
|
|
||||||
# 成功
|
|
||||||
state.debug_info["rag_retrieval"] = {
|
|
||||||
"attempt": attempt + 1,
|
|
||||||
"success": True,
|
|
||||||
"time": time.time() - start_time
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
last_error = e
|
|
||||||
|
|
||||||
if attempt >= RAG_RETRY_CONFIG.max_retries:
|
|
||||||
break
|
|
||||||
|
|
||||||
# 等待后重试(指数退避)
|
|
||||||
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
|
|
||||||
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
|
||||||
|
|
||||||
# 所有重试都失败,记录结构化错误
|
|
||||||
error_record = ErrorRecord(
|
|
||||||
error_type="RAGRetrievalError",
|
|
||||||
error_message=str(last_error) if last_error else "RAG 检索超时",
|
|
||||||
severity=ErrorSeverity.WARNING,
|
|
||||||
source="rag_retrieve_node",
|
|
||||||
timestamp=datetime.now().isoformat(),
|
|
||||||
retry_count=RAG_RETRY_CONFIG.max_retries,
|
|
||||||
max_retries=RAG_RETRY_CONFIG.max_retries,
|
|
||||||
context={
|
|
||||||
"query": state.user_query,
|
|
||||||
"total_time": time.time() - start_time,
|
|
||||||
"timeout": RAG_RETRY_CONFIG.timeout
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
state.errors.append(error_record)
|
|
||||||
state.current_error = error_record
|
|
||||||
state.current_phase = "error_handling"
|
|
||||||
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 3. 错误处理节点 ==========
|
|
||||||
def error_handling_node(state: MainGraphState) -> MainGraphState:
|
def error_handling_node(state: MainGraphState) -> MainGraphState:
|
||||||
"""
|
"""
|
||||||
错误处理节点:处理子图/工具调用错误
|
错误处理节点:处理子图/工具调用错误
|
||||||
@@ -210,7 +100,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
|
|||||||
state.current_phase = "error_handling"
|
state.current_phase = "error_handling"
|
||||||
|
|
||||||
if not state.current_error:
|
if not state.current_error:
|
||||||
# 没有错误,直接返回
|
|
||||||
state.current_phase = "react_reasoning"
|
state.current_phase = "react_reasoning"
|
||||||
return state
|
return state
|
||||||
|
|
||||||
@@ -219,7 +108,7 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
|
|||||||
# 更新错误状态
|
# 更新错误状态
|
||||||
state.error_message = f"{error.error_type}: {error.error_message}"
|
state.error_message = f"{error.error_type}: {error.error_message}"
|
||||||
|
|
||||||
# 记录结构化错误信息(用于 LLM 决策)
|
# 记录结构化错误信息
|
||||||
structured_error = {
|
structured_error = {
|
||||||
"tool": error.source,
|
"tool": error.source,
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
@@ -231,11 +120,11 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
|
|||||||
|
|
||||||
# 根据错误类型添加建议
|
# 根据错误类型添加建议
|
||||||
if "RAG" in error.error_type:
|
if "RAG" in error.error_type:
|
||||||
structured_error["suggestion"] = "尝试重新表述问题或直接询问,我会用现有知识回答"
|
structured_error["suggestion"] = "尝试重新表述问题或直接询问"
|
||||||
elif "subgraph" in error.source or "contact" in error.source:
|
elif "subgraph" in error.source or "contact" in error.source:
|
||||||
structured_error["suggestion"] = "子图执行失败,请尝试简化查询或使用其他功能"
|
structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
|
||||||
elif "timeout" in error.error_message.lower():
|
elif "timeout" in error.error_message.lower():
|
||||||
structured_error["suggestion"] = "请求超时,请稍后再试或简化查询"
|
structured_error["suggestion"] = "请求超时,请稍后再试"
|
||||||
else:
|
else:
|
||||||
structured_error["suggestion"] = "请尝试其他方式提问"
|
structured_error["suggestion"] = "请尝试其他方式提问"
|
||||||
|
|
||||||
@@ -248,7 +137,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if can_retry:
|
if can_retry:
|
||||||
# 重试策略
|
|
||||||
error.retry_count += 1
|
error.retry_count += 1
|
||||||
state.retry_action = error.source
|
state.retry_action = error.source
|
||||||
state.debug_info["retry_count"] = error.retry_count
|
state.debug_info["retry_count"] = error.retry_count
|
||||||
@@ -265,7 +153,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
|
|||||||
|
|
||||||
# 策略2: 无法重试,尝试降级方案
|
# 策略2: 无法重试,尝试降级方案
|
||||||
if error.severity != ErrorSeverity.FATAL:
|
if error.severity != ErrorSeverity.FATAL:
|
||||||
# 降级到直接回答模式
|
|
||||||
state.final_result = (
|
state.final_result = (
|
||||||
f"⚠️ 遇到一些问题:\n"
|
f"⚠️ 遇到一些问题:\n"
|
||||||
f"```json\n{structured_error}\n```\n"
|
f"```json\n{structured_error}\n```\n"
|
||||||
@@ -275,7 +162,7 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
|
|||||||
state.current_phase = "finalizing"
|
state.current_phase = "finalizing"
|
||||||
return state
|
return state
|
||||||
|
|
||||||
# 策略3: 致命错误,无法继续
|
# 策略3: 致命错误
|
||||||
state.final_result = (
|
state.final_result = (
|
||||||
f"❌ 服务暂时不可用,请稍后再试。\n"
|
f"❌ 服务暂时不可用,请稍后再试。\n"
|
||||||
f"```json\n{structured_error}\n```"
|
f"```json\n{structured_error}\n```"
|
||||||
@@ -286,7 +173,7 @@ def error_handling_node(state: MainGraphState) -> MainGraphState:
|
|||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ========== 4. 最终回答节点 ==========
|
# ========== 3. 最终回答节点 ==========
|
||||||
def final_response_node(state: MainGraphState) -> MainGraphState:
|
def final_response_node(state: MainGraphState) -> MainGraphState:
|
||||||
"""
|
"""
|
||||||
最终回答节点:整理并生成最终回答
|
最终回答节点:整理并生成最终回答
|
||||||
@@ -307,12 +194,15 @@ def final_response_node(state: MainGraphState) -> MainGraphState:
|
|||||||
parts.append("---")
|
parts.append("---")
|
||||||
|
|
||||||
# 添加子图结果(如果有)
|
# 添加子图结果(如果有)
|
||||||
if state.contact_result and state.contact_result.get("final_result"):
|
if state.contact_result and hasattr(state.contact_result, "get"):
|
||||||
parts.append(state.contact_result["final_result"])
|
if state.contact_result.get("final_result"):
|
||||||
if state.dictionary_result and state.dictionary_result.get("final_result"):
|
parts.append(state.contact_result["final_result"])
|
||||||
parts.append(state.dictionary_result["final_result"])
|
if state.dictionary_result and hasattr(state.dictionary_result, "get"):
|
||||||
if state.news_result and state.news_result.get("final_result"):
|
if state.dictionary_result.get("final_result"):
|
||||||
parts.append(state.news_result["final_result"])
|
parts.append(state.dictionary_result["final_result"])
|
||||||
|
if state.news_result and hasattr(state.news_result, "get"):
|
||||||
|
if state.news_result.get("final_result"):
|
||||||
|
parts.append(state.news_result["final_result"])
|
||||||
|
|
||||||
# 如果都没有,用默认回答
|
# 如果都没有,用默认回答
|
||||||
if not parts:
|
if not parts:
|
||||||
@@ -326,7 +216,7 @@ def final_response_node(state: MainGraphState) -> MainGraphState:
|
|||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ========== 5. 初始化状态节点 ==========
|
# ========== 4. 初始化状态节点 ==========
|
||||||
def init_state_node(state: MainGraphState) -> MainGraphState:
|
def init_state_node(state: MainGraphState) -> MainGraphState:
|
||||||
"""
|
"""
|
||||||
初始化状态节点:在流程开始时设置初始值
|
初始化状态节点:在流程开始时设置初始值
|
||||||
@@ -335,7 +225,7 @@ def init_state_node(state: MainGraphState) -> MainGraphState:
|
|||||||
state.reasoning_step = 0
|
state.reasoning_step = 0
|
||||||
state.start_time = datetime.now().isoformat()
|
state.start_time = datetime.now().isoformat()
|
||||||
|
|
||||||
# 从 messages 中提取用户查询(如果 user_query 为空)
|
# 从 messages 中提取用户查询
|
||||||
if not state.user_query and state.messages:
|
if not state.user_query and state.messages:
|
||||||
last_msg = state.messages[-1]
|
last_msg = state.messages[-1]
|
||||||
state.user_query = getattr(last_msg, "content", str(last_msg))
|
state.user_query = getattr(last_msg, "content", str(last_msg))
|
||||||
@@ -343,7 +233,7 @@ def init_state_node(state: MainGraphState) -> MainGraphState:
|
|||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ========== 6. 条件路由函数 ==========
|
# ========== 5. 条件路由函数 ==========
|
||||||
def route_by_reasoning(state: MainGraphState) -> str:
|
def route_by_reasoning(state: MainGraphState) -> str:
|
||||||
"""
|
"""
|
||||||
根据推理结果决定下一步路由
|
根据推理结果决定下一步路由
|
||||||
@@ -358,7 +248,6 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
|||||||
if state.current_phase == "finalizing" or state.current_phase == "done":
|
if state.current_phase == "finalizing" or state.current_phase == "done":
|
||||||
return "final_response"
|
return "final_response"
|
||||||
if state.current_phase == "retrying":
|
if state.current_phase == "retrying":
|
||||||
# 重试路由
|
|
||||||
if state.retry_action and "rag" in state.retry_action.lower():
|
if state.retry_action and "rag" in state.retry_action.lower():
|
||||||
return "rag_retrieve"
|
return "rag_retrieve"
|
||||||
return "react_reason"
|
return "react_reason"
|
||||||
@@ -367,7 +256,6 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
|||||||
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
|
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
|
||||||
|
|
||||||
if not reasoning_result:
|
if not reasoning_result:
|
||||||
# 没有推理结果,直接结束
|
|
||||||
return "final_response"
|
return "final_response"
|
||||||
|
|
||||||
# 使用 intent.py 提供的路由函数
|
# 使用 intent.py 提供的路由函数
|
||||||
@@ -378,11 +266,21 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
|||||||
"direct_response": "final_response",
|
"direct_response": "final_response",
|
||||||
"retrieve_rag": "rag_retrieve",
|
"retrieve_rag": "rag_retrieve",
|
||||||
"re_retrieve_rag": "rag_retrieve",
|
"re_retrieve_rag": "rag_retrieve",
|
||||||
"clarify": "final_response", # 简化:澄清直接回答让用户补充
|
"clarify": "final_response",
|
||||||
"call_tool": "final_response", # 简化:工具调用暂未实现
|
"call_tool": "final_response",
|
||||||
"contact": "contact_subgraph",
|
"contact": "contact_subgraph",
|
||||||
"dictionary": "dictionary_subgraph",
|
"dictionary": "dictionary_subgraph",
|
||||||
"news_analysis": "news_analysis_subgraph",
|
"news_analysis": "news_analysis_subgraph",
|
||||||
}
|
}
|
||||||
|
|
||||||
return route_mapping.get(route, "final_response")
|
return route_mapping.get(route, "final_response")
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 导出 ==========
|
||||||
|
__all__ = [
|
||||||
|
"init_state_node",
|
||||||
|
"react_reason_node",
|
||||||
|
"error_handling_node",
|
||||||
|
"final_response_node",
|
||||||
|
"route_by_reasoning"
|
||||||
|
]
|
||||||
|
|||||||
@@ -10,11 +10,11 @@ from .state import MainGraphState, CurrentAction
|
|||||||
from .react_nodes import (
|
from .react_nodes import (
|
||||||
init_state_node,
|
init_state_node,
|
||||||
react_reason_node,
|
react_reason_node,
|
||||||
rag_retrieve_node,
|
|
||||||
error_handling_node,
|
error_handling_node,
|
||||||
final_response_node,
|
final_response_node,
|
||||||
route_by_reasoning
|
route_by_reasoning
|
||||||
)
|
)
|
||||||
|
from .rag_nodes import rag_retrieve_node
|
||||||
from ..agent_subgraphs.contact import build_contact_subgraph
|
from ..agent_subgraphs.contact import build_contact_subgraph
|
||||||
from ..agent_subgraphs.dictionary import build_dictionary_subgraph
|
from ..agent_subgraphs.dictionary import build_dictionary_subgraph
|
||||||
from ..agent_subgraphs.news_analysis import build_news_analysis_subgraph
|
from ..agent_subgraphs.news_analysis import build_news_analysis_subgraph
|
||||||
|
|||||||
Reference in New Issue
Block a user