diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py
index bd9449d..9d6eb34 100644
--- a/khaosz/trainer/data_util.py
+++ b/khaosz/trainer/data_util.py
@@ -295,7 +295,8 @@ class RandomSampler(Sampler[int]):
         if self._indices is None:
             self._generate_indices()
 
-        for i in range(self.current_iter, n):
+        start = self.current_iter % n
+        for i in range(start, n):
             yield self._indices[i]
             self.current_iter += 1
 
@@ -303,8 +304,7 @@
         self._indices = None
 
     def __len__(self):
-        n = len(self.data_source)
-        return n - self.current_iter % n
+        return len(self.data_source)
 
     def state_dict(self):
         return {
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 2350195..4e182d0 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -37,8 +37,6 @@ class Trainer:
     ]
 
     def _set_train_kwargs(self, kwargs: dict):
-        used_epochs = 0
-        used_iters = 0
         seed = self.train_config.random_seed
         sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
         optim = self.train_config.optimizer
@@ -59,8 +57,6 @@
 
         if sampler_state:
             sampler.load_state_dict(sampler_state)
-            used_epochs = sampler_state.get('epoch', 0)
-            used_iters = sampler_state.get('iter', 0)
 
         if optim_state:
             optim.load_state_dict(optim_state)
@@ -76,8 +72,8 @@
 
         kwargs["dataloader"] = dataloader
         kwargs["optimizer"] = self.train_config.optimizer
-        kwargs["epoch"] = used_epochs
-        kwargs["current_iter"] = used_iters
+        kwargs["epoch"] = sampler.epoch
+        kwargs["current_iter"] = sampler.current_iter
         kwargs["sampler"] = sampler
         kwargs["checkpoint"] = checkpoint
 
diff --git a/khaosz/trainer/trainer_callback.py b/khaosz/trainer/trainer_callback.py
index 59b2823..de898fc 100644
--- a/khaosz/trainer/trainer_callback.py
+++ b/khaosz/trainer/trainer_callback.py
@@ -103,9 +103,6 @@ class CheckpointCallback(TrainerCallback):
         checkpoint.sampler_state = random_sampler.state_dict()
         checkpoint.optim_state = optimizer.state_dict()
 
-        checkpoint.sampler_state['epoch'] = kwargs.get('epoch', 0)
-        checkpoint.sampler_state['current_iter'] = kwargs.get('current_iter', 0)
-
         checkpoint.save(save_path)
 
     def on_batch_end(self, trainer: 'Trainer', **kwargs):
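For context, below is a minimal, self-contained sketch of the resumable-sampler pattern this changeset converges on: the sampler owns all resume state (epoch, current_iter), exposes it through state_dict()/load_state_dict(), and the trainer reads it back from the sampler instead of threading separate used_epochs/used_iters copies through kwargs. The class name ResumableRandomSampler, the seed + epoch reshuffle rule, and the set_epoch hook are illustrative assumptions, not khaosz's actual implementation.

import torch
from torch.utils.data import Sampler


class ResumableRandomSampler(Sampler[int]):
    """Sketch of a random sampler that can checkpoint and resume mid-epoch."""

    def __init__(self, data_source, seed: int = 0):
        self.data_source = data_source
        self.seed = seed
        self.epoch = 0          # advanced externally via set_epoch() (assumed hook)
        self.current_iter = 0   # total samples yielded; keeps growing across epochs
        self._indices = None

    def _generate_indices(self):
        # Derive the permutation from (seed, epoch) so a restored sampler
        # reproduces exactly the order it was interrupted in (assumed rule).
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        self._indices = torch.randperm(len(self.data_source), generator=g).tolist()

    def __iter__(self):
        n = len(self.data_source)
        if self._indices is None:
            self._generate_indices()
        # current_iter counts samples across epochs, so reduce it modulo n
        # before using it as a position (the off-by-an-epoch bug fixed above).
        start = self.current_iter % n
        for i in range(start, n):
            yield self._indices[i]
            self.current_iter += 1

    def set_epoch(self, epoch: int):
        self.epoch = epoch
        self._indices = None  # force a reshuffle for the new epoch

    def __len__(self):
        # Full dataset size, not the remaining count: DataLoader sizing
        # must not shrink just because we resumed partway through an epoch.
        return len(self.data_source)

    def state_dict(self):
        # The sampler is the single source of truth for resume state.
        return {"seed": self.seed, "epoch": self.epoch, "current_iter": self.current_iter}

    def load_state_dict(self, state: dict):
        self.seed = state["seed"]
        self.epoch = state["epoch"]
        self.current_iter = state["current_iter"]
        self._indices = None  # regenerate lazily on the next __iter__

Two details the sketch mirrors from the diff: __len__ returns the full dataset size, so len(dataloader) stays stable when resuming mid-epoch (the old version returned the shrinking remainder n - current_iter % n); and epoch/current_iter live inside the sampler's own state_dict(), which removes the old key mismatch where CheckpointCallback wrote 'current_iter' into the saved state while the trainer read it back as 'iter', silently restarting resumed runs from iteration 0.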