fix(khaosz/trainer/train_callback): 修复基类函数命名错误

This commit is contained in:
ViperEkura 2025-10-07 11:43:51 +08:00
parent 57cd7b921e
commit 0764cb8296
1 changed files with 14 additions and 23 deletions

View File

@ -2,7 +2,7 @@ import os
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from typing import Optional, TYPE_CHECKING
from typing import Optional, Protocol, TYPE_CHECKING
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
if TYPE_CHECKING:
@ -10,43 +10,34 @@ if TYPE_CHECKING:
from khaosz.trainer.train_context import TrainContext
class TrainCallback:
class TrainCallback(Protocol):
"""
Callback interface for trainer.
and we use '_' to ignore unused parameters.
"""
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
""" Called at the beginning of training. """
_ = trainer, context
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
""" Called at the end of training. """
_ = trainer, context
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
""" Called at the beginning of each epoch. """
_ = trainer, context
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
""" Called at the end of each epoch. """
_ = trainer, context
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
""" Called at the beginning of each batch. """
_ = trainer, context
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
""" Called at the end of each batch. """
_ = trainer, context
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
""" Called at the beginning of each step. """
_ = trainer, context
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'):
""" Called at the end of each step."""
_ = trainer, context
def on_batch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
""" Called at the beginning of each batch. """
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
""" Called at the end of each batch. """
class ProgressBarCallback(TrainCallback):