From 648e4e177be6733b76b8df0a38bac6f65abf2d4f Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Mon, 29 Sep 2025 13:18:44 +0800
Subject: [PATCH] feat(khaosz/trainer): add SchedulerCallback feature
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/trainer/__init__.py | 15 ++++++++++++++-
 khaosz/trainer/callback.py | 41 +++++++++++++++++++++++++++++++++++++++--
 khaosz/trainer/trainer.py  | 23 +++++++++--------------
 3 files changed, 62 insertions(+), 17 deletions(-)

diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py
index 6d40793..2954dbb 100644
--- a/khaosz/trainer/__init__.py
+++ b/khaosz/trainer/__init__.py
@@ -7,13 +7,26 @@ from khaosz.trainer.strategy import (
     StrategyFactory,
     SchedulerFactory
 )
+from khaosz.trainer.callback import (
+    ProgressBarCallback,
+    CheckpointCallback,
+    TrainerCallback,
+    SchedulerCallback
+)
 
 __all__ = [
+    # strategy
     "DatasetLoader",
     "Trainer",
     "TrainConfig",
     "CosineScheduleConfig",
     "SgdrScheduleConfig",
     "StrategyFactory",
-    "SchedulerFactory"
+    "SchedulerFactory",
+
+    # callback
+    "ProgressBarCallback",
+    "CheckpointCallback",
+    "TrainerCallback",
+    "SchedulerCallback",
 ]
\ No newline at end of file
diff --git a/khaosz/trainer/callback.py b/khaosz/trainer/callback.py
index 05ba1ba..e7cdf25 100644
--- a/khaosz/trainer/callback.py
+++ b/khaosz/trainer/callback.py
@@ -1,7 +1,9 @@
 from tqdm import tqdm
 from khaosz.core.parameter import Checkpoint
 from torch.nn.utils import clip_grad_norm_
-from typing import cast, TYPE_CHECKING
+from torch.optim.lr_scheduler import LambdaLR
+from typing import Optional, cast, TYPE_CHECKING
+from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
 
 if TYPE_CHECKING:
     from khaosz.trainer.trainer import Trainer
@@ -129,4 +131,39 @@ class GradientClippingCallback(TrainerCallback):
         clip_grad_norm_(
             trainer.checkpoint.model.parameters(),
             trainer.train_config.max_grad_norm
-        )
\ No newline at end of file
+        )
+
+
+class SchedulerCallback(TrainerCallback):
+    """
+    Scheduler callback for trainer.
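+
+    On train begin, builds a LambdaLR from the given ScheduleConfig and
+    fast-forwards it past the steps already recorded in checkpoint.loss_list,
+    so resumed runs continue the schedule instead of restarting it.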
+ """ + def __init__(self, schedule_config: ScheduleConfig): + self.schedule_config = schedule_config + self.scheduler: Optional[LambdaLR] = None + self.current_iter = 0 + + def on_train_begin(self, trainer: 'Trainer', **kwargs): + checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) + self.current_iter = len(checkpoint.loss_list) + + lambda_scheduler_fn = SchedulerFactory.load_schedule_fn( + **self.schedule_config.get_kwargs() + ) + + self.scheduler = LambdaLR( + trainer.train_config.optimizer, + lambda_scheduler_fn, + last_epoch=self.current_iter - 1 + ) + + def on_step_end(self, trainer: 'Trainer', **kwargs): + _ = trainer, kwargs + + if self.scheduler: + self.scheduler.step() + self.current_iter += 1 diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 973a097..a2a26b6 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -1,12 +1,17 @@ import os import torch from typing import Optional, List -from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader, RandomSampler from khaosz.core import ModelParameter, Checkpoint -from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig -from khaosz.trainer.callback import TrainerCallback, ProgressBarCallback, CheckpointCallback, GradientClippingCallback +from khaosz.trainer.strategy import TrainConfig, ScheduleConfig +from khaosz.trainer.callback import ( + TrainerCallback, + ProgressBarCallback, + CheckpointCallback, + GradientClippingCallback, + SchedulerCallback +) class Trainer: @@ -32,6 +37,7 @@ class Trainer: ProgressBarCallback(), CheckpointCallback(self.train_config.checkpoint_interval), GradientClippingCallback(), + SchedulerCallback(self.schedule_config), ] def _create_dataloader(self) -> DataLoader: @@ -72,16 +78,6 @@ class Trainer: for group in self.train_config.optimizer.param_groups: if "initial_lr" not in group: group["initial_lr"] = group["lr"] - - lambda_scheduler_fn = SchedulerFactory.load_schedule_fn( - **self.schedule_config.get_kwargs() - ) - - scheduler = LambdaLR( - self.train_config.optimizer, - lambda_scheduler_fn, - last_epoch=current_iter - 1 if train_checkpoint else -1 - ) reamining_steps = self.train_config.n_epoch - current_iter total_steps = len(self.train_config.dataset) // self.train_config.batch_size @@ -114,7 +110,6 @@ class Trainer: self._call_callbacks('on_step_end', current_iter=current_iter) current_iter += 1 - scheduler.step() self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)