From b2f3fefa1babe30a8ec7ac63071c7a1cc86c59af Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 29 Sep 2025 12:48:01 +0800 Subject: [PATCH] =?UTF-8?q?feat(callback):=20=E4=B8=BA=20TrainerCallback?= =?UTF-8?q?=20=E5=8F=8A=E5=85=B6=E5=AD=90=E7=B1=BB=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=AD=97=E7=AC=A6=E4=B8=B2=E5=92=8C=E6=9C=AA?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E5=8F=82=E6=95=B0=E5=8D=A0=E4=BD=8D=E7=AC=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/callback.py | 60 ++++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/khaosz/trainer/callback.py b/khaosz/trainer/callback.py index 4a6ef96..4e55311 100644 --- a/khaosz/trainer/callback.py +++ b/khaosz/trainer/callback.py @@ -6,32 +6,66 @@ from typing import cast class TrainerCallback: + """ + Callback interface for trainer. + and we use '_' to ignore unused parameters. + """ + def on_train_begin(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the beginning of training. + """ + _ = trainer, kwargs def on_train_end(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the end of training. + """ + _ = trainer, kwargs def on_epoch_begin(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the beginning of each epoch. + """ + _ = trainer, kwargs def on_epoch_end(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the end of each epoch. + """ + _ = trainer, kwargs def on_batch_begin(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the beginning of each batch. + """ + _ = trainer, kwargs def on_batch_end(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the end of each batch. + """ + _ = trainer, kwargs def on_step_begin(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the beginning of each step. + """ + + _ = trainer, kwargs def on_step_end(self, trainer: 'Trainer', **kwargs): - pass + """ + Called at the end of each step. + """ + + _ = trainer, kwargs class ProgressBarCallback(TrainerCallback): + """ + Progress bar callback for trainer. + """ def __init__(self): self.progress_bar: tqdm = None @@ -53,16 +87,21 @@ class ProgressBarCallback(TrainerCallback): self.progress_bar.update(1) def on_epoch_end(self, trainer: 'Trainer', **kwargs): + _ = trainer, kwargs if self.progress_bar: self.progress_bar.close() class CheckpointCallback(TrainerCallback): + """ + Checkpoint callback for trainer. + """ def __init__(self, checkpoint_interval: int): self.checkpoint_interval = checkpoint_interval self.last_ckpt_iter = 0 def on_train_begin(self, trainer: 'Trainer', **kwargs): + _ = trainer checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) self.last_ckpt_iter = len(checkpoint.loss_list) @@ -80,8 +119,11 @@ class CheckpointCallback(TrainerCallback): class GradientClippingCallback(TrainerCallback): - + """ + Gradient clipping callback for trainer. + """ def on_step_begin(self, trainer: 'Trainer', **kwargs): + _ = kwargs clip_grad_norm_( trainer.checkpoint.model.parameters(), trainer.train_config.max_grad_norm