From 83c08cfbb9890544a18c3d1bc7cce4849a4a6e1f Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 2 Oct 2025 14:22:38 +0800 Subject: [PATCH] =?UTF-8?q?fix(trainer):=20=E4=BF=AE=E5=A4=8D=E9=9A=8F?= =?UTF-8?q?=E6=9C=BA=E9=87=87=E6=A0=B7=E5=99=A8=E8=BF=AD=E4=BB=A3=E9=87=8D?= =?UTF-8?q?=E7=BD=AE=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/data_util.py | 6 +++--- khaosz/trainer/trainer.py | 8 ++------ khaosz/trainer/trainer_callback.py | 3 --- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py index bd9449d..9d6eb34 100644 --- a/khaosz/trainer/data_util.py +++ b/khaosz/trainer/data_util.py @@ -295,7 +295,8 @@ class RandomSampler(Sampler[int]): if self._indices is None: self._generate_indices() - for i in range(self.current_iter, n): + start = self.current_iter % n + for i in range(start, n): yield self._indices[i] self.current_iter += 1 @@ -303,8 +304,7 @@ class RandomSampler(Sampler[int]): self._indices = None def __len__(self): - n = len(self.data_source) - return n - self.current_iter % n + return len(self.data_source) def state_dict(self): return { diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 2350195..4e182d0 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -37,8 +37,6 @@ class Trainer: ] def _set_train_kwargs(self, kwargs: dict): - used_epochs = 0 - used_iters = 0 seed = self.train_config.random_seed sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed) optim = self.train_config.optimizer @@ -59,8 +57,6 @@ class Trainer: if sampler_state: sampler.load_state_dict(sampler_state) - used_epochs = sampler_state.get('epoch', 0) - used_iters = sampler_state.get('iter', 0) if optim_state: optim.load_state_dict(optim_state) @@ -76,8 +72,8 @@ class Trainer: kwargs["dataloader"] = dataloader kwargs["optimizer"] = self.train_config.optimizer - kwargs["epoch"] = used_epochs - kwargs["current_iter"] = used_iters + kwargs["epoch"] = sampler.epoch + kwargs["current_iter"] = sampler.current_iter kwargs["sampler"] = sampler kwargs["checkpoint"] = checkpoint diff --git a/khaosz/trainer/trainer_callback.py b/khaosz/trainer/trainer_callback.py index 59b2823..de898fc 100644 --- a/khaosz/trainer/trainer_callback.py +++ b/khaosz/trainer/trainer_callback.py @@ -103,9 +103,6 @@ class CheckpointCallback(TrainerCallback): checkpoint.sampler_state = random_sampler.state_dict() checkpoint.optim_state = optimizer.state_dict() - checkpoint.sampler_state['epoch'] = kwargs.get('epoch', 0) - checkpoint.sampler_state['current_iter'] = kwargs.get('current_iter', 0) - checkpoint.save(save_path) def on_batch_end(self, trainer: 'Trainer', **kwargs):