fix(data/sampler): 增加sampler边界情况处理

This commit is contained in:
ViperEkura 2025-11-27 19:32:40 +08:00
parent 09963a3beb
commit 36b410384b
1 changed file with 14 additions and 3 deletions

View File

@ -12,6 +12,7 @@ class ResumeableRandomSampler(Sampler[int]):
start_epoch: int=0, start_epoch: int=0,
start_iter: int=0, start_iter: int=0,
seed: int=42, seed: int=42,
drop_last: bool=False,
process_group: Optional[dist.ProcessGroup]=None, process_group: Optional[dist.ProcessGroup]=None,
): ):
self.epoch = start_epoch self.epoch = start_epoch
@ -34,7 +35,12 @@ class ResumeableRandomSampler(Sampler[int]):
# single process # single process
self.rank = 0 self.rank = 0
self.num_replicas = 1 self.num_replicas = 1
self.drop_last = drop_last
offset = 0 if drop_last else self.num_replicas - 1
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
self.total_size = self.num_samples_per_replica * self.num_replicas
self._indices = None self._indices = None
def _get_indices(self): def _get_indices(self):
@ -42,8 +48,13 @@ class ResumeableRandomSampler(Sampler[int]):
generator.manual_seed(self.seed + self.epoch) generator.manual_seed(self.seed + self.epoch)
indices = torch.randperm(self.num_samples, generator=generator).tolist() indices = torch.randperm(self.num_samples, generator=generator).tolist()
self.iter = self.iter % self.num_samples if not self.drop_last and self.num_samples < self.total_size:
local_indices = indices[self.rank: self.num_samples: self.num_replicas] padding_size = self.total_size - len(indices)
indices += indices[:padding_size]
local_indices = indices[self.rank:self.total_size:self.num_replicas]
self.iter = self.iter % self.num_samples_per_replica
self._indices = local_indices[self.iter:] self._indices = local_indices[self.iter:]
def __iter__(self): def __iter__(self):