refactor(data): rework the MmapFileHandler class and improve the data loading mechanism

ViperEkura 2026-01-11 19:37:28 +08:00
parent 9dab96c31f
commit 7dfa5cc0ac
3 changed files with 42 additions and 58 deletions
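At its core, the commit swaps raw-byte persistence (`numpy` `tofile` plus `torch.from_file` with an explicit size/dtype pulled from the metadata) for standard `.pt` serialization with memory-mapped loading. A minimal sketch of the two mechanisms, not part of the commit itself; file names are illustrative:

```python
import torch

# Old mechanism: raw bytes on disk; size and dtype must be tracked in metadata.json.
t = torch.arange(1000, dtype=torch.float32)
t.numpy().tofile("key1_0.bin")
old = torch.from_file("key1_0.bin", shared=True, size=1000, dtype=torch.float32)

# New mechanism: a standard .pt archive; dtype and shape live inside the file,
# and mmap=True memory-maps the storage instead of copying it into RAM.
torch.save(t, "key1_0.pt")
new = torch.load("key1_0.pt", map_location="cpu", mmap=True)
assert torch.equal(old, new)
```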

View File

@@ -4,7 +4,7 @@ import bisect
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from khaosz.data.mmap import MmapFileHander
+from khaosz.data.mmap import MmapFileHandler
 from typing import Callable, List, Dict, Literal, Optional, Union

 Seg = List[Tensor]
@@ -74,7 +74,7 @@ class BaseDataset(Dataset, ABC):
         self.total_samples = None

     def load(self, load_path: str):
-        self.segments, self.total_samples = MmapFileHander.load(load_path)
+        self.segments, self.total_samples = MmapFileHandler.load(load_path)
         self.fetcher = MultiSegmentFetcher(self.segments)

     def get_index(self, index: int) -> int:

View File: khaosz/data/mmap.py

@@ -5,14 +5,14 @@ import torch
 from torch import Tensor
 from typing import List, Dict, Tuple

-class MmapFileHander:
+class MmapFileHandler:
     """
     json metadata like this:
     ```
     [
-        {"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"},
-        {"file_name": "file2.bin", "size": 2000, "dtype": "float32", "key": "key2"}
+        {"file_name": "file1.bin", "size": 1000, "key": "key1"},
+        {"file_name": "file2.bin", "size": 2000, "key": "key2"}
         ...
     ]
     ```
@@ -20,29 +20,20 @@ class MmapFileHander:
     ```
     folder_path:
-    - file_mapper.json
+    - metadata.json
     - file1.bin
     - file2.bin
     ...
     ```
     """
-    DTYPE_MAP = {
-        "float32": torch.float32,
-        "float64": torch.float64,
-        "int32": torch.int32,
-        "int64": torch.int64,
-        "bool": torch.bool,
-    }
-    REVERSE_DTYPE_MAP = {v: k for k, v in DTYPE_MAP.items()}
+    META_DATA = "metadata.json"

     @staticmethod
     def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]:
         metadata_list = []
-        mmap_shared_group: Dict[str, List[Tensor]] = {}
+        tensor_group: Dict[str, List[Tensor]] = {}

-        file_mapper_path = os.path.join(root_path, "file_mapper.json")
+        file_mapper_path = os.path.join(root_path, MmapFileHandler.META_DATA)
         if not os.path.exists(file_mapper_path):
             raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
@@ -50,25 +41,20 @@ class MmapFileHander:
             metadata_list = json.load(f)

         for metadata in metadata_list:
-            file_path = os.path.join(root_path, metadata["file_name"])
-            if not os.path.exists(file_path):
-                raise FileNotFoundError(f"Binary data file not found: {file_path}")
-            size = metadata["size"]
-            dtype = MmapFileHander.DTYPE_MAP[metadata["dtype"]]
-            segment_key = metadata["key"]
-            mmap_tensor = torch.from_file(file_path, shared=shared, size=size, dtype=dtype)
-            if segment_key not in mmap_shared_group:
-                mmap_shared_group[segment_key] = []
-            mmap_shared_group[segment_key].append(mmap_tensor)
+            file_key = metadata["key"]
+            file_name = metadata["file_name"]
+            file_path = os.path.join(root_path, file_name)
+            elm = torch.load(file_path, map_location="cpu", mmap=shared)
+            if file_key not in tensor_group:
+                tensor_group[file_key] = []
+            tensor_group[file_key].append(elm)

         num_samples = sum(metadata["size"] for metadata in metadata_list)
         num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
         sample_per_key = num_samples // num_keys
-        return mmap_shared_group, sample_per_key
+        return tensor_group, sample_per_key

     @staticmethod
     def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
@@ -79,18 +65,18 @@ class MmapFileHander:
            for idx, tensor in enumerate(segment_tensors):
                try:
-                    with open(os.path.join(save_path, f"{segment_key}_{idx}.bin"), "wb") as f:
-                        f.write(tensor.cpu().numpy().tobytes())
+                    with open(os.path.join(save_path, f"{segment_key}_{idx}.pt"), "wb") as f:
+                        torch.save(tensor.contiguous().cpu(), f)
                except Exception as e:
                    raise RuntimeError(f"Error saving tensor: {e}")

                metadata_list.append({
-                    "file_name": f"{segment_key}_{idx}.bin",
+                    "file_name": f"{segment_key}_{idx}.pt",
                     "size": tensor.numel(),
-                    "dtype": MmapFileHander.REVERSE_DTYPE_MAP[tensor.dtype],
                     "key": segment_key
                })

-        metadata_path = os.path.join(save_path, "file_mapper.json")
+        metadata_path = os.path.join(save_path, MmapFileHandler.META_DATA)
         with open(metadata_path, "w") as f:
             json.dump(metadata_list, f)
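Taken together, the new `save`/`load` pair round-trips a dict of tensor lists through per-segment `.pt` files plus a `metadata.json` index, with dtype carried inside each `.pt` file. A minimal usage sketch against the post-commit API; the target directory is illustrative, and it assumes `save` creates the directory (the lines that would do so fall outside this hunk):

```python
import torch
from khaosz.data.mmap import MmapFileHandler

# Two segment keys, each mapping to a list of 1-D tensors.
data = {
    "sequence": [torch.randint(0, 1000, (128,), dtype=torch.int64)],
    "mask": [torch.ones(128, dtype=torch.bool)],
}

# Writes sequence_0.pt, mask_0.pt, and metadata.json.
MmapFileHandler.save("/tmp/mmap_demo", data)

# load() memory-maps each .pt file and returns (segments, sample_per_key),
# where sample_per_key = sum of metadata "size" // number of distinct keys.
segments, sample_per_key = MmapFileHandler.load("/tmp/mmap_demo", shared=True)
assert torch.equal(segments["sequence"][0], data["sequence"][0])
assert sample_per_key == 128  # (128 + 128) // 2
```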

View File

@@ -17,23 +17,21 @@ def create_mmap_dataset(dir_path, data_dict, dataset_name):
     for key, tensor in data_dict.items():
         # Convert tensor to numpy array and save as binary file
-        np_array = tensor.numpy()
-        file_name = f"{key}.bin"
+        file_name = f"{key}.pt"
         file_path = os.path.join(dataset_dir, file_name)

-        # Save as binary file
-        np_array.tofile(file_path)
+        with open(file_path, "wb") as f:
+            torch.save(tensor, f)

         # Add to file mapper
         file_mapper.append({
             "file_name": file_name,
-            "size": len(np_array),
-            "dtype": str(np_array.dtype),
+            "size": tensor.numel(),
             "key": key
         })

     # Save file mapper
-    mapper_path = os.path.join(dataset_dir, "file_mapper.json")
+    mapper_path = os.path.join(dataset_dir, "metadata.json")
     with open(mapper_path, "w") as f:
         json.dump(file_mapper, f, indent=2)
@@ -194,7 +192,7 @@ def test_multi_segment_fetcher(base_test_env):
     dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")

     # Load the memory mapped files directly
-    multi_segments, _ = MmapFileHander.load(dataset_path)
+    multi_segments, _ = MmapFileHandler.load(dataset_path)

     # Create fetcher
     fetcher = MultiSegmentFetcher(multi_segments)
@@ -218,14 +216,14 @@ def test_multi_segment_fetcher(base_test_env):
 def test_mmap_file_handler_direct(base_test_env):
-    """Test MmapFileHander directly without DatasetLoader"""
+    """Test MmapFileHandler directly without DatasetLoader"""
     test_dir = base_test_env["test_dir"]

     # Create test data with multiple segments
     seq_length1 = 100
     seq_length2 = 200

-    # Create data in the format expected by MmapFileHander
+    # Create data in the format expected by MmapFileHandler
     dummy_data = {
         "sequence": [
             torch.randint(0, 1000, (seq_length1,), dtype=torch.int64),
@@ -237,12 +235,12 @@ def test_mmap_file_handler_direct(base_test_env):
         ]
     }

-    # Save data using MmapFileHander
+    # Save data using MmapFileHandler
     dataset_dir = os.path.join(test_dir, "mmap_direct_test")
-    MmapFileHander.save(dataset_dir, dummy_data)
+    MmapFileHandler.save(dataset_dir, dummy_data)

-    # Load data using MmapFileHander
-    loaded_data, num_samples = MmapFileHander.load(dataset_dir)
+    # Load data using MmapFileHandler
+    loaded_data, num_samples = MmapFileHandler.load(dataset_dir)

     # Verify data structure
     assert set(loaded_data.keys()) == set(dummy_data.keys())
@@ -255,7 +253,7 @@ def test_mmap_file_handler_direct(base_test_env):
             assert torch.equal(loaded_data[key][i], dummy_data[key][i])

 def test_mmap_file_handler_dtypes(base_test_env):
-    """Test MmapFileHander with different data types"""
+    """Test MmapFileHandler with different data types"""
     test_dir = base_test_env["test_dir"]

     # Create test data with different dtypes
@@ -269,10 +267,10 @@ def test_mmap_file_handler_dtypes(base_test_env):
     # Save data
     dataset_dir = os.path.join(test_dir, "dtype_test")
-    MmapFileHander.save(dataset_dir, data)
+    MmapFileHandler.save(dataset_dir, data)

     # Load data
-    loaded_data, _ = MmapFileHander.load(dataset_dir)
+    loaded_data, _ = MmapFileHandler.load(dataset_dir)

     # Verify data types
     for key in data:
@@ -280,14 +278,14 @@ def test_mmap_file_handler_dtypes(base_test_env):
         assert torch.equal(loaded_data[key][0], data[key][0])

 def test_mmap_file_handler_error_handling(base_test_env):
-    """Test MmapFileHander error handling"""
+    """Test MmapFileHandler error handling"""
     test_dir = base_test_env["test_dir"]

     # Test loading without file_mapper.json
     empty_dir = os.path.join(test_dir, "empty_dir")
     os.makedirs(empty_dir, exist_ok=True)

     with pytest.raises(FileNotFoundError):
-        MmapFileHander.load(empty_dir)
+        MmapFileHandler.load(empty_dir)

     # Test loading with invalid file_mapper.json
     invalid_dir = os.path.join(test_dir, "invalid_dir")
@@ -299,4 +297,4 @@ def test_mmap_file_handler_error_handling(base_test_env):
     # This should raise FileNotFoundError because no binary files exist
     with pytest.raises(FileNotFoundError):
-        MmapFileHander.load(invalid_dir)
+        MmapFileHandler.load(invalid_dir)
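The dtype test above keeps passing with no `dtype` field in the metadata because `torch.save` embeds dtype and shape in the `.pt` archive, and `torch.load` restores them even when memory-mapping. A standalone check of that behavior in plain PyTorch (the `mmap` argument to `torch.load` requires PyTorch 2.1 or newer):

```python
import os
import tempfile
import torch

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "t.pt")
    torch.save(torch.arange(10, dtype=torch.int32), path)

    # mmap=True maps the serialized storage from disk instead of copying it;
    # the dtype survives the round trip with no external bookkeeping.
    t = torch.load(path, map_location="cpu", mmap=True)
    assert t.dtype == torch.int32
    assert torch.equal(t, torch.arange(10, dtype=torch.int32))
```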