diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 2392da0..25e34db 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -1,7 +1,8 @@
+
 import os
 import torch
-
-from typing import Optional
+from abc import abstractmethod
+from typing import Optional, List
 from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader, RandomSampler
@@ -11,12 +12,106 @@ from khaosz.core import ModelParameter, Checkpoint
 from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
 
 
+class TrainerCallback:
+    @abstractmethod
+    def on_train_begin(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    @abstractmethod
+    def on_train_end(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    @abstractmethod
+    def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    @abstractmethod
+    def on_epoch_end(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    @abstractmethod
+    def on_batch_begin(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    @abstractmethod
+    def on_batch_end(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    @abstractmethod
+    def on_step_begin(self, trainer: 'Trainer', **kwargs):
+        pass
+
+    @abstractmethod
+    def on_step_end(self, trainer: 'Trainer', **kwargs):
+        pass
+
+
+class ProgressBarCallback(TrainerCallback):
+    def __init__(self):
+        self.progress_bar: Optional[tqdm] = None
+
+    def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
+        epoch = kwargs.get('epoch')
+        dataloader = trainer._create_dataloader()
+        self.progress_bar = tqdm(
+            total=len(dataloader),
+            desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
+            dynamic_ncols=True
+        )
+
+    def on_batch_end(self, trainer: 'Trainer', **kwargs):
+        loss = kwargs.get('loss')
+        self.progress_bar.set_postfix({
+            "loss": f"{loss:.4f}",
+            "lr": f"{trainer.train_config.optimizer.param_groups[0]['lr']:.2e}"
+        })
+        self.progress_bar.update(1)
+
+    def on_epoch_end(self, trainer: 'Trainer', **kwargs):
+        if self.progress_bar:
+            self.progress_bar.close()
+
+
+class CheckpointCallback(TrainerCallback):
+    def __init__(self, checkpoint_interval: int):
+        self.checkpoint_interval = checkpoint_interval
+        self.last_ckpt_iter = 0
+
+    def on_train_begin(self, trainer: 'Trainer', **kwargs):
+        checkpoint = kwargs.get('checkpoint')
+        self.last_ckpt_iter = len(checkpoint.loss_list)
+
+    def on_batch_end(self, trainer: 'Trainer', **kwargs):
+        current_iter = kwargs.get('current_iter')
+        if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
+            trainer._save_checkpoint()
+            self.last_ckpt_iter = current_iter
+
+    def on_train_end(self, trainer: 'Trainer', **kwargs):
+        checkpoint = kwargs.get('checkpoint')
+        current_iter = len(checkpoint.loss_list)
+        if current_iter != self.last_ckpt_iter:
+            trainer._save_checkpoint()
+
+
+class GradientClippingCallback(TrainerCallback):
+
+    def on_step_begin(self, trainer: 'Trainer', **kwargs):
+        clip_grad_norm_(
+            trainer.checkpoint.model.parameters(),
+            trainer.train_config.max_grad_norm
+        )
+
+
 class Trainer:
     def __init__(
         self,
         parameter: ModelParameter,
         train_config: TrainConfig,
-        schedule_config: ScheduleConfig
+        schedule_config: ScheduleConfig,
+        callbacks: Optional[List[TrainerCallback]] = None
     ):
         self.checkpoint = Checkpoint(
             model=parameter.model,
@@ -26,16 +121,37 @@ class Trainer:
         self.train_config = train_config
         self.schedule_config = schedule_config
 
-    def save_checkpoint(
-        self,
-        loss_list: list,
-    ):
-        current_iter = len(loss_list)
+        self.callbacks = callbacks or self._get_default_callbacks()
+
+    def _get_default_callbacks(self) -> List[TrainerCallback]:
+        return [
+            ProgressBarCallback(),
+            CheckpointCallback(self.train_config.checkpoint_interval),
+            GradientClippingCallback(),
+        ]
+
+    def _create_dataloader(self) -> DataLoader:
+        seed = self.train_config.random_seed
+        generator = torch.Generator().manual_seed(seed)
+        sampler = RandomSampler(self.train_config.dataset, generator=generator)
+        return DataLoader(
+            self.train_config.dataset,
+            batch_size=self.train_config.batch_size,
+            sampler=sampler
+        )
+
+    def _save_checkpoint(self):
+        current_iter = len(self.checkpoint.loss_list)
         save_path = os.path.join(self.train_config.checkpoint_dir, f"iter_{current_iter}")
-        self.checkpoint.loss_list = loss_list
         self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
         self.checkpoint.save(save_path)
 
+    def _call_callbacks(self, method_name: str, **kwargs):
+        for callback in self.callbacks:
+            method = getattr(callback, method_name, None)
+            if method:
+                method(self, **kwargs)
+
     def train(
         self,
         train_checkpoint: Optional[Checkpoint] = None
@@ -47,16 +163,13 @@ class Trainer:
             self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
             self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
 
-        loss_list = self.checkpoint.loss_list
         current_iter = len(self.checkpoint.loss_list)
-        last_ckpt_iter = current_iter
 
         for group in self.train_config.optimizer.param_groups:
             if "initial_lr" not in group:
                 group["initial_lr"] = group["lr"]
-
-        lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
+        lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
             **self.schedule_config.get_kwargs()
         )
@@ -66,52 +179,42 @@ class Trainer:
             last_epoch=current_iter - 1 if train_checkpoint else -1
         )
 
-        seed = self.train_config.random_seed
-        generator = torch.Generator().manual_seed(seed)
-        sampler = RandomSampler(self.train_config.dataset, generator=generator)
-        remaining_epochs = self.train_config.n_epoch - current_iter // (
-            len(self.train_config.dataset) // self.train_config.batch_size)
+        steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size
+        remaining_steps = self.train_config.n_epoch * steps_per_epoch - current_iter
+        remaining_epochs = (remaining_steps + steps_per_epoch - 1) // steps_per_epoch
+        # notify callbacks that training is starting
+        self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
-        for epoch in range(remaining_epochs):
-            self.checkpoint.model.train()
-            dataloader = DataLoader(
-                self.train_config.dataset,
-                batch_size=self.train_config.batch_size,
-                sampler=sampler
-            )
-            progress_bar = tqdm(
-                dataloader,
-                desc=f"Epoch {epoch+1}/{self.train_config.n_epoch}",
-                dynamic_ncols=True
-            )
-            for batch in progress_bar:
-                #forward
-                loss = self.train_config.strategy(batch)
-                loss_list.append(loss.item())
-                #backward
-                loss.backward()
-                #step
-                if current_iter % self.train_config.accumulation_steps == 0:
-                    clip_grad_norm_(
-                        self.checkpoint.model.parameters(),
-                        self.train_config.max_grad_norm
-                    )
-                    self.train_config.optimizer.step()
-                    self.train_config.optimizer.zero_grad()
+        try:
+            for epoch in range(remaining_epochs):
+                self.checkpoint.model.train()
+
+                # epoch-level hooks
+                self._call_callbacks('on_epoch_begin', epoch=epoch)
+
+                dataloader = self._create_dataloader()
+
+                for batch in dataloader:
+                    # batch-level hooks around forward/backward
+                    self._call_callbacks('on_batch_begin', batch=batch)
+                    loss = self.train_config.strategy(batch)
+                    self.checkpoint.loss_list.append(loss.item())
+                    loss.backward()
+                    self._call_callbacks('on_batch_end', batch=batch, loss=loss.item(), current_iter=current_iter)
 
-                current_iter += 1
-                scheduler.step()
-                progress_bar.set_postfix({
-                    "loss": f"{loss.item():.4f}",
-                    "lr": f"{self.train_config.optimizer.param_groups[0]['lr']:.2e}"
-                })
-                #save checkpotint
-                if current_iter - last_ckpt_iter >= self.train_config.checkpoint_interval:
-                    self.save_checkpoint(loss_list)
-                    last_ckpt_iter = current_iter
-
-        if current_iter != last_ckpt_iter:
-            self.save_checkpoint(loss_list)
-            last_ckpt_iter = current_iter
+                    if current_iter % self.train_config.accumulation_steps == 0:
+                        # step-level hooks around the optimizer update
+                        self._call_callbacks('on_step_begin', current_iter=current_iter)
+                        self.train_config.optimizer.step()
+                        self.train_config.optimizer.zero_grad()
+                        self._call_callbacks('on_step_end', current_iter=current_iter)
+
+                    current_iter += 1
+                    scheduler.step()
+
+                self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
+
+        finally:
+            self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
 
         return self.checkpoint
\ No newline at end of file
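
For reference, a minimal sketch of how a downstream user would plug a custom hook into the new API. Trainer, TrainerCallback, and the built-in callbacks are taken from the diff above; LossLoggingCallback and build_trainer are hypothetical names, and the parameter/train_config/schedule_config objects are assumed to be constructed by the caller.

    from typing import List

    from khaosz.core import ModelParameter
    from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
    from khaosz.trainer.trainer import (
        CheckpointCallback,
        GradientClippingCallback,
        ProgressBarCallback,
        Trainer,
        TrainerCallback,
    )


    class LossLoggingCallback(TrainerCallback):
        # Hypothetical example: print the running mean loss after each epoch.
        # Only the hook of interest is overridden; the other hooks fall back
        # to the base-class pass bodies (TrainerCallback does not inherit
        # from abc.ABC, so partial overrides are allowed).

        def on_epoch_end(self, trainer: Trainer, **kwargs):
            loss_list = kwargs.get('loss_list') or []
            if loss_list:
                mean_loss = sum(loss_list) / len(loss_list)
                print(f"epoch {kwargs.get('epoch')}: mean loss so far {mean_loss:.4f}")


    def build_trainer(
        parameter: ModelParameter,        # assumed built by the caller
        train_config: TrainConfig,        # assumed built by the caller
        schedule_config: ScheduleConfig,  # assumed built by the caller
    ) -> Trainer:
        # Passing callbacks= replaces the defaults entirely, so the built-in
        # callbacks are re-listed here alongside the custom one.
        callbacks: List[TrainerCallback] = [
            ProgressBarCallback(),
            CheckpointCallback(train_config.checkpoint_interval),
            GradientClippingCallback(),
            LossLoggingCallback(),
        ]
        return Trainer(parameter, train_config, schedule_config, callbacks=callbacks)

One wrinkle worth noting: `callbacks or self._get_default_callbacks()` treats an empty list the same as None, so a caller who genuinely wants no callbacks cannot express that without a follow-up change (e.g. an explicit `is None` check).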