diff --git a/khaosz/data/__init__.py b/khaosz/data/__init__.py
index 38164f7..44648a8 100644
--- a/khaosz/data/__init__.py
+++ b/khaosz/data/__init__.py
@@ -1,9 +1,9 @@
 from khaosz.data.dataset import (
     BaseDataset,
-    SeqDataset,
-    DpoDataset,
-    SftDataset,
-    PpoDataset,
+    SEQDataset,
+    DPODataset,
+    SFTDataset,
+    GRPODataset,
     MultiSegmentFetcher,
     DatasetLoader
 )
@@ -13,10 +13,10 @@ from khaosz.data.sampler import ResumableDistributedSampler
 
 __all__ = [
     "BaseDataset",
-    "SeqDataset",
-    "DpoDataset",
-    "SftDataset",
-    "PpoDataset",
+    "SEQDataset",
+    "SFTDataset",
+    "DPODataset",
+    "GRPODataset",
     "MultiSegmentFetcher",
     "DatasetLoader",
     "BpeTokenizer",
diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py
index c15bc54..b4e7fd6 100644
--- a/khaosz/data/dataset.py
+++ b/khaosz/data/dataset.py
@@ -4,7 +4,7 @@ import bisect
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from khaosz.data.file import load_h5
+from khaosz.data.serialization import load_h5
 from typing import Callable, List, Dict, Literal, Optional, Union
 
 
@@ -105,7 +105,7 @@
         return (self.total_samples - 1 - self.window_size) // self.stride + 1
 
 
-class SeqDataset(BaseDataset):
+class SEQDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -123,7 +123,7 @@
         return {"input_ids": x, "target_ids": y}
 
 
-class SftDataset(BaseDataset):
+class SFTDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -141,7 +141,7 @@
         return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
 
 
-class DpoDataset(BaseDataset):
+class DPODataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -160,7 +160,7 @@
         return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
 
 
-class PpoDataset(BaseDataset):
+class GRPODataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -171,12 +171,12 @@
     def __getitem__(self, index: int) -> Dict[str, Tensor]:
         begin_idx, end_idx = self.get_index(index)
 
-        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
-        actions = self._fetch_data(begin_idx, end_idx, "actions"),
-        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs"),
+        prompts = self._fetch_data(begin_idx, end_idx, "prompts")
+        responses = self._fetch_data(begin_idx, end_idx, "responses")
+        masks = self._fetch_data(begin_idx, end_idx, "masks")
         rewards = self._fetch_data(begin_idx, end_idx, "rewards")
 
-        return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
+        return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
 
 
 class DatasetLoader:
@@ -191,9 +191,10 @@
             stride = window_size
 
         dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
-            "seq": lambda window_size: SeqDataset(window_size, stride),
-            "sft": lambda window_size: SftDataset(window_size, stride),
-            "dpo": lambda window_size: DpoDataset(window_size, stride),
+            "seq": lambda window_size: SEQDataset(window_size, stride),
+            "sft": lambda window_size: SFTDataset(window_size, stride),
+            "dpo": lambda window_size: DPODataset(window_size, stride),
+            "grpo": lambda window_size: GRPODataset(window_size, stride),
         }
         dataset = dataset_router[train_type](window_size)
         dataset.load(load_path)
diff --git a/khaosz/data/file.py b/khaosz/data/file.py
deleted file mode 100644
index e6d5535..0000000
--- a/khaosz/data/file.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import os
-import h5py
-import torch
-
-from pathlib import Path
-from torch import Tensor
-from typing import Dict, List
-
-
-def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
-    os.makedirs(file_path, exist_ok=True)
-    full_file_path = os.path.join(file_path, f"{file_name}.h5")
-    with h5py.File(full_file_path, 'w') as f:
-        for key, tensors in tensor_group.items():
-            grp = f.create_group(key)
-            for idx, tensor in enumerate(tensors):
-                arr = tensor.cpu().numpy()
-                grp.create_dataset(f'data_{idx}', data=arr)
-
-def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
-    tensor_group: Dict[str, List[Tensor]] = {}
-
-    root_path = Path(file_path)
-    h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
-
-    for h5_file in h5_files:
-        with h5py.File(h5_file, 'r') as f:
-            for key in f.keys():
-                grp = f[key]
-                dsets = []
-                for dset_name in grp.keys():
-                    dset = grp[dset_name]
-                    tensor = torch.from_numpy(dset[:])
-                    if share_memory:
-                        tensor = tensor.share_memory_()
-                    dsets.append(tensor)
-
-                if tensor_group.get(key) is None:
-                    tensor_group[key] = []
-                tensor_group[key].extend(dsets)
-
-    return tensor_group
\ No newline at end of file
diff --git a/khaosz/data/checkpoint.py b/khaosz/data/serialization.py
similarity index 52%
rename from khaosz/data/checkpoint.py
rename to khaosz/data/serialization.py
index 272aaf4..3258172 100644
--- a/khaosz/data/checkpoint.py
+++ b/khaosz/data/serialization.py
@@ -1,11 +1,49 @@
+import os
+import h5py
+import torch
 import json
 import safetensors.torch as st
 import torch.distributed as dist
 
 from pathlib import Path
-from typing import Dict, Any
+from torch import Tensor
+from typing import Any, Dict, List
 from khaosz.parallel.setup import get_rank
 
 
+def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
+    os.makedirs(file_path, exist_ok=True)
+    full_file_path = os.path.join(file_path, f"{file_name}.h5")
+    with h5py.File(full_file_path, 'w') as f:
+        for key, tensors in tensor_group.items():
+            grp = f.create_group(key)
+            for idx, tensor in enumerate(tensors):
+                arr = tensor.cpu().numpy()
+                grp.create_dataset(f'data_{idx}', data=arr)
+
+def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
+    tensor_group: Dict[str, List[Tensor]] = {}
+
+    root_path = Path(file_path)
+    h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
+
+    for h5_file in h5_files:
+        with h5py.File(h5_file, 'r') as f:
+            for key in f.keys():
+                grp = f[key]
+                dsets = []
+                for dset_name in grp.keys():
+                    dset = grp[dset_name]
+                    tensor = torch.from_numpy(dset[:])
+                    if share_memory:
+                        tensor = tensor.share_memory_()
+                    dsets.append(tensor)
+
+                if tensor_group.get(key) is None:
+                    tensor_group[key] = []
+                tensor_group[key].extend(dsets)
+
+    return tensor_group
+
 class Checkpoint:
     def __init__(
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index 5fa9719..6b12b0d 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -20,7 +20,7 @@ from khaosz.trainer.metric_util import (
     ctx_get_grad_std,
     ctx_get_grad_nan_num
 )
-from khaosz.data.checkpoint import Checkpoint
+from khaosz.data.serialization import Checkpoint
 from khaosz.trainer.train_context import TrainContext
 
 
diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index d12aeb9..a9aac47 100644
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 
-from khaosz.data.file import save_h5
+from khaosz.data.serialization import save_h5
 from khaosz.data.dataset import *
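
Usage note (outside the patch): `save_h5` / `load_h5` move verbatim from `khaosz/data/file.py` into `khaosz/data/serialization.py`. A minimal round-trip sketch against the signatures shown above; the output directory, keys, counts, and tensor shapes are made up for illustration:

```python
import torch
from khaosz.data.serialization import save_h5, load_h5

# save_h5 writes one HDF5 group per key; each tensor in the list
# becomes a dataset named data_0, data_1, ... inside that group.
tensor_group = {
    "prompts": [torch.randint(0, 100, (16,)) for _ in range(4)],
    "rewards": [torch.rand(16) for _ in range(4)],
}
save_h5("out/grpo", "shard_0", tensor_group)  # -> out/grpo/shard_0.h5

# load_h5 rglobs *.h5 / *.hdf5 under the path and merges groups that
# share a key; share_memory=True (the default) moves each tensor into
# shared memory so DataLoader worker processes can read it without copies.
restored = load_h5("out/grpo")
assert len(restored["prompts"]) == 4
assert restored["rewards"][0].shape == (16,)
```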
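Similarly, a hypothetical call site for the renamed GRPO dataset. The directory name and window size are illustrative; `load` comes from `BaseDataset`, as `DatasetLoader` uses it in the router above:

```python
from khaosz.data.dataset import GRPODataset

dataset = GRPODataset(window_size=512, stride=512)
dataset.load("data/grpo_shards")  # a directory of .h5 shards, see load_h5

# __getitem__ now yields the GRPO field names instead of the old
# PPO-style input_ids / actions / logprobs.
sample = dataset[0]
assert set(sample) == {"prompts", "responses", "masks", "rewards"}
```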