from typing import Any, Dict, Literal, Optional
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class ScheduleConfig(ABC):
    """Base configuration for learning-rate schedules.

    Subclasses pin ``schedule_type`` with a ``Literal`` override and
    implement ``get_kwargs`` to produce the keyword arguments consumed
    by the scheduler factory.
    """

    # Which schedule implementation to build ("cosine" or "sgdr").
    schedule_type: str = field(
        default="cosine",
        metadata={
            "help": "Type of learning rate schedule.",
            "choices": ["cosine", "sgdr"]
        }
    )
    # Number of linear warmup steps before the main schedule starts.
    warmup_steps: int = field(
        default=1000,
        metadata={"help": "Number of warmup steps."}
    )
    # Lower bound on the LR multiplier, in [0, 1].
    min_rate: float = field(
        default=0.05,
        metadata={"help": "Minimum learning rate multiplier."}
    )

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments used to build the scheduler."""
        raise NotImplementedError

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: if ``warmup_steps`` is negative or ``min_rate``
                lies outside the closed interval [0, 1].
        """
        if self.warmup_steps < 0:
            raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
        if not 0 <= self.min_rate <= 1:
            raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")


@dataclass
class CosineScheduleConfig(ScheduleConfig):
    """Configuration for cosine decay with linear warmup."""

    # FIX: annotated Optional[int] — the original declared `int` while
    # defaulting to None, which static type checkers reject.
    total_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Total training steps for cosine schedule."}
    )
    schedule_type: Literal["cosine"] = "cosine"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return factory kwargs; decay length is total minus warmup.

        Raises:
            ValueError: if ``total_steps`` was never set.
        """
        if self.total_steps is None:
            raise ValueError("total_steps must be specified for cosine schedule")

        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "lr_decay_steps": self.total_steps - self.warmup_steps,
            "min_rate": self.min_rate
        }

    def validate(self) -> None:
        super().validate()
        # total_steps may legitimately still be None here; get_kwargs
        # enforces that it is set before the scheduler is built.
        if self.total_steps is not None and self.total_steps <= self.warmup_steps:
            raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")


@dataclass
class SgdrScheduleConfig(ScheduleConfig):
    """Configuration for SGDR (warm-restarts) schedule."""

    # Length of the first cosine cycle; later cycles grow by t_mult.
    cycle_length: int = field(
        default=1000,
        metadata={"help": "Length of the first cycle in steps."}
    )
    t_mult: int = field(
        default=2,
        metadata={"help": "Multiplier for cycle length growth."}
    )
    schedule_type: Literal["sgdr"] = "sgdr"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the factory kwargs for an SGDR scheduler."""
        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "cycle_length": self.cycle_length,
            "min_rate": self.min_rate,
            "t_mult": self.t_mult
        }

    def validate(self) -> None:
        super().validate()
        if self.cycle_length <= 0:
            raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
        if self.t_mult < 1:
            raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
import math
from abc import abstractmethod, ABC
from typing import Any, Dict, List
from torch.optim.lr_scheduler import LRScheduler


class BaseScheduler(LRScheduler, ABC):
    """Common base class for learning-rate schedulers in this package.

    Subclasses implement ``get_lr``; stepping, ``last_epoch`` bookkeeping
    and serialization of non-optimizer attributes are inherited from
    ``torch.optim.lr_scheduler.LRScheduler``.

    NOTE: the original pass-through ``__init__`` / ``state_dict`` /
    ``load_state_dict`` overrides were removed — they only delegated to
    ``super()`` and added no behavior.
    """

    @abstractmethod
    def get_lr(self) -> List[float]:
        """Return the current learning rate for every param group."""
        raise NotImplementedError


class CosineScheduler(BaseScheduler):
    """Linear warmup followed by cosine decay, floored at ``min_rate``.

    Warmup (step < warmup_steps): factor = max(min_rate, step / warmup_steps).
    Decay: factor = max(min_rate, 0.5 * (1 + cos(pi * progress))) with
    progress clamped to [0, 1], so the LR holds at ``min_rate * base_lr``
    once ``total_steps`` is exceeded.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        lr_decay_steps: int,
        min_rate: float = 0.05,
        last_epoch: int = -1
    ):
        """
        Args:
            optimizer: wrapped torch optimizer.
            warmup_steps: number of linear warmup steps.
            lr_decay_steps: number of cosine-decay steps after warmup.
            min_rate: lower bound on the LR multiplier, in [0, 1].
            last_epoch: index of the last step when resuming (-1 = fresh).
        """
        self.warmup_steps = warmup_steps
        self.lr_decay_steps = lr_decay_steps
        self.min_rate = min_rate
        self.total_steps = warmup_steps + lr_decay_steps
        # LRScheduler.__init__ invokes self.step() once, which calls
        # get_lr(); all attributes above must be assigned first.
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        # Warmup: linear ramp from min_rate up to 1.0.
        if self.last_epoch < self.warmup_steps:
            warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]

        # Cosine decay, clamped so the factor never drops below min_rate.
        decay_progress = (self.last_epoch - self.warmup_steps) / self.lr_decay_steps
        decay_progress = min(decay_progress, 1.0)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
        decay_factor = max(self.min_rate, cosine_decay)
        return [base_lr * decay_factor for base_lr in self.base_lrs]

    def state_dict(self) -> Dict[str, Any]:
        # LRScheduler.state_dict already serializes every non-optimizer
        # attribute; the explicit update documents the checkpoint contract.
        state = super().state_dict()
        state.update({
            'warmup_steps': self.warmup_steps,
            'lr_decay_steps': self.lr_decay_steps,
            'min_rate': self.min_rate,
            'total_steps': self.total_steps,
        })
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]):
        # FIX: operate on a copy — the original popped keys out of the
        # caller's dict, silently corrupting the checkpoint state it was
        # loaded from.
        state_dict = dict(state_dict)
        self.warmup_steps = state_dict.pop('warmup_steps')
        self.lr_decay_steps = state_dict.pop('lr_decay_steps')
        self.min_rate = state_dict.pop('min_rate')
        self.total_steps = state_dict.pop('total_steps')
        super().load_state_dict(state_dict)


class SGDRScheduler(BaseScheduler):
    """SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler.

    After warmup, the LR follows cosine annealing within cycles; each
    restart resets the LR to its base value and the cycle length grows
    by a factor of ``t_mult``.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        cycle_length: int,
        min_rate: float = 0.05,
        t_mult: int = 2,
        last_epoch: int = -1,
    ):
        """
        Args:
            optimizer: wrapped torch optimizer.
            warmup_steps: number of linear warmup steps.
            cycle_length: length of the first cosine cycle, in steps.
            min_rate: lower bound on the LR multiplier, in [0, 1].
            t_mult: growth factor applied to each successive cycle length.
            last_epoch: index of the last step when resuming (-1 = fresh).
        """
        self.warmup_steps = warmup_steps
        self.cycle_length = cycle_length
        self.min_rate = min_rate
        self.t_mult = t_mult
        # Must precede super().__init__, which triggers get_lr().
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        # Warmup: linear ramp from min_rate up to 1.0.
        if self.last_epoch < self.warmup_steps:
            warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]

        steps_since_warmup = self.last_epoch - self.warmup_steps

        # Locate the current cycle: cycle i has length cycle_length * t_mult**i.
        current_cycle_length = self.cycle_length
        total_cycles_length = 0
        while total_cycles_length + current_cycle_length <= steps_since_warmup:
            total_cycles_length += current_cycle_length
            current_cycle_length *= self.t_mult

        steps_in_cycle = steps_since_warmup - total_cycles_length

        # Cosine annealing within the current cycle, blended so the factor
        # spans [min_rate, 1.0].
        cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length))
        learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor

        return [base_lr * learning_rate_factor for base_lr in self.base_lrs]

    def state_dict(self) -> Dict[str, Any]:
        """Return the scheduler state for checkpointing."""
        state = super().state_dict()
        state.update({
            'warmup_steps': self.warmup_steps,
            'cycle_length': self.cycle_length,
            'min_rate': self.min_rate,
            't_mult': self.t_mult
        })
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore the scheduler state from a checkpoint dict."""
        # FIX: copy first — never mutate the caller's checkpoint dict.
        state_dict = dict(state_dict)
        self.warmup_steps = state_dict.pop('warmup_steps')
        self.cycle_length = state_dict.pop('cycle_length')
        self.min_rate = state_dict.pop('min_rate')
        self.t_mult = state_dict.pop('t_mult')
        super().load_state_dict(state_dict)
+ """ + + @staticmethod + def load_scheduler(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler: + kwargs = scedule_config.get_kwargs() + schedule_type = kwargs.pop("schedule_type") + + if schedule_type == "cosine": + return CosineScheduler(optimizer, **kwargs) + elif schedule_type == "sgdr": + return SGDRScheduler(optimizer, **kwargs) + else: + raise ValueError(f"Unsupported schedule type: {schedule_type}") + \ No newline at end of file diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 81cb15d..0771a55 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -1,13 +1,11 @@ import copy -import math import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from typing import Any, Literal, Tuple, Callable, Dict, Union +from typing import Any, Tuple, Callable, Dict, Union from abc import ABC, abstractmethod -from dataclasses import dataclass, field def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int): @@ -167,181 +165,4 @@ class StrategyFactory: ) } strategy = train_strategy[train_type]() - return strategy - - -@dataclass -class ScheduleConfig(ABC): - schedule_type: str = field( - default="cosine", - metadata={ - "help": "Type of learning rate schedule.", - "choices": ["cosine", "sgdr"] - } - ) - warmup_steps: int = field( - default=1000, - metadata={"help": "Number of warmup steps."} - ) - min_rate: float = field( - default=0.05, - metadata={"help": "Minimum learning rate multiplier."} - ) - - @abstractmethod - def get_kwargs(self) -> Dict[str, Any]: - raise NotImplementedError - - def validate(self) -> None: - """Validate configuration parameters.""" - if self.warmup_steps < 0: - raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}") - if not 0 <= self.min_rate <= 1: - raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}") - - -@dataclass -class CosineScheduleConfig(ScheduleConfig): - total_steps: int = 
field( - default=None, - metadata={"help": "Total training steps for cosine schedule."} - ) - schedule_type: Literal["cosine"] = "cosine" - - def get_kwargs(self) -> Dict[str, Any]: - if self.total_steps is None: - raise ValueError("total_steps must be specified for cosine schedule") - - return { - "schedule_type": self.schedule_type, - "warmup_steps": self.warmup_steps, - "lr_decay_steps": self.total_steps - self.warmup_steps, - "min_rate": self.min_rate - } - - def validate(self) -> None: - super().validate() - if self.total_steps is not None and self.total_steps <= self.warmup_steps: - raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})") - - -@dataclass -class SgdrScheduleConfig(ScheduleConfig): - cycle_length: int = field( - default=1000, - metadata={"help": "Length of the first cycle in steps."} - ) - t_mult: int = field( - default=2, - metadata={"help": "Multiplier for cycle length growth."} - ) - schedule_type: Literal["sgdr"] = "sgdr" - - def get_kwargs(self) -> Dict[str, Any]: - return { - "schedule_type": self.schedule_type, - "warmup_steps": self.warmup_steps, - "cycle_length": self.cycle_length, - "min_rate": self.min_rate, - "t_mult": self.t_mult - } - - def validate(self) -> None: - super().validate() - if self.cycle_length <= 0: - raise ValueError(f"cycle_length must be positive, got {self.cycle_length}") - if self.t_mult < 1: - raise ValueError(f"t_mult must be >= 1, got {self.t_mult}") - - -class SchedulerFactory: - """Factory for creating learning rate schedule functions.""" - - @staticmethod - def get_sgdr_schedule( - warmup_steps: int, - cycle_length: int, - min_rate: float = 0.05, - t_mult: int = 2 - ) -> Callable[[int], float]: - """ - Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule. 
- - Args: - warmup_steps: Number of warmup steps - cycle_length: Length of the first cycle - min_rate: Minimum learning rate multiplier - t_mult: Cycle length multiplier - - Returns: - Schedule function that takes current step and returns LR multiplier - """ - - def sgdr_schedule(current_step: int) -> float: - # Warmup phase - if current_step < warmup_steps: - return max(min_rate, current_step / warmup_steps) - - # SGDR phase - steps_since_warmup = current_step - warmup_steps - - # Find current cycle and position within cycle - cycle_start = 0 - current_cycle_length = cycle_length - cycle_index = 0 - - while steps_since_warmup >= cycle_start + current_cycle_length: - cycle_start += current_cycle_length - current_cycle_length *= t_mult - cycle_index += 1 - - position_in_cycle = steps_since_warmup - cycle_start - progress = position_in_cycle / current_cycle_length - - # Cosine annealing within cycle - return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress))) - - return sgdr_schedule - - @staticmethod - def get_cosine_schedule( - warmup_steps: int, - lr_decay_steps: int, - min_rate: float = 0.05 - ) -> Callable[[int], float]: - """ - Create cosine decay schedule with warmup. 
- - Args: - warmup_steps: Number of warmup steps - lr_decay_steps: Number of steps for cosine decay after warmup - min_rate: Minimum learning rate multiplier - - Returns: - Schedule function that takes current step and returns LR multiplier - """ - - def cosine_schedule(current_step: int) -> float: - if current_step < warmup_steps: - # Linear warmup - return max(min_rate, current_step / warmup_steps) - else: - # Cosine decay - decay_progress = (current_step - warmup_steps) / lr_decay_steps - decay_progress = min(decay_progress, 1.0) # Clamp at 1.0 - return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress))) - - return cosine_schedule - - @staticmethod - def load_schedule_fn(scedule_config: ScheduleConfig) -> Callable[[int], float]: - kwargs = scedule_config.get_kwargs() - schedule_type = kwargs.pop("schedule_type") - - if schedule_type == "cosine": - return SchedulerFactory.get_cosine_schedule(**kwargs) - elif schedule_type == "sgdr": - return SchedulerFactory.get_sgdr_schedule(**kwargs) - else: - raise ValueError(f"Unsupported schedule type: {schedule_type}") - \ No newline at end of file + return strategy \ No newline at end of file diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 3ca86bb..73dc53d 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -8,7 +8,7 @@ from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import LambdaLR from typing import List, Optional, Protocol, TYPE_CHECKING -from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory +from khaosz.config import ScheduleConfig from khaosz.trainer.metric_util import ( grad_max, grad_min, @@ -78,17 +78,8 @@ class SchedulerCallback(TrainCallback): for group in trainer.train_config.optimizer.param_groups: if "initial_lr" not in group: group["initial_lr"] = group["lr"] - - self.schedule_config.validate() - lambda_scheduler_fn = SchedulerFactory.load_schedule_fn( - self.schedule_config - ) - 
self.scheduler = LambdaLR( - trainer.train_config.optimizer, - lambda_scheduler_fn, - last_epoch=context.current_iter - 1 - ) + self.scheduler = context.scheduler def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): _ = trainer, context @@ -106,8 +97,8 @@ class CheckpointCallback(TrainCallback): def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'): save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}") - # context.checkpoint.scheduler_state = context.sampler.state_dict() context.checkpoint.optimizer_state = context.optimizer.state_dict() + context.checkpoint.scheduler_state = context.scheduler.state_dict() context.checkpoint.save(save_path) self.last_ckpt_iter = context.current_iter diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index d9b8a9c..b1ed4e0 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -1,10 +1,10 @@ from dataclasses import dataclass, field, fields from typing import Optional, Self, TYPE_CHECKING from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader -from khaosz.config.param_config import Checkpoint -from khaosz.data.data_util import ResumeableRandomSampler +from khaosz.config import Checkpoint +from khaosz.data import ResumeableRandomSampler +from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory if TYPE_CHECKING: from khaosz.trainer.trainer import Trainer @@ -14,7 +14,7 @@ if TYPE_CHECKING: class TrainContext: dataloader: DataLoader = field(default=None) optimizer: Optimizer = field(default=None) - scheduler: LRScheduler = field(default=None) + scheduler: BaseScheduler = field(default=None) checkpoint: Checkpoint = field(default=None) epoch: int = field(default=0) current_iter: int = field(default=0) @@ -54,8 +54,20 @@ class TrainContextBuilder: return self def with_scheduler(self) -> Self: - return self + # the build order has 
any problem ? + optimizer = self.trainer.train_config.optimizer + schedule_config = self.trainer.schedule_config + scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config) + if self._context.checkpoint and self._context.checkpoint.scheduler_state: + scheduler.load_state_dict(self._context.checkpoint.scheduler_state) + + self._context.scheduler = scheduler + + if self._context.checkpoint: + self._context.checkpoint.scheduler_state = scheduler.state_dict() + + return self def with_dataloader(self) -> Self: resumeable_sampler = ResumeableRandomSampler( diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 14878d0..0241ded 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -1,9 +1,11 @@ import logging from typing import Optional, List - -from khaosz.config import ModelParameter, Checkpoint -from khaosz.trainer.strategy import ScheduleConfig -from khaosz.config.train_config import TrainConfig +from khaosz.config import ( + ModelParameter, + Checkpoint, + ScheduleConfig, + TrainConfig +) from khaosz.trainer.train_callback import ( TrainCallback, ProgressBarCallback, @@ -15,6 +17,7 @@ from khaosz.trainer.train_context import TrainContext, TrainContextBuilder logger = logging.getLogger(__name__) + class Trainer: def __init__( self, @@ -36,7 +39,7 @@ class Trainer: SchedulerCallback(self.schedule_config), ] - def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: + def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: return (TrainContextBuilder(self) .with_checkpoint(checkpoint) .with_optimizer() @@ -51,8 +54,7 @@ class Trainer: method(self, context) def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint: - context = self._build_train_context(checkpoint) - + context = self._build_context(checkpoint) self._call_callbacks('on_train_begin', context) try: diff --git a/train.py b/train.py index 6249b91..93e3c88 100644 --- a/train.py +++ b/train.py @@ -3,9 
+3,9 @@ import argparse import torch from torch.optim import AdamW -from khaosz.core import ParameterLoader, Checkpoint -from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig -from khaosz.trainer import StrategyFactory +from khaosz.config import ParameterLoader, Checkpoint, TrainConfig, CosineScheduleConfig +from khaosz.trainer import Trainer, StrategyFactory +from khaosz.data import DatasetLoader PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))