diff --git a/khaosz/trainer/callback.py b/khaosz/trainer/callback.py index bb0a3ea..6505833 100644 --- a/khaosz/trainer/callback.py +++ b/khaosz/trainer/callback.py @@ -154,9 +154,9 @@ class SchedulerCallback(TrainerCallback): def on_train_begin(self, trainer: 'Trainer', **kwargs): checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) self.current_iter = len(checkpoint.loss_list) - + self.schedule_config.validate() lambda_scheduler_fn = SchedulerFactory.load_schedule_fn( - **self.schedule_config.get_kwargs() + self.schedule_config ) self.scheduler = LambdaLR( diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index eadaf1f..93a2fea 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -386,24 +386,10 @@ class SchedulerFactory: return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress))) return cosine_schedule - + @staticmethod - def create_schedule(config: ScheduleConfig) -> Callable[[int], float]: - """ - Create schedule from configuration. - - Args: - config: Schedule configuration instance - - Returns: - Schedule function - """ - config.validate() - kwargs = config.get_kwargs() - return SchedulerFactory.load_schedule_fn(**kwargs) - - @staticmethod - def load_schedule_fn(**kwargs) -> Callable[[int], float]: + def load_schedule_fn(scedule_config: ScheduleConfig) -> Callable[[int], float]: + kwargs = scedule_config.get_kwargs() schedule_type = kwargs.pop("schedule_type") if schedule_type == "cosine":