feat(khaosz/trainer): 更新梯度裁剪回调
This commit is contained in:
parent
05b012820b
commit
0db046f8d9
|
|
@ -60,9 +60,12 @@ class GradientClippingCallback(TrainCallback):
|
|||
"""
|
||||
Gradient clipping callback for trainer.
|
||||
"""
|
||||
def __init__(self, max_grad_norm: float):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
_ = context
|
||||
clip_grad_norm_(trainer.parameter.model.parameters(), trainer.train_config.max_grad_norm)
|
||||
clip_grad_norm_(trainer.parameter.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
class SchedulerCallback(TrainCallback):
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class Trainer:
|
|||
return [
|
||||
ProgressBarCallback(),
|
||||
CheckpointCallback(self.train_config.checkpoint_interval),
|
||||
GradientClippingCallback(),
|
||||
GradientClippingCallback(self.train_config.max_grad_norm),
|
||||
SchedulerCallback(self.schedule_config),
|
||||
]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue