refactor: rename dataset classes and restructure files in the data module

ViperEkura 2026-03-19 22:37:32 +08:00
parent 0f518473af
commit 50f76cd7c7
6 changed files with 62 additions and 65 deletions

View File

@@ -1,9 +1,9 @@
 from khaosz.data.dataset import (
     BaseDataset,
-    SeqDataset,
-    DpoDataset,
-    SftDataset,
-    PpoDataset,
+    SEQDataset,
+    DPODataset,
+    SFTDataset,
+    GRPODataset,
     MultiSegmentFetcher,
     DatasetLoader
 )
@@ -13,10 +13,10 @@ from khaosz.data.sampler import ResumableDistributedSampler
 __all__ = [
     "BaseDataset",
-    "SeqDataset",
-    "DpoDataset",
-    "SftDataset",
-    "PpoDataset",
+    "SEQDataset",
+    "SFTDataset",
+    "DPODataset",
+    "GRPODataset",
     "MultiSegmentFetcher",
     "DatasetLoader",
     "BpeTokenizer",

View File

@@ -4,7 +4,7 @@ import bisect
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from khaosz.data.file import load_h5
+from khaosz.data.serialization import load_h5
 from typing import Callable, List, Dict, Literal, Optional, Union
@@ -105,7 +105,7 @@ class BaseDataset(Dataset, ABC):
         return (self.total_samples - 1 - self.window_size) // self.stride + 1


-class SeqDataset(BaseDataset):
+class SEQDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -123,7 +123,7 @@ class SeqDataset(BaseDataset):
         return {"input_ids": x, "target_ids": y}


-class SftDataset(BaseDataset):
+class SFTDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -141,7 +141,7 @@ class SftDataset(BaseDataset):
         return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}


-class DpoDataset(BaseDataset):
+class DPODataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -160,7 +160,7 @@ class DpoDataset(BaseDataset):
         return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}


-class PpoDataset(BaseDataset):
+class GRPODataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -171,12 +171,12 @@ class PpoDataset(BaseDataset):
     def __getitem__(self, index: int) -> Dict[str, Tensor]:
         begin_idx, end_idx = self.get_index(index)

-        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
-        actions = self._fetch_data(begin_idx, end_idx, "actions"),
-        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs"),
+        prompts = self._fetch_data(begin_idx, end_idx, "prompts"),
+        responses = self._fetch_data(begin_idx, end_idx, "responses"),
+        masks = self._fetch_data(begin_idx, end_idx, "masks"),
         rewards = self._fetch_data(begin_idx, end_idx, "rewards")

-        return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
+        return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}


 class DatasetLoader:
@@ -191,9 +191,10 @@ class DatasetLoader:
         stride = window_size

         dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
-            "seq": lambda window_size: SeqDataset(window_size, stride),
-            "sft": lambda window_size: SftDataset(window_size, stride),
-            "dpo": lambda window_size: DpoDataset(window_size, stride),
+            "seq": lambda window_size: SEQDataset(window_size, stride),
+            "sft": lambda window_size: SFTDataset(window_size, stride),
+            "dpo": lambda window_size: DPODataset(window_size, stride),
+            "grpo": lambda window_size: GRPODataset(window_size, stride),
         }
         dataset = dataset_router[train_type](window_size)
         dataset.load(load_path)

View File

@@ -1,42 +0,0 @@
-import os
-import h5py
-import torch
-from pathlib import Path
-from torch import Tensor
-from typing import Dict, List
-
-
-def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
-    os.makedirs(file_path, exist_ok=True)
-    full_file_path = os.path.join(file_path, f"{file_name}.h5")
-
-    with h5py.File(full_file_path, 'w') as f:
-        for key, tensors in tensor_group.items():
-            grp = f.create_group(key)
-            for idx, tensor in enumerate(tensors):
-                arr = tensor.cpu().numpy()
-                grp.create_dataset(f'data_{idx}', data=arr)
-
-
-def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
-    tensor_group: Dict[str, List[Tensor]] = {}
-    root_path = Path(file_path)
-    h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
-
-    for h5_file in h5_files:
-        with h5py.File(h5_file, 'r') as f:
-            for key in f.keys():
-                grp = f[key]
-                dsets = []
-                for dset_name in grp.keys():
-                    dset = grp[dset_name]
-                    tensor = torch.from_numpy(dset[:])
-                    if share_memory:
-                        tensor = tensor.share_memory_()
-                    dsets.append(tensor)
-                if tensor_group.get(key) is None:
-                    tensor_group[key] = []
-                tensor_group[key].extend(dsets)
-
-    return tensor_group

View File

@@ -1,11 +1,49 @@
+import os
+import h5py
+import torch
 import json
 import safetensors.torch as st
 import torch.distributed as dist
 from pathlib import Path
-from typing import Dict, Any
+from torch import Tensor
+from typing import Any, Dict, List
 from khaosz.parallel.setup import get_rank
+
+
+def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
+    os.makedirs(file_path, exist_ok=True)
+    full_file_path = os.path.join(file_path, f"{file_name}.h5")
+
+    with h5py.File(full_file_path, 'w') as f:
+        for key, tensors in tensor_group.items():
+            grp = f.create_group(key)
+            for idx, tensor in enumerate(tensors):
+                arr = tensor.cpu().numpy()
+                grp.create_dataset(f'data_{idx}', data=arr)
+
+
+def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
+    tensor_group: Dict[str, List[Tensor]] = {}
+    root_path = Path(file_path)
+    h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
+
+    for h5_file in h5_files:
+        with h5py.File(h5_file, 'r') as f:
+            for key in f.keys():
+                grp = f[key]
+                dsets = []
+                for dset_name in grp.keys():
+                    dset = grp[dset_name]
+                    tensor = torch.from_numpy(dset[:])
+                    if share_memory:
+                        tensor = tensor.share_memory_()
+                    dsets.append(tensor)
+                if tensor_group.get(key) is None:
+                    tensor_group[key] = []
+                tensor_group[key].extend(dsets)
+
+    return tensor_group
+
+
 class Checkpoint:
     def __init__(
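
With save_h5 and load_h5 relocated from khaosz.data.file into khaosz.data.serialization, only the import path changes, not the behavior. A round-trip sketch under the new path (directory and file names are illustrative):

import torch
from khaosz.data.serialization import save_h5, load_h5

# save_h5 writes one HDF5 group per key and one dataset per tensor;
# load_h5 recursively gathers *.h5/*.hdf5 files and regroups tensors by key.
group = {"rewards": [torch.randn(8), torch.randn(8)]}
save_h5("h5_out", "shard_0", group)   # creates h5_out/shard_0.h5
restored = load_h5("h5_out")          # {"rewards": [Tensor, Tensor]}
assert torch.equal(group["rewards"][0], restored["rewards"][0])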

View File

@@ -20,7 +20,7 @@ from khaosz.trainer.metric_util import (
     ctx_get_grad_std,
     ctx_get_grad_nan_num
 )

-from khaosz.data.checkpoint import Checkpoint
+from khaosz.data.serialization import Checkpoint
 from khaosz.trainer.train_context import TrainContext

View File

@@ -1,7 +1,7 @@
 import torch
 import numpy as np
-from khaosz.data.file import save_h5
+from khaosz.data.serialization import save_h5
 from khaosz.data.dataset import *
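
Scripts that produce training data must also write under the keys GRPODataset now reads. A hypothetical preprocessing snippet in the spirit of this file (token ids, shapes, and paths are illustrative):

import torch
from khaosz.data.serialization import save_h5

# Keys must match what GRPODataset._fetch_data looks up.
tensor_group = {
    "prompts":   [torch.randint(0, 32000, (128,))],
    "responses": [torch.randint(0, 32000, (128,))],
    "masks":     [torch.ones(128, dtype=torch.bool)],
    "rewards":   [torch.tensor(1.0)],
}
save_h5("grpo_data", "episode_0", tensor_group)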