From eb57e55fcac966860b9e7dedc1657827a86c9a6e Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 30 Mar 2026 23:35:22 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E6=9B=B4=E6=96=B0=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=E9=A1=BA=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/trainer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 179fdd4..5a5a88e 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -76,6 +76,13 @@ class Trainer: self._call_callbacks("on_epoch_begin", context) for batch in context.dataloader: + if context.iteration % self.train_config.accumulation_steps == 0: + # 2. step + self._call_callbacks("on_step_begin", context) + context.optimizer.step() + context.optimizer.zero_grad() + self._call_callbacks("on_step_end", context) + # 3. batch self._call_callbacks("on_batch_begin", context) loss = context.strategy(batch) @@ -88,13 +95,6 @@ class Trainer: self._call_callbacks("on_batch_end", context) - if context.iteration % self.train_config.accumulation_steps == 0: - # 2. step - self._call_callbacks("on_step_begin", context) - context.optimizer.step() - context.optimizer.zero_grad() - self._call_callbacks("on_step_end", context) - self._call_callbacks("on_epoch_end", context) except Exception as e: