chore: 更新计算顺序
This commit is contained in:
parent
426af2d75f
commit
eb57e55fca
|
|
@ -76,6 +76,13 @@ class Trainer:
|
||||||
self._call_callbacks("on_epoch_begin", context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
|
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||||
|
# 2. step
|
||||||
|
self._call_callbacks("on_step_begin", context)
|
||||||
|
context.optimizer.step()
|
||||||
|
context.optimizer.zero_grad()
|
||||||
|
self._call_callbacks("on_step_end", context)
|
||||||
|
|
||||||
# 3. batch
|
# 3. batch
|
||||||
self._call_callbacks("on_batch_begin", context)
|
self._call_callbacks("on_batch_begin", context)
|
||||||
loss = context.strategy(batch)
|
loss = context.strategy(batch)
|
||||||
|
|
@ -88,13 +95,6 @@ class Trainer:
|
||||||
|
|
||||||
self._call_callbacks("on_batch_end", context)
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
||||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
|
||||||
# 2. step
|
|
||||||
self._call_callbacks("on_step_begin", context)
|
|
||||||
context.optimizer.step()
|
|
||||||
context.optimizer.zero_grad()
|
|
||||||
self._call_callbacks("on_step_end", context)
|
|
||||||
|
|
||||||
self._call_callbacks("on_epoch_end", context)
|
self._call_callbacks("on_epoch_end", context)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue