fix(trainer): 修复随机采样器迭代重置问题
This commit is contained in:
parent
e43a5b9b66
commit
83c08cfbb9
|
|
@@ -295,7 +295,8 @@ class RandomSampler(Sampler[int]):
|
||||||
if self._indices is None:
|
if self._indices is None:
|
||||||
self._generate_indices()
|
self._generate_indices()
|
||||||
|
|
||||||
for i in range(self.current_iter, n):
|
start = self.current_iter % n
|
||||||
|
for i in range(start, n):
|
||||||
yield self._indices[i]
|
yield self._indices[i]
|
||||||
self.current_iter += 1
|
self.current_iter += 1
|
||||||
|
|
||||||
|
|
@@ -303,8 +304,7 @@ class RandomSampler(Sampler[int]):
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
n = len(self.data_source)
|
return len(self.data_source)
|
||||||
return n - self.current_iter % n
|
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@@ -37,8 +37,6 @@ class Trainer:
|
||||||
]
|
]
|
||||||
|
|
||||||
def _set_train_kwargs(self, kwargs: dict):
|
def _set_train_kwargs(self, kwargs: dict):
|
||||||
used_epochs = 0
|
|
||||||
used_iters = 0
|
|
||||||
seed = self.train_config.random_seed
|
seed = self.train_config.random_seed
|
||||||
sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
|
sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
|
||||||
optim = self.train_config.optimizer
|
optim = self.train_config.optimizer
|
||||||
|
|
@@ -59,8 +57,6 @@ class Trainer:
|
||||||
|
|
||||||
if sampler_state:
|
if sampler_state:
|
||||||
sampler.load_state_dict(sampler_state)
|
sampler.load_state_dict(sampler_state)
|
||||||
used_epochs = sampler_state.get('epoch', 0)
|
|
||||||
used_iters = sampler_state.get('iter', 0)
|
|
||||||
|
|
||||||
if optim_state:
|
if optim_state:
|
||||||
optim.load_state_dict(optim_state)
|
optim.load_state_dict(optim_state)
|
||||||
|
|
@@ -76,8 +72,8 @@ class Trainer:
|
||||||
|
|
||||||
kwargs["dataloader"] = dataloader
|
kwargs["dataloader"] = dataloader
|
||||||
kwargs["optimizer"] = self.train_config.optimizer
|
kwargs["optimizer"] = self.train_config.optimizer
|
||||||
kwargs["epoch"] = used_epochs
|
kwargs["epoch"] = sampler.epoch
|
||||||
kwargs["current_iter"] = used_iters
|
kwargs["current_iter"] = sampler.current_iter
|
||||||
kwargs["sampler"] = sampler
|
kwargs["sampler"] = sampler
|
||||||
kwargs["checkpoint"] = checkpoint
|
kwargs["checkpoint"] = checkpoint
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -103,9 +103,6 @@ class CheckpointCallback(TrainerCallback):
|
||||||
checkpoint.sampler_state = random_sampler.state_dict()
|
checkpoint.sampler_state = random_sampler.state_dict()
|
||||||
checkpoint.optim_state = optimizer.state_dict()
|
checkpoint.optim_state = optimizer.state_dict()
|
||||||
|
|
||||||
checkpoint.sampler_state['epoch'] = kwargs.get('epoch', 0)
|
|
||||||
checkpoint.sampler_state['current_iter'] = kwargs.get('current_iter', 0)
|
|
||||||
|
|
||||||
checkpoint.save(save_path)
|
checkpoint.save(save_path)
|
||||||
|
|
||||||
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
def on_batch_end(self, trainer: 'Trainer', **kwargs):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue