diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 9a0005f..2350195 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -1,3 +1,4 @@ +import logging from typing import Optional, List, cast from torch.utils.data import DataLoader @@ -12,6 +13,7 @@ from khaosz.trainer.trainer_callback import ( SchedulerCallback ) +logger = logging.getLogger(__name__) class Trainer: def __init__( @@ -115,12 +117,6 @@ class Trainer: train_kwargs["epoch"] = epoch self._call_callbacks('on_epoch_begin', **train_kwargs) for batch in dataloader: - # batch - self._call_callbacks('on_batch_begin', **train_kwargs) - loss = self.train_config.strategy(batch) - loss.backward() - train_kwargs["loss"] = loss.item() - self._call_callbacks('on_batch_end', **train_kwargs) if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0: # step @@ -128,13 +124,21 @@ class Trainer: self.train_config.optimizer.step() self.train_config.optimizer.zero_grad() self._call_callbacks('on_step_end', **train_kwargs) - + + # batch + self._call_callbacks('on_batch_begin', **train_kwargs) + loss = self.train_config.strategy(batch) + train_kwargs["loss"] = loss.item() train_kwargs["current_iter"] += 1 - + loss.backward() + + self._call_callbacks('on_batch_end', **train_kwargs) + self._call_callbacks('on_epoch_end', **train_kwargs) - + except Exception as e: - raise e + logger.error(f"Training failed: {str(e)}", exc_info=True) + raise finally: self._call_callbacks('on_train_end', **train_kwargs) return checkpoint \ No newline at end of file diff --git a/khaosz/trainer/trainer_callback.py b/khaosz/trainer/trainer_callback.py index e1a9ec8..cb4b954 100644 --- a/khaosz/trainer/trainer_callback.py +++ b/khaosz/trainer/trainer_callback.py @@ -87,10 +87,12 @@ class ProgressBarCallback(TrainerCallback): ) def on_batch_end(self, trainer: 'Trainer', **kwargs): + _ = trainer loss = kwargs.get('loss') + optimizer = cast(optim.Optimizer, kwargs.get('optimizer')) self.progress_bar.set_postfix({ "loss": f"{loss:.4f}", - "lr": f"{trainer.train_config.optimizer.param_groups[0]['lr']:.2e}" + "lr": f"{optimizer.param_groups[-1]['lr']:.2e}" }) self.progress_bar.update(1) @@ -180,7 +182,7 @@ class SchedulerCallback(TrainerCallback): last_epoch=self.current_iter - 1 ) - def on_step_end(self, trainer: 'Trainer', **kwargs): + def on_batch_end(self, trainer: 'Trainer', **kwargs): _ = trainer, kwargs if self.scheduler: