refactor: rename dataset classes and reorganize file structure in the data module
parent 0f518473af
commit 50f76cd7c7
@@ -1,9 +1,9 @@
 from khaosz.data.dataset import (
     BaseDataset,
-    SeqDataset,
-    DpoDataset,
-    SftDataset,
-    PpoDataset,
+    SEQDataset,
+    DPODataset,
+    SFTDataset,
+    GRPODataset,
     MultiSegmentFetcher,
     DatasetLoader
 )
@@ -13,10 +13,10 @@ from khaosz.data.sampler import ResumableDistributedSampler

 __all__ = [
     "BaseDataset",
-    "SeqDataset",
-    "DpoDataset",
-    "SftDataset",
-    "PpoDataset",
+    "SEQDataset",
+    "SFTDataset",
+    "DPODataset",
+    "GRPODataset",
     "MultiSegmentFetcher",
     "DatasetLoader",
     "BpeTokenizer",
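Note: for downstream code the rename is a breaking change, since the old names are no longer exported. A hypothetical caller would migrate like this (PpoDataset has no direct successor; GRPODataset takes over its slot):

# Before this commit:
#   from khaosz.data.dataset import SeqDataset, SftDataset, DpoDataset, PpoDataset
# After this commit:
from khaosz.data.dataset import SEQDataset, SFTDataset, DPODataset, GRPODataset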
@@ -4,7 +4,7 @@ import bisect
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from khaosz.data.file import load_h5
+from khaosz.data.serialization import load_h5
 from typing import Callable, List, Dict, Literal, Optional, Union

@@ -105,7 +105,7 @@ class BaseDataset(Dataset, ABC):
         return (self.total_samples - 1 - self.window_size) // self.stride + 1


-class SeqDataset(BaseDataset):
+class SEQDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -123,7 +123,7 @@ class SeqDataset(BaseDataset):
         return {"input_ids": x, "target_ids": y}


-class SftDataset(BaseDataset):
+class SFTDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -141,7 +141,7 @@ class SftDataset(BaseDataset):
         return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}


-class DpoDataset(BaseDataset):
+class DPODataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -160,7 +160,7 @@ class DpoDataset(BaseDataset):
         return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}


-class PpoDataset(BaseDataset):
+class GRPODataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
         self.fetcher = MultiSegmentFetcher(self.segments)
@@ -171,12 +171,12 @@ class PpoDataset(BaseDataset):
     def __getitem__(self, index: int) -> Dict[str, Tensor]:
         begin_idx, end_idx = self.get_index(index)

-        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
-        actions = self._fetch_data(begin_idx, end_idx, "actions"),
-        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs"),
+        prompts = self._fetch_data(begin_idx, end_idx, "prompts"),
+        responses = self._fetch_data(begin_idx, end_idx, "responses"),
+        masks = self._fetch_data(begin_idx, end_idx, "masks"),
         rewards = self._fetch_data(begin_idx, end_idx, "rewards")

-        return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
+        return {"prompts": prompts, "responses": responses, "masks": masks, "rewards": rewards}


 class DatasetLoader:
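Note: this hunk carries one detail over unchanged from the PPO version. The trailing comma on each _fetch_data assignment binds a one-element tuple rather than the fetched value itself, for every key except rewards. A minimal sketch of the effect, with a stand-in fetch function:

def fetch():
    return [1, 2, 3]

prompts = fetch(),   # trailing comma: prompts == ([1, 2, 3],), a 1-tuple
rewards = fetch()    # no comma: rewards is the fetched value itself
assert isinstance(prompts, tuple) and prompts[0] == rewards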
@@ -191,9 +191,10 @@ class DatasetLoader:
         stride = window_size

         dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
-            "seq": lambda window_size: SeqDataset(window_size, stride),
-            "sft": lambda window_size: SftDataset(window_size, stride),
-            "dpo": lambda window_size: DpoDataset(window_size, stride),
+            "seq": lambda window_size: SEQDataset(window_size, stride),
+            "sft": lambda window_size: SFTDataset(window_size, stride),
+            "dpo": lambda window_size: DPODataset(window_size, stride),
+            "grpo": lambda window_size: GRPODataset(window_size, stride),
         }
         dataset = dataset_router[train_type](window_size)
         dataset.load(load_path)
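Note: the new "grpo" route implies GRPO training shards must carry the keys GRPODataset fetches. A minimal sketch of writing such a shard, assuming BaseDataset.load reads it back via load_h5 (dataset.py imports load_h5 above, though the load body is not part of this diff); shapes and paths are illustrative:

import torch
from khaosz.data.serialization import save_h5

# One hypothetical GRPO sample per key; save_h5 stores each list entry as a
# separate dataset under a group named after the key.
save_h5("./grpo_data", "shard_0", {
    "prompts":   [torch.randint(0, 32000, (128,))],
    "responses": [torch.randint(0, 32000, (128,))],
    "masks":     [torch.ones(128, dtype=torch.bool)],
    "rewards":   [torch.tensor([1.0])],
})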
@@ -1,42 +0,0 @@
-import os
-import h5py
-import torch
-
-from pathlib import Path
-from torch import Tensor
-from typing import Dict, List
-
-
-def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
-    os.makedirs(file_path, exist_ok=True)
-    full_file_path = os.path.join(file_path, f"{file_name}.h5")
-    with h5py.File(full_file_path, 'w') as f:
-        for key, tensors in tensor_group.items():
-            grp = f.create_group(key)
-            for idx, tensor in enumerate(tensors):
-                arr = tensor.cpu().numpy()
-                grp.create_dataset(f'data_{idx}', data=arr)
-
-def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
-    tensor_group: Dict[str, List[Tensor]] = {}
-
-    root_path = Path(file_path)
-    h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
-
-    for h5_file in h5_files:
-        with h5py.File(h5_file, 'r') as f:
-            for key in f.keys():
-                grp = f[key]
-                dsets = []
-                for dset_name in grp.keys():
-                    dset = grp[dset_name]
-                    tensor = torch.from_numpy(dset[:])
-                    if share_memory:
-                        tensor = tensor.share_memory_()
-                    dsets.append(tensor)
-
-                if tensor_group.get(key) is None:
-                    tensor_group[key] = []
-                tensor_group[key].extend(dsets)
-
-    return tensor_group
@@ -1,11 +1,49 @@
+import os
+import h5py
+import torch
 import json
 import safetensors.torch as st
 import torch.distributed as dist

 from pathlib import Path
-from typing import Dict, Any
+from torch import Tensor
+from typing import Any, Dict, List
 from khaosz.parallel.setup import get_rank


+def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
+    os.makedirs(file_path, exist_ok=True)
+    full_file_path = os.path.join(file_path, f"{file_name}.h5")
+    with h5py.File(full_file_path, 'w') as f:
+        for key, tensors in tensor_group.items():
+            grp = f.create_group(key)
+            for idx, tensor in enumerate(tensors):
+                arr = tensor.cpu().numpy()
+                grp.create_dataset(f'data_{idx}', data=arr)
+
+def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
+    tensor_group: Dict[str, List[Tensor]] = {}
+
+    root_path = Path(file_path)
+    h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
+
+    for h5_file in h5_files:
+        with h5py.File(h5_file, 'r') as f:
+            for key in f.keys():
+                grp = f[key]
+                dsets = []
+                for dset_name in grp.keys():
+                    dset = grp[dset_name]
+                    tensor = torch.from_numpy(dset[:])
+                    if share_memory:
+                        tensor = tensor.share_memory_()
+                    dsets.append(tensor)
+
+                if tensor_group.get(key) is None:
+                    tensor_group[key] = []
+                tensor_group[key].extend(dsets)
+
+    return tensor_group
+
+
 class Checkpoint:
     def __init__(
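Note: both helpers move verbatim from the deleted module into serialization.py, so only the import path changes. A round trip against the new location, as a minimal sketch with illustrative paths and keys:

import torch
from khaosz.data.serialization import save_h5, load_h5

# Write one group with two tensors to ./cache/sample.h5, then read back every
# *.h5 / *.hdf5 found under ./cache; load_h5 merges groups across files by key.
group = {"input_ids": [torch.arange(8), torch.arange(8, 16)]}
save_h5("./cache", "sample", group)

restored = load_h5("./cache", share_memory=True)  # tensors moved to shared memory
assert torch.equal(restored["input_ids"][0], torch.arange(8))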
@@ -20,7 +20,7 @@ from khaosz.trainer.metric_util import (
     ctx_get_grad_std,
     ctx_get_grad_nan_num
 )
-from khaosz.data.checkpoint import Checkpoint
+from khaosz.data.serialization import Checkpoint
 from khaosz.trainer.train_context import TrainContext

@@ -1,7 +1,7 @@
 import torch
 import numpy as np

-from khaosz.data.file import save_h5
+from khaosz.data.serialization import save_h5
 from khaosz.data.dataset import *
