fix(data/sampler): 修正拼写错误并增强采样器功能
This commit is contained in:
parent
36b410384b
commit
019bfe4e05
|
|
@ -10,7 +10,7 @@ from khaosz.data.dataset import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
from khaosz.data.sampler import ResumeableRandomSampler
|
from khaosz.data.sampler import ResumableDistributedSampler
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseDataset",
|
"BaseDataset",
|
||||||
|
|
@ -22,5 +22,5 @@ __all__ = [
|
||||||
"DatasetLoader",
|
"DatasetLoader",
|
||||||
"load_pkl_files",
|
"load_pkl_files",
|
||||||
"BpeTokenizer",
|
"BpeTokenizer",
|
||||||
"ResumeableRandomSampler"
|
"ResumableDistributedSampler"
|
||||||
]
|
]
|
||||||
|
|
@ -5,7 +5,7 @@ from torch.utils.data import Dataset, Sampler
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class ResumeableRandomSampler(Sampler[int]):
|
class ResumableDistributedSampler(Sampler[int]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_source: Dataset,
|
data_source: Dataset,
|
||||||
|
|
@ -13,6 +13,7 @@ class ResumeableRandomSampler(Sampler[int]):
|
||||||
start_iter: int=0,
|
start_iter: int=0,
|
||||||
seed: int=42,
|
seed: int=42,
|
||||||
drop_last: bool=False,
|
drop_last: bool=False,
|
||||||
|
shuffle: bool=True,
|
||||||
process_group: Optional[dist.ProcessGroup]=None,
|
process_group: Optional[dist.ProcessGroup]=None,
|
||||||
):
|
):
|
||||||
self.epoch = start_epoch
|
self.epoch = start_epoch
|
||||||
|
|
@ -37,6 +38,8 @@ class ResumeableRandomSampler(Sampler[int]):
|
||||||
self.num_replicas = 1
|
self.num_replicas = 1
|
||||||
|
|
||||||
self.drop_last = drop_last
|
self.drop_last = drop_last
|
||||||
|
self.shuffle = shuffle
|
||||||
|
|
||||||
offset = 0 if drop_last else self.num_replicas - 1
|
offset = 0 if drop_last else self.num_replicas - 1
|
||||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||||
|
|
@ -44,9 +47,12 @@ class ResumeableRandomSampler(Sampler[int]):
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
def _get_indices(self):
|
def _get_indices(self):
|
||||||
|
if self.shuffle:
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
generator.manual_seed(self.seed + self.epoch)
|
generator.manual_seed(self.seed + self.epoch)
|
||||||
indices = torch.randperm(self.num_samples, generator=generator).tolist()
|
indices = torch.randperm(self.num_samples, generator=generator).tolist()
|
||||||
|
else:
|
||||||
|
indices = torch.arange(self.num_samples).tolist()
|
||||||
|
|
||||||
if not self.drop_last and self.num_samples < self.total_size:
|
if not self.drop_last and self.num_samples < self.total_size:
|
||||||
padding_size = self.total_size - len(indices)
|
padding_size = self.total_size - len(indices)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import Optional, Self, TYPE_CHECKING
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from khaosz.config import Checkpoint
|
from khaosz.config import Checkpoint
|
||||||
from khaosz.data import ResumeableRandomSampler
|
from khaosz.data import ResumableDistributedSampler
|
||||||
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
|
from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -83,7 +83,7 @@ class TrainContextBuilder:
|
||||||
def with_dataloader(self) -> Self:
|
def with_dataloader(self) -> Self:
|
||||||
# fix: change batch level batch_iter to sample level offset
|
# fix: change batch level batch_iter to sample level offset
|
||||||
sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
|
sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
|
||||||
resumeable_sampler = ResumeableRandomSampler(
|
resumeable_sampler = ResumableDistributedSampler(
|
||||||
data_source=self.trainer.train_config.dataset,
|
data_source=self.trainer.train_config.dataset,
|
||||||
start_epoch=self._context.epoch,
|
start_epoch=self._context.epoch,
|
||||||
start_iter=sampler_offset,
|
start_iter=sampler_offset,
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ def test_random_sampler_consistency(random_dataset):
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
|
|
||||||
# Create two samplers with same seed
|
# Create two samplers with same seed
|
||||||
sampler1 = ResumeableRandomSampler(dataset, seed=42)
|
sampler1 = ResumableDistributedSampler(dataset, seed=42)
|
||||||
sampler2 = ResumeableRandomSampler(dataset, seed=42)
|
sampler2 = ResumableDistributedSampler(dataset, seed=42)
|
||||||
|
|
||||||
indices1 = list(iter(sampler1))
|
indices1 = list(iter(sampler1))
|
||||||
indices2 = list(iter(sampler2))
|
indices2 = list(iter(sampler2))
|
||||||
|
|
@ -19,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset):
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
|
|
||||||
# Create two samplers with different seeds
|
# Create two samplers with different seeds
|
||||||
sampler1 = ResumeableRandomSampler(dataset, seed=42)
|
sampler1 = ResumableDistributedSampler(dataset, seed=42)
|
||||||
sampler2 = ResumeableRandomSampler(dataset, seed=123)
|
sampler2 = ResumableDistributedSampler(dataset, seed=123)
|
||||||
|
|
||||||
indices1 = list(iter(sampler1))
|
indices1 = list(iter(sampler1))
|
||||||
indices2 = list(iter(sampler2))
|
indices2 = list(iter(sampler2))
|
||||||
|
|
@ -34,7 +34,7 @@ def test_sampler_across_epochs(random_dataset):
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
n = len(dataset)
|
n = len(dataset)
|
||||||
|
|
||||||
sampler = ResumeableRandomSampler(dataset, seed=42)
|
sampler = ResumableDistributedSampler(dataset, seed=42)
|
||||||
|
|
||||||
# Get indices for first epoch
|
# Get indices for first epoch
|
||||||
epoch1_indices = list(iter(sampler))
|
epoch1_indices = list(iter(sampler))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue