From 25ec56a1f589225c598cbacd118e6a13c1b14812 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 28 Sep 2025 14:38:02 +0800
Subject: [PATCH] =?UTF-8?q?fix(trainer):=20=E4=BF=AE=E5=A4=8D=E8=AE=AD?=
 =?UTF-8?q?=E7=BB=83=E5=99=A8=E6=81=A2=E5=A4=8D=E6=A3=80=E6=9F=A5=E7=82=B9?=
 =?UTF-8?q?=E6=97=B6=E7=9A=84=E5=AD=A6=E4=B9=A0=E7=8E=87=E5=88=9D=E5=A7=8B?=
 =?UTF-8?q?=E5=8C=96=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 khaosz/trainer/trainer.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 285f5c0..d906232 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -45,10 +45,16 @@ class Trainer:
         if train_checkpoint:
             self.checkpoint = train_checkpoint
             train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
-
+
+        self.checkpoint.optim_state = train_config.optimizer.state_dict()
         loss_list = self.checkpoint.loss_list
         current_iter = len(self.checkpoint.loss_list)
         last_ckpt_iter = current_iter
+
+        for group in train_config.optimizer.param_groups:
+            if "initial_lr" not in group:
+                group["initial_lr"] = group["lr"]
+
         lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
             **schedule_config.get_kwargs()