diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 823511c..89faf26 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -86,8 +86,7 @@ class Trainer: context.iteration += 1 # to make the loss normalized by accumulation steps - stand_batch = self.train_config.accumulation_steps * self.train_config.nprocs - stand_loss = loss / stand_batch + stand_loss = loss / self.train_config.accumulation_steps stand_loss.backward() self._call_callbacks('on_batch_end', context)