chore: 更新计算顺序
This commit is contained in:
parent
426af2d75f
commit
eb57e55fca
|
|
@ -76,6 +76,13 @@ class Trainer:
|
|||
self._call_callbacks("on_epoch_begin", context)
|
||||
|
||||
for batch in context.dataloader:
|
||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||
# 2. step
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks("on_step_end", context)
|
||||
|
||||
# 3. batch
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
loss = context.strategy(batch)
|
||||
|
|
@ -88,13 +95,6 @@ class Trainer:
|
|||
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||
# 2. step
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks("on_step_end", context)
|
||||
|
||||
self._call_callbacks("on_epoch_end", context)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
Loading…
Reference in New Issue