fix: correct gradient averaging
parent c4feab96fe
commit e55b57d771
@@ -86,8 +86,7 @@ class Trainer:
             context.iteration += 1
 
             # to make the loss normalized by accumulation steps
-            stand_batch = self.train_config.accumulation_steps * self.train_config.nprocs
-            stand_loss = loss / stand_batch
+            stand_loss = loss / self.train_config.accumulation_steps
             stand_loss.backward()
 
             self._call_callbacks('on_batch_end', context)
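
Note on the change: the commit does not say why `nprocs` was dropped from the divisor. One plausible reading, assuming the trainer runs under a wrapper that already averages gradients across processes (as PyTorch's DistributedDataParallel does), is that dividing by `nprocs` as well shrank every gradient by a factor of `nprocs`. The sketch below checks that arithmetic in plain Python on a toy scalar model. The names `grad`, `simulate`, `shards`, and `extra_divisor` are hypothetical and do not come from the repository.

```python
# Toy model: loss_i = 0.5 * (w * x_i - y_i) ** 2, so d(loss_i)/dw = (w * x_i - y_i) * x_i.

def grad(w, x, y):
    """Analytic gradient of one micro-batch loss with respect to w."""
    return (w * x - y) * x


def simulate(w, shards, accumulation_steps, extra_divisor=1):
    """Gradient seen by the optimizer after accumulation and cross-rank averaging.

    shards[rank] holds `accumulation_steps` (x, y) micro-batches for that rank.
    """
    per_rank = []
    for micro_batches in shards:
        accumulated = 0.0
        for x, y in micro_batches:
            # Mirrors `stand_loss = loss / divisor; stand_loss.backward()`.
            accumulated += grad(w, x, y) / (accumulation_steps * extra_divisor)
        per_rank.append(accumulated)
    # Assumption: the distributed wrapper averages (not sums) gradients across ranks.
    return sum(per_rank) / len(per_rank)


w = 0.5
shards = [[(1.0, 2.0), (2.0, 1.0)], [(3.0, 0.0), (-1.0, 4.0)]]
nprocs = len(shards)
accumulation_steps = 2

# Reference: mean gradient over every sample in the effective global batch.
samples = [pair for rank in shards for pair in rank]
reference = sum(grad(w, x, y) for x, y in samples) / len(samples)

fixed = simulate(w, shards, accumulation_steps)                     # divisor after this commit
old = simulate(w, shards, accumulation_steps, extra_divisor=nprocs)  # divisor before this commit
```

Under that assumption, `fixed` equals the global mean gradient, while `old` is that value divided by `nprocs`. If the wrapper summed gradients instead of averaging them, the old divisor would have been the correct one.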