fix(trainer): 修复随机采样器迭代重置问题

This commit is contained in:
ViperEkura 2025-10-02 14:22:38 +08:00
parent e43a5b9b66
commit 83c08cfbb9
3 changed files with 5 additions and 12 deletions

View File

@@ -295,7 +295,8 @@ class RandomSampler(Sampler[int]):
         if self._indices is None:
             self._generate_indices()
-        for i in range(self.current_iter, n):
+        start = self.current_iter % n
+        for i in range(start, n):
             yield self._indices[i]
             self.current_iter += 1
@@ -303,8 +304,7 @@ class RandomSampler(Sampler[int]):
         self._indices = None

     def __len__(self):
-        n = len(self.data_source)
-        return n - self.current_iter % n
+        return len(self.data_source)

     def state_dict(self):
         return {

View File

@@ -37,8 +37,6 @@ class Trainer:
         ]

     def _set_train_kwargs(self, kwargs: dict):
-        used_epochs = 0
-        used_iters = 0
         seed = self.train_config.random_seed
         sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
         optim = self.train_config.optimizer
@@ -59,8 +57,6 @@ class Trainer:
         if sampler_state:
             sampler.load_state_dict(sampler_state)
-            used_epochs = sampler_state.get('epoch', 0)
-            used_iters = sampler_state.get('iter', 0)

         if optim_state:
             optim.load_state_dict(optim_state)
@@ -76,8 +72,8 @@ class Trainer:
         kwargs["dataloader"] = dataloader
         kwargs["optimizer"] = self.train_config.optimizer
-        kwargs["epoch"] = used_epochs
-        kwargs["current_iter"] = used_iters
+        kwargs["epoch"] = sampler.epoch
+        kwargs["current_iter"] = sampler.current_iter
         kwargs["sampler"] = sampler
         kwargs["checkpoint"] = checkpoint

View File

@@ -103,9 +103,6 @@ class CheckpointCallback(TrainerCallback):
         checkpoint.sampler_state = random_sampler.state_dict()
         checkpoint.optim_state = optimizer.state_dict()
-        checkpoint.sampler_state['epoch'] = kwargs.get('epoch', 0)
-        checkpoint.sampler_state['current_iter'] = kwargs.get('current_iter', 0)

         checkpoint.save(save_path)

     def on_batch_end(self, trainer: 'Trainer', **kwargs):