fix(data/sampler): 修正拼写错误并增强采样器功能
This commit is contained in:
parent
36b410384b
commit
019bfe4e05
|
|
@ -10,7 +10,7 @@ from khaosz.data.dataset import (
|
|||
)
|
||||
|
||||
from khaosz.data.tokenizer import BpeTokenizer
|
||||
from khaosz.data.sampler import ResumeableRandomSampler
|
||||
from khaosz.data.sampler import ResumableDistributedSampler
|
||||
|
||||
__all__ = [
|
||||
"BaseDataset",
|
||||
|
|
@ -22,5 +22,5 @@ __all__ = [
|
|||
"DatasetLoader",
|
||||
"load_pkl_files",
|
||||
"BpeTokenizer",
|
||||
"ResumeableRandomSampler"
|
||||
"ResumableDistributedSampler"
|
||||
]
|
||||
|
|
@ -5,7 +5,7 @@ from torch.utils.data import Dataset, Sampler
|
|||
from typing import Optional
|
||||
|
||||
|
||||
class ResumeableRandomSampler(Sampler[int]):
|
||||
class ResumableDistributedSampler(Sampler[int]):
|
||||
def __init__(
|
||||
self,
|
||||
data_source: Dataset,
|
||||
|
|
@ -13,6 +13,7 @@ class ResumeableRandomSampler(Sampler[int]):
|
|||
start_iter: int=0,
|
||||
seed: int=42,
|
||||
drop_last: bool=False,
|
||||
shuffle: bool=True,
|
||||
process_group: Optional[dist.ProcessGroup]=None,
|
||||
):
|
||||
self.epoch = start_epoch
|
||||
|
|
@ -37,6 +38,8 @@ class ResumeableRandomSampler(Sampler[int]):
|
|||
self.num_replicas = 1
|
||||
|
||||
self.drop_last = drop_last
|
||||
self.shuffle = shuffle
|
||||
|
||||
offset = 0 if drop_last else self.num_replicas - 1
|
||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||
|
|
@ -44,9 +47,12 @@ class ResumeableRandomSampler(Sampler[int]):
|
|||
self._indices = None
|
||||
|
||||
def _get_indices(self):
|
||||
if self.shuffle:
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(self.seed + self.epoch)
|
||||
indices = torch.randperm(self.num_samples, generator=generator).tolist()
|
||||
else:
|
||||
indices = torch.arange(self.num_samples).tolist()
|
||||
|
||||
if not self.drop_last and self.num_samples < self.total_size:
|
||||
padding_size = self.total_size - len(indices)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Optional, Self, TYPE_CHECKING
|
|||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
from khaosz.config import Checkpoint
|
||||
from khaosz.data import ResumeableRandomSampler
|
||||
from khaosz.data import ResumableDistributedSampler
|
||||
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -83,7 +83,7 @@ class TrainContextBuilder:
|
|||
def with_dataloader(self) -> Self:
|
||||
# fix: convert the batch-level batch_iter into a sample-level offset
|
||||
sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
|
||||
resumeable_sampler = ResumeableRandomSampler(
|
||||
resumeable_sampler = ResumableDistributedSampler(
|
||||
data_source=self.trainer.train_config.dataset,
|
||||
start_epoch=self._context.epoch,
|
||||
start_iter=sampler_offset,
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ def test_random_sampler_consistency(random_dataset):
|
|||
dataset = random_dataset
|
||||
|
||||
# Create two samplers with the same seed
|
||||
sampler1 = ResumeableRandomSampler(dataset, seed=42)
|
||||
sampler2 = ResumeableRandomSampler(dataset, seed=42)
|
||||
sampler1 = ResumableDistributedSampler(dataset, seed=42)
|
||||
sampler2 = ResumableDistributedSampler(dataset, seed=42)
|
||||
|
||||
indices1 = list(iter(sampler1))
|
||||
indices2 = list(iter(sampler2))
|
||||
|
|
@ -19,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset):
|
|||
dataset = random_dataset
|
||||
|
||||
# Create two samplers with different seeds
|
||||
sampler1 = ResumeableRandomSampler(dataset, seed=42)
|
||||
sampler2 = ResumeableRandomSampler(dataset, seed=123)
|
||||
sampler1 = ResumableDistributedSampler(dataset, seed=42)
|
||||
sampler2 = ResumableDistributedSampler(dataset, seed=123)
|
||||
|
||||
indices1 = list(iter(sampler1))
|
||||
indices2 = list(iter(sampler2))
|
||||
|
|
@ -34,7 +34,7 @@ def test_sampler_across_epochs(random_dataset):
|
|||
dataset = random_dataset
|
||||
n = len(dataset)
|
||||
|
||||
sampler = ResumeableRandomSampler(dataset, seed=42)
|
||||
sampler = ResumableDistributedSampler(dataset, seed=42)
|
||||
|
||||
# Get indices for first epoch
|
||||
epoch1_indices = list(iter(sampler))
|
||||
|
|
|
|||
Loading…
Reference in New Issue