diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 9d88704..fefa70f 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -57,22 +57,28 @@ class Trainer: try: self.parameter.model.train() + # 1.epoch for epoch in range(context.epoch, self.train_config.n_epoch): context.epoch = epoch self._call_callbacks('on_epoch_begin', context) for batch in context.dataloader: if context.current_iter % self.train_config.accumulation_steps == 0: + # 2. step self._call_callbacks('on_step_begin', context) self.train_config.optimizer.step() self.train_config.optimizer.zero_grad() self._call_callbacks('on_step_end', context) + # 3. batch self._call_callbacks('on_batch_begin', context) loss = self.train_config.strategy(batch) context.loss = loss.item() context.current_iter += 1 - loss.backward() + + # to make the loss normalized by accumulation steps + normalized_loss = loss / self.train_config.accumulation_steps + normalized_loss.backward() self._call_callbacks('on_batch_end', context)