feat(khaosz/trainer): 改进训练循环中的损失归一化处理
This commit is contained in:
parent
c1bf22b6ec
commit
57cd7b921e
|
|
@ -57,22 +57,28 @@ class Trainer:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.parameter.model.train()
|
self.parameter.model.train()
|
||||||
|
# 1.epoch
|
||||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks('on_epoch_begin', context)
|
self._call_callbacks('on_epoch_begin', context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
if context.current_iter % self.train_config.accumulation_steps == 0:
|
if context.current_iter % self.train_config.accumulation_steps == 0:
|
||||||
|
# 2. step
|
||||||
self._call_callbacks('on_step_begin', context)
|
self._call_callbacks('on_step_begin', context)
|
||||||
self.train_config.optimizer.step()
|
self.train_config.optimizer.step()
|
||||||
self.train_config.optimizer.zero_grad()
|
self.train_config.optimizer.zero_grad()
|
||||||
self._call_callbacks('on_step_end', context)
|
self._call_callbacks('on_step_end', context)
|
||||||
|
|
||||||
|
# 3. batch
|
||||||
self._call_callbacks('on_batch_begin', context)
|
self._call_callbacks('on_batch_begin', context)
|
||||||
loss = self.train_config.strategy(batch)
|
loss = self.train_config.strategy(batch)
|
||||||
context.loss = loss.item()
|
context.loss = loss.item()
|
||||||
context.current_iter += 1
|
context.current_iter += 1
|
||||||
loss.backward()
|
|
||||||
|
# to make the loss normalized by accumulation steps
|
||||||
|
normalized_loss = loss / self.train_config.accumulation_steps
|
||||||
|
normalized_loss.backward()
|
||||||
|
|
||||||
self._call_callbacks('on_batch_end', context)
|
self._call_callbacks('on_batch_end', context)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue