diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 1673221..4805e37 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -60,9 +60,12 @@ class GradientClippingCallback(TrainCallback): """ Gradient clipping callback for trainer. """ + def __init__(self, max_grad_norm: float): + self.max_grad_norm = max_grad_norm + def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'): _ = context - clip_grad_norm_(trainer.parameter.model.parameters(), trainer.train_config.max_grad_norm) + clip_grad_norm_(trainer.parameter.model.parameters(), self.max_grad_norm) class SchedulerCallback(TrainCallback): diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 1565584..be28e8c 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -35,7 +35,7 @@ class Trainer: return [ ProgressBarCallback(), CheckpointCallback(self.train_config.checkpoint_interval), - GradientClippingCallback(), + GradientClippingCallback(self.train_config.max_grad_norm), SchedulerCallback(self.schedule_config), ]