fix(mmap): 修复样本数与键值计算逻辑并增强错误处理
This commit is contained in:
parent
701fb9bf78
commit
831933fb66
|
|
@ -19,9 +19,10 @@ class MmapFileHander:
|
||||||
files like:
|
files like:
|
||||||
|
|
||||||
```
|
```
|
||||||
file_mapper.json
|
folder_path:
|
||||||
file1.bin
|
- file_mapper.json
|
||||||
file2.bin
|
- file1.bin
|
||||||
|
- file2.bin
|
||||||
...
|
...
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
@ -64,9 +65,8 @@ class MmapFileHander:
|
||||||
mmap_shared_group[segment_key].append(mmap_tensor)
|
mmap_shared_group[segment_key].append(mmap_tensor)
|
||||||
|
|
||||||
num_samples = sum(metadata["size"] for metadata in metadata_list)
|
num_samples = sum(metadata["size"] for metadata in metadata_list)
|
||||||
num_keys = len(set(metadata['key'] for metadata in metadata_list))
|
num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
|
||||||
|
sample_per_key = num_samples // num_keys
|
||||||
sample_per_key = num_samples / num_keys
|
|
||||||
|
|
||||||
return mmap_shared_group, sample_per_key
|
return mmap_shared_group, sample_per_key
|
||||||
|
|
||||||
|
|
@ -77,15 +77,19 @@ class MmapFileHander:
|
||||||
metadata_list = []
|
metadata_list = []
|
||||||
for segment_key, segment_tensors in mmap_shared_group.items():
|
for segment_key, segment_tensors in mmap_shared_group.items():
|
||||||
for idx, tensor in enumerate(segment_tensors):
|
for idx, tensor in enumerate(segment_tensors):
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(os.path.join(save_path, f"{segment_key}_{idx}.bin"), "wb") as f:
|
||||||
|
f.write(tensor.cpu().numpy().tobytes())
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error saving tensor: {e}")
|
||||||
|
|
||||||
metadata_list.append({
|
metadata_list.append({
|
||||||
"file_name": f"{segment_key}_{idx}.bin",
|
"file_name": f"{segment_key}_{idx}.bin",
|
||||||
"size": tensor.numel(),
|
"size": tensor.numel(),
|
||||||
"dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
|
"dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
|
||||||
"key": segment_key
|
"key": segment_key
|
||||||
})
|
})
|
||||||
file_path = os.path.join(save_path, f"{segment_key}_{idx}.bin")
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(tensor.cpu().numpy().tobytes())
|
|
||||||
|
|
||||||
metadata_path = os.path.join(save_path, "file_mapper.json")
|
metadata_path = os.path.join(save_path, "file_mapper.json")
|
||||||
with open(metadata_path, "w") as f:
|
with open(metadata_path, "w") as f:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -193,7 +194,7 @@ def test_multi_segment_fetcher(base_test_env):
|
||||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
|
dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
|
||||||
|
|
||||||
# Load the memory mapped files directly
|
# Load the memory mapped files directly
|
||||||
multi_segments, _ = load_mmap_files(dataset_path)
|
multi_segments, _ = MmapFileHander.load(dataset_path)
|
||||||
|
|
||||||
# Create fetcher
|
# Create fetcher
|
||||||
fetcher = MultiSegmentFetcher(multi_segments)
|
fetcher = MultiSegmentFetcher(multi_segments)
|
||||||
|
|
@ -213,4 +214,89 @@ def test_multi_segment_fetcher(base_test_env):
|
||||||
# Test fetching all keys
|
# Test fetching all keys
|
||||||
all_data = fetcher.fetch_data(0, 10)
|
all_data = fetcher.fetch_data(0, 10)
|
||||||
assert "sequence" in all_data
|
assert "sequence" in all_data
|
||||||
assert "mask" in all_data
|
assert "mask" in all_data
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmap_file_handler_direct(base_test_env):
    """Round-trip a two-key, two-segment dataset through MmapFileHander.

    The data is written with ``MmapFileHander.save`` and read back with
    ``MmapFileHander.load`` without going through DatasetLoader; the keys,
    the reported per-key sample count and every tensor must survive intact.
    """
    root = base_test_env["test_dir"]

    # Two segments of different length for every key.
    seq_length1, seq_length2 = 100, 200
    segment_lengths = (seq_length1, seq_length2)

    # Layout expected by MmapFileHander: key -> list of 1-D tensors.
    dummy_data = {
        "sequence": [
            torch.randint(0, 1000, (length,), dtype=torch.int64)
            for length in segment_lengths
        ],
        "mask": [
            torch.ones(length, dtype=torch.bool)
            for length in segment_lengths
        ],
    }

    # Write the dataset to disk, then read it straight back.
    dataset_dir = os.path.join(root, "mmap_direct_test")
    MmapFileHander.save(dataset_dir, dummy_data)
    loaded_data, num_samples = MmapFileHander.load(dataset_dir)

    # Structure: same keys, and the per-key sample count equals the total
    # number of elements stored under one key (100 + 200 = 300).
    assert set(loaded_data.keys()) == set(dummy_data)
    assert num_samples == seq_length1 + seq_length2

    # Content: every segment must come back element-for-element identical.
    for key, expected_segments in dummy_data.items():
        actual_segments = loaded_data[key]
        assert len(actual_segments) == len(expected_segments)
        for actual, expected in zip(actual_segments, expected_segments):
            assert torch.equal(actual, expected)
|
||||||
|
|
||||||
|
def test_mmap_file_handler_dtypes(base_test_env):
    """Check that MmapFileHander preserves dtype and values across dtypes.

    One single-segment entry is stored per dtype; after a save/load
    round-trip each loaded tensor must have the same dtype and contents.
    """
    root = base_test_env["test_dir"]

    # One key per dtype under test, each holding a single 100-element segment.
    data = {
        "float32": [torch.randn(100, dtype=torch.float32)],
        "float64": [torch.randn(100, dtype=torch.float64)],
        "int32": [torch.randint(0, 1000, (100,), dtype=torch.int32)],
        "int64": [torch.randint(0, 1000, (100,), dtype=torch.int64)],
        "bool": [torch.randint(0, 2, (100,), dtype=torch.bool)],
    }

    # Save, then load; the sample count is irrelevant for this test.
    dataset_dir = os.path.join(root, "dtype_test")
    MmapFileHander.save(dataset_dir, data)
    loaded_data, _ = MmapFileHander.load(dataset_dir)

    # Compare the first (only) segment of every key.
    for key, segments in data.items():
        expected = segments[0]
        actual = loaded_data[key][0]
        assert actual.dtype == expected.dtype
        assert torch.equal(actual, expected)
|
||||||
|
|
||||||
|
def test_mmap_file_handler_error_handling(base_test_env):
    """Verify that MmapFileHander.load raises FileNotFoundError on bad input.

    Two cases are covered: a directory with no ``file_mapper.json`` at all,
    and a directory whose ``file_mapper.json`` lists a binary file that was
    never written.
    """
    root = base_test_env["test_dir"]

    # Case 1: the directory exists but contains no file_mapper.json.
    empty_dir = os.path.join(root, "empty_dir")
    os.makedirs(empty_dir, exist_ok=True)
    with pytest.raises(FileNotFoundError):
        MmapFileHander.load(empty_dir)

    # Case 2: file_mapper.json is present but points at a missing .bin file.
    invalid_dir = os.path.join(root, "invalid_dir")
    os.makedirs(invalid_dir, exist_ok=True)

    # Metadata describing a binary file that does not exist on disk.
    dangling_entries = [
        {"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"}
    ]
    mapper_path = os.path.join(invalid_dir, "file_mapper.json")
    with open(mapper_path, "w") as mapper_file:
        json.dump(dangling_entries, mapper_file)

    # Loading must fail because the referenced binary file is absent.
    with pytest.raises(FileNotFoundError):
        MmapFileHander.load(invalid_dir)
|
||||||
Loading…
Reference in New Issue