""" 状态基类工具模块 提供类型安全的 LangGraph 状态基类和常用状态操作工具 功能: 1. BaseState - 基础状态基类,包含通用字段(消息、token统计、耗时等) 2. StateUtils - 状态操作工具类,提供常用的状态访问和修改方法 3. TypedStateBuilder - 类型化状态构建器,支持链式创建自定义状态 4. StateValidation - 状态验证工具,确保状态完整性 """ from typing import Dict, Any, List, Optional from dataclasses import dataclass, field from datetime import datetime from enum import Enum, auto class Phase(Enum): """执行阶段枚举""" INIT = auto() INTENT_PARSING = auto() EXECUTING = auto() FORMATTING = auto() COMPLETED = auto() ERROR = auto() @dataclass class TokenUsage: """Token 使用统计""" prompt_tokens: int = 0 completion_tokens: int = 0 total_tokens: int = 0 def add(self, other: 'TokenUsage') -> 'TokenUsage': """累加另一个统计""" return TokenUsage( prompt_tokens=self.prompt_tokens + other.prompt_tokens, completion_tokens=self.completion_tokens + other.completion_tokens, total_tokens=self.total_tokens + other.total_tokens ) @dataclass class BaseState: """ 基础状态基类 所有子图的状态都应继承此类 """ # 核心字段 user_query: str = "" user_id: str = "default" thread_id: Optional[str] = None # 执行阶段 current_phase: Phase = Phase.INIT phase_history: List[Phase] = field(default_factory=list) # 结果 final_result: str = "" success: bool = True error_message: str = "" # 统计 token_usage: TokenUsage = field(default_factory=TokenUsage) start_time: Optional[datetime] = None end_time: Optional[datetime] = None # 元数据 metadata: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): """初始化后调用""" if self.start_time is None: self.start_time = datetime.now() if not self.phase_history: self.phase_history.append(self.current_phase) def transition_to(self, phase: Phase) -> None: """转换到新阶段""" self.current_phase = phase self.phase_history.append(phase) def complete(self, result: str, success: bool = True) -> None: """完成执行""" self.final_result = result self.success = success self.end_time = datetime.now() self.transition_to(Phase.COMPLETED) def fail(self, error: str) -> None: """执行失败""" self.error_message = error self.success = False self.end_time = datetime.now() self.transition_to(Phase.ERROR) @property def elapsed_time(self) -> float: """获取耗时(秒)""" if self.start_time and self.end_time: return (self.end_time - self.start_time).total_seconds() return 0.0 class StateUtils: """状态操作工具类""" @staticmethod def merge_metadata(base: Dict[str, Any], overlay: Dict[str, Any]) -> Dict[str, Any]: """合并元数据""" result = base.copy() result.update(overlay) return result @staticmethod def create_snapshot(state: BaseState) -> Dict[str, Any]: """创建状态快照(用于调试)""" return { "user_query": state.user_query, "user_id": state.user_id, "current_phase": state.current_phase.name, "success": state.success, "elapsed_time": state.elapsed_time }