diff --git a/khaosz/trainer/callback.py b/khaosz/trainer/callback.py
new file mode 100644
index 0000000..4a6ef96
--- /dev/null
+++ b/khaosz/trainer/callback.py
@@ -0,0 +1,97 @@
+from tqdm import tqdm
+from khaosz.core.parameter import Checkpoint
+from torch.nn.utils import clip_grad_norm_
+from typing import cast, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    # Import Trainer for type annotations only: trainer.py imports this
+    # module at runtime, so an unconditional import here would be circular.
+    from khaosz.trainer.trainer import Trainer
+
+
+class TrainerCallback:
+    """Base class of no-op hooks; subclasses override the events they need."""
+
+    def on_train_begin(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    def on_train_end(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    def on_epoch_end(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    def on_batch_begin(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    def on_batch_end(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    def on_step_begin(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    def on_step_end(self, trainer: 'Trainer', **kwargs):
+        pass
+
+
+class ProgressBarCallback(TrainerCallback):
+    def __init__(self):
+        self.progress_bar: Optional[tqdm] = None
+
+    def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
+        epoch = kwargs.get('epoch')
+        dataloader = trainer._create_dataloader()
+        self.progress_bar = tqdm(
+            dataloader,
+            desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
+            dynamic_ncols=True
+        )
+
+    def on_batch_end(self, trainer: 'Trainer', **kwargs):
+        loss = kwargs.get('loss')
+        self.progress_bar.set_postfix({
+            "loss": f"{loss:.4f}",
+            "lr": f"{trainer.train_config.optimizer.param_groups[0]['lr']:.2e}"
+        })
+        self.progress_bar.update(1)
+
+    def on_epoch_end(self, trainer: 'Trainer', **kwargs):
+        if self.progress_bar:
+            self.progress_bar.close()
+
+
+class CheckpointCallback(TrainerCallback):
+    def __init__(self, checkpoint_interval: int):
+        self.checkpoint_interval = checkpoint_interval
+        self.last_ckpt_iter = 0
+
+    def on_train_begin(self, trainer: 'Trainer', **kwargs):
+        checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
+        self.last_ckpt_iter = len(checkpoint.loss_list)
+
+    def on_batch_end(self, trainer: 'Trainer', **kwargs):
+        # Save once enough iterations have elapsed since the last checkpoint.
+        current_iter = kwargs.get('current_iter')
+        if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
+            trainer._save_checkpoint()
+            self.last_ckpt_iter = current_iter
+
+    def on_train_end(self, trainer: 'Trainer', **kwargs):
+        # Save a final checkpoint if any iterations ran after the last save.
+        checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
+        current_iter = len(checkpoint.loss_list)
+        if current_iter != self.last_ckpt_iter:
+            trainer._save_checkpoint()
+
+
+class GradientClippingCallback(TrainerCallback):
+
+    def on_step_begin(self, trainer: 'Trainer', **kwargs):
+        # Clip gradients in place before the optimizer applies them.
+        clip_grad_norm_(
+            trainer.checkpoint.model.parameters(),
+            trainer.train_config.max_grad_norm
+        )
\ No newline at end of file
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 25e34db..973a097 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -1,108 +1,12 @@
-
 import os
 import torch
-from abc import abstractmethod
-from typing import Optional, List, override
-from torch.nn.utils import clip_grad_norm_
+from typing import Optional, List
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader, RandomSampler
-from tqdm import tqdm
 from khaosz.core import ModelParameter, Checkpoint
 from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
 
-
-
-class TrainerCallback:
-    @abstractmethod
-    def on_train_begin(self, trainer: 'Trainer', **kwargs):
-        pass
-
-    @abstractmethod
-    def on_train_end(self, trainer: 'Trainer', **kwargs):
-        pass
-
-    @abstractmethod
-    def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
-        pass
-
-    @abstractmethod
-    def on_epoch_end(self, trainer: 'Trainer', **kwargs):
-        pass
-
-    @abstractmethod
-    def on_batch_begin(self, trainer: 'Trainer', **kwargs):
-        pass
-
-    @abstractmethod
-    def on_batch_end(self, trainer: 'Trainer', **kwargs):
-        pass
-
-    @abstractmethod
-    def on_step_begin(self, trainer: 'Trainer', **kwargs):
-        pass
-
-    @abstractmethod
-    def on_step_end(self, trainer: 'Trainer', **kwargs):
-        pass
-
-
-
-class ProgressBarCallback(TrainerCallback):
-    def __init__(self):
-        self.progress_bar: tqdm = None
-
-    def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
-        epoch = kwargs.get('epoch')
-        dataloader = trainer._create_dataloader()
-        self.progress_bar = tqdm(
-            dataloader,
-            desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
-            dynamic_ncols=True
-        )
-
-    def on_batch_end(self, trainer: 'Trainer', **kwargs):
-        loss = kwargs.get('loss')
-        self.progress_bar.set_postfix({
-            "loss": f"{loss:.4f}",
-            "lr": f"{trainer.train_config.optimizer.param_groups[0]['lr']:.2e}"
-        })
-        self.progress_bar.update(1)
-
-
-    def on_epoch_end(self, trainer: 'Trainer', **kwargs):
-        if self.progress_bar:
-            self.progress_bar.close()
-
-
-class CheckpointCallback(TrainerCallback):
-    def __init__(self, checkpoint_interval: int):
-        self.checkpoint_interval = checkpoint_interval
-        self.last_ckpt_iter = 0
-
-    def on_train_begin(self, trainer: 'Trainer', **kwargs):
-        checkpoint = kwargs.get('checkpoint')
-        self.last_ckpt_iter = len(checkpoint.loss_list)
-
-    def on_batch_end(self, trainer: 'Trainer', **kwargs):
-        current_iter = kwargs.get('current_iter')
-        if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
-            trainer._save_checkpoint()
-            self.last_ckpt_iter = current_iter
-
-    def on_train_end(self, trainer: 'Trainer', **kwargs):
-        checkpoint = kwargs.get('checkpoint')
-        current_iter = len(checkpoint.loss_list)
-        if current_iter != self.last_ckpt_iter:
-            trainer._save_checkpoint()
-
-
-class GradientClippingCallback(TrainerCallback):
-
-    def on_step_begin(self, trainer: 'Trainer', **kwargs):
-        clip_grad_norm_(
-            trainer.checkpoint.model.parameters(),
-            trainer.train_config.max_grad_norm
-        )
+from khaosz.trainer.callback import TrainerCallback, ProgressBarCallback, CheckpointCallback, GradientClippingCallback
 
 
 class Trainer:
@@ -214,6 +118,9 @@ class Trainer:
                 self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
 
+        except Exception:
+            raise
+
         finally:
             self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
 
 
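Note for reviewers: a minimal sketch of how a downstream user would extend the new hook API. The hook names and the `loss` / `current_iter` kwargs are taken from this diff (they are the same kwargs `ProgressBarCallback` and `CheckpointCallback` consume); `LossLoggerCallback` itself is hypothetical, and how callbacks get registered with `Trainer` is not shown in this change.

```python
from khaosz.trainer.callback import TrainerCallback


class LossLoggerCallback(TrainerCallback):
    """Hypothetical example: print the batch loss every `log_interval` iterations."""

    def __init__(self, log_interval: int = 100):
        self.log_interval = log_interval

    def on_batch_end(self, trainer, **kwargs):
        # Assumes Trainer passes these kwargs to on_batch_end, mirroring
        # ProgressBarCallback (loss) and CheckpointCallback (current_iter).
        loss = kwargs.get('loss')
        current_iter = kwargs.get('current_iter')
        if current_iter % self.log_interval == 0:
            print(f"iter {current_iter}: loss {loss:.4f}")
```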