From 0764cb82963033ca96a8d8aac8c47466b5756e0f Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 7 Oct 2025 11:43:51 +0800 Subject: [PATCH] =?UTF-8?q?fix(khaosz/trainer/train=5Fcallback):=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=9F=BA=E7=B1=BB=E5=87=BD=E6=95=B0=E5=91=BD?= =?UTF-8?q?=E5=90=8D=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/train_callback.py | 37 ++++++++++++-------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 7f351ad..dcb868e 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -2,7 +2,7 @@ import os from tqdm import tqdm from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import LambdaLR -from typing import Optional, TYPE_CHECKING +from typing import Optional, Protocol, TYPE_CHECKING from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory if TYPE_CHECKING: @@ -10,43 +10,34 @@ if TYPE_CHECKING: from khaosz.trainer.train_context import TrainContext -class TrainCallback: +class TrainCallback(Protocol): """ Callback interface for trainer. - and we use '_' to ignore unused parameters. """ def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the beginning of training. """ - _ = trainer, context - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the end of training. """ - _ = trainer, context - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the beginning of each epoch. """ - _ = trainer, context - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the end of each epoch. """ - _ = trainer, context - - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): - """ Called at the beginning of each batch. """ - _ = trainer, context - - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): - """ Called at the end of each batch. """ - _ = trainer, context - - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + + def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the beginning of each step. """ - _ = trainer, context - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'): """ Called at the end of each step.""" - _ = trainer, context + + def on_batch_begin(self, trainer: 'Trainer', context: 'TrainContext'): + """ Called at the beginning of each batch. """ + + def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): + """ Called at the end of each batch. """ class ProgressBarCallback(TrainCallback):