refactor: 重构数据模块中的数据集类命名和文件结构
This commit is contained in:
parent
0f518473af
commit
50f76cd7c7
|
|
@ -1,9 +1,9 @@
|
|||
from khaosz.data.dataset import (
|
||||
BaseDataset,
|
||||
SeqDataset,
|
||||
DpoDataset,
|
||||
SftDataset,
|
||||
PpoDataset,
|
||||
SEQDataset,
|
||||
DPODataset,
|
||||
SFTDataset,
|
||||
GRPODataset,
|
||||
MultiSegmentFetcher,
|
||||
DatasetLoader
|
||||
)
|
||||
|
|
@ -13,10 +13,10 @@ from khaosz.data.sampler import ResumableDistributedSampler
|
|||
|
||||
__all__ = [
|
||||
"BaseDataset",
|
||||
"SeqDataset",
|
||||
"DpoDataset",
|
||||
"SftDataset",
|
||||
"PpoDataset",
|
||||
"SEQDataset",
|
||||
"SFTDataset",
|
||||
"DPODataset",
|
||||
"GRPODataset",
|
||||
"MultiSegmentFetcher",
|
||||
"DatasetLoader",
|
||||
"BpeTokenizer",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import bisect
|
|||
from abc import ABC, abstractmethod
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from khaosz.data.file import load_h5
|
||||
from khaosz.data.serialization import load_h5
|
||||
from typing import Callable, List, Dict, Literal, Optional, Union
|
||||
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ class BaseDataset(Dataset, ABC):
|
|||
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
||||
|
||||
|
||||
class SeqDataset(BaseDataset):
|
||||
class SEQDataset(BaseDataset):
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
|
|
@ -123,7 +123,7 @@ class SeqDataset(BaseDataset):
|
|||
return {"input_ids": x, "target_ids": y}
|
||||
|
||||
|
||||
class SftDataset(BaseDataset):
|
||||
class SFTDataset(BaseDataset):
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
|
|
@ -141,7 +141,7 @@ class SftDataset(BaseDataset):
|
|||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||
|
||||
|
||||
class DpoDataset(BaseDataset):
|
||||
class DPODataset(BaseDataset):
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
|
|
@ -160,7 +160,7 @@ class DpoDataset(BaseDataset):
|
|||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
||||
|
||||
|
||||
class PpoDataset(BaseDataset):
|
||||
class GRPODataset(BaseDataset):
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
|
|
@ -171,12 +171,12 @@ class PpoDataset(BaseDataset):
|
|||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
    """Return the GRPO training sample for window ``index``.

    The window bounds come from ``self.get_index(index)``; each field is then
    read over that range with ``self._fetch_data``.

    Args:
        index: Position of the sample window in the dataset.

    Returns:
        A dict with the keys ``"prompts"``, ``"responses"``, ``"masks"`` and
        ``"rewards"``, each holding whatever ``_fetch_data`` returns for that
        key (a Tensor according to the annotation -- TODO confirm against
        ``_fetch_data``, which is not visible here).
    """
    begin_idx, end_idx = self.get_index(index)

    # No trailing commas after these calls: a trailing comma would wrap each
    # value in a 1-tuple and break the Dict[str, Tensor] contract.
    prompts = self._fetch_data(begin_idx, end_idx, "prompts")
    responses = self._fetch_data(begin_idx, end_idx, "responses")
    masks = self._fetch_data(begin_idx, end_idx, "masks")
    rewards = self._fetch_data(begin_idx, end_idx, "rewards")

    return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}
|
||||
|
||||
|
||||
class DatasetLoader:
|
||||
|
|
@ -191,9 +191,10 @@ class DatasetLoader:
|
|||
stride = window_size
|
||||
|
||||
dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
|
||||
"seq": lambda window_size: SeqDataset(window_size, stride),
|
||||
"sft": lambda window_size: SftDataset(window_size, stride),
|
||||
"dpo": lambda window_size: DpoDataset(window_size, stride),
|
||||
"seq": lambda window_size: SEQDataset(window_size, stride),
|
||||
"sft": lambda window_size: SFTDataset(window_size, stride),
|
||||
"dpo": lambda window_size: DPODataset(window_size, stride),
|
||||
"grpo": lambda window_size: GRPODataset(window_size, stride),
|
||||
}
|
||||
dataset = dataset_router[train_type](window_size)
|
||||
dataset.load(load_path)
|
||||
|
|
|
|||
|
|
@ -1,42 +0,0 @@
|
|||
import os
|
||||
import h5py
|
||||
import torch
|
||||
|
||||
from pathlib import Path
|
||||
from torch import Tensor
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write groups of tensors to ``<file_path>/<file_name>.h5``.

    One HDF5 group is created per key of ``tensor_group``. The tensors of a
    key are stored inside its group as datasets named ``data_0``, ``data_1``,
    ... following their list order. The target directory is created when it
    does not exist, and the file is opened in ``'w'`` mode, so an existing
    file with the same name is replaced.

    Args:
        file_path: Directory that receives the file.
        file_name: File name without the ``.h5`` suffix.
        tensor_group: Mapping from group name to the tensors stored under it.
    """
    os.makedirs(file_path, exist_ok=True)
    full_file_path = os.path.join(file_path, f"{file_name}.h5")
    with h5py.File(full_file_path, 'w') as f:
        for key, tensors in tensor_group.items():
            grp = f.create_group(key)
            for idx, tensor in enumerate(tensors):
                # detach() first: Tensor.numpy() raises on tensors that
                # require grad, and a data file never needs the autograd graph.
                arr = tensor.detach().cpu().numpy()
                grp.create_dataset(f'data_{idx}', data=arr)
|
||||
|
||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Load every HDF5 file found under ``file_path`` into per-key tensor lists.

    ``file_path`` is searched recursively for ``*.h5`` and ``*.hdf5`` files.
    Every top-level group of every file contributes its datasets, converted
    to tensors, to the list stored under the group's name; groups with the
    same name in different files are concatenated.

    Files are visited in sorted path order and datasets in natural name order
    (``data_2`` before ``data_10``), so the result is reproducible across
    machines and matches the order in which ``save_h5`` wrote the tensors.

    Args:
        file_path: Root directory to search.
        share_memory: When true, ``share_memory_()`` is called on every
            loaded tensor.

    Returns:
        Mapping from group name to the list of tensors loaded for it. The
        mapping is empty when no matching file is found.
    """
    import re  # local import: only the dataset-name sort key needs it

    def natural_key(name: str) -> list:
        # "data_10" -> ["data_", 10, ""]: with a capturing group, re.split
        # puts the digit runs at the odd positions, which compare as ints.
        parts = re.split(r"([0-9]+)", name)
        return [int(part) if pos % 2 else part for pos, part in enumerate(parts)]

    tensor_group: Dict[str, List[Tensor]] = {}

    root_path = Path(file_path)
    # rglob yields entries in an unspecified order; sort for determinism.
    h5_files = sorted([*root_path.rglob("*.h5"), *root_path.rglob("*.hdf5")])

    for h5_file in h5_files:
        with h5py.File(h5_file, 'r') as f:
            for key in f.keys():
                grp = f[key]
                # setdefault keeps an (empty) entry even for empty groups.
                loaded = tensor_group.setdefault(key, [])
                for dset_name in sorted(grp.keys(), key=natural_key):
                    # dset[:] reads the whole dataset into a numpy array.
                    tensor = torch.from_numpy(grp[dset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    loaded.append(tensor)

    return tensor_group
|
||||
|
|
@ -1,11 +1,49 @@
|
|||
import os
|
||||
import h5py
|
||||
import torch
|
||||
import json
|
||||
import safetensors.torch as st
|
||||
import torch.distributed as dist
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from torch import Tensor
|
||||
from typing import Any, Dict, List
|
||||
from khaosz.parallel.setup import get_rank
|
||||
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write groups of tensors to ``<file_path>/<file_name>.h5``.

    One HDF5 group is created per key of ``tensor_group``. The tensors of a
    key are stored inside its group as datasets named ``data_0``, ``data_1``,
    ... following their list order. The target directory is created when it
    does not exist, and the file is opened in ``'w'`` mode, so an existing
    file with the same name is replaced.

    Args:
        file_path: Directory that receives the file.
        file_name: File name without the ``.h5`` suffix.
        tensor_group: Mapping from group name to the tensors stored under it.
    """
    os.makedirs(file_path, exist_ok=True)
    full_file_path = os.path.join(file_path, f"{file_name}.h5")
    with h5py.File(full_file_path, 'w') as f:
        for key, tensors in tensor_group.items():
            grp = f.create_group(key)
            for idx, tensor in enumerate(tensors):
                # detach() first: Tensor.numpy() raises on tensors that
                # require grad, and a data file never needs the autograd graph.
                arr = tensor.detach().cpu().numpy()
                grp.create_dataset(f'data_{idx}', data=arr)
|
||||
|
||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Load every HDF5 file found under ``file_path`` into per-key tensor lists.

    ``file_path`` is searched recursively for ``*.h5`` and ``*.hdf5`` files.
    Every top-level group of every file contributes its datasets, converted
    to tensors, to the list stored under the group's name; groups with the
    same name in different files are concatenated.

    Files are visited in sorted path order and datasets in natural name order
    (``data_2`` before ``data_10``), so the result is reproducible across
    machines and matches the order in which ``save_h5`` wrote the tensors.

    Args:
        file_path: Root directory to search.
        share_memory: When true, ``share_memory_()`` is called on every
            loaded tensor.

    Returns:
        Mapping from group name to the list of tensors loaded for it. The
        mapping is empty when no matching file is found.
    """
    import re  # local import: only the dataset-name sort key needs it

    def natural_key(name: str) -> list:
        # "data_10" -> ["data_", 10, ""]: with a capturing group, re.split
        # puts the digit runs at the odd positions, which compare as ints.
        parts = re.split(r"([0-9]+)", name)
        return [int(part) if pos % 2 else part for pos, part in enumerate(parts)]

    tensor_group: Dict[str, List[Tensor]] = {}

    root_path = Path(file_path)
    # rglob yields entries in an unspecified order; sort for determinism.
    h5_files = sorted([*root_path.rglob("*.h5"), *root_path.rglob("*.hdf5")])

    for h5_file in h5_files:
        with h5py.File(h5_file, 'r') as f:
            for key in f.keys():
                grp = f[key]
                # setdefault keeps an (empty) entry even for empty groups.
                loaded = tensor_group.setdefault(key, [])
                for dset_name in sorted(grp.keys(), key=natural_key):
                    # dset[:] reads the whole dataset into a numpy array.
                    tensor = torch.from_numpy(grp[dset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    loaded.append(tensor)

    return tensor_group
|
||||
|
||||
|
||||
class Checkpoint:
|
||||
def __init__(
|
||||
|
|
@ -20,7 +20,7 @@ from khaosz.trainer.metric_util import (
|
|||
ctx_get_grad_std,
|
||||
ctx_get_grad_nan_num
|
||||
)
|
||||
from khaosz.data.checkpoint import Checkpoint
|
||||
from khaosz.data.serialization import Checkpoint
|
||||
from khaosz.trainer.train_context import TrainContext
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
from khaosz.data.file import save_h5
|
||||
from khaosz.data.serialization import save_h5
|
||||
from khaosz.data.dataset import *
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue