diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py index 4a18373..b769777 100644 --- a/khaosz/data/dataset.py +++ b/khaosz/data/dataset.py @@ -1,93 +1,16 @@ -import os -import json import torch import bisect from abc import ABC, abstractmethod from torch import Tensor from torch.utils.data import Dataset -from typing import Callable, List, Dict, Literal, Optional, Tuple, Union +from khaosz.data.mmap import MmapFileHander +from typing import Callable, List, Dict, Literal, Optional, Union Seg = List[Tensor] MultiSeg = Dict[str, Seg] -def load_mmap_files(root_path: str, shared: bool=True) -> Tuple[MultiSeg, int]: - """Load memory-mapped binary files as torch tensors. - - Loads configuration from file_mapper.json in the specified directory, then loads - corresponding binary files as memory-mapped tensors. Returns tensors grouped by key - and total number of elements. - - json metadata like this: - - ``` - [ - { - "file_name": "file1.bin", - "size": 1000, - "dtype": "float32", - "key": "key1" - }, - ... - ] - ``` - - Args: - root_path: Root directory path containing file_mapper.json and binary files - shared: Whether to load tensors in shared mode. 
If True, tensors can be - shared between processes - - Raises: - FileNotFoundError: If file_mapper.json or any binary file in config is missing - KeyError: If dtype in config is not in supported DTYPE_MAP - json.JSONDecodeError: If config file is not valid JSON - - Returns: - Tuple containing: - - MultiSeg: Dictionary of tensors grouped by key, structure: {key: [tensor1, tensor2, ...]} - - int: Total number of elements across all tensors - """ - - DTYPE_MAP = { - "float32": torch.float32, - "float64": torch.float64, - "int32": torch.int32, - "int64": torch.int64, - "bool": torch.bool, - } - - metadata_list = [] - mmap_shared_group: MultiSeg = {} - - file_mapper_path = os.path.join(root_path, "file_mapper.json") - if not os.path.exists(file_mapper_path): - raise FileNotFoundError(f"File mapper not found: {file_mapper_path}") - - with open(file_mapper_path, "r") as f: - metadata_list = json.load(f) - - for metadata in metadata_list: - file_path = os.path.join(root_path, metadata["file_name"]) - if not os.path.exists(file_path): - raise FileNotFoundError(f"Binary data file not found: {file_path}") - - size = metadata["size"] - dtype = DTYPE_MAP[metadata["dtype"]] - segment_key = metadata["key"] - mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype) - - if segment_key not in mmap_shared_group: - mmap_shared_group[segment_key] = [] - - mmap_shared_group[segment_key].append(mmap_tensor) - - num_samples = sum(metadata["size"] for metadata in metadata_list - if segment_key == metadata["key"]) - - return mmap_shared_group, num_samples - - class BaseSegmentFetcher: def __init__(self, segments: Seg): self.segments = segments @@ -151,7 +74,7 @@ class BaseDataset(Dataset, ABC): self.total_samples = None def load(self, load_path: str): - self.segments, self.total_samples = load_mmap_files(load_path) + self.segments, self.total_samples = MmapFileHander.load(load_path) self.fetcher = MultiSegmentFetcher(self.segments) def get_index(self, index: int) -> 
import os
import json
import torch

from torch import Tensor
from typing import List, Dict, Tuple


class MmapFileHander:
    """Save/load groups of tensors as raw memory-mapped binary files.

    NOTE(review): the class name keeps the original "Hander" spelling
    ("Handler" was almost certainly intended) because dataset.py imports
    it under this exact name; renaming would break callers.

    Directory layout produced by save() and consumed by load():

    ```
    file_mapper.json
    file1.bin
    file2.bin
    ...
    ```

    json metadata like this:

    ```
    [
        {"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"},
        {"file_name": "file2.bin", "size": 2000, "dtype": "float32", "key": "key2"}
        ...
    ]
    ```
    """

    # On-disk dtype names -> torch dtypes accepted in the mapper file.
    DTYPE_MAP = {
        "float32": torch.float32,
        "float64": torch.float64,
        "int32": torch.int32,
        "int64": torch.int64,
        "bool": torch.bool,
    }
    # torch dtype -> on-disk name, used by save().
    REVERSE_DTYPE_MAP = {v: k for k, v in DTYPE_MAP.items()}

    @staticmethod
    def load(root_path: str, shared: bool = True) -> Tuple[Dict[str, List[Tensor]], int]:
        """Memory-map every binary file listed in file_mapper.json.

        Args:
            root_path: Directory containing file_mapper.json and the .bin files.
            shared: Forwarded to torch.from_file; if True the mapping can be
                shared between processes.

        Returns:
            Tuple of:
            - Dict of tensors grouped by key: {key: [tensor1, tensor2, ...]}
            - int: number of elements per key (every key must hold the same total).

        Raises:
            FileNotFoundError: If file_mapper.json or any listed binary file is missing.
            KeyError: If a dtype name in the mapper is not in DTYPE_MAP.
            ValueError: If the keys do not all contain the same total element count.
            json.JSONDecodeError: If the mapper file is not valid JSON.
        """
        file_mapper_path = os.path.join(root_path, "file_mapper.json")
        if not os.path.exists(file_mapper_path):
            raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")

        with open(file_mapper_path, "r") as f:
            metadata_list = json.load(f)

        mmap_shared_group: Dict[str, List[Tensor]] = {}
        sizes_per_key: Dict[str, int] = {}

        for metadata in metadata_list:
            file_path = os.path.join(root_path, metadata["file_name"])
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Binary data file not found: {file_path}")

            size = metadata["size"]
            dtype = MmapFileHander.DTYPE_MAP[metadata["dtype"]]
            segment_key = metadata["key"]
            mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)

            mmap_shared_group.setdefault(segment_key, []).append(mmap_tensor)
            sizes_per_key[segment_key] = sizes_per_key.get(segment_key, 0) + size

        # BUG FIX: the previous version computed num_samples / num_keys with
        # true division, which (a) returned a float despite the declared int
        # return type, (b) silently averaged keys of unequal size, and
        # (c) raised ZeroDivisionError on an empty mapper. Validate that every
        # key holds the same total and return that count as an int (0 if empty).
        distinct_totals = set(sizes_per_key.values())
        if len(distinct_totals) > 1:
            raise ValueError(f"Keys have unequal total sizes: {sizes_per_key}")
        sample_per_key = distinct_totals.pop() if distinct_totals else 0

        return mmap_shared_group, sample_per_key

    @staticmethod
    def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
        """Write each tensor as raw bytes plus a file_mapper.json index.

        Files are named "{key}_{idx}.bin" and can be re-opened with load().

        Args:
            save_path: Output directory (created if missing).
            mmap_shared_group: Tensors grouped by key, as returned by load().

        Raises:
            KeyError: If a tensor's dtype has no entry in REVERSE_DTYPE_MAP.
        """
        os.makedirs(save_path, exist_ok=True)

        metadata_list = []
        for segment_key, segment_tensors in mmap_shared_group.items():
            for idx, tensor in enumerate(segment_tensors):
                file_name = f"{segment_key}_{idx}.bin"
                metadata_list.append({
                    "file_name": file_name,
                    "size": tensor.numel(),
                    "dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
                    "key": segment_key,
                })
                file_path = os.path.join(save_path, file_name)
                # detach() so tensors with requires_grad=True can be exported;
                # contiguous() so the byte stream matches the flat numel()
                # size recorded above even for strided/transposed tensors.
                with open(file_path, "wb") as f:
                    f.write(tensor.detach().cpu().contiguous().numpy().tobytes())

        metadata_path = os.path.join(save_path, "file_mapper.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata_list, f)