diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py index b769777..ad253ac 100644 --- a/khaosz/data/dataset.py +++ b/khaosz/data/dataset.py @@ -4,7 +4,7 @@ import bisect from abc import ABC, abstractmethod from torch import Tensor from torch.utils.data import Dataset -from khaosz.data.mmap import MmapFileHander +from khaosz.data.mmap import MmapFileHandler from typing import Callable, List, Dict, Literal, Optional, Union Seg = List[Tensor] @@ -74,7 +74,7 @@ class BaseDataset(Dataset, ABC): self.total_samples = None def load(self, load_path: str): - self.segments, self.total_samples = MmapFileHander.load(load_path) + self.segments, self.total_samples = MmapFileHandler.load(load_path) self.fetcher = MultiSegmentFetcher(self.segments) def get_index(self, index: int) -> int: diff --git a/khaosz/data/mmap.py b/khaosz/data/mmap.py index 6920e47..d9d4bd6 100644 --- a/khaosz/data/mmap.py +++ b/khaosz/data/mmap.py @@ -5,14 +5,14 @@ import torch from torch import Tensor from typing import List, Dict, Tuple -class MmapFileHander: +class MmapFileHandler: """ json metadata like this: ``` [ - {"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"}, - {"file_name": "file2.bin", "size": 2000, "dtype": "float32", "key": "key2"} + {"file_name": "file1.bin", "size": 1000, "key": "key1"}, + {"file_name": "file2.bin", "size": 2000, "key": "key2"} ... ] ``` @@ -20,29 +20,20 @@ class MmapFileHander: ``` folder_path: - - file_mapper.json + - metadata.json - file1.bin - file2.bin ... - ``` """ - - DTYPE_MAP = { - "float32": torch.float32, - "float64": torch.float64, - "int32": torch.int32, - "int64": torch.int64, - "bool": torch.bool, - } - REVERSE_DTYPE_MAP = {v: k for k, v in DTYPE_MAP.items()} + META_DATA = "metadata.json" @staticmethod def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]: metadata_list = [] - mmap_shared_group: Dict[str, List[Tensor]] = {} + tensor_group: Dict[str, List[Tensor]] = {} - file_mapper_path = os.path.join(root_path, "file_mapper.json") + file_mapper_path = os.path.join(root_path, MmapFileHandler.META_DATA) if not os.path.exists(file_mapper_path): raise FileNotFoundError(f"File mapper not found: {file_mapper_path}") @@ -50,25 +41,20 @@ class MmapFileHander: metadata_list = json.load(f) for metadata in metadata_list: - file_path = os.path.join(root_path, metadata["file_name"]) - if not os.path.exists(file_path): - raise FileNotFoundError(f"Binary data file not found: {file_path}") - - size = metadata["size"] - dtype = MmapFileHander.DTYPE_MAP[metadata["dtype"]] - segment_key = metadata["key"] - mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype) - - if segment_key not in mmap_shared_group: - mmap_shared_group[segment_key] = [] + file_key = metadata["key"] + file_name = metadata["file_name"] + file_path = os.path.join(root_path, file_name) + elm = torch.load(file_path, map_location="cpu", mmap=shared) - mmap_shared_group[segment_key].append(mmap_tensor) - + if file_key not in tensor_group: + tensor_group[file_key] = [] + tensor_group[file_key].append(elm) + num_samples = sum(metadata["size"] for metadata in metadata_list) num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1) sample_per_key = num_samples // num_keys - return mmap_shared_group, sample_per_key + return tensor_group, sample_per_key @staticmethod def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None: @@ -79,18 +65,18 @@ class MmapFileHander: for idx, tensor in enumerate(segment_tensors): try: - with open(os.path.join(save_path, f"{segment_key}_{idx}.bin"), "wb") as f: - f.write(tensor.cpu().numpy().tobytes()) + with open(os.path.join(save_path, f"{segment_key}_{idx}.pt"), "wb") as f: + torch.save(tensor.contiguous().cpu(), f) except Exception as e: raise RuntimeError(f"Error saving tensor: {e}") metadata_list.append({ - "file_name": f"{segment_key}_{idx}.bin", + "file_name": f"{segment_key}_{idx}.pt", "size": tensor.numel(), - "dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype], "key": segment_key }) - metadata_path = os.path.join(save_path, "file_mapper.json") + metadata_path = os.path.join(save_path, MmapFileHandler.META_DATA) + with open(metadata_path, "w") as f: - json.dump(metadata_list, f) + json.dump(metadata_list, f) \ No newline at end of file diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 1c12d21..419c3ad 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -17,23 +17,21 @@ def create_mmap_dataset(dir_path, data_dict, dataset_name): for key, tensor in data_dict.items(): # Convert tensor to numpy array and save as binary file - np_array = tensor.numpy() - file_name = f"{key}.bin" + file_name = f"{key}.pt" file_path = os.path.join(dataset_dir, file_name) - # Save as binary file - np_array.tofile(file_path) + with open(file_path, "wb") as f: + torch.save(tensor, f) # Add to file mapper file_mapper.append({ "file_name": file_name, - "size": len(np_array), - "dtype": str(np_array.dtype), + "size": tensor.numel(), "key": key }) # Save file mapper - mapper_path = os.path.join(dataset_dir, "file_mapper.json") + mapper_path = os.path.join(dataset_dir, "metadata.json") with open(mapper_path, "w") as f: json.dump(file_mapper, f, indent=2) @@ -194,7 +192,7 @@ def test_multi_segment_fetcher(base_test_env): dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test") # Load the memory mapped files directly - multi_segments, _ = MmapFileHander.load(dataset_path) + multi_segments, _ = MmapFileHandler.load(dataset_path) # Create fetcher fetcher = MultiSegmentFetcher(multi_segments) @@ -218,14 +216,14 @@ def test_multi_segment_fetcher(base_test_env): def test_mmap_file_handler_direct(base_test_env): - """Test MmapFileHander directly without DatasetLoader""" + """Test MmapFileHandler directly without DatasetLoader""" test_dir = base_test_env["test_dir"] # Create test data with multiple segments seq_length1 = 100 seq_length2 = 200 - # Create data in the format expected by MmapFileHander + # Create data in the format expected by MmapFileHandler dummy_data = { "sequence": [ torch.randint(0, 1000, (seq_length1,), dtype=torch.int64), @@ -237,12 +235,12 @@ def test_mmap_file_handler_direct(base_test_env): ] } - # Save data using MmapFileHander + # Save data using MmapFileHandler dataset_dir = os.path.join(test_dir, "mmap_direct_test") - MmapFileHander.save(dataset_dir, dummy_data) + MmapFileHandler.save(dataset_dir, dummy_data) - # Load data using MmapFileHander - loaded_data, num_samples = MmapFileHander.load(dataset_dir) + # Load data using MmapFileHandler + loaded_data, num_samples = MmapFileHandler.load(dataset_dir) # Verify data structure assert set(loaded_data.keys()) == set(dummy_data.keys()) @@ -255,7 +253,7 @@ def test_mmap_file_handler_direct(base_test_env): assert torch.equal(loaded_data[key][i], dummy_data[key][i]) def test_mmap_file_handler_dtypes(base_test_env): - """Test MmapFileHander with different data types""" + """Test MmapFileHandler with different data types""" test_dir = base_test_env["test_dir"] # Create test data with different dtypes @@ -269,10 +267,10 @@ def test_mmap_file_handler_dtypes(base_test_env): # Save data dataset_dir = os.path.join(test_dir, "dtype_test") - MmapFileHander.save(dataset_dir, data) + MmapFileHandler.save(dataset_dir, data) # Load data - loaded_data, _ = MmapFileHander.load(dataset_dir) + loaded_data, _ = MmapFileHandler.load(dataset_dir) # Verify data types for key in data: @@ -280,14 +278,14 @@ def test_mmap_file_handler_dtypes(base_test_env): assert torch.equal(loaded_data[key][0], data[key][0]) def test_mmap_file_handler_error_handling(base_test_env): - """Test MmapFileHander error handling""" + """Test MmapFileHandler error handling""" test_dir = base_test_env["test_dir"] # Test loading without file_mapper.json empty_dir = os.path.join(test_dir, "empty_dir") os.makedirs(empty_dir, exist_ok=True) with pytest.raises(FileNotFoundError): - MmapFileHander.load(empty_dir) + MmapFileHandler.load(empty_dir) # Test loading with invalid file_mapper.json invalid_dir = os.path.join(test_dir, "invalid_dir") @@ -299,4 +297,4 @@ def test_mmap_file_handler_error_handling(base_test_env): # This should raise FileNotFoundError because no binary files exist with pytest.raises(FileNotFoundError): - MmapFileHander.load(invalid_dir) + MmapFileHandler.load(invalid_dir)