diff --git a/khaosz/data/__init__.py b/khaosz/data/__init__.py index cdb4362..8882a1b 100644 --- a/khaosz/data/__init__.py +++ b/khaosz/data/__init__.py @@ -10,7 +10,7 @@ from khaosz.data.dataset import ( ) from khaosz.data.tokenizer import BpeTokenizer -from khaosz.data.sampler import ResumeableRandomSampler +from khaosz.data.sampler import ResumableDistributedSampler __all__ = [ "BaseDataset", @@ -22,5 +22,5 @@ __all__ = [ "DatasetLoader", "load_pkl_files", "BpeTokenizer", - "ResumeableRandomSampler" + "ResumableDistributedSampler" ] \ No newline at end of file diff --git a/khaosz/data/sampler.py b/khaosz/data/sampler.py index 0769f66..4a7ddbc 100644 --- a/khaosz/data/sampler.py +++ b/khaosz/data/sampler.py @@ -5,7 +5,7 @@ from torch.utils.data import Dataset, Sampler from typing import Optional -class ResumeableRandomSampler(Sampler[int]): +class ResumableDistributedSampler(Sampler[int]): def __init__( self, data_source: Dataset, @@ -13,6 +13,7 @@ class ResumeableRandomSampler(Sampler[int]): start_iter: int=0, seed: int=42, drop_last: bool=False, + shuffle: bool=True, process_group: Optional[dist.ProcessGroup]=None, ): self.epoch = start_epoch @@ -37,6 +38,8 @@ class ResumeableRandomSampler(Sampler[int]): self.num_replicas = 1 self.drop_last = drop_last + self.shuffle = shuffle + offset = 0 if drop_last else self.num_replicas - 1 self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas self.total_size = self.num_samples_per_replica * self.num_replicas @@ -44,9 +47,12 @@ class ResumeableRandomSampler(Sampler[int]): self._indices = None def _get_indices(self): - generator = torch.Generator() - generator.manual_seed(self.seed + self.epoch) - indices = torch.randperm(self.num_samples, generator=generator).tolist() + if self.shuffle: + generator = torch.Generator() + generator.manual_seed(self.seed + self.epoch) + indices = torch.randperm(self.num_samples, generator=generator).tolist() + else: + indices = torch.arange(self.num_samples).tolist() if not self.drop_last and self.num_samples < self.total_size: padding_size = self.total_size - len(indices) diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index f0009b7..77a8199 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -3,7 +3,7 @@ from typing import Optional, Self, TYPE_CHECKING from torch.optim import Optimizer from torch.utils.data import DataLoader from khaosz.config import Checkpoint -from khaosz.data import ResumeableRandomSampler +from khaosz.data import ResumableDistributedSampler from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory if TYPE_CHECKING: @@ -83,7 +83,7 @@ class TrainContextBuilder: def with_dataloader(self) -> Self: # fix: change batch level batch_iter to sample level offset sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size - resumeable_sampler = ResumeableRandomSampler( + resumeable_sampler = ResumableDistributedSampler( data_source=self.trainer.train_config.dataset, start_epoch=self._context.epoch, start_iter=sampler_offset, diff --git a/tests/test_sampler.py b/tests/test_sampler.py index c6ce6cf..3876ebd 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -6,8 +6,8 @@ def test_random_sampler_consistency(random_dataset): dataset = random_dataset # Create two samplers with same seed - sampler1 = ResumeableRandomSampler(dataset, seed=42) - sampler2 = ResumeableRandomSampler(dataset, seed=42) + sampler1 = ResumableDistributedSampler(dataset, seed=42) + sampler2 = ResumableDistributedSampler(dataset, seed=42) indices1 = list(iter(sampler1)) indices2 = list(iter(sampler2)) @@ -19,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset): dataset = random_dataset # Create two samplers with different seeds - sampler1 = ResumeableRandomSampler(dataset, seed=42) - sampler2 = ResumeableRandomSampler(dataset, seed=123) + sampler1 = ResumableDistributedSampler(dataset, seed=42) + sampler2 = ResumableDistributedSampler(dataset, seed=123) indices1 = list(iter(sampler1)) indices2 = list(iter(sampler2)) @@ -34,7 +34,7 @@ def test_sampler_across_epochs(random_dataset): dataset = random_dataset n = len(dataset) - sampler = ResumeableRandomSampler(dataset, seed=42) + sampler = ResumableDistributedSampler(dataset, seed=42) # Get indices for first epoch epoch1_indices = list(iter(sampler))