diff --git a/khaosz/data/__init__.py b/khaosz/data/__init__.py index 9a3c5ae..cdb4362 100644 --- a/khaosz/data/__init__.py +++ b/khaosz/data/__init__.py @@ -1,16 +1,16 @@ -from khaosz.data.data_util import ( +from khaosz.data.dataset import ( BaseDataset, SeqDataset, DpoDataset, SftDataset, PpoDataset, MutiSegmentFetcher, - ResumeableRandomSampler, DatasetLoader, load_pkl_files, ) from khaosz.data.tokenizer import BpeTokenizer +from khaosz.data.sampler import ResumeableRandomSampler __all__ = [ "BaseDataset", @@ -19,8 +19,8 @@ __all__ = [ "SftDataset", "PpoDataset", "MutiSegmentFetcher", - "ResumeableRandomSampler", "DatasetLoader", "load_pkl_files", - "BpeTokenizer" + "BpeTokenizer", + "ResumeableRandomSampler" ] \ No newline at end of file diff --git a/khaosz/data/data_util.py b/khaosz/data/dataset.py similarity index 87% rename from khaosz/data/data_util.py rename to khaosz/data/dataset.py index d3608ad..42bb5dd 100644 --- a/khaosz/data/data_util.py +++ b/khaosz/data/dataset.py @@ -1,9 +1,10 @@ import torch import bisect import pickle as pkl + from abc import ABC, abstractmethod from torch import Tensor -from torch.utils.data import Dataset, Sampler +from torch.utils.data import Dataset from typing import Callable, List, Dict, Literal, Optional, Union MutiSeg = Dict[str, List[Tensor]] @@ -217,40 +218,3 @@ class DatasetLoader: dataset.load(load_path) return dataset - - -class ResumeableRandomSampler(Sampler[int]): - def __init__(self, data_source, start_epoch=0, start_iter=0, seed=42): - self.num_samples = len(data_source) - self.epoch = start_epoch - self.iter = start_iter - - generator = torch.Generator() - generator.manual_seed(seed) - - # consume previous epochs - for _ in range(start_epoch): - torch.randperm(self.num_samples, generator=generator) - - self.generator = generator - self._indices = None - - def _get_indices(self): - current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist() - self._indices = current_epoch_indices[self.iter % self.num_samples:] - - def __iter__(self): - if self._indices is None: - self._get_indices() - - for i in self._indices: - self.iter += 1 - yield i - - self.epoch += 1 - self._indices = None - - def __len__(self): - if self._indices is None: - self._get_indices() - return len(self._indices) \ No newline at end of file diff --git a/khaosz/data/sampler.py b/khaosz/data/sampler.py new file mode 100644 index 0000000..f628885 --- /dev/null +++ b/khaosz/data/sampler.py @@ -0,0 +1,63 @@ +import torch +import torch.distributed as dist + +from torch.utils.data import Dataset, Sampler +from typing import Optional + + +class ResumeableRandomSampler(Sampler[int]): + def __init__( + self, + data_source: Dataset, + start_epoch: int=0, + start_iter: int=0, + seed: int=42, + process_group: Optional[dist.ProcessGroup]=None, + ): + self.epoch = start_epoch + self.iter = start_iter + self.seed = seed + self.num_samples = len(data_source) + + if process_group is not None: + # input process group + self.rank = dist.get_rank(process_group) + self.num_replicas = dist.get_world_size(process_group) + + elif dist.is_available() and dist.is_initialized(): + # use default process group + process_group = dist.group.WORLD + self.rank = dist.get_rank() + self.num_replicas = dist.get_world_size() + + else: + # single process + self.rank = 0 + self.num_replicas = 1 + + self._indices = None + + def _get_indices(self): + generator = torch.Generator() + generator.manual_seed(self.seed + self.epoch) + indices = torch.randperm(self.num_samples, generator=generator).tolist() + + self.iter = self.iter % self.num_samples + local_indices = indices[self.rank: self.num_samples: self.num_replicas] + self._indices = local_indices[self.iter:] + + def __iter__(self): + if self._indices is None: + self._get_indices() + + for i in self._indices: + self.iter += 1 + yield i + + self.epoch += 1 + self._indices = None + + def __len__(self): + if self._indices is None: + self._get_indices() + return len(self._indices) \ No newline at end of file diff --git a/tests/test_dataset_loader.py b/tests/test_dataset_loader.py index a5b54e8..a6e282c 100644 --- a/tests/test_dataset_loader.py +++ b/tests/test_dataset_loader.py @@ -4,7 +4,7 @@ import pickle import numpy as np from khaosz.trainer import * -from khaosz.data.data_util import * +from khaosz.data.dataset import * def test_dataset_loader_random_paths(base_test_env): diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 3ff1568..c6ce6cf 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -1,5 +1,5 @@ from khaosz.trainer import * -from khaosz.data.data_util import * +from khaosz.data import * def test_random_sampler_consistency(random_dataset): """Test RandomSampler produces consistent results with same seed""" diff --git a/tests/test_train_config.py b/tests/test_train_config.py index 972f256..3cd2112 100644 --- a/tests/test_train_config.py +++ b/tests/test_train_config.py @@ -4,7 +4,7 @@ import numpy as np from khaosz.config import * from khaosz.trainer import * -from khaosz.data.data_util import * +from khaosz.data.dataset import * def test_different_batch_sizes(base_test_env, random_dataset): """Test training with different batch sizes""" diff --git a/tests/test_train_strategy.py b/tests/test_train_strategy.py index 102a9da..d97c254 100644 --- a/tests/test_train_strategy.py +++ b/tests/test_train_strategy.py @@ -4,7 +4,7 @@ import pytest from khaosz.config import * from khaosz.trainer.schedule import * -from khaosz.data.data_util import * +from khaosz.data.dataset import * def test_schedule_factory_random_configs():