Files
ailine/backend/app/agent_subgraphs/common/state_base.py
root d05a57948c
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m20s
refactor: 所有子图使用公共工具,避免重复造轮子
2026-04-25 20:02:20 +08:00

126 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
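State base-class utilities.
Provides type-safe LangGraph state base classes and common state-manipulation helpers.
Features:
1. BaseState - base state class holding the shared fields (messages, token statistics, elapsed time, etc.)
2. StateUtils - state helper class offering common state access and modification methods
3. TypedStateBuilder - typed state builder supporting chained creation of custom states
4. StateValidation - state validation helpers that ensure state integrity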
"""
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum, auto
class Phase(Enum):
    """Execution phases of a run.

    The integer values are written out explicitly; they are the same
    1-based numbers that ``auto()`` hands out in declaration order.
    """

    INIT = 1
    INTENT_PARSING = 2
    EXECUTING = 3
    FORMATTING = 4
    COMPLETED = 5
    ERROR = 6
@dataclass
class TokenUsage:
    """Token usage counters: prompt, completion, and total."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    def add(self, other: 'TokenUsage') -> 'TokenUsage':
        """Return a new ``TokenUsage`` with the field-wise sum of both operands.

        Neither ``self`` nor ``other`` is modified.
        """
        prompt = self.prompt_tokens + other.prompt_tokens
        completion = self.completion_tokens + other.completion_tokens
        total = self.total_tokens + other.total_tokens
        # Positional order matches the field declaration order above.
        return TokenUsage(prompt, completion, total)
@dataclass
class BaseState:
    """Base state class.

    Every sub-graph state is expected to inherit from this class. It carries
    the request fields, the phase bookkeeping, the outcome, and basic
    statistics (token usage and start/end timestamps).
    """

    # Core fields
    user_query: str = ""
    user_id: str = "default"
    thread_id: Optional[str] = None
    # Execution phase
    current_phase: Phase = Phase.INIT
    phase_history: List[Phase] = field(default_factory=list)
    # Outcome
    final_result: str = ""
    success: bool = True
    error_message: str = ""
    # Statistics
    token_usage: TokenUsage = field(default_factory=TokenUsage)
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    # Free-form metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Fill in defaults that depend on construction time.

        Sets ``start_time`` to the current time when none was supplied and
        seeds an empty ``phase_history`` with the initial ``current_phase``.
        """
        if self.start_time is None:
            self.start_time = datetime.now()
        if not self.phase_history:
            self.phase_history.append(self.current_phase)

    def transition_to(self, phase: Phase) -> None:
        """Make ``phase`` the current phase and append it to the history."""
        self.current_phase = phase
        self.phase_history.append(phase)

    def _finish(self, phase: Phase) -> None:
        """Stamp ``end_time`` and transition to the terminal ``phase``."""
        # Reuse start_time's tzinfo so that elapsed_time never subtracts a
        # naive datetime from an aware one (which raises TypeError). With a
        # naive start_time this is exactly datetime.now().
        tz = self.start_time.tzinfo if self.start_time is not None else None
        self.end_time = datetime.now(tz=tz)
        self.transition_to(phase)

    def complete(self, result: str, success: bool = True) -> None:
        """Finish the run: store ``result`` and ``success``, then move to COMPLETED.

        Args:
            result: Text stored in ``final_result``.
            success: Value stored in ``success``; the phase becomes
                ``Phase.COMPLETED`` regardless of this flag.
        """
        self.final_result = result
        self.success = success
        self._finish(Phase.COMPLETED)

    def fail(self, error: str) -> None:
        """Finish the run as failed: store ``error`` and move to ERROR.

        Args:
            error: Text stored in ``error_message``.
        """
        self.error_message = error
        self.success = False
        self._finish(Phase.ERROR)

    @property
    def elapsed_time(self) -> float:
        """Elapsed time in seconds between ``start_time`` and ``end_time``.

        Returns ``0.0`` while either timestamp is still unset.
        """
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds()
        return 0.0
class StateUtils:
    """Stateless helpers for working with state objects."""

    @staticmethod
    def merge_metadata(base: Dict[str, Any], overlay: Dict[str, Any]) -> Dict[str, Any]:
        """Return a shallow copy of ``base`` updated with ``overlay``.

        Keys present in both take the value from ``overlay``; neither input
        is modified.
        """
        merged = base.copy()
        merged.update(overlay)
        return merged

    @staticmethod
    def create_snapshot(state: BaseState) -> Dict[str, Any]:
        """Build a small dict summarising ``state`` (intended for debugging).

        The snapshot contains the query, user id, current phase name,
        success flag, and elapsed time.
        """
        query = state.user_query
        uid = state.user_id
        phase_name = state.current_phase.name
        succeeded = state.success
        seconds = state.elapsed_time
        return {
            "user_query": query,
            "user_id": uid,
            "current_phase": phase_name,
            "success": succeeded,
            "elapsed_time": seconds,
        }