fix(trainer): 修复训练器恢复检查点时的学习率初始化问题
This commit is contained in:
parent
c8a38743a4
commit
25ec56a1f5
|
|
@ -46,10 +46,16 @@ class Trainer:
|
|||
self.checkpoint = train_checkpoint
|
||||
train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
|
||||
|
||||
self.checkpoint.optim_state = train_config.optimizer.state_dict()
|
||||
loss_list = self.checkpoint.loss_list
|
||||
current_iter = len(self.checkpoint.loss_list)
|
||||
last_ckpt_iter = current_iter
|
||||
|
||||
for group in train_config.optimizer.param_groups:
|
||||
if "initial_lr" not in group:
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
|
||||
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
|
||||
**schedule_config.get_kwargs()
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue