fix(mmap): 修复样本数与键值计算逻辑并增强错误处理

This commit is contained in:
ViperEkura 2025-12-15 09:27:29 +08:00
parent 701fb9bf78
commit 831933fb66
2 changed files with 101 additions and 11 deletions

View File

@ -19,9 +19,10 @@ class MmapFileHander:
files like:
```
file_mapper.json
file1.bin
file2.bin
folder_path:
- file_mapper.json
- file1.bin
- file2.bin
...
```
@ -64,9 +65,8 @@ class MmapFileHander:
mmap_shared_group[segment_key].append(mmap_tensor)
num_samples = sum(metadata["size"] for metadata in metadata_list)
num_keys = len(set(metadata['key'] for metadata in metadata_list))
sample_per_key = num_samples / num_keys
num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
sample_per_key = num_samples // num_keys
return mmap_shared_group, sample_per_key
@ -77,15 +77,19 @@ class MmapFileHander:
metadata_list = []
for segment_key, segment_tensors in mmap_shared_group.items():
for idx, tensor in enumerate(segment_tensors):
try:
with open(os.path.join(save_path, f"{segment_key}_{idx}.bin"), "wb") as f:
f.write(tensor.cpu().numpy().tobytes())
except Exception as e:
raise RuntimeError(f"Error saving tensor: {e}")
metadata_list.append({
"file_name": f"{segment_key}_{idx}.bin",
"size": tensor.numel(),
"dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
"key": segment_key
})
file_path = os.path.join(save_path, f"{segment_key}_{idx}.bin")
with open(file_path, "wb") as f:
f.write(tensor.cpu().numpy().tobytes())
metadata_path = os.path.join(save_path, "file_mapper.json")
with open(metadata_path, "w") as f:

View File

@ -1,5 +1,6 @@
import os
import json
import pytest
import torch
import numpy as np
@ -193,7 +194,7 @@ def test_multi_segment_fetcher(base_test_env):
dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
# Load the memory mapped files directly
multi_segments, _ = load_mmap_files(dataset_path)
multi_segments, _ = MmapFileHander.load(dataset_path)
# Create fetcher
fetcher = MultiSegmentFetcher(multi_segments)
@ -213,4 +214,89 @@ def test_multi_segment_fetcher(base_test_env):
# Test fetching all keys
all_data = fetcher.fetch_data(0, 10)
assert "sequence" in all_data
assert "mask" in all_data
assert "mask" in all_data
def test_mmap_file_handler_direct(base_test_env):
"""Test MmapFileHander directly without DatasetLoader"""
test_dir = base_test_env["test_dir"]
# Create test data with multiple segments
seq_length1 = 100
seq_length2 = 200
# Create data in the format expected by MmapFileHander
dummy_data = {
"sequence": [
torch.randint(0, 1000, (seq_length1,), dtype=torch.int64),
torch.randint(0, 1000, (seq_length2,), dtype=torch.int64)
],
"mask": [
torch.ones(seq_length1, dtype=torch.bool),
torch.ones(seq_length2, dtype=torch.bool)
]
}
# Save data using MmapFileHander
dataset_dir = os.path.join(test_dir, "mmap_direct_test")
MmapFileHander.save(dataset_dir, dummy_data)
# Load data using MmapFileHander
loaded_data, num_samples = MmapFileHander.load(dataset_dir)
# Verify data structure
assert set(loaded_data.keys()) == set(dummy_data.keys())
assert num_samples == seq_length1 + seq_length2 # 300
# Verify data content
for key in dummy_data:
assert len(loaded_data[key]) == len(dummy_data[key])
for i in range(len(dummy_data[key])):
assert torch.equal(loaded_data[key][i], dummy_data[key][i])
def test_mmap_file_handler_dtypes(base_test_env):
"""Test MmapFileHander with different data types"""
test_dir = base_test_env["test_dir"]
# Create test data with different dtypes
data = {
"float32": [torch.randn(100, dtype=torch.float32)],
"float64": [torch.randn(100, dtype=torch.float64)],
"int32": [torch.randint(0, 1000, (100,), dtype=torch.int32)],
"int64": [torch.randint(0, 1000, (100,), dtype=torch.int64)],
"bool": [torch.randint(0, 2, (100,), dtype=torch.bool)]
}
# Save data
dataset_dir = os.path.join(test_dir, "dtype_test")
MmapFileHander.save(dataset_dir, data)
# Load data
loaded_data, _ = MmapFileHander.load(dataset_dir)
# Verify data types
for key in data:
assert loaded_data[key][0].dtype == data[key][0].dtype
assert torch.equal(loaded_data[key][0], data[key][0])
def test_mmap_file_handler_error_handling(base_test_env):
"""Test MmapFileHander error handling"""
test_dir = base_test_env["test_dir"]
# Test loading without file_mapper.json
empty_dir = os.path.join(test_dir, "empty_dir")
os.makedirs(empty_dir, exist_ok=True)
with pytest.raises(FileNotFoundError):
MmapFileHander.load(empty_dir)
# Test loading with invalid file_mapper.json
invalid_dir = os.path.join(test_dir, "invalid_dir")
os.makedirs(invalid_dir, exist_ok=True)
# Create empty file_mapper.json
with open(os.path.join(invalid_dir, "file_mapper.json"), "w") as f:
json.dump([{"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"}], f)
# This should raise FileNotFoundError because no binary files exist
with pytest.raises(FileNotFoundError):
MmapFileHander.load(invalid_dir)