refactor: 重构数据模块中的数据集类命名和文件结构

This commit is contained in:
ViperEkura 2026-03-19 22:37:32 +08:00
parent 0f518473af
commit 50f76cd7c7
6 changed files with 62 additions and 65 deletions

View File

@@ -1,9 +1,9 @@
from khaosz.data.dataset import (
BaseDataset,
SeqDataset,
DpoDataset,
SftDataset,
PpoDataset,
SEQDataset,
DPODataset,
SFTDataset,
GRPODataset,
MultiSegmentFetcher,
DatasetLoader
)
@@ -13,10 +13,10 @@ from khaosz.data.sampler import ResumableDistributedSampler
__all__ = [
"BaseDataset",
"SeqDataset",
"DpoDataset",
"SftDataset",
"PpoDataset",
"SEQDataset",
"SFTDataset",
"DPODataset",
"GRPODataset",
"MultiSegmentFetcher",
"DatasetLoader",
"BpeTokenizer",

View File

@@ -4,7 +4,7 @@ import bisect
from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset
from khaosz.data.file import load_h5
from khaosz.data.serialization import load_h5
from typing import Callable, List, Dict, Literal, Optional, Union
@@ -105,7 +105,7 @@ class BaseDataset(Dataset, ABC):
return (self.total_samples - 1 - self.window_size) // self.stride + 1
class SeqDataset(BaseDataset):
class SEQDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
@@ -123,7 +123,7 @@ class SeqDataset(BaseDataset):
return {"input_ids": x, "target_ids": y}
class SftDataset(BaseDataset):
class SFTDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
@@ -141,7 +141,7 @@ class SftDataset(BaseDataset):
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
class DpoDataset(BaseDataset):
class DPODataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
@@ -160,7 +160,7 @@ class DpoDataset(BaseDataset):
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
class PpoDataset(BaseDataset):
class GRPODataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
self.fetcher = MultiSegmentFetcher(self.segments)
@@ -171,12 +171,12 @@ class PpoDataset(BaseDataset):
def __getitem__(self, index: int) -> Dict[str, Tensor]:
    """Return one GRPO sample: prompts, responses, masks and rewards.

    Bug fix: in the committed code every ``_fetch_data`` assignment ended
    with a trailing comma, silently wrapping each fetched tensor in a
    1-tuple; the commas are removed here.  The stale pre-refactor fetches
    (``input_ids``/``actions``/``logprobs``) and the duplicate unreachable
    ``return`` left over from the diff are dropped in favour of the
    prompts/responses/masks/rewards contract the new return declares.
    """
    begin_idx, end_idx = self.get_index(index)
    # NOTE(review): shapes/semantics of the fetched tensors come from
    # MultiSegmentFetcher, which is defined elsewhere — confirm keys exist.
    prompts = self._fetch_data(begin_idx, end_idx, "prompts")
    responses = self._fetch_data(begin_idx, end_idx, "responses")
    masks = self._fetch_data(begin_idx, end_idx, "masks")
    rewards = self._fetch_data(begin_idx, end_idx, "rewards")
    return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
class DatasetLoader:
@@ -191,9 +191,10 @@ class DatasetLoader:
stride = window_size
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
"seq": lambda window_size: SeqDataset(window_size, stride),
"sft": lambda window_size: SftDataset(window_size, stride),
"dpo": lambda window_size: DpoDataset(window_size, stride),
"seq": lambda window_size: SEQDataset(window_size, stride),
"sft": lambda window_size: SFTDataset(window_size, stride),
"dpo": lambda window_size: DPODataset(window_size, stride),
"grpo": lambda window_size: GRPODataset(window_size, stride),
}
dataset = dataset_router[train_type](window_size)
dataset.load(load_path)

View File

@@ -1,42 +0,0 @@
import os
import h5py
import torch
from pathlib import Path
from torch import Tensor
from typing import Dict, List
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Persist grouped tensors to ``<file_path>/<file_name>.h5``.

    Each key of *tensor_group* becomes an HDF5 group; every tensor in the
    corresponding list is written as a dataset named ``data_<idx>`` inside
    that group (CPU-side numpy copy).
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.h5")
    with h5py.File(target, 'w') as handle:
        for group_name, tensor_list in tensor_group.items():
            group = handle.create_group(group_name)
            for position, item in enumerate(tensor_list):
                group.create_dataset(f'data_{position}', data=item.cpu().numpy())
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Load every ``*.h5`` / ``*.hdf5`` file found under *file_path* recursively.

    Returns a mapping from HDF5 group name to the list of tensors stored in
    that group, accumulated across all discovered files.  When
    *share_memory* is True each tensor is moved into shared memory.
    """
    collected: Dict[str, List[Tensor]] = {}
    root = Path(file_path)
    candidates = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
    for candidate in candidates:
        with h5py.File(candidate, 'r') as handle:
            for group_name in handle.keys():
                group = handle[group_name]
                bucket = collected.setdefault(group_name, [])
                for dataset_name in group.keys():
                    tensor = torch.from_numpy(group[dataset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    bucket.append(tensor)
    return collected

View File

@@ -1,11 +1,49 @@
import os
import h5py
import torch
import json
import safetensors.torch as st
import torch.distributed as dist
from pathlib import Path
from typing import Dict, Any
from torch import Tensor
from typing import Any, Dict, List
from khaosz.parallel.setup import get_rank
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write *tensor_group* to ``<file_path>/<file_name>.h5``.

    One HDF5 group per dictionary key; each tensor in the list is stored
    under the group as ``data_<idx>`` after being moved to CPU numpy.
    """
    os.makedirs(file_path, exist_ok=True)
    destination = os.path.join(file_path, f"{file_name}.h5")
    with h5py.File(destination, 'w') as out:
        for key, tensors in tensor_group.items():
            bucket = out.create_group(key)
            for i, t in enumerate(tensors):
                bucket.create_dataset(f'data_{i}', data=t.cpu().numpy())
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Recursively gather tensors from all ``*.h5``/``*.hdf5`` files below *file_path*.

    The result maps each HDF5 group name to the tensors found under it,
    merged across every matching file.  With *share_memory* set, tensors
    are placed into shared memory before being returned.
    """
    merged: Dict[str, List[Tensor]] = {}
    base = Path(file_path)
    for archive in list(base.rglob("*.h5")) + list(base.rglob("*.hdf5")):
        with h5py.File(archive, 'r') as src:
            for key in src.keys():
                entries = merged.setdefault(key, [])
                for name in src[key].keys():
                    loaded = torch.from_numpy(src[key][name][:])
                    if share_memory:
                        loaded = loaded.share_memory_()
                    entries.append(loaded)
    return merged
class Checkpoint:
def __init__(

View File

@@ -20,7 +20,7 @@ from khaosz.trainer.metric_util import (
ctx_get_grad_std,
ctx_get_grad_nan_num
)
from khaosz.data.checkpoint import Checkpoint
from khaosz.data.serialization import Checkpoint
from khaosz.trainer.train_context import TrainContext

View File

@@ -1,7 +1,7 @@
import torch
import numpy as np
from khaosz.data.file import save_h5
from khaosz.data.serialization import save_h5
from khaosz.data.dataset import *