fix(data/sampler): fix spelling error and extend sampler functionality

ViperEkura 2025-11-27 19:43:36 +08:00
parent 36b410384b
commit 019bfe4e05
4 changed files with 19 additions and 13 deletions

View File

@@ -10,7 +10,7 @@ from khaosz.data.dataset import (
 )
 from khaosz.data.tokenizer import BpeTokenizer
-from khaosz.data.sampler import ResumeableRandomSampler
+from khaosz.data.sampler import ResumableDistributedSampler
 
 
 __all__ = [
     "BaseDataset",
@@ -22,5 +22,5 @@ __all__ = [
     "DatasetLoader",
     "load_pkl_files",
     "BpeTokenizer",
-    "ResumeableRandomSampler"
+    "ResumableDistributedSampler"
 ]
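The rename fixes the misspelling ("Resumeable" to "Resumable") and reflects that the sampler is distributed-aware rather than purely random. A minimal usage sketch of the re-exported name; the TensorDataset here is a stand-in for any map-style dataset, not something from this repository:

import torch
from torch.utils.data import TensorDataset
from khaosz.data import ResumableDistributedSampler

# any map-style dataset works as data_source; TensorDataset is a stand-in
dataset = TensorDataset(torch.arange(10))
sampler = ResumableDistributedSampler(dataset, seed=42)
print(list(iter(sampler)))  # deterministic permutation of 0..9 for this seed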

View File

@@ -5,7 +5,7 @@ from torch.utils.data import Dataset, Sampler
 from typing import Optional
 
 
-class ResumeableRandomSampler(Sampler[int]):
+class ResumableDistributedSampler(Sampler[int]):
     def __init__(
         self,
         data_source: Dataset,
@@ -13,6 +13,7 @@ class ResumeableRandomSampler(Sampler[int]):
         start_iter: int=0,
         seed: int=42,
         drop_last: bool=False,
+        shuffle: bool=True,
         process_group: Optional[dist.ProcessGroup]=None,
     ):
         self.epoch = start_epoch
@@ -37,6 +38,8 @@ class ResumeableRandomSampler(Sampler[int]):
             self.num_replicas = 1
 
         self.drop_last = drop_last
+        self.shuffle = shuffle
+
         offset = 0 if drop_last else self.num_replicas - 1
         self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
         self.total_size = self.num_samples_per_replica * self.num_replicas
@@ -44,9 +47,12 @@ class ResumeableRandomSampler(Sampler[int]):
         self._indices = None
 
     def _get_indices(self):
-        generator = torch.Generator()
-        generator.manual_seed(self.seed + self.epoch)
-        indices = torch.randperm(self.num_samples, generator=generator).tolist()
+        if self.shuffle:
+            generator = torch.Generator()
+            generator.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(self.num_samples, generator=generator).tolist()
+        else:
+            indices = torch.arange(self.num_samples).tolist()
 
         if not self.drop_last and self.num_samples < self.total_size:
             padding_size = self.total_size - len(indices)
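The new shuffle flag plus the existing padding arithmetic make up the whole per-epoch index pipeline. Below is a standalone sketch of that pipeline; epoch_indices is a hypothetical free function written for illustration, and the tail (how padding is filled and how each rank slices the index list) is an assumption modeled on torch.utils.data.DistributedSampler, since the diff cuts off right after padding_size:

import torch

def epoch_indices(num_samples: int, num_replicas: int, rank: int,
                  seed: int, epoch: int,
                  shuffle: bool = True, drop_last: bool = False) -> list:
    # Seeding with seed + epoch gives a new but reproducible order per epoch.
    if shuffle:
        generator = torch.Generator()
        generator.manual_seed(seed + epoch)
        indices = torch.randperm(num_samples, generator=generator).tolist()
    else:
        indices = torch.arange(num_samples).tolist()

    # drop_last=False rounds up (ceil division via the num_replicas - 1
    # offset), drop_last=True rounds down, so every replica gets an equal share.
    offset = 0 if drop_last else num_replicas - 1
    per_replica = (num_samples + offset) // num_replicas
    total_size = per_replica * num_replicas

    if not drop_last and num_samples < total_size:
        # Assumption: pad by recycling the head of the list, as
        # torch.utils.data.DistributedSampler does.
        padding_size = total_size - len(indices)
        indices += indices[:padding_size]
    elif drop_last:
        indices = indices[:total_size]  # assumption: excess samples dropped

    # Assumption: rank r takes the strided slice r, r + R, r + 2R, ...
    return indices[rank:total_size:num_replicas]

For example, with num_samples=10, num_replicas=4, drop_last=False this gives per_replica = (10 + 3) // 4 = 3 and total_size = 12, so two indices are recycled as padding and every rank still receives exactly three samples.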

View File

@@ -3,7 +3,7 @@ from typing import Optional, Self, TYPE_CHECKING
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from khaosz.config import Checkpoint
-from khaosz.data import ResumeableRandomSampler
+from khaosz.data import ResumableDistributedSampler
 from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
 
 if TYPE_CHECKING:
@@ -83,7 +83,7 @@ class TrainContextBuilder:
     def with_dataloader(self) -> Self:
         # fix: change batch level batch_iter to sample level offset
         sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
-        resumeable_sampler = ResumeableRandomSampler(
+        resumeable_sampler = ResumableDistributedSampler(
             data_source=self.trainer.train_config.dataset,
             start_epoch=self._context.epoch,
             start_iter=sampler_offset,
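Multiplying batch_iter by batch_size converts the checkpointed batch counter into the sample-level offset the sampler expects. A sketch of the same wiring outside the builder, where batch_size, batch_iter, and epoch are placeholder values standing in for checkpointed trainer state:

import torch
from torch.utils.data import DataLoader, TensorDataset
from khaosz.data import ResumableDistributedSampler

batch_size, batch_iter, epoch = 8, 25, 3       # placeholder checkpoint state
dataset = TensorDataset(torch.arange(1000))

# batch-level counter -> sample-level offset, as in with_dataloader above
sampler_offset = batch_iter * batch_size        # 200 samples already consumed
sampler = ResumableDistributedSampler(
    data_source=dataset,
    start_epoch=epoch,
    start_iter=sampler_offset,
    seed=42,
)
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)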

View File

@@ -6,8 +6,8 @@ def test_random_sampler_consistency(random_dataset):
     dataset = random_dataset
 
     # Create two samplers with same seed
-    sampler1 = ResumeableRandomSampler(dataset, seed=42)
-    sampler2 = ResumeableRandomSampler(dataset, seed=42)
+    sampler1 = ResumableDistributedSampler(dataset, seed=42)
+    sampler2 = ResumableDistributedSampler(dataset, seed=42)
 
     indices1 = list(iter(sampler1))
     indices2 = list(iter(sampler2))
@@ -19,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset):
     dataset = random_dataset
 
     # Create two samplers with different seeds
-    sampler1 = ResumeableRandomSampler(dataset, seed=42)
-    sampler2 = ResumeableRandomSampler(dataset, seed=123)
+    sampler1 = ResumableDistributedSampler(dataset, seed=42)
+    sampler2 = ResumableDistributedSampler(dataset, seed=123)
 
     indices1 = list(iter(sampler1))
     indices2 = list(iter(sampler2))
@@ -34,7 +34,7 @@ def test_sampler_across_epochs(random_dataset):
     dataset = random_dataset
     n = len(dataset)
 
-    sampler = ResumeableRandomSampler(dataset, seed=42)
+    sampler = ResumableDistributedSampler(dataset, seed=42)
 
     # Get indices for first epoch
     epoch1_indices = list(iter(sampler))
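The commit adds the shuffle flag, but the visible tests only exercise the shuffled path. A hypothetical companion test (not part of this commit) for the shuffle=False branch, which on a single non-distributed replica should yield indices in natural order:

def test_sampler_no_shuffle(random_dataset):
    dataset = random_dataset

    # shuffle=False takes the torch.arange branch of _get_indices
    sampler = ResumableDistributedSampler(dataset, shuffle=False)
    indices = list(iter(sampler))

    assert indices == list(range(len(dataset)))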