fix(trainer): 修复训练器恢复检查点时的学习率初始化问题

This commit is contained in:
ViperEkura 2025-09-28 14:38:02 +08:00
parent c8a38743a4
commit 25ec56a1f5
1 changed file with 7 additions and 1 deletion

View File

@@ -45,10 +45,16 @@ class Trainer:
if train_checkpoint: if train_checkpoint:
self.checkpoint = train_checkpoint self.checkpoint = train_checkpoint
train_config.optimizer.load_state_dict(train_checkpoint.optim_state) train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
self.checkpoint.optim_state = train_config.optimizer.state_dict()
loss_list = self.checkpoint.loss_list loss_list = self.checkpoint.loss_list
current_iter = len(self.checkpoint.loss_list) current_iter = len(self.checkpoint.loss_list)
last_ckpt_iter = current_iter last_ckpt_iter = current_iter
for group in train_config.optimizer.param_groups:
if "initial_lr" not in group:
group["initial_lr"] = group["lr"]
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn( lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
**schedule_config.get_kwargs() **schedule_config.get_kwargs()