fix(mmap): 修复样本数与键值计算逻辑并增强错误处理
This commit is contained in:
parent
701fb9bf78
commit
831933fb66
|
|
@ -19,9 +19,10 @@ class MmapFileHander:
|
|||
files like:
|
||||
|
||||
```
|
||||
file_mapper.json
|
||||
file1.bin
|
||||
file2.bin
|
||||
folder_path:
|
||||
- file_mapper.json
|
||||
- file1.bin
|
||||
- file2.bin
|
||||
...
|
||||
|
||||
```
|
||||
|
|
@ -64,9 +65,8 @@ class MmapFileHander:
|
|||
mmap_shared_group[segment_key].append(mmap_tensor)
|
||||
|
||||
num_samples = sum(metadata["size"] for metadata in metadata_list)
|
||||
num_keys = len(set(metadata['key'] for metadata in metadata_list))
|
||||
|
||||
sample_per_key = num_samples / num_keys
|
||||
num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
|
||||
sample_per_key = num_samples // num_keys
|
||||
|
||||
return mmap_shared_group, sample_per_key
|
||||
|
||||
|
|
@ -77,15 +77,19 @@ class MmapFileHander:
|
|||
metadata_list = []
|
||||
for segment_key, segment_tensors in mmap_shared_group.items():
|
||||
for idx, tensor in enumerate(segment_tensors):
|
||||
|
||||
try:
|
||||
with open(os.path.join(save_path, f"{segment_key}_{idx}.bin"), "wb") as f:
|
||||
f.write(tensor.cpu().numpy().tobytes())
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error saving tensor: {e}")
|
||||
|
||||
metadata_list.append({
|
||||
"file_name": f"{segment_key}_{idx}.bin",
|
||||
"size": tensor.numel(),
|
||||
"dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
|
||||
"key": segment_key
|
||||
})
|
||||
file_path = os.path.join(save_path, f"{segment_key}_{idx}.bin")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(tensor.cpu().numpy().tobytes())
|
||||
|
||||
metadata_path = os.path.join(save_path, "file_mapper.json")
|
||||
with open(metadata_path, "w") as f:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import json
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -193,7 +194,7 @@ def test_multi_segment_fetcher(base_test_env):
|
|||
dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
|
||||
|
||||
# Load the memory mapped files directly
|
||||
multi_segments, _ = load_mmap_files(dataset_path)
|
||||
multi_segments, _ = MmapFileHander.load(dataset_path)
|
||||
|
||||
# Create fetcher
|
||||
fetcher = MultiSegmentFetcher(multi_segments)
|
||||
|
|
@ -214,3 +215,88 @@ def test_multi_segment_fetcher(base_test_env):
|
|||
all_data = fetcher.fetch_data(0, 10)
|
||||
assert "sequence" in all_data
|
||||
assert "mask" in all_data
|
||||
|
||||
|
||||
def test_mmap_file_handler_direct(base_test_env):
    """Exercise MmapFileHander save/load round-trip without DatasetLoader."""
    root = base_test_env["test_dir"]

    # Two segments of different lengths, laid out the way MmapFileHander expects:
    # a dict mapping each key to a list of per-segment tensors.
    len_a, len_b = 100, 200
    dummy_data = {
        "sequence": [
            torch.randint(0, 1000, (len_a,), dtype=torch.int64),
            torch.randint(0, 1000, (len_b,), dtype=torch.int64),
        ],
        "mask": [
            torch.ones(len_a, dtype=torch.bool),
            torch.ones(len_b, dtype=torch.bool),
        ],
    }

    # Round-trip the data through the handler.
    target = os.path.join(root, "mmap_direct_test")
    MmapFileHander.save(target, dummy_data)
    loaded, num_samples = MmapFileHander.load(target)

    # Structure survives: same keys, and the reported sample count is the
    # total element count across both segments.
    assert set(loaded.keys()) == set(dummy_data.keys())
    assert num_samples == len_a + len_b  # 300

    # Content survives: every tensor comes back exactly as written.
    for key, originals in dummy_data.items():
        assert len(loaded[key]) == len(originals)
        for restored, original in zip(loaded[key], originals):
            assert torch.equal(restored, original)
|
||||
|
||||
def test_mmap_file_handler_dtypes(base_test_env):
    """Verify MmapFileHander round-trips tensors of every supported dtype."""
    root = base_test_env["test_dir"]

    # One single-segment entry per dtype; the key names mirror the dtypes
    # purely for readability of failures.
    data = {
        "float32": [torch.randn(100, dtype=torch.float32)],
        "float64": [torch.randn(100, dtype=torch.float64)],
        "int32": [torch.randint(0, 1000, (100,), dtype=torch.int32)],
        "int64": [torch.randint(0, 1000, (100,), dtype=torch.int64)],
        "bool": [torch.randint(0, 2, (100,), dtype=torch.bool)],
    }

    # Save, then load back through the handler.
    target = os.path.join(root, "dtype_test")
    MmapFileHander.save(target, data)
    loaded, _ = MmapFileHander.load(target)

    # Both the dtype and the values must be preserved exactly.
    for key, tensors in data.items():
        restored = loaded[key][0]
        assert restored.dtype == tensors[0].dtype
        assert torch.equal(restored, tensors[0])
|
||||
|
||||
def test_mmap_file_handler_error_handling(base_test_env):
    """MmapFileHander.load must fail loudly on missing or dangling metadata."""
    root = base_test_env["test_dir"]

    # Case 1: a directory with no file_mapper.json at all.
    empty_dir = os.path.join(root, "empty_dir")
    os.makedirs(empty_dir, exist_ok=True)
    with pytest.raises(FileNotFoundError):
        MmapFileHander.load(empty_dir)

    # Case 2: a mapper that references a binary file which was never written.
    invalid_dir = os.path.join(root, "invalid_dir")
    os.makedirs(invalid_dir, exist_ok=True)
    with open(os.path.join(invalid_dir, "file_mapper.json"), "w") as f:
        json.dump([{"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"}], f)

    # Loading should fail because file1.bin does not exist on disk.
    with pytest.raises(FileNotFoundError):
        MmapFileHander.load(invalid_dir)
|
||||
Loading…
Reference in New Issue