chore: 更新计算顺序

This commit is contained in:
ViperEkura 2026-03-30 23:35:22 +08:00
parent 426af2d75f
commit eb57e55fca
1 changed files with 7 additions and 7 deletions

View File

@ -76,6 +76,13 @@ class Trainer:
self._call_callbacks("on_epoch_begin", context) self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader: for batch in context.dataloader:
if context.iteration % self.train_config.accumulation_steps == 0:
# 2. step
self._call_callbacks("on_step_begin", context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks("on_step_end", context)
# 3. batch # 3. batch
self._call_callbacks("on_batch_begin", context) self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch) loss = context.strategy(batch)
@ -88,13 +95,6 @@ class Trainer:
self._call_callbacks("on_batch_end", context) self._call_callbacks("on_batch_end", context)
if context.iteration % self.train_config.accumulation_steps == 0:
# 2. step
self._call_callbacks("on_step_begin", context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks("on_step_end", context)
self._call_callbacks("on_epoch_end", context) self._call_callbacks("on_epoch_end", context)
except Exception as e: except Exception as e: