refactor(data): restructure the MmapFileHandler class and improve the data loading mechanism

ViperEkura 2026-01-11 19:37:28 +08:00
parent 9dab96c31f
commit 7dfa5cc0ac
3 changed files with 42 additions and 58 deletions
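At its core, the change swaps raw-byte dumps reloaded via `torch.from_file` for `torch.save` files reloaded via `torch.load` with memory mapping. A minimal sketch of that round trip, independent of the khaosz code and assuming a PyTorch recent enough to support the `mmap` flag on `torch.load` (roughly 2.1+):

```
import torch

t = torch.randint(0, 1000, (100,), dtype=torch.int64)

# Old style: raw bytes on disk; size and dtype must be tracked externally.
t.cpu().numpy().tofile("sample.bin")
old = torch.from_file("sample.bin", shared=True, size=100, dtype=torch.int64)

# New style: the .pt container records dtype/shape, and torch.load can mmap it.
torch.save(t.contiguous().cpu(), "sample.pt")
new = torch.load("sample.pt", map_location="cpu", mmap=True)

assert torch.equal(old, new)
```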


@@ -4,7 +4,7 @@ import bisect
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from khaosz.data.mmap import MmapFileHander
+from khaosz.data.mmap import MmapFileHandler
 from typing import Callable, List, Dict, Literal, Optional, Union

 Seg = List[Tensor]
@@ -74,7 +74,7 @@ class BaseDataset(Dataset, ABC):
         self.total_samples = None

     def load(self, load_path: str):
-        self.segments, self.total_samples = MmapFileHander.load(load_path)
+        self.segments, self.total_samples = MmapFileHandler.load(load_path)
         self.fetcher = MultiSegmentFetcher(self.segments)

     def get_index(self, index: int) -> int:


@@ -5,14 +5,14 @@ import torch
 from torch import Tensor
 from typing import List, Dict, Tuple

-class MmapFileHander:
+class MmapFileHandler:
     """
     json metadata like this:
     ```
     [
-        {"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"},
-        {"file_name": "file2.bin", "size": 2000, "dtype": "float32", "key": "key2"}
+        {"file_name": "file1.bin", "size": 1000, "key": "key1"},
+        {"file_name": "file2.bin", "size": 2000, "key": "key2"}
         ...
     ]
     ```
@@ -20,29 +20,20 @@ class MmapFileHander:
     ```
     folder_path:
-    - file_mapper.json
+    - metadata.json
     - file1.bin
     - file2.bin
     ...
     ```
     """
-    DTYPE_MAP = {
-        "float32": torch.float32,
-        "float64": torch.float64,
-        "int32": torch.int32,
-        "int64": torch.int64,
-        "bool": torch.bool,
-    }
-    REVERSE_DTYPE_MAP = {v: k for k, v in DTYPE_MAP.items()}
+    META_DATA = "metadata.json"

     @staticmethod
     def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]:
         metadata_list = []
-        mmap_shared_group: Dict[str, List[Tensor]] = {}
+        tensor_group: Dict[str, List[Tensor]] = {}

-        file_mapper_path = os.path.join(root_path, "file_mapper.json")
+        file_mapper_path = os.path.join(root_path, MmapFileHandler.META_DATA)
         if not os.path.exists(file_mapper_path):
             raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
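Concretely, for a group like `{"sequence": [t0, t1]}` the reworked `save` further down writes `sequence_0.pt` and `sequence_1.pt` next to a `metadata.json` along these lines (sizes illustrative):

```
[
    {"file_name": "sequence_0.pt", "size": 100, "key": "sequence"},
    {"file_name": "sequence_1.pt", "size": 200, "key": "sequence"}
]
```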
@@ -50,25 +41,20 @@ class MmapFileHander:
             metadata_list = json.load(f)

         for metadata in metadata_list:
-            file_path = os.path.join(root_path, metadata["file_name"])
-            if not os.path.exists(file_path):
-                raise FileNotFoundError(f"Binary data file not found: {file_path}")
+            file_key = metadata["key"]
+            file_name = metadata["file_name"]
+            file_path = os.path.join(root_path, file_name)
+            elm = torch.load(file_path, map_location="cpu", mmap=shared)

-            size = metadata["size"]
-            dtype = MmapFileHander.DTYPE_MAP[metadata["dtype"]]
-            segment_key = metadata["key"]
-            mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)
-
-            if segment_key not in mmap_shared_group:
-                mmap_shared_group[segment_key] = []
-            mmap_shared_group[segment_key].append(mmap_tensor)
+            if file_key not in tensor_group:
+                tensor_group[file_key] = []
+            tensor_group[file_key].append(elm)

         num_samples = sum(metadata["size"] for metadata in metadata_list)
         num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
         sample_per_key = num_samples // num_keys
-        return mmap_shared_group, sample_per_key
+        return tensor_group, sample_per_key

     @staticmethod
     def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
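`load` keeps its contract: a dict mapping each key to a list of tensors (now lazily memory-mapped) and a per-key sample count, computed as the total `size` across metadata entries divided by the number of distinct keys. A usage sketch, assuming `dataset_dir` is a directory previously produced by `MmapFileHandler.save`:

```
from khaosz.data.mmap import MmapFileHandler

tensor_group, samples_per_key = MmapFileHandler.load("dataset_dir", shared=True)

for key, tensors in tensor_group.items():
    # Tensors are mmap-backed: pages are pulled from disk on first access.
    print(key, [t.shape for t in tensors])
```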
@@ -79,18 +65,18 @@ class MmapFileHander:
         for idx, tensor in enumerate(segment_tensors):
             try:
-                with open(os.path.join(save_path, f"{segment_key}_{idx}.bin"), "wb") as f:
-                    f.write(tensor.cpu().numpy().tobytes())
+                with open(os.path.join(save_path, f"{segment_key}_{idx}.pt"), "wb") as f:
+                    torch.save(tensor.contiguous().cpu(), f)
             except Exception as e:
                 raise RuntimeError(f"Error saving tensor: {e}")

             metadata_list.append({
-                "file_name": f"{segment_key}_{idx}.bin",
+                "file_name": f"{segment_key}_{idx}.pt",
                 "size": tensor.numel(),
-                "dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
                 "key": segment_key
             })

-        metadata_path = os.path.join(save_path, "file_mapper.json")
+        metadata_path = os.path.join(save_path, MmapFileHandler.META_DATA)
         with open(metadata_path, "w") as f:
             json.dump(metadata_list, f)
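One detail worth noting in `save`: the `tensor.contiguous().cpu()` matters because `torch.save` serializes a tensor's underlying storage. Saving a non-contiguous view directly would write out the whole parent storage; making it contiguous first writes only the elements the tensor actually holds. A small illustration (a sketch, not repo code):

```
import os
import torch

base = torch.arange(1_000_000)
view = base[::2]  # non-contiguous view that still references the full storage

torch.save(view, "raw_view.pt")              # serializes the entire 1M-element storage
torch.save(view.contiguous(), "compact.pt")  # serializes only the 500k selected elements

print(os.path.getsize("raw_view.pt"), os.path.getsize("compact.pt"))
```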


@@ -17,23 +17,21 @@ def create_mmap_dataset(dir_path, data_dict, dataset_name):
     for key, tensor in data_dict.items():
-        # Convert tensor to numpy array and save as binary file
-        np_array = tensor.numpy()
-        file_name = f"{key}.bin"
+        file_name = f"{key}.pt"
         file_path = os.path.join(dataset_dir, file_name)

         # Save as binary file
-        np_array.tofile(file_path)
+        with open(file_path, "wb") as f:
+            torch.save(tensor, f)

         # Add to file mapper
         file_mapper.append({
             "file_name": file_name,
-            "size": len(np_array),
-            "dtype": str(np_array.dtype),
+            "size": tensor.numel(),
             "key": key
         })

     # Save file mapper
-    mapper_path = os.path.join(dataset_dir, "file_mapper.json")
+    mapper_path = os.path.join(dataset_dir, "metadata.json")
     with open(mapper_path, "w") as f:
         json.dump(file_mapper, f, indent=2)
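Since the helper now mirrors `MmapFileHandler.save`'s layout (`.pt` files plus `metadata.json`), its output can be fed straight into `MmapFileHandler.load`, which is what the tests below rely on. A condensed version of that flow (the path is illustrative):

```
import torch
from khaosz.data.mmap import MmapFileHandler

data = {"sequence": torch.randint(0, 1000, (100,), dtype=torch.int64)}
dataset_path = create_mmap_dataset("/tmp/khaosz_tests", data, "demo")  # helper above

loaded, _ = MmapFileHandler.load(dataset_path)
assert torch.equal(loaded["sequence"][0], data["sequence"])
```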
@@ -194,7 +192,7 @@ def test_multi_segment_fetcher(base_test_env):
     dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")

     # Load the memory mapped files directly
-    multi_segments, _ = MmapFileHander.load(dataset_path)
+    multi_segments, _ = MmapFileHandler.load(dataset_path)

     # Create fetcher
     fetcher = MultiSegmentFetcher(multi_segments)
@@ -218,14 +216,14 @@ def test_multi_segment_fetcher(base_test_env):

 def test_mmap_file_handler_direct(base_test_env):
-    """Test MmapFileHander directly without DatasetLoader"""
+    """Test MmapFileHandler directly without DatasetLoader"""
     test_dir = base_test_env["test_dir"]

     # Create test data with multiple segments
     seq_length1 = 100
     seq_length2 = 200

-    # Create data in the format expected by MmapFileHander
+    # Create data in the format expected by MmapFileHandler
     dummy_data = {
         "sequence": [
             torch.randint(0, 1000, (seq_length1,), dtype=torch.int64),
@@ -237,12 +235,12 @@ def test_mmap_file_handler_direct(base_test_env):
         ]
     }

-    # Save data using MmapFileHander
+    # Save data using MmapFileHandler
     dataset_dir = os.path.join(test_dir, "mmap_direct_test")
-    MmapFileHander.save(dataset_dir, dummy_data)
+    MmapFileHandler.save(dataset_dir, dummy_data)

-    # Load data using MmapFileHander
-    loaded_data, num_samples = MmapFileHander.load(dataset_dir)
+    # Load data using MmapFileHandler
+    loaded_data, num_samples = MmapFileHandler.load(dataset_dir)

     # Verify data structure
     assert set(loaded_data.keys()) == set(dummy_data.keys())
@@ -255,7 +253,7 @@ def test_mmap_file_handler_direct(base_test_env):
             assert torch.equal(loaded_data[key][i], dummy_data[key][i])

 def test_mmap_file_handler_dtypes(base_test_env):
-    """Test MmapFileHander with different data types"""
+    """Test MmapFileHandler with different data types"""
     test_dir = base_test_env["test_dir"]

     # Create test data with different dtypes
@@ -269,10 +267,10 @@ def test_mmap_file_handler_dtypes(base_test_env):
     # Save data
     dataset_dir = os.path.join(test_dir, "dtype_test")
-    MmapFileHander.save(dataset_dir, data)
+    MmapFileHandler.save(dataset_dir, data)

     # Load data
-    loaded_data, _ = MmapFileHander.load(dataset_dir)
+    loaded_data, _ = MmapFileHandler.load(dataset_dir)

     # Verify data types
     for key in data:
@@ -280,14 +278,14 @@ def test_mmap_file_handler_dtypes(base_test_env):
         assert torch.equal(loaded_data[key][0], data[key][0])

 def test_mmap_file_handler_error_handling(base_test_env):
-    """Test MmapFileHander error handling"""
+    """Test MmapFileHandler error handling"""
     test_dir = base_test_env["test_dir"]

     # Test loading without file_mapper.json
     empty_dir = os.path.join(test_dir, "empty_dir")
     os.makedirs(empty_dir, exist_ok=True)

     with pytest.raises(FileNotFoundError):
-        MmapFileHander.load(empty_dir)
+        MmapFileHandler.load(empty_dir)

     # Test loading with invalid file_mapper.json
     invalid_dir = os.path.join(test_dir, "invalid_dir")
@@ -299,4 +297,4 @@ def test_mmap_file_handler_error_handling(base_test_env):
     # This should raise FileNotFoundError because no binary files exist
     with pytest.raises(FileNotFoundError):
-        MmapFileHander.load(invalid_dir)
+        MmapFileHandler.load(invalid_dir)