diff --git a/khaosz/trainer/callback.py b/khaosz/trainer/callback.py index 4a6ef96..4e55311 100644 --- a/khaosz/trainer/callback.py +++ b/khaosz/trainer/callback.py @@ -6,32 +6,66 @@ from typing import cast class TrainerCallback: + """ + Callback interface for trainer. + and we use '_' to ignore unused parameters. + """ + def on_train_begin(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the beginning of training. + """ + _ = trainer, kwargs def on_train_end(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the end of training. + """ + _ = trainer, kwargs def on_epoch_begin(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the beginning of each epoch. + """ + _ = trainer, kwargs def on_epoch_end(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the end of each epoch. + """ + _ = trainer, kwargs def on_batch_begin(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the beginning of each batch. + """ + _ = trainer, kwargs def on_batch_end(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the end of each batch. + """ + _ = trainer, kwargs def on_step_begin(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the beginning of each step. + """ + + _ = trainer, kwargs def on_step_end(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the end of each step. + """ + + _ = trainer, kwargs class ProgressBarCallback(TrainerCallback): + """ + Progress bar callback for trainer. + """ def __init__(self): self.progress_bar: tqdm = None @@ -53,16 +87,21 @@ class ProgressBarCallback(TrainerCallback): self.progress_bar.update(1) def on_epoch_end(self, trainer: 'Trainer', **kwargs): + _ = trainer, kwargs if self.progress_bar: self.progress_bar.close() class CheckpointCallback(TrainerCallback): + """ + Checkpoint callback for trainer. + """ def __init__(self, checkpoint_interval: int): self.checkpoint_interval = checkpoint_interval self.last_ckpt_iter = 0 def on_train_begin(self, trainer: 'Trainer', **kwargs): + _ = trainer checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) self.last_ckpt_iter = len(checkpoint.loss_list) @@ -80,8 +119,11 @@ class CheckpointCallback(TrainerCallback): class GradientClippingCallback(TrainerCallback): - + """ + Gradient clipping callback for trainer. + """ def on_step_begin(self, trainer: 'Trainer', **kwargs): + _ = kwargs clip_grad_norm_( trainer.checkpoint.model.parameters(), trainer.train_config.max_grad_norm