150 lines
4.6 KiB
Python
150 lines
4.6 KiB
Python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Type
|
|
|
|
|
|
@dataclass
class ScheduleConfig(ABC):
    """Abstract base for learning rate schedule configurations.

    Declares the options shared by every schedule (``schedule_type``,
    ``warmup_steps`` and ``min_rate``) and the range checks that apply to
    them. Concrete subclasses must implement :meth:`get_kwargs`.
    """

    schedule_type: str = field(
        metadata={
            "help": "Type of learning rate schedule.",
            "choices": ["cosine", "sgdr"],
        },
        default="cosine",
    )
    warmup_steps: int = field(
        metadata={"help": "Number of warmup steps."},
        default=1000,
    )
    min_rate: float = field(
        metadata={"help": "Minimum learning rate multiplier."},
        default=0.05,
    )

    def validate(self) -> None:
        """Check the shared fields, raising on the first invalid one.

        Raises:
            ValueError: If ``warmup_steps`` is negative, or if ``min_rate``
                does not lie in the closed interval [0, 1].
        """
        warmup = self.warmup_steps
        if warmup < 0:
            raise ValueError(f"warmup_steps must be non-negative, got {warmup}")

        rate = self.min_rate
        # Written as a positive range test so that values which compare False
        # on both sides (e.g. NaN) are still rejected.
        if 0 <= rate <= 1:
            return
        raise ValueError(f"min_rate must be between 0 and 1, got {rate}")

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments used to create the scheduler."""
        raise NotImplementedError
|
|
|
|
|
|
@dataclass
class CosineScheduleConfig(ScheduleConfig):
    """Cosine annealing learning rate schedule configuration.

    ``total_steps`` defaults to ``None`` so the config can be constructed
    before the training length is known; it must be set to a value greater
    than ``warmup_steps`` before :meth:`get_kwargs` is called.
    """

    # Optional because the default is None; get_kwargs() enforces that a
    # concrete value has been supplied before any kwargs are produced.
    total_steps: Optional[int] = field(
        default=None, metadata={"help": "Total training steps for cosine schedule."}
    )

    def __post_init__(self) -> None:
        """Force ``schedule_type`` to ``"cosine"`` and validate the fields."""
        self.schedule_type = "cosine"
        self.validate()

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments used to create the cosine scheduler.

        Returns:
            A dict with ``schedule_type``, ``warmup_steps``, ``min_rate`` and
            ``lr_decay_steps`` (``total_steps - warmup_steps``).

        Raises:
            ValueError: If ``total_steps`` has not been set, or if the current
                field values fail :meth:`validate`.
        """
        if self.total_steps is None:
            raise ValueError("total_steps must be specified for cosine schedule")

        # Fields can be reassigned after __post_init__ (typically total_steps,
        # which starts as None and therefore skipped its range check), so
        # re-validate here rather than emit a non-positive lr_decay_steps.
        self.validate()

        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "lr_decay_steps": self.total_steps - self.warmup_steps,
            "min_rate": self.min_rate,
        }

    def validate(self) -> None:
        """Validate the shared fields and, when set, ``total_steps``.

        Raises:
            ValueError: If a shared field is invalid, or if ``total_steps`` is
                set but not strictly greater than ``warmup_steps``.
        """
        super().validate()
        if self.total_steps is not None and self.total_steps <= self.warmup_steps:
            raise ValueError(
                f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})"
            )
|
|
|
|
|
|
@dataclass
class SGDRScheduleConfig(ScheduleConfig):
    """Configuration for an SGDR (warm restarts) learning rate schedule."""

    cycle_length: int = field(
        metadata={"help": "Length of the first cycle in steps."},
        default=1000,
    )
    t_mult: int = field(
        metadata={"help": "Multiplier for cycle length growth."},
        default=2,
    )

    def __post_init__(self) -> None:
        """Force ``schedule_type`` to ``"sgdr"`` and validate the fields."""
        self.schedule_type = "sgdr"
        self.validate()

    def validate(self) -> None:
        """Validate the shared fields, then the SGDR-specific ones.

        Raises:
            ValueError: If a shared field is invalid, if ``cycle_length`` is
                not positive, or if ``t_mult`` is less than 1.
        """
        super().validate()

        if self.cycle_length <= 0:
            message = f"cycle_length must be positive, got {self.cycle_length}"
            raise ValueError(message)

        if self.t_mult < 1:
            message = f"t_mult must be >= 1, got {self.t_mult}"
            raise ValueError(message)

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments used to create the SGDR scheduler."""
        exported = (
            "schedule_type",
            "warmup_steps",
            "cycle_length",
            "min_rate",
            "t_mult",
        )
        return {name: getattr(self, name) for name in exported}
|
|
|
|
|
|
class ScheduleConfigFactory:
    """Build ScheduleConfig instances from a schedule type name.

    The config classes can be instantiated directly, or looked up by name
    through :meth:`create`.

    Example usage:
        # Direct creation
        config = CosineScheduleConfig(total_steps=10000)

        # Factory method
        config = ScheduleConfigFactory.create("cosine", total_steps=10000)
    """

    # Maps each supported schedule type name to its config class.
    CONFIG_MAP: Dict[str, Type[ScheduleConfig]] = {
        "cosine": CosineScheduleConfig,
        "sgdr": SGDRScheduleConfig,
    }

    @classmethod
    def create(cls, schedule_type: str, **kwargs) -> ScheduleConfig:
        """Instantiate the config class registered under ``schedule_type``.

        Args:
            schedule_type: Type of schedule ("cosine", "sgdr")
            **kwargs: Arguments passed to the config constructor

        Returns:
            ScheduleConfig instance

        Raises:
            ValueError: If schedule_type is not supported
        """
        registry = cls.CONFIG_MAP
        if schedule_type in registry:
            return registry[schedule_type](**kwargs)

        supported = sorted(registry)
        raise ValueError(
            f"Unknown schedule type: '{schedule_type}'. "
            f"Supported types: {supported}"
        )

    @classmethod
    def available_types(cls) -> list:
        """Return the registered schedule type names as a list."""
        return [*cls.CONFIG_MAP]
|