diff --git a/khaosz/trainer/callback.py b/khaosz/trainer/callback.py
index e7cdf25..bb0a3ea 100644
--- a/khaosz/trainer/callback.py
+++ b/khaosz/trainer/callback.py
@@ -1,3 +1,4 @@
+import os
 from tqdm import tqdm
 from khaosz.core.parameter import Checkpoint
 from torch.nn.utils import clip_grad_norm_
@@ -104,6 +105,13 @@ class CheckpointCallback(TrainerCallback):
         self.checkpoint_interval = checkpoint_interval
         self.last_ckpt_iter = 0
 
+    @staticmethod
+    def _save_checkpoint(trainer: 'Trainer'):
+        current_iter = len(trainer.checkpoint.loss_list)
+        save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
+        trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict()
+        trainer.checkpoint.save(save_path)
+
     def on_train_begin(self, trainer: 'Trainer', **kwargs):
         _ = trainer
         checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
@@ -112,14 +120,14 @@ class CheckpointCallback(TrainerCallback):
     def on_batch_end(self, trainer: 'Trainer', **kwargs):
         current_iter = kwargs.get('current_iter')
         if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
-            trainer._save_checkpoint()
+            CheckpointCallback._save_checkpoint(trainer)
         self.last_ckpt_iter = current_iter
 
     def on_train_end(self, trainer: 'Trainer', **kwargs):
         checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
         current_iter = len(checkpoint.loss_list)
         if current_iter != self.last_ckpt_iter:
-            trainer._save_checkpoint()
+            CheckpointCallback._save_checkpoint(trainer)
 
 
 class GradientClippingCallback(TrainerCallback):
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index a2a26b6..3d44968 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -1,4 +1,3 @@
-import os
 import torch
 from typing import Optional, List
 from torch.utils.data import DataLoader, RandomSampler
@@ -29,7 +28,6 @@ class Trainer:
         )
         self.train_config = train_config
         self.schedule_config = schedule_config
-
         self.callbacks = callbacks or self._get_default_callbacks()
 
     def _get_default_callbacks(self) -> List[TrainerCallback]:
@@ -49,19 +47,13 @@ class Trainer:
             batch_size=self.train_config.batch_size,
             sampler=sampler
         )
-
-    def _save_checkpoint(self):
-        current_iter = len(self.checkpoint.loss_list)
-        save_path = os.path.join(self.train_config.checkpoint_dir, f"iter_{current_iter}")
-        self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
-        self.checkpoint.save(save_path)
 
     def _call_callbacks(self, method_name: str, **kwargs):
         for callback in self.callbacks:
             method = getattr(callback, method_name, None)
            if method:
                 method(self, **kwargs)
-    
+
     def train(
         self,
         train_checkpoint: Optional[Checkpoint] = None
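
For context, a minimal self-contained sketch of the checkpoint flow after this change, built only from the attribute names visible in the hunks above (trainer.checkpoint.loss_list, train_config.checkpoint_dir, train_config.optimizer). The Stub* classes below are hypothetical stand-ins for the real khaosz types, not part of the patch:

import os

# Hypothetical stand-ins mirroring the interfaces the patch touches.
class StubOptimizer:
    def state_dict(self):
        return {"lr": 1e-4}

class StubCheckpoint:
    def __init__(self):
        self.loss_list = []      # one entry per completed iteration
        self.optim_state = None

    def save(self, path):
        print(f"saved checkpoint -> {path}")

class StubTrainConfig:
    def __init__(self):
        self.checkpoint_dir = "checkpoints"
        self.optimizer = StubOptimizer()

class StubTrainer:
    def __init__(self):
        self.checkpoint = StubCheckpoint()
        self.train_config = StubTrainConfig()

# Mirrors CheckpointCallback._save_checkpoint from the patch: the callback now
# owns the save logic instead of calling a private Trainer method.
def save_checkpoint(trainer):
    current_iter = len(trainer.checkpoint.loss_list)
    save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
    trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict()
    trainer.checkpoint.save(save_path)

trainer = StubTrainer()
trainer.checkpoint.loss_list.extend([0.9, 0.8, 0.7])  # pretend 3 iterations ran
save_checkpoint(trainer)  # -> saved checkpoint -> checkpoints/iter_3

Moving the save logic into the callback removes the last private-method coupling between CheckpointCallback and Trainer, which is what allows the unused os import to be dropped from trainer.py.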