feat(callback): 为 TrainerCallback 及其子类添加文档字符串和未使用参数占位符
This commit is contained in:
parent
e52803ddc3
commit
b2f3fefa1b
|
|
@ -6,32 +6,66 @@ from typing import cast
|
|||
|
||||
|
||||
class TrainerCallback:
|
||||
"""
|
||||
Callback interface for trainer.
|
||||
and we use '_' to ignore unused parameters.
|
||||
"""
|
||||
|
||||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
||||
pass
|
||||
"""
|
||||
Called at the beginning of training.
|
||||
"""
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_train_end(self, trainer: 'Trainer', **kwargs):
|
||||
pass
|
||||
"""
|
||||
Called at the end of training.
|
||||
"""
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
|
||||
pass
|
||||
"""
|
||||
Called at the beginning of each epoch.
|
||||
"""
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_epoch_end(self, trainer: 'Trainer', **kwargs):
|
||||
pass
|
||||
"""
|
||||
Called at the end of each epoch.
|
||||
"""
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_batch_begin(self, trainer: 'Trainer', **kwargs):
|
||||
pass
|
||||
"""
|
||||
Called at the beginning of each batch.
|
||||
"""
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||
pass
|
||||
"""
|
||||
Called at the end of each batch.
|
||||
"""
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_step_begin(self, trainer: 'Trainer', **kwargs):
|
||||
pass
|
||||
"""
|
||||
Called at the beginning of each step.
|
||||
"""
|
||||
|
||||
_ = trainer, kwargs
|
||||
|
||||
def on_step_end(self, trainer: 'Trainer', **kwargs):
|
||||
pass
|
||||
"""
|
||||
Called at the end of each step.
|
||||
"""
|
||||
|
||||
_ = trainer, kwargs
|
||||
|
||||
|
||||
class ProgressBarCallback(TrainerCallback):
|
||||
"""
|
||||
Progress bar callback for trainer.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.progress_bar: tqdm = None
|
||||
|
||||
|
|
@ -53,16 +87,21 @@ class ProgressBarCallback(TrainerCallback):
|
|||
self.progress_bar.update(1)
|
||||
|
||||
def on_epoch_end(self, trainer: 'Trainer', **kwargs):
|
||||
_ = trainer, kwargs
|
||||
if self.progress_bar:
|
||||
self.progress_bar.close()
|
||||
|
||||
|
||||
class CheckpointCallback(TrainerCallback):
|
||||
"""
|
||||
Checkpoint callback for trainer.
|
||||
"""
|
||||
def __init__(self, checkpoint_interval: int):
|
||||
self.checkpoint_interval = checkpoint_interval
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
def on_train_begin(self, trainer: 'Trainer', **kwargs):
|
||||
_ = trainer
|
||||
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
|
||||
self.last_ckpt_iter = len(checkpoint.loss_list)
|
||||
|
||||
|
|
@ -80,8 +119,11 @@ class CheckpointCallback(TrainerCallback):
|
|||
|
||||
|
||||
class GradientClippingCallback(TrainerCallback):
|
||||
|
||||
"""
|
||||
Gradient clipping callback for trainer.
|
||||
"""
|
||||
def on_step_begin(self, trainer: 'Trainer', **kwargs):
|
||||
_ = kwargs
|
||||
clip_grad_norm_(
|
||||
trainer.checkpoint.model.parameters(),
|
||||
trainer.train_config.max_grad_norm
|
||||
|
|
|
|||
Loading…
Reference in New Issue