feat(khaosz/trainer): 改进训练循环中的损失归一化处理

This commit is contained in:
ViperEkura 2025-10-06 20:17:47 +08:00
parent c1bf22b6ec
commit 57cd7b921e
1 changed file with 7 additions and 1 deletion

View File

@@ -57,22 +57,28 @@ class Trainer:
try:
self.parameter.model.train()
# 1.epoch
for epoch in range(context.epoch, self.train_config.n_epoch):
context.epoch = epoch
self._call_callbacks('on_epoch_begin', context)
for batch in context.dataloader:
if context.current_iter % self.train_config.accumulation_steps == 0:
# 2. step
self._call_callbacks('on_step_begin', context)
self.train_config.optimizer.step()
self.train_config.optimizer.zero_grad()
self._call_callbacks('on_step_end', context)
# 3. batch
self._call_callbacks('on_batch_begin', context)
loss = self.train_config.strategy(batch)
context.loss = loss.item()
context.current_iter += 1
loss.backward()
# to make the loss normalized by accumulation steps
normalized_loss = loss / self.train_config.accumulation_steps
normalized_loss.backward()
self._call_callbacks('on_batch_end', context)