From 831933fb6618326d46e5571fd7e3560961d2b903 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Mon, 15 Dec 2025 09:27:29 +0800
Subject: [PATCH] fix(mmap): fix sample count and key calculation logic and
 strengthen error handling
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/data/mmap.py                           | 22 +++--
 ...test_dataset_loader.py => test_dataset.py} | 90 ++++++++++++++++++-
 2 files changed, 101 insertions(+), 11 deletions(-)
 rename tests/{test_dataset_loader.py => test_dataset.py} (68%)

diff --git a/khaosz/data/mmap.py b/khaosz/data/mmap.py
index 842a78b..6920e47 100644
--- a/khaosz/data/mmap.py
+++ b/khaosz/data/mmap.py
@@ -19,9 +19,10 @@ class MmapFileHander:
     files like:
 
     ```
-    file_mapper.json
-    file1.bin
-    file2.bin
+    folder_path:
+    - file_mapper.json
+    - file1.bin
+    - file2.bin
     ...
     ```
 
@@ -64,9 +65,8 @@ class MmapFileHander:
             mmap_shared_group[segment_key].append(mmap_tensor)
 
         num_samples = sum(metadata["size"] for metadata in metadata_list)
-        num_keys = len(set(metadata['key'] for metadata in metadata_list))
-
-        sample_per_key = num_samples / num_keys
+        num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
+        sample_per_key = num_samples // num_keys
 
         return mmap_shared_group, sample_per_key
 
@@ -77,15 +77,19 @@ class MmapFileHander:
         metadata_list = []
         for segment_key, segment_tensors in mmap_shared_group.items():
             for idx, tensor in enumerate(segment_tensors):
+
+                try:
+                    with open(os.path.join(save_path, f"{segment_key}_{idx}.bin"), "wb") as f:
+                        f.write(tensor.cpu().numpy().tobytes())
+                except Exception as e:
+                    raise RuntimeError(f"Error saving tensor: {e}")
+
                 metadata_list.append({
                     "file_name": f"{segment_key}_{idx}.bin",
                     "size": tensor.numel(),
                     "dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
                     "key": segment_key
                 })
-                file_path = os.path.join(save_path, f"{segment_key}_{idx}.bin")
-                with open(file_path, "wb") as f:
-                    f.write(tensor.cpu().numpy().tobytes())
 
         metadata_path = os.path.join(save_path, "file_mapper.json")
         with open(metadata_path, "w") as f:
diff --git a/tests/test_dataset_loader.py b/tests/test_dataset.py
similarity index 68%
rename from tests/test_dataset_loader.py
rename to tests/test_dataset.py
index 3187fe4..1c12d21 100644
--- a/tests/test_dataset_loader.py
+++ b/tests/test_dataset.py
@@ -1,5 +1,6 @@
 import os
 import json
+import pytest
 import torch
 import numpy as np
 
@@ -193,7 +194,7 @@ def test_multi_segment_fetcher(base_test_env):
     dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
 
     # Load the memory mapped files directly
-    multi_segments, _ = load_mmap_files(dataset_path)
+    multi_segments, _ = MmapFileHander.load(dataset_path)
 
     # Create fetcher
     fetcher = MultiSegmentFetcher(multi_segments)
@@ -213,4 +214,89 @@
     # Test fetching all keys
     all_data = fetcher.fetch_data(0, 10)
     assert "sequence" in all_data
-    assert "mask" in all_data
\ No newline at end of file
+    assert "mask" in all_data
+
+
+def test_mmap_file_handler_direct(base_test_env):
+    """Test MmapFileHander directly without DatasetLoader"""
+    test_dir = base_test_env["test_dir"]
+
+    # Create test data with multiple segments
+    seq_length1 = 100
+    seq_length2 = 200
+
+    # Create data in the format expected by MmapFileHander
+    dummy_data = {
+        "sequence": [
+            torch.randint(0, 1000, (seq_length1,), dtype=torch.int64),
+            torch.randint(0, 1000, (seq_length2,), dtype=torch.int64)
+        ],
+        "mask": [
+            torch.ones(seq_length1, dtype=torch.bool),
+            torch.ones(seq_length2, dtype=torch.bool)
+        ]
+    }
+
+    # Save data using MmapFileHander
+    dataset_dir = os.path.join(test_dir, "mmap_direct_test")
+    MmapFileHander.save(dataset_dir, dummy_data)
+
+    # Load data using MmapFileHander
+    loaded_data, num_samples = MmapFileHander.load(dataset_dir)
+
+    # Verify data structure
+    assert set(loaded_data.keys()) == set(dummy_data.keys())
+    assert num_samples == seq_length1 + seq_length2  # 300
+
+    # Verify data content
+    for key in dummy_data:
+        assert len(loaded_data[key]) == len(dummy_data[key])
+        for i in range(len(dummy_data[key])):
+            assert torch.equal(loaded_data[key][i], dummy_data[key][i])
+
+def test_mmap_file_handler_dtypes(base_test_env):
+    """Test MmapFileHander with different data types"""
+    test_dir = base_test_env["test_dir"]
+
+    # Create test data with different dtypes
+    data = {
+        "float32": [torch.randn(100, dtype=torch.float32)],
+        "float64": [torch.randn(100, dtype=torch.float64)],
+        "int32": [torch.randint(0, 1000, (100,), dtype=torch.int32)],
+        "int64": [torch.randint(0, 1000, (100,), dtype=torch.int64)],
+        "bool": [torch.randint(0, 2, (100,), dtype=torch.bool)]
+    }
+
+    # Save data
+    dataset_dir = os.path.join(test_dir, "dtype_test")
+    MmapFileHander.save(dataset_dir, data)
+
+    # Load data
+    loaded_data, _ = MmapFileHander.load(dataset_dir)
+
+    # Verify data types
+    for key in data:
+        assert loaded_data[key][0].dtype == data[key][0].dtype
+        assert torch.equal(loaded_data[key][0], data[key][0])
+
+def test_mmap_file_handler_error_handling(base_test_env):
+    """Test MmapFileHander error handling"""
+    test_dir = base_test_env["test_dir"]
+
+    # Test loading without file_mapper.json
+    empty_dir = os.path.join(test_dir, "empty_dir")
+    os.makedirs(empty_dir, exist_ok=True)
+    with pytest.raises(FileNotFoundError):
+        MmapFileHander.load(empty_dir)
+
+    # Test loading with invalid file_mapper.json
+    invalid_dir = os.path.join(test_dir, "invalid_dir")
+    os.makedirs(invalid_dir, exist_ok=True)
+
+    # Create a file_mapper.json that references a missing binary file
+    with open(os.path.join(invalid_dir, "file_mapper.json"), "w") as f:
+        json.dump([{"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"}], f)
+
+    # This should raise FileNotFoundError because no binary files exist
+    with pytest.raises(FileNotFoundError):
+        MmapFileHander.load(invalid_dir)
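
Below is a minimal round-trip sketch of the behavior this patch fixes, assuming
the MmapFileHander.save/MmapFileHander.load API exactly as exercised in
tests/test_dataset.py above; the directory path and tensor sizes are
illustrative only, and the makedirs call is a precaution in case save() does
not create the target directory itself.

```
import os
import torch
from khaosz.data.mmap import MmapFileHander

save_dir = "/tmp/mmap_demo"            # illustrative path
os.makedirs(save_dir, exist_ok=True)   # precaution; save() may not create it

data = {
    "sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)],
    "mask": [torch.ones(100, dtype=torch.bool)],
}
MmapFileHander.save(save_dir, data)
loaded, sample_per_key = MmapFileHander.load(save_dir)

# Pre-patch, sample_per_key used true division (yielding a float) and
# num_keys could be 0 for an empty mapper; post-patch it is integer
# division with num_keys clamped to >= 1: (100 + 100) // 2 == 100.
assert sample_per_key == 100
assert torch.equal(loaded["sequence"][0], data["sequence"][0])
```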