feat(khaosz/trainer): 更新梯度裁剪回调
This commit is contained in:
parent
05b012820b
commit
0db046f8d9
|
|
@ -60,9 +60,12 @@ class GradientClippingCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Gradient clipping callback for trainer.
|
Gradient clipping callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, max_grad_norm: float):
|
||||||
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
_ = context
|
_ = context
|
||||||
clip_grad_norm_(trainer.parameter.model.parameters(), trainer.train_config.max_grad_norm)
|
clip_grad_norm_(trainer.parameter.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
class SchedulerCallback(TrainCallback):
|
class SchedulerCallback(TrainCallback):
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ class Trainer:
|
||||||
return [
|
return [
|
||||||
ProgressBarCallback(),
|
ProgressBarCallback(),
|
||||||
CheckpointCallback(self.train_config.checkpoint_interval),
|
CheckpointCallback(self.train_config.checkpoint_interval),
|
||||||
GradientClippingCallback(),
|
GradientClippingCallback(self.train_config.max_grad_norm),
|
||||||
SchedulerCallback(self.schedule_config),
|
SchedulerCallback(self.schedule_config),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue