diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 285f5c0..d906232 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -45,10 +45,16 @@ class Trainer: if train_checkpoint: self.checkpoint = train_checkpoint train_config.optimizer.load_state_dict(train_checkpoint.optim_state) - + + self.checkpoint.optim_state = train_config.optimizer.state_dict() loss_list = self.checkpoint.loss_list current_iter = len(self.checkpoint.loss_list) last_ckpt_iter = current_iter + + for group in train_config.optimizer.param_groups: + if "initial_lr" not in group: + group["initial_lr"] = group["lr"] + lambda_scheduler_fn = SchedulerFactory.load_schedule_fn( **schedule_config.get_kwargs()