From 57cd7b921ed0b7e96c62c3c9328911ef4b919771 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Mon, 6 Oct 2025 20:17:47 +0800
Subject: [PATCH] =?UTF-8?q?feat(khaosz/trainer):=20=E6=94=B9=E8=BF=9B?=
 =?UTF-8?q?=E8=AE=AD=E7=BB=83=E5=BE=AA=E7=8E=AF=E4=B8=AD=E7=9A=84=E6=8D=9F?=
 =?UTF-8?q?=E5=A4=B1=E5=BD=92=E4=B8=80=E5=8C=96=E5=A4=84=E7=90=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/trainer/trainer.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 9d88704..fefa70f 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -57,22 +57,28 @@ class Trainer:
         try:
             self.parameter.model.train()
+            # 1. epoch
             for epoch in range(context.epoch, self.train_config.n_epoch):
                 context.epoch = epoch
                 self._call_callbacks('on_epoch_begin', context)
 
                 for batch in context.dataloader:
                     if context.current_iter % self.train_config.accumulation_steps == 0:
+                        # 2. step
                         self._call_callbacks('on_step_begin', context)
                         self.train_config.optimizer.step()
                         self.train_config.optimizer.zero_grad()
                         self._call_callbacks('on_step_end', context)
 
+                    # 3. batch
                     self._call_callbacks('on_batch_begin', context)
                     loss = self.train_config.strategy(batch)
                     context.loss = loss.item()
                     context.current_iter += 1
-                    loss.backward()
+
+                    # Normalize the loss by the number of accumulation steps so the
+                    # accumulated gradient averages over the effective batch.
+                    normalized_loss = loss / self.train_config.accumulation_steps
+                    normalized_loss.backward()
                     self._call_callbacks('on_batch_end', context)