diff --git a/khaosz/data/__init__.py b/khaosz/data/__init__.py
index 8882a1b..4d1930c 100644
--- a/khaosz/data/__init__.py
+++ b/khaosz/data/__init__.py
@@ -4,7 +4,7 @@ from khaosz.data.dataset import (
     DpoDataset,
     SftDataset,
     PpoDataset,
-    MutiSegmentFetcher,
+    MultiSegmentFetcher,
     DatasetLoader,
-    load_pkl_files,
+    load_mmap_files,
 )
@@ -18,7 +18,7 @@ __all__ = [
     "DpoDataset",
     "SftDataset",
     "PpoDataset",
-    "MutiSegmentFetcher",
+    "MultiSegmentFetcher",
     "DatasetLoader",
-    "load_pkl_files",
+    "load_mmap_files",
     "BpeTokenizer",
diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py
index 42bb5dd..8bb1b28 100644
--- a/khaosz/data/dataset.py
+++ b/khaosz/data/dataset.py
@@ -1,34 +1,91 @@
+import os
+import json
 import torch
 import bisect
-import pickle as pkl
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from typing import Callable, List, Dict, Literal, Optional, Union
-
-MutiSeg = Dict[str, List[Tensor]]
-Seg = Dict[str, Tensor]
+from typing import Callable, List, Dict, Literal, Optional, Tuple, Union
 
 
-def load_pkl_files(paths: List[str]):
-    segments: MutiSeg = {}
-    total_samples = 0
+Seg = List[Tensor]
+MultiSeg = Dict[str, Seg]
 
-    for path in paths:
-        with open(path, "rb") as f:
-            pkl_file: Seg = pkl.load(f)
-        for key, value in pkl_file.items():
-            if key not in segments:
-                segments[key] = []
-            segments[key].append(value)
-        first_key = list(pkl_file.keys())[0]
-        total_samples += pkl_file[first_key].numel()
+
+def load_mmap_files(root_path: str, shared: bool=True) -> Tuple[MultiSeg, int]:
+    """Load memory-mapped binary files as torch tensors.
 
-    return segments, total_samples
+    Loads configuration from file_mapper.json in the specified directory, then loads
+    corresponding binary files as memory-mapped tensors. Returns tensors grouped by key
+    and total number of elements.
+
+    Args:
+        root_path: Root directory path containing file_mapper.json and binary files
+        shared: Whether to load tensors in shared mode. If True, tensors can be
+            shared between processes
+
+    Raises:
+        FileNotFoundError: If file_mapper.json or any binary file in config is missing
+        KeyError: If dtype in config is not in supported DTYPE_MAP
+        json.JSONDecodeError: If config file is not valid JSON
+
+    Returns:
+        Tuple containing:
+        - MultiSeg: Dictionary of tensors grouped by key, structure: {key: [tensor1, tensor2, ...]}
+        - int: Total number of elements across all tensors
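+
+    Example:
+        A minimal file_mapper.json (schema inferred from the fields read
+        below; file and key names are illustrative):
+
+        [
+            {"file_name": "sequence_0.bin", "key": "sequence", "size": 4096, "dtype": "int64"},
+            {"file_name": "sequence_1.bin", "key": "sequence", "size": 4096, "dtype": "int64"}
+        ]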
+    """
+
+    DTYPE_MAP = {
+        "float32": torch.float32,
+        "float64": torch.float64,
+        "int32": torch.int32,
+        "int64": torch.int64,
+        "bool": torch.bool,
+    }
+
+    metadata_list = []
+    mmap_shared_group: MultiSeg = {}
+
+    file_mapper_path = os.path.join(root_path, "file_mapper.json")
+    if not os.path.exists(file_mapper_path):
+        raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
+
+    with open(file_mapper_path, "r") as f:
+        metadata_list = json.load(f)
+
+    num_samples = sum(metadata["size"] for metadata in metadata_list)
+
+    for metadata in metadata_list:
+        file_path = os.path.join(root_path, metadata["file_name"])
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(f"Binary data file not found: {file_path}")
+
+        size = metadata["size"]
+        dtype = DTYPE_MAP[metadata["dtype"]]
+        segment_key = metadata["key"]
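+        # torch.from_file memory-maps the file instead of reading it eagerly;
+        # note that size is the number of dtype elements, not a byte count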
+        mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)
+
+        if segment_key not in mmap_shared_group:
+            mmap_shared_group[segment_key] = []
+
+        mmap_shared_group[segment_key].append(mmap_tensor)
+
+    return mmap_shared_group, num_samples
 
 
 class BaseSegmentFetcher:
-    def __init__(self, segments: List[Tensor]):
+    def __init__(self, segments: Seg):
         self.segments = segments
         self.cum_lengths = []
         total = 0
@@ -58,8 +115,8 @@ class BaseSegmentFetcher:
         return torch.cat(result_segments, dim=0)
 
 
-class MutiSegmentFetcher:
-    def __init__(self, muti_segments: MutiSeg):
+class MultiSegmentFetcher:
+    def __init__(self, muti_segments: MultiSeg):
         self.muti_keys = list(muti_segments.keys())
         self.muti_fetchers = {
             key: BaseSegmentFetcher(segments)
@@ -82,29 +139,24 @@ class MutiSegmentFetcher:
 
 
 class BaseDataset(Dataset, ABC):
-    def __init__(self, window_size: int, stride: int):
+    def __init__(self, window_size: int, stride: int, share_memory: bool=False):
         super().__init__()
-        self.segments: MutiSeg = {}
+        self.segments: MultiSeg = {}
+        self.share_memory = share_memory
         self.window_size = window_size
         self.stride = stride
         self.total_samples = None
 
-    def save(self, save_path: str):
-        keys = list(self.segments.keys())
-        if not keys:
-            return
-
-        first_item = self.segments[keys[0]]
-        segment_size = len(first_item)
-
-        for i in range(segment_size):
-            formated_segment = {key: self.segments[key][i] for key in keys}
-            pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb"))
-
     def load(self, load_path: Union[str, List[str]]):
         paths = [load_path] if isinstance(load_path, str) else load_path
-        self.segments, self.total_samples = load_pkl_files(paths)
-        self.fetcher = MutiSegmentFetcher(self.segments)
+        # load_mmap_files reads a single root directory; merge results across roots
+        self.segments, self.total_samples = {}, 0
+        for path in paths:
+            segments, num_samples = load_mmap_files(path, shared=self.share_memory)
+            for key, tensors in segments.items():
+                self.segments.setdefault(key, []).extend(tensors)
+            self.total_samples += num_samples
+        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
         begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
@@ -126,7 +178,7 @@ class BaseDataset(Dataset, ABC):
 class SeqDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
-        self.fetcher = MutiSegmentFetcher(self.segments)
+        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
         return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
@@ -144,7 +196,7 @@ class SeqDataset(BaseDataset):
 class SftDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
-        self.fetcher = MutiSegmentFetcher(self.segments)
+        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
@@ -162,7 +214,7 @@ class SftDataset(BaseDataset):
 class DpoDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
-        self.fetcher = MutiSegmentFetcher(self.segments)
+        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
@@ -181,7 +233,7 @@ class DpoDataset(BaseDataset):
 class PpoDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
-        self.fetcher = MutiSegmentFetcher(self.segments)
+        self.fetcher = MultiSegmentFetcher(self.segments)
 
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
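
# Usage sketch for the new loader (illustrative, not part of the patch;
# the directory name and window/stride values below are assumptions):
from khaosz.data import SftDataset

dataset = SftDataset(window_size=1024, stride=512)
dataset.load("data")   # a root containing file_mapper.json; a list of roots is merged in order
sample = dataset[0]    # fetches concatenate across mmap segment boundaries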