2026-04-25 13:24:50 +08:00
|
|
|
|
"""
|
|
|
|
|
|
状态基类工具模块
|
|
|
|
|
|
提供类型安全的 LangGraph 状态基类和常用状态操作工具
|
|
|
|
|
|
|
|
|
|
|
|
功能:
|
|
|
|
|
|
1. BaseState - 基础状态基类,包含通用字段(消息、token统计、耗时等)
|
|
|
|
|
|
2. StateUtils - 状态操作工具类,提供常用的状态访问和修改方法
|
|
|
|
|
|
3. TypedStateBuilder - 类型化状态构建器,支持链式创建自定义状态
|
|
|
|
|
|
4. StateValidation - 状态验证工具,确保状态完整性
|
|
|
|
|
|
"""
|
2026-04-25 20:02:20 +08:00
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Any, List, Optional
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from enum import Enum, auto
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Phase(Enum):
|
|
|
|
|
|
"""执行阶段枚举"""
|
|
|
|
|
|
INIT = auto()
|
|
|
|
|
|
INTENT_PARSING = auto()
|
|
|
|
|
|
EXECUTING = auto()
|
|
|
|
|
|
FORMATTING = auto()
|
|
|
|
|
|
COMPLETED = auto()
|
|
|
|
|
|
ERROR = auto()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
|
class TokenUsage:
|
|
|
|
|
|
"""Token 使用统计"""
|
|
|
|
|
|
prompt_tokens: int = 0
|
|
|
|
|
|
completion_tokens: int = 0
|
|
|
|
|
|
total_tokens: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
def add(self, other: 'TokenUsage') -> 'TokenUsage':
|
|
|
|
|
|
"""累加另一个统计"""
|
|
|
|
|
|
return TokenUsage(
|
|
|
|
|
|
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
|
|
|
|
|
|
completion_tokens=self.completion_tokens + other.completion_tokens,
|
|
|
|
|
|
total_tokens=self.total_tokens + other.total_tokens
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
|
class BaseState:
|
|
|
|
|
|
"""
|
|
|
|
|
|
基础状态基类
|
|
|
|
|
|
所有子图的状态都应继承此类
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 核心字段
|
|
|
|
|
|
user_query: str = ""
|
|
|
|
|
|
user_id: str = "default"
|
|
|
|
|
|
thread_id: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
# 执行阶段
|
|
|
|
|
|
current_phase: Phase = Phase.INIT
|
|
|
|
|
|
phase_history: List[Phase] = field(default_factory=list)
|
|
|
|
|
|
|
|
|
|
|
|
# 结果
|
|
|
|
|
|
final_result: str = ""
|
|
|
|
|
|
success: bool = True
|
|
|
|
|
|
error_message: str = ""
|
|
|
|
|
|
|
|
|
|
|
|
# 统计
|
|
|
|
|
|
token_usage: TokenUsage = field(default_factory=TokenUsage)
|
|
|
|
|
|
start_time: Optional[datetime] = None
|
|
|
|
|
|
end_time: Optional[datetime] = None
|
|
|
|
|
|
|
|
|
|
|
|
# 元数据
|
|
|
|
|
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
def __post_init__(self):
|
|
|
|
|
|
"""初始化后调用"""
|
|
|
|
|
|
if self.start_time is None:
|
|
|
|
|
|
self.start_time = datetime.now()
|
|
|
|
|
|
if not self.phase_history:
|
|
|
|
|
|
self.phase_history.append(self.current_phase)
|
|
|
|
|
|
|
|
|
|
|
|
def transition_to(self, phase: Phase) -> None:
|
|
|
|
|
|
"""转换到新阶段"""
|
|
|
|
|
|
self.current_phase = phase
|
|
|
|
|
|
self.phase_history.append(phase)
|
|
|
|
|
|
|
|
|
|
|
|
def complete(self, result: str, success: bool = True) -> None:
|
|
|
|
|
|
"""完成执行"""
|
|
|
|
|
|
self.final_result = result
|
|
|
|
|
|
self.success = success
|
|
|
|
|
|
self.end_time = datetime.now()
|
|
|
|
|
|
self.transition_to(Phase.COMPLETED)
|
|
|
|
|
|
|
|
|
|
|
|
def fail(self, error: str) -> None:
|
|
|
|
|
|
"""执行失败"""
|
|
|
|
|
|
self.error_message = error
|
|
|
|
|
|
self.success = False
|
|
|
|
|
|
self.end_time = datetime.now()
|
|
|
|
|
|
self.transition_to(Phase.ERROR)
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
def elapsed_time(self) -> float:
|
|
|
|
|
|
"""获取耗时(秒)"""
|
|
|
|
|
|
if self.start_time and self.end_time:
|
|
|
|
|
|
return (self.end_time - self.start_time).total_seconds()
|
|
|
|
|
|
return 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StateUtils:
|
|
|
|
|
|
"""状态操作工具类"""
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def merge_metadata(base: Dict[str, Any], overlay: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
"""合并元数据"""
|
|
|
|
|
|
result = base.copy()
|
|
|
|
|
|
result.update(overlay)
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def create_snapshot(state: BaseState) -> Dict[str, Any]:
|
|
|
|
|
|
"""创建状态快照(用于调试)"""
|
|
|
|
|
|
return {
|
|
|
|
|
|
"user_query": state.user_query,
|
|
|
|
|
|
"user_id": state.user_id,
|
|
|
|
|
|
"current_phase": state.current_phase.name,
|
|
|
|
|
|
"success": state.success,
|
|
|
|
|
|
"elapsed_time": state.elapsed_time
|
|
|
|
|
|
}
|