From 36b410384b487d3696330ddef866e4d71d39f5d2 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 27 Nov 2025 19:32:40 +0800 Subject: [PATCH] =?UTF-8?q?fix(data/sampler):=20=E5=A2=9E=E5=8A=A0sampler?= =?UTF-8?q?=E8=BE=B9=E7=95=8C=E6=83=85=E5=86=B5=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/data/sampler.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/khaosz/data/sampler.py b/khaosz/data/sampler.py index f628885..0769f66 100644 --- a/khaosz/data/sampler.py +++ b/khaosz/data/sampler.py @@ -12,6 +12,7 @@ class ResumeableRandomSampler(Sampler[int]): start_epoch: int=0, start_iter: int=0, seed: int=42, + drop_last: bool=False, process_group: Optional[dist.ProcessGroup]=None, ): self.epoch = start_epoch @@ -34,7 +35,12 @@ class ResumeableRandomSampler(Sampler[int]): # single process self.rank = 0 self.num_replicas = 1 - + + self.drop_last = drop_last + offset = 0 if drop_last else self.num_replicas - 1 + self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas + self.total_size = self.num_samples_per_replica * self.num_replicas + self._indices = None def _get_indices(self): @@ -42,8 +48,13 @@ class ResumeableRandomSampler(Sampler[int]): generator.manual_seed(self.seed + self.epoch) indices = torch.randperm(self.num_samples, generator=generator).tolist() - self.iter = self.iter % self.num_samples - local_indices = indices[self.rank: self.num_samples: self.num_replicas] + if not self.drop_last and self.num_samples < self.total_size: + padding_size = self.total_size - len(indices) + indices += indices[:padding_size] + + local_indices = indices[self.rank:self.total_size:self.num_replicas] + + self.iter = self.iter % self.num_samples_per_replica self._indices = local_indices[self.iter:] def __iter__(self):