feat(khaosz/trainer): 改进训练循环中的损失归一化处理
This commit is contained in:
parent
c1bf22b6ec
commit
57cd7b921e
|
|
@ -57,22 +57,28 @@ class Trainer:
|
|||
|
||||
try:
|
||||
self.parameter.model.train()
|
||||
# 1.epoch
|
||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||
context.epoch = epoch
|
||||
self._call_callbacks('on_epoch_begin', context)
|
||||
|
||||
for batch in context.dataloader:
|
||||
if context.current_iter % self.train_config.accumulation_steps == 0:
|
||||
# 2. step
|
||||
self._call_callbacks('on_step_begin', context)
|
||||
self.train_config.optimizer.step()
|
||||
self.train_config.optimizer.zero_grad()
|
||||
self._call_callbacks('on_step_end', context)
|
||||
|
||||
# 3. batch
|
||||
self._call_callbacks('on_batch_begin', context)
|
||||
loss = self.train_config.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
context.current_iter += 1
|
||||
loss.backward()
|
||||
|
||||
# to make the loss normalized by accumulation steps
|
||||
normalized_loss = loss / self.train_config.accumulation_steps
|
||||
normalized_loss.backward()
|
||||
|
||||
self._call_callbacks('on_batch_end', context)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue