# AstrAI/khaosz/trainer/schedule.py


"""Learning rate scheduler implementations with factory pattern."""
import math
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from khaosz.config.schedule_config import ScheduleConfig


class BaseScheduler(LRScheduler, ABC):
    """Abstract base class that all schedulers in this module inherit from."""

    def __init__(self, optimizer: Optimizer, last_epoch: int = -1):
        super().__init__(optimizer, last_epoch)

    @abstractmethod
    def get_lr(self) -> List[float]:
        """Return the current learning rate for each parameter group."""
        raise NotImplementedError

    def state_dict(self) -> Dict[str, Any]:
        # Explicit pass-throughs keep the serialization interface visible here,
        # even though the base LRScheduler behavior is unchanged.
        return super().state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        super().load_state_dict(state_dict)


class SchedulerFactory:
    """Factory class for creating learning rate schedulers.

    Supports decorator-based registration for extensible scheduler types,
    as well as creation from ScheduleConfig objects.

    Example usage:

        @SchedulerFactory.register("custom")
        class CustomScheduler(BaseScheduler):
            ...

        scheduler = SchedulerFactory.create(optimizer, "custom", **kwargs)

        # Or from a config object
        config = CosineScheduleConfig(total_steps=10000)
        scheduler = SchedulerFactory.load(optimizer, config)
    """

    SCHEDULER_MAP: Dict[str, Type[BaseScheduler]] = {}

    @classmethod
    def register(cls, name: str):
        """Decorator to register a new scheduler class.

        Args:
            name: Registration name for the scheduler

        Returns:
            Decorator function that registers the scheduler class
        """
        def decorator(scheduler_cls: Type[BaseScheduler]) -> Type[BaseScheduler]:
            if not issubclass(scheduler_cls, BaseScheduler):
                raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
            cls.SCHEDULER_MAP[name] = scheduler_cls
            return scheduler_cls
        return decorator

    @classmethod
    def create(cls, optimizer: Optimizer, schedule_type: str, **kwargs) -> BaseScheduler:
        """Create a scheduler instance by type name.

        Args:
            optimizer: PyTorch optimizer
            schedule_type: Registered scheduler name (e.g. "cosine", "sgdr")
            **kwargs: Arguments passed to the scheduler constructor

        Returns:
            Scheduler instance

        Raises:
            ValueError: If schedule_type is not registered
        """
        if schedule_type not in cls.SCHEDULER_MAP:
            raise ValueError(
                f"Unknown schedule type: '{schedule_type}'. "
                f"Supported types: {sorted(cls.SCHEDULER_MAP.keys())}"
            )
        scheduler_cls = cls.SCHEDULER_MAP[schedule_type]
        return scheduler_cls(optimizer, **kwargs)

    @classmethod
    def load(cls, optimizer: Optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
        """Create a scheduler from a ScheduleConfig object.

        Args:
            optimizer: PyTorch optimizer
            schedule_config: ScheduleConfig instance

        Returns:
            Scheduler instance
        """
        # get_kwargs() is expected to include a "schedule_type" entry alongside
        # the constructor arguments; split it out before dispatching.
        kwargs = schedule_config.get_kwargs()
        schedule_type = kwargs.pop("schedule_type")
        return cls.create(optimizer, schedule_type, **kwargs)

    @classmethod
    def available_types(cls) -> List[str]:
        """Return the list of registered scheduler type names."""
        return list(cls.SCHEDULER_MAP.keys())
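

# A hedged sketch of the extension point, kept as a comment so it does not
# register anything at import time: a hypothetical "linear" scheduler (not
# part of this module) would plug in like so.
#
#     @SchedulerFactory.register("linear")
#     class LinearScheduler(BaseScheduler):
#         def __init__(self, optimizer, total_steps: int, last_epoch: int = -1):
#             self.total_steps = total_steps
#             super().__init__(optimizer, last_epoch)
#
#         def get_lr(self) -> List[float]:
#             remaining = max(0.0, 1.0 - self.last_epoch / self.total_steps)
#             return [base_lr * remaining for base_lr in self.base_lrs]
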
# ============== Scheduler Classes ==============
# All scheduler classes are registered at class definition time using the decorator


@SchedulerFactory.register("cosine")
class CosineScheduler(BaseScheduler):
    """Cosine decay scheduler with linear warmup, implemented as a PyTorch LRScheduler.

    Args:
        optimizer: PyTorch optimizer
        warmup_steps: Number of linear warmup steps
        lr_decay_steps: Number of cosine decay steps after warmup
        min_rate: Floor for the learning rate as a fraction of the base LR
        last_epoch: Index of the last step (-1 starts a fresh schedule)
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int,
        lr_decay_steps: int,
        min_rate: float = 0.05,
        last_epoch: int = -1
    ):
        self.warmup_steps = warmup_steps
        self.lr_decay_steps = lr_decay_steps
        self.min_rate = min_rate
        self.total_steps = warmup_steps + lr_decay_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        # Linear warmup: ramp from min_rate up to 1.0 over warmup_steps.
        if self.last_epoch < self.warmup_steps:
            warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        # Cosine decay: factor = 0.5 * (1 + cos(pi * t / T)) for t in [0, T],
        # clamped so the learning rate never drops below min_rate * base_lr.
        decay_progress = (self.last_epoch - self.warmup_steps) / self.lr_decay_steps
        decay_progress = min(decay_progress, 1.0)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
        decay_factor = max(self.min_rate, cosine_decay)
        return [base_lr * decay_factor for base_lr in self.base_lrs]
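
    # Illustrative trace (values chosen here for exposition only), with
    # warmup_steps=100, lr_decay_steps=900, min_rate=0.05:
    #   step    0 -> factor 0.05  (warmup floor)
    #   step   50 -> factor 0.50  (halfway through warmup)
    #   step  100 -> factor 1.00  (decay begins at the full rate)
    #   step  550 -> factor 0.50  (cos(pi/2) midpoint of decay)
    #   step 1000 -> factor 0.05  (clamped at min_rate)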

    def state_dict(self):
        # The base LRScheduler already serializes instance attributes; the
        # custom fields are added explicitly to make the persisted schema obvious.
        state = super().state_dict()
        state.update({
            'warmup_steps': self.warmup_steps,
            'lr_decay_steps': self.lr_decay_steps,
            'min_rate': self.min_rate,
            'total_steps': self.total_steps,
        })
        return state

    def load_state_dict(self, state_dict):
        # Copy before popping so the caller's dict is not mutated.
        state_dict = dict(state_dict)
        self.warmup_steps = state_dict.pop('warmup_steps')
        self.lr_decay_steps = state_dict.pop('lr_decay_steps')
        self.min_rate = state_dict.pop('min_rate')
        self.total_steps = state_dict.pop('total_steps')
        super().load_state_dict(state_dict)


@SchedulerFactory.register("sgdr")
class SGDRScheduler(BaseScheduler):
    """SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler.

    Args:
        optimizer: PyTorch optimizer
        warmup_steps: Number of linear warmup steps
        cycle_length: Length of the first cosine cycle after warmup
        min_rate: Floor for the learning rate as a fraction of the base LR
        t_mult: Factor by which each successive cycle grows
        last_epoch: Index of the last step (-1 starts a fresh schedule)
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int,
        cycle_length: int,
        min_rate: float = 0.05,
        t_mult: int = 2,
        last_epoch: int = -1,
    ):
        self.warmup_steps = warmup_steps
        self.cycle_length = cycle_length
        self.min_rate = min_rate
        self.t_mult = t_mult
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Linear warmup: ramp from min_rate up to 1.0 over warmup_steps.
        if self.last_epoch < self.warmup_steps:
            warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        steps_since_warmup = self.last_epoch - self.warmup_steps
        # 1. Locate the current cycle: cycles have lengths cycle_length,
        #    cycle_length * t_mult, cycle_length * t_mult**2, ...
        current_cycle_length = self.cycle_length
        total_cycles_length = 0
        while total_cycles_length + current_cycle_length <= steps_since_warmup:
            total_cycles_length += current_cycle_length
            current_cycle_length *= self.t_mult
        steps_in_cycle = steps_since_warmup - total_cycles_length
        # 2. Cosine annealing within the current cycle, restarting at the full
        #    rate and annealing down toward min_rate.
        cosine_factor = 0.5 * (1 + math.cos(math.pi * steps_in_cycle / current_cycle_length))
        learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
        return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
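
    # Illustrative trace (values chosen here for exposition only), with
    # warmup_steps=0, cycle_length=100, t_mult=2, so cycles cover steps
    # [0, 100), [100, 300), [300, 700), ...
    #   At step 250: the loop runs once (100 <= 250, then 100 + 200 > 250),
    #   leaving current_cycle_length=200 and steps_in_cycle=150, so
    #   cosine_factor = 0.5 * (1 + cos(0.75 * pi)) ~= 0.146.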

    def state_dict(self):
        """Return the state of the scheduler as a dict."""
        state = super().state_dict()
        state.update({
            'warmup_steps': self.warmup_steps,
            'cycle_length': self.cycle_length,
            'min_rate': self.min_rate,
            't_mult': self.t_mult
        })
        return state

    def load_state_dict(self, state_dict):
        """Load the scheduler's state."""
        # Copy before popping so the caller's dict is not mutated.
        state_dict = dict(state_dict)
        self.warmup_steps = state_dict.pop('warmup_steps')
        self.cycle_length = state_dict.pop('cycle_length')
        self.min_rate = state_dict.pop('min_rate')
        self.t_mult = state_dict.pop('t_mult')
        super().load_state_dict(state_dict)
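

# A minimal smoke test, kept behind the __main__ guard. It assumes torch is
# installed; the model, learning rate, and step counts below are illustrative
# only, not values used elsewhere in this project.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    print("registered:", SchedulerFactory.available_types())

    scheduler = SchedulerFactory.create(
        optimizer, "cosine", warmup_steps=10, lr_decay_steps=90
    )
    for step in range(100):
        optimizer.step()
        scheduler.step()
        if step % 25 == 0:
            print(f"step {step:3d} lr {scheduler.get_last_lr()[0]:.5f}")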