fix(khaosz/trainer): 调整训练循环中回调调用顺序并增强异常日志记录

This commit is contained in:
ViperEkura 2025-09-30 17:57:55 +08:00
parent dd6a9e4ede
commit 288e2c3da6
2 changed files with 18 additions and 12 deletions

View File

@ -1,3 +1,4 @@
import logging
from typing import Optional, List, cast
from torch.utils.data import DataLoader
@ -12,6 +13,7 @@ from khaosz.trainer.trainer_callback import (
SchedulerCallback
)
logger = logging.getLogger(__name__)
class Trainer:
def __init__(
@ -115,12 +117,6 @@ class Trainer:
train_kwargs["epoch"] = epoch
self._call_callbacks('on_epoch_begin', **train_kwargs)
for batch in dataloader:
# batch
self._call_callbacks('on_batch_begin', **train_kwargs)
loss = self.train_config.strategy(batch)
loss.backward()
train_kwargs["loss"] = loss.item()
self._call_callbacks('on_batch_end', **train_kwargs)
if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
# step
@ -128,13 +124,21 @@ class Trainer:
self.train_config.optimizer.step()
self.train_config.optimizer.zero_grad()
self._call_callbacks('on_step_end', **train_kwargs)
# batch
self._call_callbacks('on_batch_begin', **train_kwargs)
loss = self.train_config.strategy(batch)
train_kwargs["loss"] = loss.item()
train_kwargs["current_iter"] += 1
loss.backward()
self._call_callbacks('on_batch_end', **train_kwargs)
self._call_callbacks('on_epoch_end', **train_kwargs)
except Exception as e:
raise e
logger.error(f"Training failed: {str(e)}", exc_info=True)
raise
finally:
self._call_callbacks('on_train_end', **train_kwargs)
return checkpoint

View File

@ -87,10 +87,12 @@ class ProgressBarCallback(TrainerCallback):
)
def on_batch_end(self, trainer: 'Trainer', **kwargs):
_ = trainer
loss = kwargs.get('loss')
optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
self.progress_bar.set_postfix({
"loss": f"{loss:.4f}",
"lr": f"{trainer.train_config.optimizer.param_groups[0]['lr']:.2e}"
"lr": f"{optimizer.param_groups[-1]['lr']:.2e}"
})
self.progress_bar.update(1)
@ -180,7 +182,7 @@ class SchedulerCallback(TrainerCallback):
last_epoch=self.current_iter - 1
)
def on_step_end(self, trainer: 'Trainer', **kwargs):
def on_batch_end(self, trainer: 'Trainer', **kwargs):
_ = trainer, kwargs
if self.scheduler: