From 288e2c3da6e36ffedecdc430b827d01835599ac8 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Tue, 30 Sep 2025 17:57:55 +0800
Subject: [PATCH] =?UTF-8?q?fix(khaosz/trainer):=20=E8=B0=83=E6=95=B4?=
 =?UTF-8?q?=E8=AE=AD=E7=BB=83=E5=BE=AA=E7=8E=AF=E4=B8=AD=E5=9B=9E=E8=B0=83?=
 =?UTF-8?q?=E8=B0=83=E7=94=A8=E9=A1=BA=E5=BA=8F=E5=B9=B6=E5=A2=9E=E5=BC=BA?=
 =?UTF-8?q?=E5=BC=82=E5=B8=B8=E6=97=A5=E5=BF=97=E8=AE=B0=E5=BD=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/trainer/trainer.py          | 24 ++++++++++++++----------
 khaosz/trainer/trainer_callback.py |  6 ++++--
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 9a0005f..2350195 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Optional, List, cast
 
 from torch.utils.data import DataLoader
@@ -12,6 +13,7 @@ from khaosz.trainer.trainer_callback import (
     SchedulerCallback
 )
 
+logger = logging.getLogger(__name__)
 
 class Trainer:
     def __init__(
@@ -115,12 +117,6 @@ class Trainer:
                 train_kwargs["epoch"] = epoch
                 self._call_callbacks('on_epoch_begin', **train_kwargs)
                 for batch in dataloader:
-                    # batch
-                    self._call_callbacks('on_batch_begin', **train_kwargs)
-                    loss = self.train_config.strategy(batch)
-                    loss.backward()
-                    train_kwargs["loss"] = loss.item()
-                    self._call_callbacks('on_batch_end', **train_kwargs)
 
                     if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
                         # step
@@ -128,13 +124,21 @@ class Trainer:
                         self.train_config.optimizer.step()
                         self.train_config.optimizer.zero_grad()
                         self._call_callbacks('on_step_end', **train_kwargs)
-
+
+                    # batch
+                    self._call_callbacks('on_batch_begin', **train_kwargs)
+                    loss = self.train_config.strategy(batch)
+                    train_kwargs["loss"] = loss.item()
                     train_kwargs["current_iter"] += 1
-
+                    loss.backward()
+
+                    self._call_callbacks('on_batch_end', **train_kwargs)
+
                 self._call_callbacks('on_epoch_end', **train_kwargs)
-
+
         except Exception as e:
-            raise e
+            logger.error(f"Training failed: {str(e)}", exc_info=True)
+            raise
         finally:
             self._call_callbacks('on_train_end', **train_kwargs)
         
         return checkpoint
\ No newline at end of file
diff --git a/khaosz/trainer/trainer_callback.py b/khaosz/trainer/trainer_callback.py
index e1a9ec8..cb4b954 100644
--- a/khaosz/trainer/trainer_callback.py
+++ b/khaosz/trainer/trainer_callback.py
@@ -87,10 +87,12 @@ class ProgressBarCallback(TrainerCallback):
         )
 
     def on_batch_end(self, trainer: 'Trainer', **kwargs):
+        _ = trainer
         loss = kwargs.get('loss')
+        optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
         self.progress_bar.set_postfix({
             "loss": f"{loss:.4f}",
-            "lr": f"{trainer.train_config.optimizer.param_groups[0]['lr']:.2e}"
+            "lr": f"{optimizer.param_groups[-1]['lr']:.2e}"
         })
         self.progress_bar.update(1)
@@ -180,7 +182,7 @@ class SchedulerCallback(TrainerCallback):
             last_epoch=self.current_iter - 1
         )
 
-    def on_step_end(self, trainer: 'Trainer', **kwargs):
+    def on_batch_end(self, trainer: 'Trainer', **kwargs):
         _ = trainer, kwargs
         if self.scheduler: