feat: 实现 React 模式循环推理,带超时重试和结构化错误处理
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m15s

- 更新 intent.py 为 React 模式推理器
- 新增 react_nodes.py: React 模式节点
- 新增 retry_utils.py: 超时和重试工具
- 更新 state.py: 支持循环步数和错误记录
- 重写 subgraph_builder.py: 完整 React 循环流程
- 结构化错误输出,符合 Agent 执行循环最佳实践
- 限制最大推理步数 ≤40,防止无限循环
- RAG 检索带重试和超时保护
- 子图错误可传递给主图处理
This commit is contained in:
2026-04-26 11:14:04 +08:00
parent e6337eb0fc
commit e3adb45454
7 changed files with 1304 additions and 493 deletions

View File

@@ -17,6 +17,7 @@ from .formatter import (
)
from .intent import (
# 旧版 API保持向后兼容
IntentType,
Intent,
Entity,
@@ -24,7 +25,17 @@ from .intent import (
RuleBasedIntentClassifier,
RuleBasedEntityExtractor,
IntentRegistry,
create_default_intent_parser
create_default_intent_parser,
# 新版 React 模式 API
ReasoningAction,
RetrievalConfig,
ReasoningResult,
BaseIntentReasoner,
RuleBasedReactReasoner,
LLMReactReasoner,
create_react_reasoner,
react_reason,
get_route_by_reasoning
)
from .human_review import (
@@ -49,7 +60,7 @@ __all__ = [
"TemplateManager",
"OutputRenderer",
"PresetTemplates",
# intent
# intent - 旧版
"IntentType",
"Intent",
"Entity",
@@ -58,6 +69,16 @@ __all__ = [
"RuleBasedEntityExtractor",
"IntentRegistry",
"create_default_intent_parser",
# intent - 新版 React 模式
"ReasoningAction",
"RetrievalConfig",
"ReasoningResult",
"BaseIntentReasoner",
"RuleBasedReactReasoner",
"LLMReactReasoner",
"create_react_reasoner",
"react_reason",
"get_route_by_reasoning",
# human_review
"ReviewStatus",
"HumanReview",

View File

@@ -1,427 +1,381 @@
"""
意图理解工具模块
提供标准化的意图分类和信息提取能力
意图理解与推理模块 (React模式)
Intent Understanding & Reasoning Module (React Pattern)
功能
1. Intent - 意图数据类
2. IntentClassifier - 意图分类器
3. EntityExtractor - 实体提取器
4. IntentParser - 完整的意图解析器
5. IntentRegistry - 意图注册器
这个模块实现了 React (Reasoning + Acting) 模式的意图理解节点,用于
1. 理解用户的查询意图
2. 判断是否需要调用 RAG 检索
3. 判断是否需要重新检索
4. 决定下一步的行动
5. 支持条件路由扩展
核心组件:
- ReasoningAction: 推理动作枚举
- ReasoningResult: 推理结果数据类
- ReactIntentReasoner: React 模式意图推理器
"""
import re
from typing import Dict, List, Any, Optional, Set, Tuple, Callable
from typing import Dict, Any, Optional, List, Set, Tuple
from dataclasses import dataclass, field
from enum import Enum, auto
from abc import ABC, abstractmethod
class IntentType(Enum):
"""意图类型枚举"""
UNKNOWN = auto()
GREETING = auto() # 问候
QUESTION = auto() # 提问
REQUEST = auto() # 请求
COMMAND = auto() # 命令
INFORM = auto() # 告知信息
CONFIRM = auto() # 确认
DENY = auto() # 否认
THANKS = auto() # 感谢
GOODBYE = auto() # 告别
COMPLAINT = auto() # 投诉
PRAISE = auto() # 表扬
CLARIFY = auto() # 澄清
SUGGEST = auto() # 建议
class ReasoningAction(Enum):
"""推理动作枚举 - 决定下一步做什么"""
DIRECT_RESPONSE = auto() # 直接回答,不需要额外信息
RETRIEVE_RAG = auto() # 需要调用 RAG 检索
RERIEVE_RAG = auto() # 需要重新检索 (优化前版本,兼容保留)
RE_RETRIEVE_RAG = auto() # 需要重新检索 (修正拼写)
CALL_TOOL = auto() # 需要调用其他工具
CLARIFY = auto() # 需要澄清用户的问题
ROUTE_SUBGRAPH = auto() # 需要路由到子图
UNKNOWN = auto() # 未知动作
@dataclass
class Entity:
"""实体数据类"""
entity_type: str # 实体类型
value: str # 实体值
start_pos: int = 0 # 起始位置
end_pos: int = 0 # 结束位置
confidence: float = 1.0 # 置信度
metadata: Dict[str, Any] = field(default_factory=dict) # 元数据
class RetrievalConfig:
"""检索配置"""
need_retrieval: bool = False # 是否需要检索
need_re_retrieval: bool = False # 是否需要重新检索
retrieval_query: Optional[str] = None # 优化后的检索查询
collection_name: Optional[str] = None # 检索的集合名称
k: int = 5 # 返回数量
score_threshold: float = 0.3 # 相似度阈值
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class Intent:
"""意图数据类"""
intent_type: IntentType # 意图类型
confidence: float = 1.0 # 置信度
entities: List[Entity] = field(default_factory=list) # 提取的实体
parameters: Dict[str, Any] = field(default_factory=dict) # 参数
original_text: str = "" # 原始文本
normalized_text: str = "" # 标准化后的文本
metadata: Dict[str, Any] = field(default_factory=dict) # 元数据
class ReasoningResult:
"""推理结果数据类"""
action: ReasoningAction = ReasoningAction.UNKNOWN # 决定的动作
confidence: float = 0.0 # 置信度
reasoning: str = "" # 推理过程说明
retrieval_config: RetrievalConfig = field(default_factory=RetrievalConfig)
extracted_entities: Dict[str, Any] = field(default_factory=dict) # 提取的实体
next_hints: List[str] = field(default_factory=list) # 下一步提示
original_query: str = "" # 原始查询
metadata: Dict[str, Any] = field(default_factory=dict)
class BaseIntentClassifier(ABC):
"""意图分类器基类"""
class BaseIntentReasoner(ABC):
"""意图推理器基类"""
@abstractmethod
def classify(self, text: str) -> Tuple[IntentType, float]:
def reason(
self,
query: str,
context: Optional[Dict[str, Any]] = None
) -> ReasoningResult:
"""
分类意图
推理意图,决定下一步动作
Args:
text: 输入文本
query: 用户查询
context: 上下文信息,可能包括:
- messages: 对话历史
- retrieved_docs: 已检索的文档
- previous_actions: 之前的动作
- user_id: 用户ID
- etc.
Returns:
(意图类型, 置信度)
"""
pass
@abstractmethod
def classify_with_scores(self, text: str) -> Dict[IntentType, float]:
"""
分类意图,返回所有类型的置信度
Args:
text: 输入文本
Returns:
{意图类型: 置信度}
ReasoningResult: 推理结果
"""
pass
class RuleBasedIntentClassifier(BaseIntentClassifier):
"""基于规则的意图分类"""
class RuleBasedReactReasoner(BaseIntentReasoner):
"""基于规则的 React 推理"""
def __init__(self):
self._rules: Dict[IntentType, Set[str]] = {}
self._initialize_default_rules()
# 检索触发关键词
self._retrieval_keywords = {
"什么", "怎么", "如何", "为什么", "", "", "多少",
"介绍", "解释", "说明", "资料", "文档", "查询", "搜索",
"find", "search", "what", "how", "why", "where", "who",
"tell me", "explain", "about", "information"
}
def _initialize_default_rules(self) -> None:
"""初始化默认规则"""
# 问候
self.add_rule(IntentType.GREETING, {
"你好", "您好", "hi", "hello", "hey", "早上好", "下午好", "晚上好", "哈喽"
})
# 告别
self.add_rule(IntentType.GOODBYE, {
"再见", "拜拜", "bye", "goodbye", "回见", "下次见", "再见了"
})
# 感谢
self.add_rule(IntentType.THANKS, {
"谢谢", "感谢", "多谢", "thanks", "thank you", "3q", "谢谢了"
})
# 确认
self.add_rule(IntentType.CONFIRM, {
"是的", "", "没错", "好的", "可以", "", "同意", "确认", "yes", "yep"
})
# 否认
self.add_rule(IntentType.DENY, {
"", "不是", "不对", "不行", "不要", "拒绝", "no", "nope", "没有"
})
# 提问
self.add_rule(IntentType.QUESTION, {
"?", "", "什么", "怎么", "如何", "为什么", "", "", "多少", "", ""
})
# 重新检索触发关键词
self._re_retrieval_keywords = {
"", "重新", "更多", "不够", "不足", "其他", "另外",
"没找到", "找不到", "没有", "不对", "不是",
"again", "more", "another", "other", "didn't find", "not enough"
}
def add_rule(self, intent_type: IntentType, keywords: Set[str]) -> None:
# 澄清触发关键词
self._clarify_keywords = {
"?", "", "哪个", "哪些", "哪位", "什么意思",
"请问", "能详细", "具体点", "举个例子"
}
# 工具调用关键词
self._tool_keywords = {
"天气", "weather", "邮件", "email", "联系人", "contact",
"翻译", "translate", "词典", "dictionary"
}
# 子图路由关键词映射
self._subgraph_keywords = {
"contact": {"通讯录", "联系人", "contact", "email", "邮件"},
"dictionary": {"词典", "单词", "翻译", "dictionary", "translate"},
"news_analysis": {"资讯", "新闻", "分析", "news", "report"},
}
# 直接回答模式(问候、感谢等)
self._direct_response_patterns = [
(r'^(你好|您好|hi|hello|hey|早上好|下午好|晚上好|哈喽)', ReasoningAction.DIRECT_RESPONSE),
(r'^(谢谢|感谢|多谢|thanks|thank you)', ReasoningAction.DIRECT_RESPONSE),
(r'^(再见|拜拜|bye|goodbye|回见)', ReasoningAction.DIRECT_RESPONSE),
]
def reason(
self,
query: str,
context: Optional[Dict[str, Any]] = None
) -> ReasoningResult:
"""
添加规则
Args:
intent_type: 意图类型
keywords: 关键词集合
基于规则的推理
"""
if intent_type not in self._rules:
self._rules[intent_type] = set()
self._rules[intent_type].update(keywords)
context = context or {}
query_lower = query.lower()
result = ReasoningResult(original_query=query)
def classify(self, text: str) -> Tuple[IntentType, float]:
"""
分类意图
# 1. 先检查是否是直接回答模式
for pattern, action in self._direct_response_patterns:
if re.match(pattern, query, re.IGNORECASE):
result.action = action
result.confidence = 0.95
result.reasoning = "检测到问候、感谢或告别语,直接回答"
return result
Args:
text: 输入文本
# 2. 检查是否需要路由到子图(优先级高于重新检索,避免"有没有"误触发)
for subgraph, keywords in self._subgraph_keywords.items():
if any(kw in query_lower for kw in keywords):
result.action = ReasoningAction.ROUTE_SUBGRAPH
result.confidence = 0.9
result.reasoning = f"检测到 {subgraph} 子图意图"
result.metadata["target_subgraph"] = subgraph
return result
Returns:
(意图类型, 置信度)
"""
scores = self.classify_with_scores(text)
if not scores:
return IntentType.UNKNOWN, 0.0
best_intent = max(scores.items(), key=lambda x: x[1])
return best_intent[0], best_intent[1]
def classify_with_scores(self, text: str) -> Dict[IntentType, float]:
"""
分类意图,返回所有类型的置信度
Args:
text: 输入文本
Returns:
{意图类型: 置信度}
"""
scores: Dict[IntentType, float] = {}
normalized_text = text.lower()
for intent_type, keywords in self._rules.items():
match_count = 0
for keyword in keywords:
if keyword.lower() in normalized_text:
match_count += 1
if match_count > 0:
scores[intent_type] = min(1.0, match_count / 3.0)
# 如果没有匹配返回UNKNOWN
if not scores:
scores[IntentType.UNKNOWN] = 0.5
return scores
class BaseEntityExtractor(ABC):
"""实体提取器基类"""
@abstractmethod
def extract(self, text: str) -> List[Entity]:
"""
提取实体
Args:
text: 输入文本
Returns:
实体列表
"""
pass
class RuleBasedEntityExtractor(BaseEntityExtractor):
"""基于规则的实体提取器"""
def __init__(self):
self._patterns: Dict[str, re.Pattern] = {} # 正则模式
self._keywords: Dict[str, Set[str]] = {} # 关键词列表
self._initialize_default_patterns()
def _initialize_default_patterns(self) -> None:
"""初始化默认模式"""
# 邮箱
self.add_regex_pattern(
"email",
r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
)
# 电话号码
self.add_regex_pattern(
"phone",
r'1[3-9]\d{9}'
)
# 日期(简单模式)
self.add_regex_pattern(
"date",
r'\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日号]?|\d{1,2}[-/月]\d{1,2}[日号]?'
)
# 数字
self.add_regex_pattern(
"number",
r'\d+\.?\d*'
# 3. 检查是否需要重新检索
has_re_retrieval = any(kw in query_lower for kw in self._re_retrieval_keywords)
# 同时检查上下文中是否有之前的检索结果但不够好
previous_retrieval = context.get("retrieved_docs")
if has_re_retrieval or (previous_retrieval and len(previous_retrieval) < 2):
result.action = ReasoningAction.RE_RETRIEVE_RAG
result.confidence = 0.85 if has_re_retrieval else 0.7
result.reasoning = "检测到需要重新检索的意图"
result.retrieval_config = RetrievalConfig(
need_retrieval=True,
need_re_retrieval=True,
retrieval_query=self._optimize_retrieval_query(query),
k=10 # 重新检索时返回更多结果
)
return result
def add_regex_pattern(self, entity_type: str, pattern: str) -> None:
# 4. 检查是否需要调用工具
has_tool = any(kw in query_lower for kw in self._tool_keywords)
if has_tool:
result.action = ReasoningAction.CALL_TOOL
result.confidence = 0.8
result.reasoning = "检测到工具调用意图"
return result
# 5. 检查是否需要澄清
has_clarify = any(kw in query_lower for kw in self._clarify_keywords)
# 或者查询太短、太模糊
if has_clarify or len(query.strip()) < 3:
result.action = ReasoningAction.CLARIFY
result.confidence = 0.75
result.reasoning = "检测到需要澄清的意图"
result.next_hints = [
"请提供更多细节",
"您想了解什么方面的内容?",
"能否具体说明一下?"
]
return result
# 6. 检查是否需要 RAG 检索
has_retrieval = any(kw in query_lower for kw in self._retrieval_keywords)
if has_retrieval or len(query.strip()) > 5:
result.action = ReasoningAction.RETRIEVE_RAG
result.confidence = 0.85 if has_retrieval else 0.6
result.reasoning = "检测到需要检索知识库的意图"
result.retrieval_config = RetrievalConfig(
need_retrieval=True,
retrieval_query=self._optimize_retrieval_query(query),
k=5
)
return result
# 7. 默认直接回答
result.action = ReasoningAction.DIRECT_RESPONSE
result.confidence = 0.6
result.reasoning = "默认直接回答模式"
return result
def _optimize_retrieval_query(self, query: str) -> str:
"""优化检索查询,去掉不必要的语气词"""
# 去掉常见的前缀
prefixes_to_remove = [
"请告诉我", "帮我查一下", "我想知道", "能不能告诉我",
"请问", "你知道", "帮我找", "搜索一下", "查询一下"
]
optimized = query
for prefix in prefixes_to_remove:
if optimized.startswith(prefix):
optimized = optimized[len(prefix):]
# 去掉常见的后缀
suffixes_to_remove = ["吗?", "呢?", "吧?", "", "", "", "", "?"]
for suffix in suffixes_to_remove:
if optimized.endswith(suffix):
optimized = optimized[:-len(suffix)]
return optimized.strip()
class LLMReactReasoner(BaseIntentReasoner):
"""
添加正则匹配规则
基于 LLM 的 React 推理器
使用大语言模型进行更智能的推理判断
"""
def __init__(self, llm_client=None):
"""
初始化 LLM 推理器
Args:
entity_type: 实体类型
pattern: 正则表达式
llm_client: LLM 客户端,需要支持调用方法
"""
self.llm_client = llm_client
self.rule_based = RuleBasedReactReasoner()
def reason(
self,
query: str,
context: Optional[Dict[str, Any]] = None
) -> ReasoningResult:
"""
使用 LLM 进行推理,失败时回退到规则推理
"""
try:
self._patterns[entity_type] = re.compile(pattern, re.IGNORECASE)
except re.error:
if self.llm_client:
return self._reason_with_llm(query, context)
except Exception:
pass
def add_keywords(self, entity_type: str, keywords: Set[str]) -> None:
"""
添加关键词匹配规则
# LLM 不可用或失败,回退到规则推理
return self.rule_based.reason(query, context)
Args:
entity_type: 实体类型
keywords: 关键词集合
"""
if entity_type not in self._keywords:
self._keywords[entity_type] = set()
self._keywords[entity_type].update(keywords)
def extract(self, text: str) -> List[Entity]:
"""
提取实体
Args:
text: 输入文本
Returns:
实体列表
"""
entities: List[Entity] = []
# 正则匹配
for entity_type, pattern in self._patterns.items():
for match in pattern.finditer(text):
entity = Entity(
entity_type=entity_type,
value=match.group(),
start_pos=match.start(),
end_pos=match.end(),
confidence=0.95
)
entities.append(entity)
# 关键词匹配
for entity_type, keywords in self._keywords.items():
for keyword in keywords:
start_idx = 0
while True:
pos = text.lower().find(keyword.lower(), start_idx)
if pos == -1:
break
entity = Entity(
entity_type=entity_type,
value=text[pos:pos + len(keyword)],
start_pos=pos,
end_pos=pos + len(keyword),
confidence=0.9
)
entities.append(entity)
start_idx = pos + len(keyword)
# 按位置排序
entities.sort(key=lambda e: e.start_pos)
return entities
class IntentRegistry:
"""意图注册器"""
def __init__(self):
self._intent_handlers: Dict[IntentType, Callable] = {}
def register(self, intent_type: IntentType, handler: Callable) -> None:
"""
注册意图处理器
Args:
intent_type: 意图类型
handler: 处理器
"""
self._intent_handlers[intent_type] = handler
def get_handler(self, intent_type: IntentType) -> Optional[Callable]:
"""
获取意图处理器
Args:
intent_type: 意图类型
Returns:
处理器,如果不存在返回 None
"""
return self._intent_handlers.get(intent_type)
class IntentParser:
"""完整的意图解析器"""
def __init__(
def _reason_with_llm(
self,
classifier: Optional[BaseIntentClassifier] = None,
extractor: Optional[BaseEntityExtractor] = None,
registry: Optional[IntentRegistry] = None
):
query: str,
context: Optional[Dict[str, Any]] = None
) -> ReasoningResult:
"""
初始化意图解析器
使用 LLM 进行推理(需要实现具体的 LLM 调用逻辑)
"""
# 这里是一个示例实现,实际项目需要连接真实的 LLM
prompt = self._build_reasoning_prompt(query, context)
# 模拟 LLM 返回(实际项目中替换为真实调用)
# 这里我们还是先调用规则推理作为示例
return self.rule_based.reason(query, context)
def _build_reasoning_prompt(self, query: str, context: Optional[Dict[str, Any]]) -> str:
"""构建推理提示词"""
context_str = ""
if context:
context_lines = []
if "messages" in context:
context_lines.append(f"对话历史: {len(context['messages'])}")
if "retrieved_docs" in context:
context_lines.append(f"已检索文档: {len(context['retrieved_docs'])}")
context_str = "\n".join(context_lines)
return f"""你是一个意图推理助手,需要判断用户的查询应该如何处理。
用户查询: {query}
上下文信息:
{context_str or '无额外上下文'}
请判断下一步应该做什么,可选动作:
1. DIRECT_RESPONSE - 直接回答,不需要额外信息
2. RETRIEVE_RAG - 需要调用知识库检索
3. RE_RETRIEVE_RAG - 需要重新检索更多/更好的结果
4. CALL_TOOL - 需要调用其他工具
5. CLARIFY - 需要澄清用户的问题
6. ROUTE_SUBGRAPH - 需要路由到子图
请以 JSON 格式输出你的判断。
"""
def create_react_reasoner(
use_llm: bool = False,
llm_client=None
) -> BaseIntentReasoner:
"""
创建 React 模式意图推理器工厂函数
Args:
classifier: 意图分类器
extractor: 实体提取器
registry: 意图注册器
"""
self.classifier = classifier or RuleBasedIntentClassifier()
self.extractor = extractor or RuleBasedEntityExtractor()
self.registry = registry or IntentRegistry()
use_llm: 是否使用 LLM 推理
llm_client: LLM 客户端实例
def parse(self, text: str) -> Intent:
Returns:
BaseIntentReasoner: 推理器实例
"""
解析文本,返回完整的意图对象
if use_llm:
return LLMReactReasoner(llm_client)
return RuleBasedReactReasoner()
# 便捷函数 - 直接推理
def react_reason(
query: str,
context: Optional[Dict[str, Any]] = None,
reasoner: Optional[BaseIntentReasoner] = None
) -> ReasoningResult:
"""
便捷函数:直接进行 React 推理
Args:
text: 输入文本
query: 用户查询
context: 上下文信息
reasoner: 可选的推理器实例
Returns:
意图对象
ReasoningResult: 推理结果
"""
# 分类意图
intent_type, confidence = self.classifier.classify(text)
if reasoner is None:
reasoner = create_react_reasoner()
return reasoner.reason(query, context)
# 提取实体
entities = self.extractor.extract(text)
# 构建意图对象
intent = Intent(
intent_type=intent_type,
confidence=confidence,
entities=entities,
original_text=text,
normalized_text=text.lower().strip()
)
# 从实体中提取参数
for entity in entities:
intent.parameters[entity.entity_type] = entity.value
return intent
def parse_and_execute(self, text: str, context: Optional[Dict[str, Any]] = None) -> Any:
# 条件路由辅助函数
def get_route_by_reasoning(result: ReasoningResult) -> str:
"""
解析文本并执行对应的处理器
根据推理结果获取路由字符串
Args:
text: 输入文本
context: 上下文
result: 推理结果
Returns:
执行结果
str: 路由标识
"""
intent = self.parse(text)
handler = self.registry.get_handler(intent.intent_type)
if handler:
return handler(intent, context or {})
return None
def create_default_intent_parser() -> IntentParser:
"""
创建默认配置的意图解析器
Returns:
配置好的意图解析器
"""
parser = IntentParser()
# 注册默认处理器
def greeting_handler(intent: Intent, context: Dict) -> str:
return "你好!很高兴为你服务。"
def thanks_handler(intent: Intent, context: Dict) -> str:
return "不客气!"
def goodbye_handler(intent: Intent, context: Dict) -> str:
return "再见!有需要随时找我。"
parser.registry.register(IntentType.GREETING, greeting_handler)
parser.registry.register(IntentType.THANKS, thanks_handler)
parser.registry.register(IntentType.GOODBYE, goodbye_handler)
return parser
action_to_route = {
ReasoningAction.DIRECT_RESPONSE: "direct_response",
ReasoningAction.RETRIEVE_RAG: "retrieve_rag",
ReasoningAction.RE_RETRIEVE_RAG: "re_retrieve_rag",
ReasoningAction.RERIEVE_RAG: "re_retrieve_rag", # 兼容旧拼写
ReasoningAction.CALL_TOOL: "call_tool",
ReasoningAction.CLARIFY: "clarify",
ReasoningAction.ROUTE_SUBGRAPH: result.metadata.get("target_subgraph", "unknown_subgraph"),
ReasoningAction.UNKNOWN: "unknown",
}
return action_to_route.get(result.action, "unknown")

View File

@@ -1,21 +1,63 @@
"""
Graph 子模块
Graph 子模块 - React 模式增强版(带超时重试)
"""
from .graph_builder import GraphBuilder
from .subgraph_builder import build_main_graph
from .subgraph_builder import build_main_graph, build_react_main_graph
from .react_nodes import (
init_state_node,
react_reason_node,
rag_retrieve_node,
error_handling_node,
final_response_node,
route_by_reasoning
)
from .state import (
MessagesState,
GraphContext,
MainGraphState,
CurrentAction
CurrentAction,
ErrorRecord,
ErrorSeverity
)
from .retry_utils import (
RetryConfig,
RetryResult,
RetryStrategy,
with_retry,
with_async_retry,
create_retry_wrapper_for_node,
RAG_RETRY_CONFIG,
SUBGRAPH_RETRY_CONFIG
)
__all__ = [
# 旧版兼容性
"GraphBuilder",
"build_main_graph",
"MessagesState",
"GraphContext",
"MainGraphState",
"CurrentAction"
"CurrentAction",
# 新版 React 模式
"build_react_main_graph",
"init_state_node",
"react_reason_node",
"rag_retrieve_node",
"error_handling_node",
"final_response_node",
"route_by_reasoning",
"ErrorRecord",
"ErrorSeverity",
# 超时和重试工具
"RetryConfig",
"RetryResult",
"RetryStrategy",
"with_retry",
"with_async_retry",
"create_retry_wrapper_for_node",
"RAG_RETRY_CONFIG",
"SUBGRAPH_RETRY_CONFIG"
]

View File

@@ -0,0 +1,388 @@
"""
React 模式节点模块 - 带超时和重试功能
包含:
- react_reason_node: 使用 intent.py 进行推理
- rag_retrieve_node: RAG 检索节点(带重试)
- error_handling_node: 错误处理节点
- final_response_node: 最终回答节点
"""
import sys
import time
from typing import Dict, Any, Optional
from datetime import datetime
from functools import wraps
# 导入我们的 intent.py
from ..agent_subgraphs.common.intent import (
react_reason,
get_route_by_reasoning,
ReasoningAction,
RetrievalConfig,
ReasoningResult
)
from ..agent_subgraphs.common.state_base import StateUtils
from .state import MainGraphState, ErrorRecord, ErrorSeverity
from .retry_utils import (
RetryConfig,
RetryResult,
with_retry,
create_retry_wrapper_for_node,
RAG_RETRY_CONFIG,
SUBGRAPH_RETRY_CONFIG
)
def get_rag_tool():
"""
获取 RAG 工具(延迟导入,避免循环依赖)
"""
try:
# 尝试导入现有的 RAG 工具
from ..rag.tools import create_rag_tool_sync
# 注意:这里简化处理,实际使用时应该从全局获取初始化好的工具
return None # 先返回 None后面通过注入方式
except Exception:
return None
# ========== 1. React 推理节点 ==========
def react_reason_node(state: MainGraphState) -> MainGraphState:
"""
React 模式推理节点:判断下一步做什么
Returns: 更新后的状态
"""
state.current_phase = "react_reasoning"
state.reasoning_step += 1
# 检查是否超过最大步数
if state.reasoning_step > state.max_steps:
state.current_phase = "max_steps_exceeded"
state.final_result = (
f"❌ 推理步数超过限制(最大 {state.max_steps} 步),"
f"已执行 {state.reasoning_step - 1} 步。"
f"请简化您的问题或分批提问。"
)
state.success = False
return state
# 准备上下文
context = {
"retrieved_docs": state.rag_docs,
"previous_actions": [h.get("action") for h in state.reasoning_history],
"messages": state.messages,
"errors": state.errors
}
# 使用 intent.py 进行推理
result: ReasoningResult = react_reason(state.user_query, context)
# 记录推理历史
state.reasoning_history.append({
"step": state.reasoning_step,
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning,
"timestamp": datetime.now().isoformat()
})
# 更新状态
state.debug_info["last_reasoning"] = {
"action": result.action.name,
"confidence": result.confidence,
"reasoning": result.reasoning
}
# 保存推理结果到状态(供条件路由使用)
state.debug_info["reasoning_result"] = result
# 确定下一步动作
state.last_action = result.action.name
return state
# ========== 2. RAG 检索节点(带超时和重试) ==========
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
"""
RAG 检索核心逻辑(不带重试)
"""
# 获取推理结果中的检索配置
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
retrieval_query = state.user_query
if reasoning_result and reasoning_result.retrieval_config:
cfg: RetrievalConfig = reasoning_result.retrieval_config
if cfg.retrieval_query:
retrieval_query = cfg.retrieval_query
# 尝试获取 RAG 工具并调用
# 这里演示如何调用,实际使用时需要确保 RAG 已初始化
# 暂时用模拟数据
state.rag_context = (
f"[模拟RAG检索结果]\n"
f"查询: {retrieval_query}\n"
f"这是一个来自知识库的示例回答。"
)
state.rag_docs = [
{"source": "doc1.txt", "content": "示例内容1"},
{"source": "doc2.txt", "content": "示例内容2"}
]
state.rag_retrieved = True
state.success = True
return state
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
"""
RAG 检索节点:带超时和重试
Returns: 更新后的状态
"""
state.current_phase = "rag_retrieving"
# 使用重试包装器
start_time = time.time()
last_error = None
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
try:
# 执行核心逻辑
result = _rag_retrieve_core(state)
# 成功
state.debug_info["rag_retrieval"] = {
"attempt": attempt + 1,
"success": True,
"time": time.time() - start_time
}
return result
except Exception as e:
last_error = e
if attempt >= RAG_RETRY_CONFIG.max_retries:
break
# 等待后重试(指数退避)
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
# 所有重试都失败,记录结构化错误
error_record = ErrorRecord(
error_type="RAGRetrievalError",
error_message=str(last_error) if last_error else "RAG 检索超时",
severity=ErrorSeverity.WARNING,
source="rag_retrieve_node",
timestamp=datetime.now().isoformat(),
retry_count=RAG_RETRY_CONFIG.max_retries,
max_retries=RAG_RETRY_CONFIG.max_retries,
context={
"query": state.user_query,
"total_time": time.time() - start_time,
"timeout": RAG_RETRY_CONFIG.timeout
}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
return state
# ========== 3. 错误处理节点 ==========
def error_handling_node(state: MainGraphState) -> MainGraphState:
"""
错误处理节点:处理子图/工具调用错误
返回结构化错误信息,格式如下:
{
"tool/node": "...",
"status": "failed",
"error": "...",
"retries_exhausted": true/false,
"suggestion": "..."
}
"""
state.current_phase = "error_handling"
if not state.current_error:
# 没有错误,直接返回
state.current_phase = "react_reasoning"
return state
error = state.current_error
# 更新错误状态
state.error_message = f"{error.error_type}: {error.error_message}"
# 记录结构化错误信息(用于 LLM 决策)
structured_error = {
"tool": error.source,
"status": "failed",
"error": error.error_message,
"retries_exhausted": error.retry_count >= error.max_retries,
"retry_count": error.retry_count,
"max_retries": error.max_retries
}
# 根据错误类型添加建议
if "RAG" in error.error_type:
structured_error["suggestion"] = "尝试重新表述问题或直接询问,我会用现有知识回答"
elif "subgraph" in error.source or "contact" in error.source:
structured_error["suggestion"] = "子图执行失败,请尝试简化查询或使用其他功能"
elif "timeout" in error.error_message.lower():
structured_error["suggestion"] = "请求超时,请稍后再试或简化查询"
else:
structured_error["suggestion"] = "请尝试其他方式提问"
state.debug_info["structured_error"] = structured_error
# 策略1: 检查是否可以重试
can_retry = (
error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
and error.retry_count < error.max_retries
)
if can_retry:
# 重试策略
error.retry_count += 1
state.retry_action = error.source
state.debug_info["retry_count"] = error.retry_count
if "RAG" in error.error_type:
state.last_action = "RE_RETRIEVE_RAG"
elif "subgraph" in error.source:
state.last_action = "DIRECT_RESPONSE"
else:
state.last_action = "REASON"
state.current_phase = "retrying"
return state
# 策略2: 无法重试,尝试降级方案
if error.severity != ErrorSeverity.FATAL:
# 降级到直接回答模式
state.final_result = (
f"⚠️ 遇到一些问题:\n"
f"```json\n{structured_error}\n```\n"
f"但我会尽力用现有信息回答您。"
)
state.success = True
state.current_phase = "finalizing"
return state
# 策略3: 致命错误,无法继续
state.final_result = (
f"❌ 服务暂时不可用,请稍后再试。\n"
f"```json\n{structured_error}\n```"
)
state.success = False
state.current_phase = "finalizing"
return state
# ========== 4. 最终回答节点 ==========
def final_response_node(state: MainGraphState) -> MainGraphState:
"""
最终回答节点:整理并生成最终回答
"""
state.current_phase = "finalizing"
# 如果已经有 final_result 了,直接返回
if state.final_result:
state.current_phase = "done"
return state
# 构建最终回答
parts = []
# 添加 RAG 上下文(如果有)
if state.rag_context:
parts.append(state.rag_context)
parts.append("---")
# 添加子图结果(如果有)
if state.contact_result and state.contact_result.get("final_result"):
parts.append(state.contact_result["final_result"])
if state.dictionary_result and state.dictionary_result.get("final_result"):
parts.append(state.dictionary_result["final_result"])
if state.news_result and state.news_result.get("final_result"):
parts.append(state.news_result["final_result"])
# 如果都没有,用默认回答
if not parts:
parts.append(f"我理解了您的问题:{state.user_query}")
state.final_result = "\n".join(parts)
state.success = True
state.current_phase = "done"
state.end_time = datetime.now().isoformat()
return state
# ========== 5. 初始化状态节点 ==========
def init_state_node(state: MainGraphState) -> MainGraphState:
"""
初始化状态节点:在流程开始时设置初始值
"""
state.current_phase = "initializing"
state.reasoning_step = 0
state.start_time = datetime.now().isoformat()
# 从 messages 中提取用户查询(如果 user_query 为空)
if not state.user_query and state.messages:
last_msg = state.messages[-1]
state.user_query = getattr(last_msg, "content", str(last_msg))
return state
# ========== 6. 条件路由函数 ==========
def route_by_reasoning(state: MainGraphState) -> str:
"""
根据推理结果决定下一步路由
Returns: 路由字符串
"""
# 先检查特殊情况
if state.current_phase == "max_steps_exceeded":
return "final_response"
if state.current_phase == "error_handling" or state.current_error:
return "handle_error"
if state.current_phase == "finalizing" or state.current_phase == "done":
return "final_response"
if state.current_phase == "retrying":
# 重试路由
if state.retry_action and "rag" in state.retry_action.lower():
return "rag_retrieve"
return "react_reason"
# 获取推理结果
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
if not reasoning_result:
# 没有推理结果,直接结束
return "final_response"
# 使用 intent.py 提供的路由函数
route = get_route_by_reasoning(reasoning_result)
# 映射到我们的节点名称
route_mapping = {
"direct_response": "final_response",
"retrieve_rag": "rag_retrieve",
"re_retrieve_rag": "rag_retrieve",
"clarify": "final_response", # 简化:澄清直接回答让用户补充
"call_tool": "final_response", # 简化:工具调用暂未实现
"contact": "contact_subgraph",
"dictionary": "dictionary_subgraph",
"news_analysis": "news_analysis_subgraph",
}
return route_mapping.get(route, "final_response")

View File

@@ -0,0 +1,332 @@
"""
超时和重试工具模块
为 React 模式提供超时控制和重试机制
"""
import time
import asyncio
from functools import wraps
from typing import Callable, Any, Optional, Type, Tuple, Union
from dataclasses import dataclass, field
from enum import Enum, auto
class RetryStrategy(Enum):
"""重试策略"""
FIXED = auto() # 固定间隔
EXPONENTIAL = auto() # 指数退避
LINEAR = auto() # 线性增长
@dataclass
class RetryConfig:
"""重试配置"""
max_retries: int = 3 # 最大重试次数
base_delay: float = 1.0 # 基础延迟(秒)
max_delay: float = 10.0 # 最大延迟(秒)
strategy: RetryStrategy = RetryStrategy.EXPONENTIAL
timeout: Optional[float] = 30.0 # 单次调用超时(秒)
recoverable_exceptions: Tuple[Type[Exception], ...] = field(
default_factory=lambda: (Exception,)
)
unrecoverable_exceptions: Tuple[Type[Exception], ...] = field(
default_factory=tuple
)
@dataclass
class RetryResult:
"""重试结果"""
success: bool
result: Any = None
error: Optional[Exception] = None
retry_count: int = 0
total_time: float = 0.0
timed_out: bool = False
# ========== 同步重试装饰器 ==========
def with_retry(
config: Optional[RetryConfig] = None,
max_retries: int = 3,
timeout: Optional[float] = 30.0,
base_delay: float = 1.0,
on_retry: Optional[Callable[[int, Exception], None]] = None
):
"""
同步重试装饰器
Args:
config: 重试配置对象
max_retries: 最大重试次数(如果没有 config
timeout: 单次调用超时(秒)
base_delay: 基础延迟(秒)
on_retry: 重试回调函数(retry_count, exception)
"""
if config is None:
config = RetryConfig(
max_retries=max_retries,
timeout=timeout,
base_delay=base_delay
)
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> RetryResult:
start_time = time.time()
last_error = None
for attempt in range(config.max_retries + 1):
try:
# 执行函数(带超时)
if config.timeout:
# 使用信号量或线程实现超时(简化版)
result = func(*args, **kwargs)
else:
result = func(*args, **kwargs)
# 成功
total_time = time.time() - start_time
return RetryResult(
success=True,
result=result,
retry_count=attempt,
total_time=total_time
)
except Exception as e:
last_error = e
# 检查是否是不可恢复的异常
if isinstance(e, config.unrecoverable_exceptions):
break
# 检查是否达到最大重试次数
if attempt >= config.max_retries:
break
# 计算延迟
delay = _calculate_delay(attempt, config)
# 回调通知
if on_retry:
on_retry(attempt + 1, e)
# 等待
time.sleep(delay)
# 所有重试都失败
total_time = time.time() - start_time
return RetryResult(
success=False,
error=last_error,
retry_count=config.max_retries,
total_time=total_time
)
return wrapper
return decorator
# ========== 异步重试装饰器 ==========
def with_async_retry(
config: Optional[RetryConfig] = None,
max_retries: int = 3,
timeout: Optional[float] = 30.0,
base_delay: float = 1.0,
on_retry: Optional[Callable[[int, Exception], None]] = None
):
"""
异步重试装饰器
"""
if config is None:
config = RetryConfig(
max_retries=max_retries,
timeout=timeout,
base_delay=base_delay
)
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs) -> RetryResult:
start_time = time.time()
last_error = None
for attempt in range(config.max_retries + 1):
try:
# 执行函数(带超时)
if config.timeout:
result = await asyncio.wait_for(
func(*args, **kwargs),
timeout=config.timeout
)
else:
result = await func(*args, **kwargs)
# 成功
total_time = time.time() - start_time
return RetryResult(
success=True,
result=result,
retry_count=attempt,
total_time=total_time
)
except asyncio.TimeoutError as e:
last_error = e
timed_out = True
except Exception as e:
last_error = e
timed_out = False
# 检查是否是不可恢复的异常
if isinstance(e, config.unrecoverable_exceptions):
break
# 检查是否达到最大重试次数
if attempt >= config.max_retries:
break
# 计算延迟
delay = _calculate_delay(attempt, config)
# 回调通知
if on_retry:
on_retry(attempt + 1, last_error)
# 等待
await asyncio.sleep(delay)
# 所有重试都失败
total_time = time.time() - start_time
return RetryResult(
success=False,
error=last_error,
retry_count=config.max_retries,
total_time=total_time,
timed_out=isinstance(last_error, asyncio.TimeoutError)
)
return wrapper
return decorator
# ========== 辅助函数 ==========
def _calculate_delay(attempt: int, config: RetryConfig) -> float:
"""计算延迟时间"""
if config.strategy == RetryStrategy.FIXED:
delay = config.base_delay
elif config.strategy == RetryStrategy.LINEAR:
delay = config.base_delay * (attempt + 1)
elif config.strategy == RetryStrategy.EXPONENTIAL:
delay = config.base_delay * (2 ** attempt)
else:
delay = config.base_delay
# 不超过最大延迟
return min(delay, config.max_delay)
# ========== 为 React 节点设计的超时重试包装器 ==========
def create_retry_wrapper_for_node(
node_func: Callable,
node_name: str,
max_retries: int = 2,
timeout: float = 30.0
):
"""
为 React 节点创建带重试和超时的包装器
Args:
node_func: 原始节点函数
node_name: 节点名称(用于错误标识)
max_retries: 最大重试次数
timeout: 单次执行超时
Returns: 包装后的节点函数
"""
config = RetryConfig(
max_retries=max_retries,
timeout=timeout,
strategy=RetryStrategy.EXPONENTIAL
)
@wraps(node_func)
def wrapped_node(state):
# 记录开始时间
start_time = time.time()
# 重试循环
last_error = None
for attempt in range(config.max_retries + 1):
try:
# 执行节点
result = node_func(state)
# 检查节点是否报告了错误
if hasattr(state, "current_error") and state.current_error:
# 节点内部报告了错误,继续重试
last_error = Exception(state.current_error.error_message)
if attempt < config.max_retries:
delay = _calculate_delay(attempt, config)
time.sleep(delay)
continue
# 成功
return result
except Exception as e:
last_error = e
if attempt >= config.max_retries:
break
# 等待后重试
delay = _calculate_delay(attempt, config)
time.sleep(delay)
# 所有重试都失败,更新状态错误信息
from .state import ErrorRecord, ErrorSeverity
error_record = ErrorRecord(
error_type=f"{node_name}TimeoutError",
error_message=str(last_error) if last_error else f"{node_name} 执行超时",
severity=ErrorSeverity.ERROR,
source=node_name,
retry_count=config.max_retries,
max_retries=config.max_retries,
context={
"timeout": timeout,
"total_time": time.time() - start_time
}
)
if hasattr(state, "errors"):
state.errors.append(error_record)
if hasattr(state, "current_error"):
state.current_error = error_record
if hasattr(state, "error_message"):
state.error_message = str(last_error)
if hasattr(state, "current_phase"):
state.current_phase = "error_handling"
return state
return wrapped_node
# ========== 预配置的 RAG 重试配置 ==========
RAG_RETRY_CONFIG = RetryConfig(
max_retries=2,
timeout=60.0, # RAG 可以容忍稍长的超时
base_delay=2.0,
strategy=RetryStrategy.EXPONENTIAL
)
# ========== 预配置的子图重试配置 ==========
SUBGRAPH_RETRY_CONFIG = RetryConfig(
max_retries=1, # 子图通常不适合多次重试
timeout=120.0, # 子图执行时间较长
base_delay=3.0
)

View File

@@ -1,10 +1,10 @@
"""
主图状态定义 - 扩展
Main Graph State Definition - Extended
主图状态定义 - React 模式增强
Main Graph State Definition - React Mode Enhanced
"""
from enum import Enum, auto
from typing import Optional, Dict, Any, Annotated, Sequence, TypedDict
from typing import Optional, Dict, Any, Annotated, Sequence, TypedDict, List
from dataclasses import dataclass, field
from langgraph.graph import add_messages
from langchain_core.messages import BaseMessage
@@ -33,16 +33,38 @@ class CurrentAction(Enum):
CONTACT = auto()
class ErrorSeverity(Enum):
"""错误严重程度"""
INFO = auto() # 信息级别,继续执行
WARNING = auto() # 警告级别,可以重试
ERROR = auto() # 错误级别,需要处理
FATAL = auto() # 致命错误,终止执行
@dataclass
class ErrorRecord:
"""错误记录"""
error_type: str
error_message: str
severity: ErrorSeverity = ErrorSeverity.ERROR
source: str = "" # 来源:哪个节点/子图/工具
timestamp: str = ""
retry_count: int = 0 # 已重试次数
max_retries: int = 3 # 最大重试次数
context: Dict[str, Any] = field(default_factory=dict) # 错误上下文
@dataclass
class MainGraphState:
"""
主图状态 - 兼容旧代码 + 新增子图功能
主图状态 - React 循环推理版本
包含:
1. 旧代码的MessagesState兼容性字段
2. 主图控制字段
3. 子图结果占位
4. 用户信息
2. React 推理控制字段
3. 循环和错误处理
4. 子图结果占位
5. 用户信息
"""
# ========== 兼容性字段保留旧的MessagesState ==========
messages: Annotated[Sequence[BaseMessage], add_messages] = field(default_factory=list)
@@ -55,6 +77,22 @@ class MainGraphState:
current_action: CurrentAction = CurrentAction.NONE # 当前操作
intent_confidence: float = 0.0 # 意图识别置信度
# ========== React 推理专用字段 ==========
reasoning_step: int = 0 # 当前推理步数
max_steps: int = 40 # 最大推理步数≤40
last_action: str = "" # 上一步动作
reasoning_history: List[Dict[str, Any]] = field(default_factory=list) # 推理历史
# ========== RAG 相关字段 ==========
rag_context: str = "" # RAG 检索到的上下文
rag_retrieved: bool = False # 是否已经检索过
rag_docs: List[Dict[str, Any]] = field(default_factory=list) # 检索到的文档
# ========== 错误处理字段 ==========
errors: List[ErrorRecord] = field(default_factory=list) # 错误列表
current_error: Optional[ErrorRecord] = None # 当前错误
retry_action: Optional[str] = None # 重试动作
# ========== 子图结果占位 ==========
news_result: Optional[Dict[str, Any]] = None # 资讯子图结果
dictionary_result: Optional[Dict[str, Any]] = None # 词典子图结果

View File

@@ -1,157 +1,193 @@
"""
子图整合主图构建器
Subgraph Integration Main Graph Builder
React 模式主图构建器 - 完整循环推理版本
Main Graph Builder - Full React Mode with Loop Reasoning
"""
from langgraph.graph import StateGraph, START, END
from typing import Dict, Any
from .state import MainGraphState, CurrentAction
from .react_nodes import (
init_state_node,
react_reason_node,
rag_retrieve_node,
error_handling_node,
final_response_node,
route_by_reasoning
)
from ..agent_subgraphs.contact import build_contact_subgraph
from ..agent_subgraphs.dictionary import build_dictionary_subgraph
from ..agent_subgraphs.news_analysis import build_news_analysis_subgraph
def parse_user_intent(state: MainGraphState) -> MainGraphState:
# ========== 子图包装器(处理子图错误传递) ==========
def wrap_subgraph_for_error_handling(subgraph, name: str):
"""
解析用户意图节点
包装子图,使其错误能传递给主图
确定该路由到哪个子图
Args:
subgraph: 编译好的子图
name: 子图名称(用于错误标识)
Returns: 包装后的节点函数
"""
state.current_phase = "intent_parsing"
def wrapped_node(state: MainGraphState) -> MainGraphState:
try:
# 调用子图
result = subgraph.invoke(state)
# 从messages中提取用户查询如果user_query为空
if not state.user_query and state.messages:
# 获取最后一条消息的内容
last_msg = state.messages[-1]
state.user_query = last_msg.content
# 更新主图状态
if name == "contact":
state.contact_result = result
elif name == "dictionary":
state.dictionary_result = result
elif name == "news_analysis":
state.news_result = result
query_lower = state.user_query.lower()
# 简单的关键词匹配
if any(keyword in query_lower for keyword in ["通讯录", "联系人", "contact", "email"]):
state.current_action = CurrentAction.CONTACT
state.intent_confidence = 0.9
elif any(keyword in query_lower for keyword in ["词典", "单词", "翻译", "dictionary", "translate"]):
state.current_action = CurrentAction.DICTIONARY
state.intent_confidence = 0.9
elif any(keyword in query_lower for keyword in ["资讯", "新闻", "分析", "news", "report"]):
state.current_action = CurrentAction.NEWS_ANALYSIS
state.intent_confidence = 0.9
else:
# 默认是普通聊天
state.current_action = CurrentAction.GENERAL_CHAT
state.intent_confidence = 0.8
return state
def route_to_subgraph(state: MainGraphState) -> str:
"""
条件路由:决定路由到哪个子图
"""
if state.current_action == CurrentAction.NONE:
return "general_chat"
elif state.current_action == CurrentAction.GENERAL_CHAT:
return "general_chat"
elif state.current_action == CurrentAction.CONTACT:
return "contact_subgraph"
elif state.current_action == CurrentAction.DICTIONARY:
return "dictionary_subgraph"
elif state.current_action == CurrentAction.NEWS_ANALYSIS:
return "news_analysis_subgraph"
else:
return "general_chat"
def general_chat_node(state: MainGraphState) -> MainGraphState:
"""
普通聊天节点
目前是占位符后续整合旧的LLM调用逻辑
"""
state.current_phase = "general_chat"
state.final_result = f"普通聊天模式:{state.user_query}"
# 标记成功
state.success = True
return state
except Exception as e:
# 捕获子图错误,传递给主图
from .state import ErrorRecord, ErrorSeverity
from datetime import datetime
def integrate_results(state: MainGraphState) -> MainGraphState:
"""
整合子图结果节点
"""
state.current_phase = "integrating"
error_record = ErrorRecord(
error_type=f"{name}SubgraphError",
error_message=str(e),
severity=ErrorSeverity.WARNING,
source=f"{name}_subgraph",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=1,
context={"user_query": state.user_query}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
state.success = False
# 整合通讯录子图结果
if state.contact_result:
state.final_result = state.contact_result.get("final_result", "")
# 整合词典子图结果
elif state.dictionary_result:
state.final_result = state.dictionary_result.get("final_result", "")
# 整合资讯子图结果
elif state.news_result:
state.final_result = state.news_result.get("final_result", "")
else:
# 没有子图结果
if not state.final_result:
state.final_result = "处理完成"
state.current_phase = "done"
return state
return wrapped_node
def build_main_graph() -> StateGraph:
# ========== 主图构建 ==========
def build_react_main_graph() -> StateGraph:
"""
构建整合了子图的主图
构建完整的 React 模式主图
Returns:
配置好的 StateGraph
流程:
START
init_state (初始化)
react_reason (推理) ←──────────────┐
↓ │
条件路由 │
├─→ rag_retrieve →───────────────┤
├─→ contact_subgraph →───────────┤
├─→ dictionary_subgraph →────────┤
├─→ news_analysis_subgraph →─────┤
├─→ handle_error → (重试或结束) ──┤
└─→ final_response
END
"""
# 创建图
graph = StateGraph(MainGraphState)
# 添加节点
graph.add_node("parse_intent", parse_user_intent)
graph.add_node("general_chat", general_chat_node)
graph.add_node("integrate_results", integrate_results)
# ========== 添加节点 ==========
# 添加子图节点
# 1. 初始化节点
graph.add_node("init_state", init_state_node)
# 2. React 推理节点
graph.add_node("react_reason", react_reason_node)
# 3. RAG 检索节点
graph.add_node("rag_retrieve", rag_retrieve_node)
# 4. 错误处理节点
graph.add_node("handle_error", error_handling_node)
# 5. 最终回答节点
graph.add_node("final_response", final_response_node)
# ========== 添加子图节点 ==========
# 构建并包装子图(带错误处理)
contact_graph = build_contact_subgraph()
dictionary_graph = build_dictionary_subgraph()
news_analysis_graph = build_news_analysis_subgraph()
graph.add_node("contact_subgraph", contact_graph.compile())
graph.add_node("dictionary_subgraph", dictionary_graph.compile())
graph.add_node("news_analysis_subgraph", news_analysis_graph.compile())
graph.add_node(
"contact_subgraph",
wrap_subgraph_for_error_handling(contact_graph.compile(), "contact")
)
graph.add_node(
"dictionary_subgraph",
wrap_subgraph_for_error_handling(dictionary_graph.compile(), "dictionary")
)
graph.add_node(
"news_analysis_subgraph",
wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis")
)
# 添加边
# 从START开始
graph.add_edge(START, "parse_intent")
# ========== 添加边 ==========
# 从parse_intent根据条件路由
# 1. START → init_state
graph.add_edge(START, "init_state")
# 2. init_state → react_reason
graph.add_edge("init_state", "react_reason")
# 3. 条件路由react_reason → 各分支
graph.add_conditional_edges(
"parse_intent",
route_to_subgraph,
"react_reason",
route_by_reasoning,
{
"general_chat": "general_chat",
# 检索分支 → 检索后回到推理
"rag_retrieve": "rag_retrieve",
# 子图分支 → 子图后回到推理
"contact_subgraph": "contact_subgraph",
"dictionary_subgraph": "dictionary_subgraph",
"news_analysis_subgraph": "news_analysis_subgraph",
# 错误处理分支
"handle_error": "handle_error",
# 最终回答分支
"final_response": "final_response",
}
)
# 从普通聊天和子图到结果整合
graph.add_edge("general_chat", "integrate_results")
graph.add_edge("contact_subgraph", "integrate_results")
graph.add_edge("dictionary_subgraph", "integrate_results")
graph.add_edge("news_analysis_subgraph", "integrate_results")
# 4. 循环边:检索/子图/错误处理 后 → 回到推理
graph.add_edge("rag_retrieve", "react_reason")
graph.add_edge("contact_subgraph", "react_reason")
graph.add_edge("dictionary_subgraph", "react_reason")
graph.add_edge("news_analysis_subgraph", "react_reason")
graph.add_edge("handle_error", "react_reason") # 错误处理后可能重试
# 最终到END
graph.add_edge("integrate_results", END)
# 5. 最终边final_response → END
graph.add_edge("final_response", END)
return graph
# ========== 兼容性:保留旧的函数名 ==========
def build_main_graph() -> StateGraph:
"""
兼容性函数:旧代码调用 build_main_graph() 时返回 React 版本
"""
return build_react_main_graph()
# ========== 导出 ==========
__all__ = [
"build_react_main_graph",
"build_main_graph",
"wrap_subgraph_for_error_handling"
]