fix(trainer): 修复随机采样器迭代重置问题
This commit is contained in:
parent
e43a5b9b66
commit
83c08cfbb9
|
|
@ -295,7 +295,8 @@ class RandomSampler(Sampler[int]):
|
|||
if self._indices is None:
|
||||
self._generate_indices()
|
||||
|
||||
for i in range(self.current_iter, n):
|
||||
start = self.current_iter % n
|
||||
for i in range(start, n):
|
||||
yield self._indices[i]
|
||||
self.current_iter += 1
|
||||
|
||||
|
|
@ -303,8 +304,7 @@ class RandomSampler(Sampler[int]):
|
|||
self._indices = None
|
||||
|
||||
def __len__(self):
|
||||
n = len(self.data_source)
|
||||
return n - self.current_iter % n
|
||||
return len(self.data_source)
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -37,8 +37,6 @@ class Trainer:
|
|||
]
|
||||
|
||||
def _set_train_kwargs(self, kwargs: dict):
|
||||
used_epochs = 0
|
||||
used_iters = 0
|
||||
seed = self.train_config.random_seed
|
||||
sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
|
||||
optim = self.train_config.optimizer
|
||||
|
|
@ -59,8 +57,6 @@ class Trainer:
|
|||
|
||||
if sampler_state:
|
||||
sampler.load_state_dict(sampler_state)
|
||||
used_epochs = sampler_state.get('epoch', 0)
|
||||
used_iters = sampler_state.get('iter', 0)
|
||||
|
||||
if optim_state:
|
||||
optim.load_state_dict(optim_state)
|
||||
|
|
@ -76,8 +72,8 @@ class Trainer:
|
|||
|
||||
kwargs["dataloader"] = dataloader
|
||||
kwargs["optimizer"] = self.train_config.optimizer
|
||||
kwargs["epoch"] = used_epochs
|
||||
kwargs["current_iter"] = used_iters
|
||||
kwargs["epoch"] = sampler.epoch
|
||||
kwargs["current_iter"] = sampler.current_iter
|
||||
kwargs["sampler"] = sampler
|
||||
kwargs["checkpoint"] = checkpoint
|
||||
|
||||
|
|
|
|||
|
|
@ -103,9 +103,6 @@ class CheckpointCallback(TrainerCallback):
|
|||
checkpoint.sampler_state = random_sampler.state_dict()
|
||||
checkpoint.optim_state = optimizer.state_dict()
|
||||
|
||||
checkpoint.sampler_state['epoch'] = kwargs.get('epoch', 0)
|
||||
checkpoint.sampler_state['current_iter'] = kwargs.get('current_iter', 0)
|
||||
|
||||
checkpoint.save(save_path)
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||
|
|
|
|||
Loading…
Reference in New Issue