From 019bfe4e0510471dd55aad971dcfa057bb915425 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 27 Nov 2025 19:43:36 +0800 Subject: [PATCH] =?UTF-8?q?fix(data/sampler):=20=E4=BF=AE=E6=AD=A3?= =?UTF-8?q?=E6=8B=BC=E5=86=99=E9=94=99=E8=AF=AF=E5=B9=B6=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=E9=87=87=E6=A0=B7=E5=99=A8=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/data/__init__.py | 4 ++-- khaosz/data/sampler.py | 14 ++++++++++---- khaosz/trainer/train_context.py | 4 ++-- tests/test_sampler.py | 10 +++++----- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/khaosz/data/__init__.py b/khaosz/data/__init__.py index cdb4362..8882a1b 100644 --- a/khaosz/data/__init__.py +++ b/khaosz/data/__init__.py @@ -10,7 +10,7 @@ from khaosz.data.dataset import ( ) from khaosz.data.tokenizer import BpeTokenizer -from khaosz.data.sampler import ResumeableRandomSampler +from khaosz.data.sampler import ResumableDistributedSampler __all__ = [ "BaseDataset", @@ -22,5 +22,5 @@ __all__ = [ "DatasetLoader", "load_pkl_files", "BpeTokenizer", - "ResumeableRandomSampler" + "ResumableDistributedSampler" ] \ No newline at end of file diff --git a/khaosz/data/sampler.py b/khaosz/data/sampler.py index 0769f66..4a7ddbc 100644 --- a/khaosz/data/sampler.py +++ b/khaosz/data/sampler.py @@ -5,7 +5,7 @@ from torch.utils.data import Dataset, Sampler from typing import Optional -class ResumeableRandomSampler(Sampler[int]): +class ResumableDistributedSampler(Sampler[int]): def __init__( self, data_source: Dataset, @@ -13,6 +13,7 @@ class ResumeableRandomSampler(Sampler[int]): start_iter: int=0, seed: int=42, drop_last: bool=False, + shuffle: bool=True, process_group: Optional[dist.ProcessGroup]=None, ): self.epoch = start_epoch @@ -37,6 +38,8 @@ class ResumeableRandomSampler(Sampler[int]): self.num_replicas = 1 self.drop_last = drop_last + self.shuffle = shuffle + offset = 0 if drop_last else self.num_replicas - 1 self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas self.total_size = self.num_samples_per_replica * self.num_replicas @@ -44,9 +47,12 @@ class ResumeableRandomSampler(Sampler[int]): self._indices = None def _get_indices(self): - generator = torch.Generator() - generator.manual_seed(self.seed + self.epoch) - indices = torch.randperm(self.num_samples, generator=generator).tolist() + if self.shuffle: + generator = torch.Generator() + generator.manual_seed(self.seed + self.epoch) + indices = torch.randperm(self.num_samples, generator=generator).tolist() + else: + indices = torch.arange(self.num_samples).tolist() if not self.drop_last and self.num_samples < self.total_size: padding_size = self.total_size - len(indices) diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index f0009b7..77a8199 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -3,7 +3,7 @@ from typing import Optional, Self, TYPE_CHECKING from torch.optim import Optimizer from torch.utils.data import DataLoader from khaosz.config import Checkpoint -from khaosz.data import ResumeableRandomSampler +from khaosz.data import ResumableDistributedSampler from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory if TYPE_CHECKING: @@ -83,7 +83,7 @@ class TrainContextBuilder: def with_dataloader(self) -> Self: # fix: change batch level batch_iter to sample level offset sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size - resumeable_sampler = ResumeableRandomSampler( + resumeable_sampler = ResumableDistributedSampler( data_source=self.trainer.train_config.dataset, start_epoch=self._context.epoch, start_iter=sampler_offset, diff --git a/tests/test_sampler.py b/tests/test_sampler.py index c6ce6cf..3876ebd 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -6,8 +6,8 @@ def test_random_sampler_consistency(random_dataset): dataset = random_dataset # Create two samplers with same seed - sampler1 = ResumeableRandomSampler(dataset, seed=42) - sampler2 = ResumeableRandomSampler(dataset, seed=42) + sampler1 = ResumableDistributedSampler(dataset, seed=42) + sampler2 = ResumableDistributedSampler(dataset, seed=42) indices1 = list(iter(sampler1)) indices2 = list(iter(sampler2)) @@ -19,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset): dataset = random_dataset # Create two samplers with different seeds - sampler1 = ResumeableRandomSampler(dataset, seed=42) - sampler2 = ResumeableRandomSampler(dataset, seed=123) + sampler1 = ResumableDistributedSampler(dataset, seed=42) + sampler2 = ResumableDistributedSampler(dataset, seed=123) indices1 = list(iter(sampler1)) indices2 = list(iter(sampler2)) @@ -34,7 +34,7 @@ def test_sampler_across_epochs(random_dataset): dataset = random_dataset n = len(dataset) - sampler = ResumeableRandomSampler(dataset, seed=42) + sampler = ResumableDistributedSampler(dataset, seed=42) # Get indices for first epoch epoch1_indices = list(iter(sampler))