refactor(data): refactor the MmapFileHandler class and improve the data-loading mechanism
parent 9dab96c31f
commit 7dfa5cc0ac
@@ -4,7 +4,7 @@ import bisect
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from khaosz.data.mmap import MmapFileHander
+from khaosz.data.mmap import MmapFileHandler
 from typing import Callable, List, Dict, Literal, Optional, Union
 
 Seg = List[Tensor]
@@ -74,7 +74,7 @@ class BaseDataset(Dataset, ABC):
         self.total_samples = None
 
     def load(self, load_path: str):
-        self.segments, self.total_samples = MmapFileHander.load(load_path)
+        self.segments, self.total_samples = MmapFileHandler.load(load_path)
         self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
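A minimal sketch of how the renamed handler is consumed from the dataset side (the dataset directory path is illustrative; the import path for `MmapFileHandler` is the one shown in the hunk above):

```python
from khaosz.data.mmap import MmapFileHandler

# "path/to/dataset" is a placeholder for a directory written by MmapFileHandler.save.
segments, samples_per_key = MmapFileHandler.load("path/to/dataset")
# segments: Dict[str, List[Tensor]] — one list of tensor segments per key.
# MultiSegmentFetcher's import is not shown in this diff; the wiring mirrors BaseDataset.load above.
fetcher = MultiSegmentFetcher(segments)
```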
@@ -5,14 +5,14 @@ import torch
 from torch import Tensor
 from typing import List, Dict, Tuple
 
-class MmapFileHander:
+class MmapFileHandler:
     """
     json metadata like this:
 
     ```
     [
-        {"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"},
-        {"file_name": "file2.bin", "size": 2000, "dtype": "float32", "key": "key2"}
+        {"file_name": "file1.bin", "size": 1000, "key": "key1"},
+        {"file_name": "file2.bin", "size": 2000, "key": "key2"}
         ...
     ]
     ```
@@ -20,29 +20,20 @@ class MmapFileHander:
 
     ```
     folder_path:
-        - file_mapper.json
+        - metadata.json
         - file1.bin
         - file2.bin
         ...
 
     ```
     """
-    DTYPE_MAP = {
-        "float32": torch.float32,
-        "float64": torch.float64,
-        "int32": torch.int32,
-        "int64": torch.int64,
-        "bool": torch.bool,
-    }
-    REVERSE_DTYPE_MAP = {v: k for k, v in DTYPE_MAP.items()}
+    META_DATA = "metadata.json"
 
     @staticmethod
     def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]:
         metadata_list = []
-        mmap_shared_group: Dict[str, List[Tensor]] = {}
+        tensor_group: Dict[str, List[Tensor]] = {}
 
-        file_mapper_path = os.path.join(root_path, "file_mapper.json")
+        file_mapper_path = os.path.join(root_path, MmapFileHandler.META_DATA)
         if not os.path.exists(file_mapper_path):
             raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
 
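The `DTYPE_MAP`/`REVERSE_DTYPE_MAP` tables can be dropped because `torch.save` stores the dtype inside the `.pt` payload itself, so `metadata.json` no longer needs to record it. A small sketch of that assumption (file name is illustrative):

```python
import torch

t = torch.arange(10, dtype=torch.int32)
torch.save(t, "example.pt")         # dtype travels inside the .pt file
loaded = torch.load("example.pt")
assert loaded.dtype == torch.int32  # no external dtype bookkeeping needed
```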
@@ -50,25 +41,20 @@ class MmapFileHander:
             metadata_list = json.load(f)
 
         for metadata in metadata_list:
-            file_path = os.path.join(root_path, metadata["file_name"])
-            if not os.path.exists(file_path):
-                raise FileNotFoundError(f"Binary data file not found: {file_path}")
-
-            size = metadata["size"]
-            dtype = MmapFileHander.DTYPE_MAP[metadata["dtype"]]
-            segment_key = metadata["key"]
-            mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)
-
-            if segment_key not in mmap_shared_group:
-                mmap_shared_group[segment_key] = []
-
-            mmap_shared_group[segment_key].append(mmap_tensor)
+            file_key = metadata["key"]
+            file_name = metadata["file_name"]
+            file_path = os.path.join(root_path, file_name)
+            elm = torch.load(file_path, map_location="cpu", mmap=shared)
+
+            if file_key not in tensor_group:
+                tensor_group[file_key] = []
+            tensor_group[file_key].append(elm)
 
         num_samples = sum(metadata["size"] for metadata in metadata_list)
         num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
         sample_per_key = num_samples // num_keys
 
-        return mmap_shared_group, sample_per_key
+        return tensor_group, sample_per_key
 
     @staticmethod
     def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
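With `torch.load(..., mmap=shared)` each `.pt` file is memory-mapped instead of eagerly read into RAM, which preserves the lazy, shared-storage behaviour the old `torch.from_file` call provided (the `mmap` argument to `torch.load` requires a reasonably recent PyTorch, 2.1 or newer). A rough sketch of that behaviour, with an illustrative file name:

```python
import torch

big = torch.zeros(1_000_000)
torch.save(big, "big.pt")

# With mmap=True the tensor's storage is backed by the file on disk,
# so pages are faulted in lazily rather than copied up front.
mapped = torch.load("big.pt", map_location="cpu", mmap=True)
assert torch.equal(mapped, big)
```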
@@ -79,18 +65,18 @@ class MmapFileHander:
             for idx, tensor in enumerate(segment_tensors):
 
                 try:
-                    with open(os.path.join(save_path, f"{segment_key}_{idx}.bin"), "wb") as f:
-                        f.write(tensor.cpu().numpy().tobytes())
+                    with open(os.path.join(save_path, f"{segment_key}_{idx}.pt"), "wb") as f:
+                        torch.save(tensor.contiguous().cpu(), f)
                 except Exception as e:
                     raise RuntimeError(f"Error saving tensor: {e}")
 
                 metadata_list.append({
-                    "file_name": f"{segment_key}_{idx}.bin",
+                    "file_name": f"{segment_key}_{idx}.pt",
                     "size": tensor.numel(),
-                    "dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
                     "key": segment_key
                 })
 
-        metadata_path = os.path.join(save_path, "file_mapper.json")
+        metadata_path = os.path.join(save_path, MmapFileHandler.META_DATA)
 
        with open(metadata_path, "w") as f:
            json.dump(metadata_list, f)
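Taken together, `save` and `load` now round-trip a `Dict[str, List[Tensor]]` through per-segment `.pt` files plus a single `metadata.json`; the tests below exercise exactly this. A condensed sketch (directory name is illustrative):

```python
import torch
from khaosz.data.mmap import MmapFileHandler

group = {
    "sequence": [torch.randint(0, 1000, (100,)), torch.randint(0, 1000, (200,))],
}
# Writes sequence_0.pt, sequence_1.pt and metadata.json into the directory.
MmapFileHandler.save("demo_dataset", group)
loaded, samples_per_key = MmapFileHandler.load("demo_dataset")
assert torch.equal(loaded["sequence"][0], group["sequence"][0])
```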
@@ -17,23 +17,21 @@ def create_mmap_dataset(dir_path, data_dict, dataset_name):
 
     for key, tensor in data_dict.items():
         # Convert tensor to numpy array and save as binary file
-        np_array = tensor.numpy()
-        file_name = f"{key}.bin"
+        file_name = f"{key}.pt"
         file_path = os.path.join(dataset_dir, file_name)
 
-        # Save as binary file
-        np_array.tofile(file_path)
+        with open(file_path, "wb") as f:
+            torch.save(tensor, f)
 
         # Add to file mapper
         file_mapper.append({
             "file_name": file_name,
-            "size": len(np_array),
-            "dtype": str(np_array.dtype),
+            "size": tensor.numel(),
             "key": key
         })
 
     # Save file mapper
-    mapper_path = os.path.join(dataset_dir, "file_mapper.json")
+    mapper_path = os.path.join(dataset_dir, "metadata.json")
     with open(mapper_path, "w") as f:
         json.dump(file_mapper, f, indent=2)
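For reference, the mapper the test helper now emits carries no `dtype` field, matching what `MmapFileHandler.load` expects. An illustrative entry for a 100-element tensor stored under the key `sequence`:

```python
# Illustrative metadata.json content produced by create_mmap_dataset after this change
# (no "dtype" field — torch.load recovers the dtype from the .pt file itself):
file_mapper = [
    {"file_name": "sequence.pt", "size": 100, "key": "sequence"},
]
```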
@@ -194,7 +192,7 @@ def test_multi_segment_fetcher(base_test_env):
     dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
 
     # Load the memory mapped files directly
-    multi_segments, _ = MmapFileHander.load(dataset_path)
+    multi_segments, _ = MmapFileHandler.load(dataset_path)
 
     # Create fetcher
     fetcher = MultiSegmentFetcher(multi_segments)
@@ -218,14 +216,14 @@ def test_multi_segment_fetcher(base_test_env):
 
 
 def test_mmap_file_handler_direct(base_test_env):
-    """Test MmapFileHander directly without DatasetLoader"""
+    """Test MmapFileHandler directly without DatasetLoader"""
     test_dir = base_test_env["test_dir"]
 
     # Create test data with multiple segments
     seq_length1 = 100
     seq_length2 = 200
 
-    # Create data in the format expected by MmapFileHander
+    # Create data in the format expected by MmapFileHandler
     dummy_data = {
         "sequence": [
             torch.randint(0, 1000, (seq_length1,), dtype=torch.int64),
@@ -237,12 +235,12 @@ def test_mmap_file_handler_direct(base_test_env):
         ]
     }
 
-    # Save data using MmapFileHander
+    # Save data using MmapFileHandler
     dataset_dir = os.path.join(test_dir, "mmap_direct_test")
-    MmapFileHander.save(dataset_dir, dummy_data)
+    MmapFileHandler.save(dataset_dir, dummy_data)
 
-    # Load data using MmapFileHander
-    loaded_data, num_samples = MmapFileHander.load(dataset_dir)
+    # Load data using MmapFileHandler
+    loaded_data, num_samples = MmapFileHandler.load(dataset_dir)
 
     # Verify data structure
     assert set(loaded_data.keys()) == set(dummy_data.keys())
@@ -255,7 +253,7 @@ def test_mmap_file_handler_direct(base_test_env):
             assert torch.equal(loaded_data[key][i], dummy_data[key][i])
 
 def test_mmap_file_handler_dtypes(base_test_env):
-    """Test MmapFileHander with different data types"""
+    """Test MmapFileHandler with different data types"""
     test_dir = base_test_env["test_dir"]
 
     # Create test data with different dtypes
@@ -269,10 +267,10 @@ def test_mmap_file_handler_dtypes(base_test_env):
 
     # Save data
    dataset_dir = os.path.join(test_dir, "dtype_test")
-    MmapFileHander.save(dataset_dir, data)
+    MmapFileHandler.save(dataset_dir, data)
 
     # Load data
-    loaded_data, _ = MmapFileHander.load(dataset_dir)
+    loaded_data, _ = MmapFileHandler.load(dataset_dir)
 
     # Verify data types
     for key in data:
@@ -280,14 +278,14 @@ def test_mmap_file_handler_dtypes(base_test_env):
         assert torch.equal(loaded_data[key][0], data[key][0])
 
 def test_mmap_file_handler_error_handling(base_test_env):
-    """Test MmapFileHander error handling"""
+    """Test MmapFileHandler error handling"""
     test_dir = base_test_env["test_dir"]
 
     # Test loading without file_mapper.json
     empty_dir = os.path.join(test_dir, "empty_dir")
     os.makedirs(empty_dir, exist_ok=True)
     with pytest.raises(FileNotFoundError):
-        MmapFileHander.load(empty_dir)
+        MmapFileHandler.load(empty_dir)
 
     # Test loading with invalid file_mapper.json
     invalid_dir = os.path.join(test_dir, "invalid_dir")
@@ -299,4 +297,4 @@ def test_mmap_file_handler_error_handling(base_test_env):
 
     # This should raise FileNotFoundError because no binary files exist
     with pytest.raises(FileNotFoundError):
-        MmapFileHander.load(invalid_dir)
+        MmapFileHandler.load(invalid_dir)