diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py
index 2954dbb..39470f0 100644
--- a/khaosz/trainer/__init__.py
+++ b/khaosz/trainer/__init__.py
@@ -1,4 +1,4 @@
-from khaosz.trainer.dataset import DatasetLoader
+from khaosz.trainer.data_util import DatasetLoader
 from khaosz.trainer.trainer import Trainer
 from khaosz.trainer.strategy import (
     TrainConfig,
@@ -7,7 +7,7 @@ from khaosz.trainer.strategy import (
     StrategyFactory,
     SchedulerFactory
 )
-from khaosz.trainer.callback import (
+from khaosz.trainer.trainer_callback import (
     ProgressBarCallback,
     CheckpointCallback,
     TrainerCallback,
diff --git a/khaosz/trainer/dataset.py b/khaosz/trainer/data_util.py
similarity index 85%
rename from khaosz/trainer/dataset.py
rename to khaosz/trainer/data_util.py
index 3cf466e..ed19708 100644
--- a/khaosz/trainer/dataset.py
+++ b/khaosz/trainer/data_util.py
@@ -3,7 +3,7 @@ import bisect
 import pickle as pkl
 from abc import ABC, abstractmethod
 from torch import Tensor
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, Sampler
 from typing import Callable, List, Dict, Literal, Union
 
 MutiSeg = Dict[str, List[Tensor]]
@@ -264,4 +264,72 @@ class DatasetLoader:
         dataset = dataset_router[train_type](max_len, device)
         dataset.load(load_path)
         
-        return dataset
\ No newline at end of file
+        return dataset
+
+
+class RandomSampler(Sampler[int]):
+    """Random sampler whose progress can be checkpointed and resumed.
+
+    Yields dataset indices in a seeded random order and exposes
+    state_dict / load_state_dict so that sampling can continue
+    mid-epoch after a training restart.
+    """
+
+    def __init__(self, data_source, generator=None, seed=42):
+        self.data_source = data_source
+        self.seed = seed
+        self.epoch = 0
+        self.current_index = 0
+        self._indices = None
+
+        if generator is None:
+            self.generator = torch.Generator()
+            self.generator.manual_seed(seed)
+        else:
+            self.generator = generator
+
+    def _generate_indices(self):
+        # Each call advances the generator, so every epoch gets a
+        # different permutation while remaining reproducible.
+        n = len(self.data_source)
+        self._indices = torch.randperm(n, generator=self.generator).tolist()
+
+    def __iter__(self):
+        n = len(self.data_source)
+
+        if self._indices is None:
+            self._generate_indices()
+
+        # Advance current_index as indices are consumed so that a
+        # state_dict() taken mid-epoch records actual progress.
+        while self.current_index < n:
+            index = self._indices[self.current_index]
+            self.current_index += 1
+            yield index
+
+        self.epoch += 1
+        self.current_index = 0
+        self._indices = None
+
+    def __len__(self):
+        # Number of indices left in the current (possibly resumed) epoch.
+        return len(self.data_source) - self.current_index
+
+    def state_dict(self):
+        return {
+            'epoch': self.epoch,
+            'current_index': self.current_index,
+            'seed': self.seed,
+            'generator_state': self.generator.get_state() if self.generator else None,
+            'indices': self._indices
+        }
+
+    def load_state_dict(self, state_dict):
+        self.epoch = state_dict['epoch']
+        self.current_index = state_dict['current_index']
+        self.seed = state_dict['seed']
+
+        if self.generator and state_dict['generator_state'] is not None:
+            self.generator.set_state(state_dict['generator_state'])
+
+        self._indices = state_dict['indices']
\ No newline at end of file
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index f374a65..6f8a46f 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -1,11 +1,12 @@
 import torch
 import itertools
 from typing import Optional, List
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader
 
 from khaosz.core import ModelParameter, Checkpoint
+from khaosz.trainer.data_util import RandomSampler
 from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
-from khaosz.trainer.callback import (
+from khaosz.trainer.trainer_callback import (
     TrainerCallback,
     ProgressBarCallback,
     CheckpointCallback,
@@ -42,7 +43,11 @@ class Trainer:
     def _create_dataloader(self, start_index: int = 0) -> DataLoader:
         seed = self.train_config.random_seed
         generator = torch.Generator().manual_seed(seed)
-        sampler = RandomSampler(self.train_config.dataset, generator=generator)
+        sampler = RandomSampler(
+            self.train_config.dataset,
+            generator=generator,
+            seed=seed
+        )
         dataloader = DataLoader(
             self.train_config.dataset,
             batch_size=self.train_config.batch_size,
diff --git a/khaosz/trainer/callback.py b/khaosz/trainer/trainer_callback.py
similarity index 100%
rename from khaosz/trainer/callback.py
rename to khaosz/trainer/trainer_callback.py