refactor(data): 重构数据模块结构并优化可恢复采样器实现

This commit is contained in:
ViperEkura 2025-11-27 18:16:35 +08:00
parent 5daf63a7a4
commit 09963a3beb
7 changed files with 73 additions and 46 deletions

View File

@ -1,16 +1,16 @@
from khaosz.data.data_util import ( from khaosz.data.dataset import (
BaseDataset, BaseDataset,
SeqDataset, SeqDataset,
DpoDataset, DpoDataset,
SftDataset, SftDataset,
PpoDataset, PpoDataset,
MutiSegmentFetcher, MutiSegmentFetcher,
ResumeableRandomSampler,
DatasetLoader, DatasetLoader,
load_pkl_files, load_pkl_files,
) )
from khaosz.data.tokenizer import BpeTokenizer from khaosz.data.tokenizer import BpeTokenizer
from khaosz.data.sampler import ResumeableRandomSampler
__all__ = [ __all__ = [
"BaseDataset", "BaseDataset",
@ -19,8 +19,8 @@ __all__ = [
"SftDataset", "SftDataset",
"PpoDataset", "PpoDataset",
"MutiSegmentFetcher", "MutiSegmentFetcher",
"ResumeableRandomSampler",
"DatasetLoader", "DatasetLoader",
"load_pkl_files", "load_pkl_files",
"BpeTokenizer" "BpeTokenizer",
"ResumeableRandomSampler"
] ]

View File

@ -1,9 +1,10 @@
import torch import torch
import bisect import bisect
import pickle as pkl import pickle as pkl
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset, Sampler from torch.utils.data import Dataset
from typing import Callable, List, Dict, Literal, Optional, Union from typing import Callable, List, Dict, Literal, Optional, Union
MutiSeg = Dict[str, List[Tensor]] MutiSeg = Dict[str, List[Tensor]]
@ -217,40 +218,3 @@ class DatasetLoader:
dataset.load(load_path) dataset.load(load_path)
return dataset return dataset
class ResumeableRandomSampler(Sampler[int]):
    """Random sampler whose position can be restored from an epoch/iteration pair.

    One permutation of ``range(len(data_source))`` is drawn from a seeded
    ``torch.Generator`` per epoch. On construction the generator is advanced
    by ``start_epoch`` permutations so that the next permutation drawn is the
    one belonging to ``start_epoch``; the first ``start_iter % num_samples``
    entries of that permutation are then skipped.

    Args:
        data_source: Sized collection; only ``len(data_source)`` is used.
        start_epoch: Number of epochs already completed.
        start_iter: Running count of samples already yielded.
        seed: Seed for the permutation generator.
    """

    def __init__(self, data_source, start_epoch=0, start_iter=0, seed=42):
        self.num_samples = len(data_source)
        self.epoch = start_epoch
        self.iter = start_iter

        generator = torch.Generator()
        generator.manual_seed(seed)

        # Replay (and discard) the permutations of the epochs that were
        # already consumed so the generator state matches the original run.
        for _ in range(start_epoch):
            torch.randperm(self.num_samples, generator=generator)

        self.generator = generator
        # Remaining indices of the current epoch; built lazily.
        self._indices = None

    def _get_indices(self):
        """Draw the current epoch's permutation and drop the consumed prefix."""
        permutation = torch.randperm(
            self.num_samples, generator=self.generator
        ).tolist()
        # Guard the modulo: an empty data source has nothing to skip and
        # would otherwise raise ZeroDivisionError.
        offset = self.iter % self.num_samples if self.num_samples else 0
        self._indices = permutation[offset:]

    def __iter__(self):
        """Yield the remaining indices of the current epoch, then advance the epoch."""
        if self._indices is None:
            self._get_indices()

        for index in self._indices:
            self.iter += 1
            yield index

        # Epoch exhausted: the next iteration draws a fresh permutation.
        self.epoch += 1
        self._indices = None

    def __len__(self):
        """Return the number of indices remaining in the current epoch."""
        if self._indices is None:
            self._get_indices()
        return len(self._indices)

63
khaosz/data/sampler.py Normal file
View File

@ -0,0 +1,63 @@
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler
from typing import Optional
class ResumeableRandomSampler(Sampler[int]):
    """Resumable random sampler with optional sharding across processes.

    For every epoch a permutation of ``range(len(data_source))`` is generated
    from ``seed + epoch``, so the order of any epoch can be reproduced without
    replaying earlier ones. Each process keeps the strided slice
    ``permutation[rank::num_replicas]`` and skips the first ``iter`` entries
    of that slice when resuming.

    Shards are not padded: when ``len(data_source)`` is not divisible by
    ``num_replicas`` the ranks yield different numbers of samples.

    Args:
        data_source: Dataset to sample from; only its length is used.
        start_epoch: Epoch to start (or resume) from.
        start_iter: Number of samples this process already yielded.
        seed: Base seed; the per-epoch seed is ``seed + epoch``.
        process_group: Process group used to look up rank and world size.
            When ``None``, the default group is used if ``torch.distributed``
            is initialized; otherwise the sampler runs as a single process.
    """

    def __init__(
        self,
        data_source: Dataset,
        start_epoch: int = 0,
        start_iter: int = 0,
        seed: int = 42,
        process_group: Optional[dist.ProcessGroup] = None,
    ):
        self.epoch = start_epoch
        self.iter = start_iter
        self.seed = seed
        self.num_samples = len(data_source)

        if process_group is not None:
            # Explicit process group supplied by the caller.
            self.rank = dist.get_rank(process_group)
            self.num_replicas = dist.get_world_size(process_group)
        elif dist.is_available() and dist.is_initialized():
            # Fall back to the default process group.
            self.rank = dist.get_rank()
            self.num_replicas = dist.get_world_size()
        else:
            # Single process: this sampler sees every index.
            self.rank = 0
            self.num_replicas = 1

        # Remaining local indices of the current epoch; built lazily.
        self._indices = None

    def _get_indices(self):
        """Build this process's remaining indices for the current epoch."""
        generator = torch.Generator()
        generator.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(self.num_samples, generator=generator).tolist()

        local_indices = indices[self.rank:self.num_samples:self.num_replicas]

        # ``self.iter`` counts samples yielded by *this* process, so it must
        # wrap at the local shard length, not the global dataset size;
        # otherwise a finished epoch leaves a non-zero offset when sharded and
        # the next epoch comes out truncated or empty. The zero guard avoids
        # a ZeroDivisionError for an empty shard.
        shard_size = len(local_indices)
        self.iter = self.iter % shard_size if shard_size else 0

        self._indices = local_indices[self.iter:]

    def __iter__(self):
        """Yield the remaining local indices, then advance to the next epoch."""
        if self._indices is None:
            self._get_indices()

        for index in self._indices:
            self.iter += 1
            yield index

        # Epoch exhausted: the next iteration reshuffles with a new seed.
        self.epoch += 1
        self._indices = None

    def __len__(self):
        """Return the number of local indices remaining in the current epoch."""
        if self._indices is None:
            self._get_indices()
        return len(self._indices)

View File

@ -4,7 +4,7 @@ import pickle
import numpy as np import numpy as np
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.data.data_util import * from khaosz.data.dataset import *
def test_dataset_loader_random_paths(base_test_env): def test_dataset_loader_random_paths(base_test_env):

View File

@ -1,5 +1,5 @@
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.data.data_util import * from khaosz.data import *
def test_random_sampler_consistency(random_dataset): def test_random_sampler_consistency(random_dataset):
"""Test RandomSampler produces consistent results with same seed""" """Test RandomSampler produces consistent results with same seed"""

View File

@ -4,7 +4,7 @@ import numpy as np
from khaosz.config import * from khaosz.config import *
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.data.data_util import * from khaosz.data.dataset import *
def test_different_batch_sizes(base_test_env, random_dataset): def test_different_batch_sizes(base_test_env, random_dataset):
"""Test training with different batch sizes""" """Test training with different batch sizes"""

View File

@ -4,7 +4,7 @@ import pytest
from khaosz.config import * from khaosz.config import *
from khaosz.trainer.schedule import * from khaosz.trainer.schedule import *
from khaosz.data.data_util import * from khaosz.data.dataset import *
def test_schedule_factory_random_configs(): def test_schedule_factory_random_configs():