fix(data/sampler): fix typo and enhance sampler functionality

ViperEkura 2025-11-27 19:43:36 +08:00
parent 36b410384b
commit 019bfe4e05
4 changed files with 19 additions and 13 deletions

View File

@@ -10,7 +10,7 @@ from khaosz.data.dataset import (
 )
 from khaosz.data.tokenizer import BpeTokenizer
-from khaosz.data.sampler import ResumeableRandomSampler
+from khaosz.data.sampler import ResumableDistributedSampler
 
 
 __all__ = [
     "BaseDataset",
@@ -22,5 +22,5 @@ __all__ = [
     "DatasetLoader",
     "load_pkl_files",
     "BpeTokenizer",
-    "ResumeableRandomSampler"
+    "ResumableDistributedSampler"
 ]
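
For context, a minimal usage sketch of the renamed export, assuming the constructor signature visible in the sampler diff below; `train_dataset` and the batch size are illustrative placeholders, not from this repo:

from torch.utils.data import DataLoader
from khaosz.data import ResumableDistributedSampler

# "train_dataset" is a placeholder for any map-style torch Dataset.
sampler = ResumableDistributedSampler(
    data_source=train_dataset,
    start_epoch=0,
    start_iter=0,        # sample-level offset to resume from mid-epoch
    seed=42,
    shuffle=True,        # flag added by this commit
)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)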

View File

@@ -5,7 +5,7 @@ from torch.utils.data import Dataset, Sampler
 from typing import Optional
 
 
-class ResumeableRandomSampler(Sampler[int]):
+class ResumableDistributedSampler(Sampler[int]):
     def __init__(
         self,
         data_source: Dataset,
@@ -13,6 +13,7 @@ class ResumeableRandomSampler(Sampler[int]):
         start_iter: int=0,
         seed: int=42,
         drop_last: bool=False,
+        shuffle: bool=True,
         process_group: Optional[dist.ProcessGroup]=None,
     ):
         self.epoch = start_epoch
@@ -37,6 +38,8 @@ class ResumeableRandomSampler(Sampler[int]):
             self.num_replicas = 1
 
         self.drop_last = drop_last
+        self.shuffle = shuffle
+
         offset = 0 if drop_last else self.num_replicas - 1
         self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
         self.total_size = self.num_samples_per_replica * self.num_replicas
@@ -44,9 +47,12 @@ class ResumeableRandomSampler(Sampler[int]):
         self._indices = None
 
     def _get_indices(self):
-        generator = torch.Generator()
-        generator.manual_seed(self.seed + self.epoch)
-        indices = torch.randperm(self.num_samples, generator=generator).tolist()
+        if self.shuffle:
+            generator = torch.Generator()
+            generator.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(self.num_samples, generator=generator).tolist()
+        else:
+            indices = torch.arange(self.num_samples).tolist()
 
         if not self.drop_last and self.num_samples < self.total_size:
             padding_size = self.total_size - len(indices)
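
The hunk cuts off inside the padding branch. A sketch of the full index-generation path, with the padding body reconstructed in the style of torch.utils.data.DistributedSampler (an assumption; the actual completion is not shown in this commit):

import torch

def _get_indices(self):
    if self.shuffle:
        # Seeded with seed + epoch so every replica draws the same permutation.
        generator = torch.Generator()
        generator.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(self.num_samples, generator=generator).tolist()
    else:
        # New in this commit: deterministic sequential order when shuffle=False.
        indices = torch.arange(self.num_samples).tolist()

    if not self.drop_last and self.num_samples < self.total_size:
        # Repeat leading indices so total_size divides evenly across replicas
        # (assumed; only the padding_size line is visible in the hunk).
        padding_size = self.total_size - len(indices)
        indices += indices[:padding_size]
    return indices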

View File

@@ -3,7 +3,7 @@ from typing import Optional, Self, TYPE_CHECKING
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from khaosz.config import Checkpoint
-from khaosz.data import ResumeableRandomSampler
+from khaosz.data import ResumableDistributedSampler
 from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
 
 if TYPE_CHECKING:
@@ -83,7 +83,7 @@ class TrainContextBuilder:
     def with_dataloader(self) -> Self:
         # fix: change batch level batch_iter to sample level offset
         sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
-        resumeable_sampler = ResumeableRandomSampler(
+        resumeable_sampler = ResumableDistributedSampler(
            data_source=self.trainer.train_config.dataset,
            start_epoch=self._context.epoch,
            start_iter=sampler_offset,
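
The comment above documents a unit fix: batch_iter counts batches, while start_iter (per the sampler signature) counts samples, so resuming must multiply by batch size. A worked example with illustrative numbers:

# Illustrative values, not from the repo.
batch_iter = 150                           # batches consumed before the checkpoint
batch_size = 32
sampler_offset = batch_iter * batch_size   # resume at sample 4800, not batch 150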

View File

@@ -6,8 +6,8 @@ def test_random_sampler_consistency(random_dataset):
     dataset = random_dataset
 
     # Create two samplers with same seed
-    sampler1 = ResumeableRandomSampler(dataset, seed=42)
-    sampler2 = ResumeableRandomSampler(dataset, seed=42)
+    sampler1 = ResumableDistributedSampler(dataset, seed=42)
+    sampler2 = ResumableDistributedSampler(dataset, seed=42)
 
     indices1 = list(iter(sampler1))
     indices2 = list(iter(sampler2))
@@ -19,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset):
     dataset = random_dataset
 
     # Create two samplers with different seeds
-    sampler1 = ResumeableRandomSampler(dataset, seed=42)
-    sampler2 = ResumeableRandomSampler(dataset, seed=123)
+    sampler1 = ResumableDistributedSampler(dataset, seed=42)
+    sampler2 = ResumableDistributedSampler(dataset, seed=123)
 
     indices1 = list(iter(sampler1))
     indices2 = list(iter(sampler2))
@@ -34,7 +34,7 @@ def test_sampler_across_epochs(random_dataset):
     dataset = random_dataset
     n = len(dataset)
 
-    sampler = ResumeableRandomSampler(dataset, seed=42)
+    sampler = ResumableDistributedSampler(dataset, seed=42)
 
     # Get indices for first epoch
     epoch1_indices = list(iter(sampler))
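
The renamed tests cover seeding and epoch behavior but not the new shuffle flag. A hedged sketch of a companion test in the same style; it assumes a single-process run (num_replicas of 1, so total_size equals the dataset length and no padding occurs) and that `__iter__` yields the whole epoch, in which case shuffle=False should produce the identity order:

def test_sampler_no_shuffle_is_sequential(random_dataset):
    dataset = random_dataset
    n = len(dataset)

    sampler = ResumableDistributedSampler(dataset, shuffle=False)
    indices = list(iter(sampler))

    # With one replica and no padding, expect 0..n-1 in order.
    assert indices == list(range(n))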