From c1bf22b6ece9afb063adc8a9c7444c6cbaa80dfd Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Mon, 6 Oct 2025 20:12:08 +0800
Subject: [PATCH] refactor(khaosz/trainer): use TrainContext instead of kwargs to pass the training context
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/trainer/train_callback.py | 120 ++++++++++++-------------------
 khaosz/trainer/train_context.py  |   8 +--
 khaosz/trainer/trainer.py        |   3 +-
 tests/test_callbacks.py          |   6 +-
 4 files changed, 55 insertions(+), 82 deletions(-)

diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index bf8d40f..7f351ad 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -1,16 +1,13 @@
 import os
-import torch.optim as optim
-
 from tqdm import tqdm
 from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import LambdaLR
-from typing import Optional, cast, TYPE_CHECKING
-from khaosz.core.parameter import Checkpoint
-from khaosz.trainer.data_util import RandomSampler
+from typing import Optional, TYPE_CHECKING
 from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
 
 if TYPE_CHECKING:
     from khaosz.trainer.trainer import Trainer
+    from khaosz.trainer.train_context import TrainContext
 
 
 class TrainCallback:
@@ -19,37 +16,37 @@ class TrainCallback:
     and we use '_' to ignore unused parameters.
     """
 
-    def on_train_begin(self, trainer: 'Trainer', **kwargs):
+    def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
         """ Called at the beginning of training. """
-        _ = trainer, kwargs
+        _ = trainer, context
 
-    def on_train_end(self, trainer: 'Trainer', **kwargs):
+    def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
         """ Called at the end of training. """
-        _ = trainer, kwargs
+        _ = trainer, context
 
-    def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
+    def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
         """ Called at the beginning of each epoch. """
-        _ = trainer, kwargs
+        _ = trainer, context
 
-    def on_epoch_end(self, trainer: 'Trainer', **kwargs):
+    def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
         """ Called at the end of each epoch. """
-        _ = trainer, kwargs
+        _ = trainer, context
 
-    def on_batch_begin(self, trainer: 'Trainer', **kwargs):
+    def on_batch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
         """ Called at the beginning of each batch. """
-        _ = trainer, kwargs
+        _ = trainer, context
 
-    def on_batch_end(self, trainer: 'Trainer', **kwargs):
+    def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
         """ Called at the end of each batch. """
-        _ = trainer, kwargs
+        _ = trainer, context
 
-    def on_step_begin(self, trainer: 'Trainer', **kwargs):
+    def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
         """ Called at the beginning of each step. """
-        _ = trainer, kwargs
+        _ = trainer, context
 
-    def on_step_end(self, trainer: 'Trainer', **kwargs):
+    def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'):
         """ Called at the end of each step."""
-        _ = trainer, kwargs
+        _ = trainer, context
 
 
 class ProgressBarCallback(TrainCallback):
@@ -59,27 +56,23 @@ class ProgressBarCallback(TrainCallback):
     def __init__(self):
         self.progress_bar: tqdm = None
 
-    def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
-        epoch = kwargs.get('epoch')
-        dataloader = kwargs.get('dataloader')
+    def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
         self.progress_bar = tqdm(
-            dataloader,
-            desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
+            context.dataloader,
+            desc=f"Epoch {context.epoch+1}/{trainer.train_config.n_epoch}",
             dynamic_ncols=True
         )
 
-    def on_batch_end(self, trainer: 'Trainer', **kwargs):
+    def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
         _ = trainer
-        loss = kwargs.get('loss')
-        optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
         self.progress_bar.set_postfix({
-            "loss": f"{loss:.4f}",
-            "lr": f"{optimizer.param_groups[-1]['lr']:.2e}"
+            "loss": f"{context.loss:.4f}",
+            "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
         })
         self.progress_bar.update(1)
 
-    def on_epoch_end(self, trainer: 'Trainer', **kwargs):
-        _ = trainer, kwargs
+    def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
+        _ = trainer, context
         if self.progress_bar:
             self.progress_bar.close()
@@ -92,46 +85,31 @@ class CheckpointCallback(TrainCallback):
         self.checkpoint_interval = checkpoint_interval
         self.last_ckpt_iter = 0
 
-    @staticmethod
-    def _save_checkpoint(trainer: 'Trainer', **kwargs):
-        current_iter = kwargs.get('current_iter')
-        random_sampler = cast(RandomSampler, kwargs.get('sampler'))
-        optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
-        checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
-
-        save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
-        checkpoint.sampler_state = random_sampler.state_dict()
-        checkpoint.optim_state = optimizer.state_dict()
-
-        checkpoint.save(save_path)
+    def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
+        save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
+        context.checkpoint.sampler_state = context.sampler.state_dict()
+        context.checkpoint.optimizer_state = context.optimizer.state_dict()
+        context.checkpoint.save(save_path)
+        self.last_ckpt_iter = context.current_iter
 
-    def on_batch_end(self, trainer: 'Trainer', **kwargs):
-        current_iter = kwargs.get('current_iter')
-        checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
-        loss = kwargs.get('loss')
-        checkpoint.loss_list.append(loss)
+    def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
+        context.checkpoint.loss_list.append(context.loss)
 
-        if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
-            CheckpointCallback._save_checkpoint(trainer, **kwargs)
-            self.last_ckpt_iter = current_iter
-
-    def on_train_end(self, trainer: 'Trainer', **kwargs):
-        current_iter = kwargs.get('current_iter')
-        if current_iter != self.last_ckpt_iter:
-            CheckpointCallback._save_checkpoint(trainer, **kwargs)
-            self.last_ckpt_iter = current_iter
+        if context.current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
+            self._save_checkpoint(trainer, context)
+
+    def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
+        if context.current_iter != self.last_ckpt_iter:
+            self._save_checkpoint(trainer, context)
 
 
 class GradientClippingCallback(TrainCallback):
     """ Gradient clipping callback for trainer. """
 
-    def on_step_begin(self, trainer: 'Trainer', **kwargs):
-        _ = kwargs
-        clip_grad_norm_(
-            trainer.parameter.model.parameters(),
-            trainer.train_config.max_grad_norm
-        )
+    def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'):
+        _ = context
+        clip_grad_norm_(trainer.parameter.model.parameters(), trainer.train_config.max_grad_norm)
 
 
 class SchedulerCallback(TrainCallback):
@@ -141,10 +119,8 @@ class SchedulerCallback(TrainCallback):
     def __init__(self, schedule_config: ScheduleConfig):
         self.schedule_config = schedule_config
         self.scheduler: Optional[LambdaLR] = None
-        self.current_iter = 0
 
-    def on_train_begin(self, trainer: 'Trainer', **kwargs):
-        self.current_iter = kwargs.get('current_iter')
+    def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'):
 
         for group in trainer.train_config.optimizer.param_groups:
             if "initial_lr" not in group:
@@ -158,12 +134,10 @@ class SchedulerCallback(TrainCallback):
         self.scheduler = LambdaLR(
             trainer.train_config.optimizer,
             lambda_scheduler_fn,
-            last_epoch=self.current_iter - 1
+            last_epoch=context.current_iter - 1
         )
 
-    def on_batch_end(self, trainer: 'Trainer', **kwargs):
-        _ = trainer, kwargs
-
+    def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
+        _ = trainer, context
         if self.scheduler:
             self.scheduler.step()
-            self.current_iter += 1
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index 1324ccb..50145e4 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -44,7 +44,7 @@ class TrainContextBuilder:
             tokenizer=self.trainer.parameter.tokenizer,
             config=self.trainer.parameter.config,
             sampler_state=None,
-            optim_state=None,
+            optimizer_state=None,
             loss_list=[]
         )
         self._context.checkpoint = checkpoint
@@ -72,13 +72,13 @@ class TrainContextBuilder:
     def with_optimizer(self) -> Self:
         optimizer = self.trainer.train_config.optimizer
 
-        if self._context.checkpoint and self._context.checkpoint.optim_state:
-            optimizer.load_state_dict(self._context.checkpoint.optim_state)
+        if self._context.checkpoint and self._context.checkpoint.optimizer_state:
+            optimizer.load_state_dict(self._context.checkpoint.optimizer_state)
 
         self._context.optimizer = optimizer
 
         if self._context.checkpoint:
-            self._context.checkpoint.optim_state = optimizer.state_dict()
+            self._context.checkpoint.optimizer_state = optimizer.state_dict()
 
         return self
 
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 0228459..9d88704 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -45,11 +45,10 @@ class Trainer:
             .build())
 
     def _call_callbacks(self, method_name: str, context: TrainContext):
-        kwargs = context.asdict()
         for callback in self.callbacks:
             method = getattr(callback, method_name, None)
             if method:
-                method(self, **kwargs)
+                method(self, context)
 
     def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
         context = self._build_train_context(checkpoint)
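
Note on the _call_callbacks change: with method(self, **kwargs) each callback
received values unpacked from a dict built by context.asdict(), so reassigning
a scalar such as loss inside a callback had no effect beyond that callback.
With method(self, context), every callback receives the same live TrainContext
instance, so attribute reads and writes are shared across callbacks in
trainer.callbacks. A hypothetical callback illustrating this (the class and
its logging logic are illustrative, not part of this patch):

    class LoggingCallback(TrainCallback):
        def on_batch_end(self, trainer, context):
            # context is the same object CheckpointCallback mutates, so
            # context.checkpoint.loss_list already reflects its appends
            # if CheckpointCallback runs earlier in trainer.callbacks.
            if context.current_iter % 100 == 0:
                print(f"iter {context.current_iter}: loss={context.loss:.4f}")
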
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index d48d548..bb4e62b 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -28,13 +28,13 @@ def test_callback_integration(base_test_env, random_dataset):
     callback_calls = []
 
     class TrackingCallback(TrainCallback):
-        def on_train_begin(self, trainer, **kwargs):
+        def on_train_begin(self, trainer, context):
             callback_calls.append('on_train_begin')
 
-        def on_batch_end(self, trainer, **kwargs):
+        def on_batch_end(self, trainer, context):
             callback_calls.append('on_batch_end')
 
-        def on_epoch_end(self, trainer, **kwargs):
+        def on_epoch_end(self, trainer, context):
             callback_calls.append('on_epoch_end')
 
     train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"])
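
Migration note for downstream callbacks: only the hook signatures change;
values formerly fetched from kwargs by string key become attributes on the
context object. A before/after sketch (OldStyleCallback and NewStyleCallback
are illustrative names, not code from this patch):

    # Before this patch: kwargs-based interface
    class OldStyleCallback(TrainCallback):
        def on_batch_end(self, trainer, **kwargs):
            loss = kwargs.get('loss')             # fetched by string key
            print(f"loss={loss:.4f}")

    # After this patch: context-based interface
    class NewStyleCallback(TrainCallback):
        def on_batch_end(self, trainer, context):
            print(f"loss={context.loss:.4f}")     # attribute access, type-checkable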