diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 179fdd4..5a5a88e 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -76,6 +76,13 @@ class Trainer: self._call_callbacks("on_epoch_begin", context) for batch in context.dataloader: + if context.iteration % self.train_config.accumulation_steps == 0: + # 2. step + self._call_callbacks("on_step_begin", context) + context.optimizer.step() + context.optimizer.zero_grad() + self._call_callbacks("on_step_end", context) + # 3. batch self._call_callbacks("on_batch_begin", context) loss = context.strategy(batch) @@ -88,13 +95,6 @@ class Trainer: self._call_callbacks("on_batch_end", context) - if context.iteration % self.train_config.accumulation_steps == 0: - # 2. step - self._call_callbacks("on_step_begin", context) - context.optimizer.step() - context.optimizer.zero_grad() - self._call_callbacks("on_step_end", context) - self._call_callbacks("on_epoch_end", context) except Exception as e: