fix(khaosz/trainer/train_callback): 修复基类函数命名错误
This commit is contained in:
parent
57cd7b921e
commit
0764cb8296
|
|
@ -2,7 +2,7 @@ import os
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
from typing import Optional, TYPE_CHECKING
|
from typing import Optional, Protocol, TYPE_CHECKING
|
||||||
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
|
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -10,43 +10,34 @@ if TYPE_CHECKING:
|
||||||
from khaosz.trainer.train_context import TrainContext
|
from khaosz.trainer.train_context import TrainContext
|
||||||
|
|
||||||
|
|
||||||
class TrainCallback:
|
class TrainCallback(Protocol):
|
||||||
"""
|
"""
|
||||||
Callback interface for trainer.
|
Callback interface for trainer.
|
||||||
and we use '_' to ignore unused parameters.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
""" Called at the beginning of training. """
|
""" Called at the beginning of training. """
|
||||||
_ = trainer, context
|
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
""" Called at the end of training. """
|
""" Called at the end of training. """
|
||||||
_ = trainer, context
|
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
""" Called at the beginning of each epoch. """
|
""" Called at the beginning of each epoch. """
|
||||||
_ = trainer, context
|
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
""" Called at the end of each epoch. """
|
""" Called at the end of each epoch. """
|
||||||
_ = trainer, context
|
|
||||||
|
def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
|
||||||
""" Called at the beginning of each batch. """
|
|
||||||
_ = trainer, context
|
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
|
||||||
""" Called at the end of each batch. """
|
|
||||||
_ = trainer, context
|
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
|
||||||
""" Called at the beginning of each step. """
|
""" Called at the beginning of each step. """
|
||||||
_ = trainer, context
|
|
||||||
|
|
||||||
def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
""" Called at the end of each step."""
|
""" Called at the end of each step."""
|
||||||
_ = trainer, context
|
|
||||||
|
def on_batch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
|
""" Called at the beginning of each batch. """
|
||||||
|
|
||||||
|
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
|
""" Called at the end of each batch. """
|
||||||
|
|
||||||
|
|
||||||
class ProgressBarCallback(TrainCallback):
|
class ProgressBarCallback(TrainCallback):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue