From 89211c16f66e134cd69e7828d254e42a7726bbf9 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Mon, 29 Sep 2025 13:38:46 +0800
Subject: [PATCH] =?UTF-8?q?fix(khaosz/trainer):=20=E5=B0=86=E4=BF=9D?=
 =?UTF-8?q?=E5=AD=98=E6=A3=80=E6=9F=A5=E7=82=B9=E9=80=BB=E8=BE=91=E7=A7=BB?=
 =?UTF-8?q?=E8=87=B3CheckpointCallback?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/trainer/callback.py | 12 ++++++++++--
 khaosz/trainer/trainer.py  | 10 +---------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/khaosz/trainer/callback.py b/khaosz/trainer/callback.py
index e7cdf25..bb0a3ea 100644
--- a/khaosz/trainer/callback.py
+++ b/khaosz/trainer/callback.py
@@ -1,3 +1,4 @@
+import os
 from tqdm import tqdm
 from khaosz.core.parameter import Checkpoint
 from torch.nn.utils import clip_grad_norm_
@@ -104,6 +105,13 @@ class CheckpointCallback(TrainerCallback):
         self.checkpoint_interval = checkpoint_interval
         self.last_ckpt_iter = 0
 
+    @staticmethod
+    def _save_checkpoint(trainer: 'Trainer'):
+        current_iter = len(trainer.checkpoint.loss_list)
+        save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
+        trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict()
+        trainer.checkpoint.save(save_path)
+
     def on_train_begin(self, trainer: 'Trainer', **kwargs):
         _ = trainer
         checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
@@ -112,14 +120,14 @@ class CheckpointCallback(TrainerCallback):
     def on_batch_end(self, trainer: 'Trainer', **kwargs):
         current_iter = kwargs.get('current_iter')
         if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
-            trainer._save_checkpoint()
+            CheckpointCallback._save_checkpoint(trainer)
             self.last_ckpt_iter = current_iter
 
     def on_train_end(self, trainer: 'Trainer', **kwargs):
         checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
         current_iter = len(checkpoint.loss_list)
         if current_iter != self.last_ckpt_iter:
-            trainer._save_checkpoint()
+            CheckpointCallback._save_checkpoint(trainer)
 
 
 class GradientClippingCallback(TrainerCallback):
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index a2a26b6..3d44968 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -1,4 +1,3 @@
-import os
 import torch
 from typing import Optional, List
 from torch.utils.data import DataLoader, RandomSampler
@@ -29,7 +28,6 @@ class Trainer:
         )
         self.train_config = train_config
         self.schedule_config = schedule_config
-
         self.callbacks = callbacks or self._get_default_callbacks()
 
     def _get_default_callbacks(self) -> List[TrainerCallback]:
@@ -49,19 +47,13 @@ class Trainer:
             batch_size=self.train_config.batch_size,
             sampler=sampler
         )
-
-    def _save_checkpoint(self):
-        current_iter = len(self.checkpoint.loss_list)
-        save_path = os.path.join(self.train_config.checkpoint_dir, f"iter_{current_iter}")
-        self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
-        self.checkpoint.save(save_path)
 
     def _call_callbacks(self, method_name: str, **kwargs):
         for callback in self.callbacks:
             method = getattr(callback, method_name, None)
             if method:
                 method(self, **kwargs)
-
+
     def train(
         self,
         train_checkpoint: Optional[Checkpoint] = None