diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 7f351ad..dcb868e 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -2,7 +2,7 @@ import os from tqdm import tqdm from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import LambdaLR -from typing import Optional, TYPE_CHECKING +from typing import Optional, Protocol, TYPE_CHECKING from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory if TYPE_CHECKING: @@ -10,43 +10,34 @@ if TYPE_CHECKING: from khaosz.trainer.train_context import TrainContext -class TrainCallback: +class TrainCallback(Protocol): """ Callback interface for trainer. - and we use '_' to ignore unused parameters. """ def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the beginning of training. """ - _ = trainer, context - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the end of training. """ - _ = trainer, context - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the beginning of each epoch. """ - _ = trainer, context - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the end of each epoch. """ - _ = trainer, context - - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): - """ Called at the beginning of each batch. """ - _ = trainer, context - - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): - """ Called at the end of each batch. """ - _ = trainer, context - - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + + def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the beginning of each step. """ - _ = trainer, context - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the end of each step.""" - _ = trainer, context + + def on_batch_begin(self, trainer: 'Trainer', context: 'TrainContext'): + """ Called at the beginning of each batch. """ + + def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): + """ Called at the end of each batch. """ class ProgressBarCallback(TrainCallback):