fix: 修复梯度平均问题

This commit is contained in:
ViperEkura 2026-03-13 23:00:26 +08:00
parent c4feab96fe
commit e55b57d771
1 changed files with 1 additions and 2 deletions

View File

@ -86,8 +86,7 @@ class Trainer:
context.iteration += 1
# to make the loss normalized by accumulation steps
stand_batch = self.train_config.accumulation_steps * self.train_config.nprocs
stand_loss = loss / stand_batch
stand_loss = loss / self.train_config.accumulation_steps
stand_loss.backward()
self._call_callbacks('on_batch_end', context)