refactor(data): 修改文件加载方案
This commit is contained in:
parent
0ca4871e80
commit
582d4ae9a7
|
|
@ -1,18 +1,17 @@
|
||||||
|
import h5py
|
||||||
import torch
|
import torch
|
||||||
import bisect
|
import bisect
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from khaosz.data.mmap import MmapFileHandler
|
from khaosz.data.file import load_h5
|
||||||
from typing import Callable, List, Dict, Literal, Optional, Union
|
from typing import Callable, List, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
Seg = List[Tensor]
|
|
||||||
MultiSeg = Dict[str, Seg]
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
class BaseSegmentFetcher:
|
||||||
def __init__(self, segments: Seg):
|
def __init__(self, segments: List[Tensor]):
|
||||||
self.segments = segments
|
self.segments = segments
|
||||||
self.cum_lengths = []
|
self.cum_lengths = []
|
||||||
total = 0
|
total = 0
|
||||||
|
|
@ -37,20 +36,21 @@ class BaseSegmentFetcher:
|
||||||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
||||||
start = max(begin_idx - prev_cum, 0)
|
start = max(begin_idx - prev_cum, 0)
|
||||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
end = min(end_idx - prev_cum, len(self.segments[i]))
|
||||||
result_segments.append(self.segments[i][start:end])
|
data = self.segments[i][start:end]
|
||||||
|
result_segments.append(data)
|
||||||
|
|
||||||
return torch.cat(result_segments, dim=0)
|
return torch.cat(result_segments, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class MultiSegmentFetcher:
|
class MultiSegmentFetcher:
|
||||||
def __init__(self, muti_segments: MultiSeg):
|
def __init__(self, muti_segments: Dict):
|
||||||
self.muti_keys = list(muti_segments.keys())
|
self.muti_keys = list(muti_segments.keys())
|
||||||
self.muti_fetchers = {
|
self.muti_fetchers = {
|
||||||
key: BaseSegmentFetcher(segments)
|
key: BaseSegmentFetcher(segments)
|
||||||
for key, segments in muti_segments.items()
|
for key, segments in muti_segments.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
|
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
||||||
fetch_dict = {}
|
fetch_dict = {}
|
||||||
keys = [keys] if isinstance(keys, str) else keys
|
keys = [keys] if isinstance(keys, str) else keys
|
||||||
|
|
||||||
|
|
@ -61,20 +61,20 @@ class MultiSegmentFetcher:
|
||||||
|
|
||||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
|
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||||
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset, ABC):
|
class BaseDataset(Dataset, ABC):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.segments: MultiSeg = {}
|
self.segments = {}
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.total_samples = None
|
self.total_samples = None
|
||||||
|
|
||||||
def load(self, load_path: str):
|
def load(self, load_path: str):
|
||||||
self.segments, self.total_samples = MmapFileHandler.load(load_path)
|
self.segments, self.total_samples = load_h5(load_path)
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def get_index(self, index: int) -> int:
|
def get_index(self, index: int) -> int:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
import os
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
with h5py.File(file_path, 'w') as f:
|
||||||
|
for key, tensors in tensor_group.items():
|
||||||
|
grp = f.create_group(key)
|
||||||
|
grp.attrs['num_tensors'] = len(tensors)
|
||||||
|
|
||||||
|
for idx, tensor in enumerate(tensors):
|
||||||
|
arr = tensor.cpu().numpy()
|
||||||
|
dset = grp.create_dataset(
|
||||||
|
f'data_{idx}',
|
||||||
|
data=arr,
|
||||||
|
compression='gzip',
|
||||||
|
compression_opts=4,
|
||||||
|
shuffle=True
|
||||||
|
)
|
||||||
|
dset.attrs['numel'] = tensor.numel()
|
||||||
|
|
||||||
|
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
|
||||||
|
tensor_group: Dict[str, List[Tensor]] = {}
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
with h5py.File(file_path, 'r') as f:
|
||||||
|
for key in f.keys():
|
||||||
|
grp = f[key]
|
||||||
|
dsets = []
|
||||||
|
for dset_name in grp.keys():
|
||||||
|
dset = grp[dset_name]
|
||||||
|
dsets.append(torch.from_numpy(dset[:]).share_memory_())
|
||||||
|
total_samples += dset.attrs.get('numel', np.prod(dset.shape))
|
||||||
|
tensor_group[key] = dsets
|
||||||
|
|
||||||
|
num_keys = max(len(tensor_group), 1)
|
||||||
|
sample_per_key = total_samples // num_keys
|
||||||
|
|
||||||
|
return tensor_group, sample_per_key
|
||||||
|
|
@ -1,82 +0,0 @@
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from torch import Tensor
|
|
||||||
from typing import List, Dict, Tuple
|
|
||||||
|
|
||||||
class MmapFileHandler:
|
|
||||||
"""
|
|
||||||
json metadata like this:
|
|
||||||
|
|
||||||
```
|
|
||||||
[
|
|
||||||
{"file_name": "file1.pt", "size": 1000, "key": "key1"},
|
|
||||||
{"file_name": "file2.pt", "size": 2000, "key": "key2"}
|
|
||||||
...
|
|
||||||
]
|
|
||||||
```
|
|
||||||
files like:
|
|
||||||
|
|
||||||
```
|
|
||||||
folder_path:
|
|
||||||
- metadata.json
|
|
||||||
- file1.pt
|
|
||||||
- file2.pt
|
|
||||||
...
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
META_DATA = "metadata.json"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]:
|
|
||||||
metadata_list = []
|
|
||||||
tensor_group: Dict[str, List[Tensor]] = {}
|
|
||||||
|
|
||||||
file_mapper_path = os.path.join(root_path, MmapFileHandler.META_DATA)
|
|
||||||
if not os.path.exists(file_mapper_path):
|
|
||||||
raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
|
|
||||||
|
|
||||||
with open(file_mapper_path, "r") as f:
|
|
||||||
metadata_list = json.load(f)
|
|
||||||
|
|
||||||
for metadata in metadata_list:
|
|
||||||
file_key = metadata["key"]
|
|
||||||
file_name = metadata["file_name"]
|
|
||||||
file_path = os.path.join(root_path, file_name)
|
|
||||||
elm = torch.load(file_path, map_location="cpu", mmap=shared)
|
|
||||||
|
|
||||||
if file_key not in tensor_group:
|
|
||||||
tensor_group[file_key] = []
|
|
||||||
tensor_group[file_key].append(elm)
|
|
||||||
|
|
||||||
num_samples = sum(metadata["size"] for metadata in metadata_list)
|
|
||||||
num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
|
|
||||||
sample_per_key = num_samples // num_keys
|
|
||||||
|
|
||||||
return tensor_group, sample_per_key
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
|
|
||||||
os.makedirs(save_path, exist_ok=True)
|
|
||||||
|
|
||||||
metadata_list = []
|
|
||||||
for segment_key, segment_tensors in mmap_shared_group.items():
|
|
||||||
for idx, tensor in enumerate(segment_tensors):
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(os.path.join(save_path, f"{segment_key}_{idx}.pt"), "wb") as f:
|
|
||||||
torch.save(tensor.contiguous().cpu(), f)
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Error saving tensor: {e}")
|
|
||||||
|
|
||||||
metadata_list.append({
|
|
||||||
"file_name": f"{segment_key}_{idx}.pt",
|
|
||||||
"size": tensor.numel(),
|
|
||||||
"key": segment_key
|
|
||||||
})
|
|
||||||
|
|
||||||
metadata_path = os.path.join(save_path, MmapFileHandler.META_DATA)
|
|
||||||
|
|
||||||
with open(metadata_path, "w") as f:
|
|
||||||
json.dump(metadata_list, f)
|
|
||||||
|
|
@ -1,41 +1,22 @@
|
||||||
import os
|
import os
|
||||||
import json
|
|
||||||
import pytest
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from khaosz.trainer import *
|
from khaosz.data.file import save_h5
|
||||||
from khaosz.data.dataset import *
|
from khaosz.data.dataset import *
|
||||||
|
|
||||||
|
|
||||||
def create_mmap_dataset(dir_path, data_dict, dataset_name):
|
def create_h5_dataset(dir_path, data_dict, dataset_name):
|
||||||
"""Helper function to create memory-mapped dataset for testing"""
|
"""Helper function to create HDF5 dataset for testing"""
|
||||||
dataset_dir = os.path.join(dir_path, dataset_name)
|
dataset_path = os.path.join(dir_path, f"{dataset_name}.h5")
|
||||||
os.makedirs(dataset_dir, exist_ok=True)
|
|
||||||
|
|
||||||
file_mapper = []
|
# Convert data_dict to the format expected by save_h5
|
||||||
|
# save_h5 expects a list of tensors for each key
|
||||||
|
tensor_group = {key: [tensor] for key, tensor in data_dict.items()}
|
||||||
|
|
||||||
for key, tensor in data_dict.items():
|
save_h5(dataset_path, tensor_group)
|
||||||
# Convert tensor to numpy array and save as binary file
|
|
||||||
file_name = f"{key}.pt"
|
|
||||||
file_path = os.path.join(dataset_dir, file_name)
|
|
||||||
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
torch.save(tensor, f)
|
|
||||||
|
|
||||||
# Add to file mapper
|
|
||||||
file_mapper.append({
|
|
||||||
"file_name": file_name,
|
|
||||||
"size": tensor.numel(),
|
|
||||||
"key": key
|
|
||||||
})
|
|
||||||
|
|
||||||
# Save file mapper
|
return dataset_path
|
||||||
mapper_path = os.path.join(dataset_dir, "metadata.json")
|
|
||||||
with open(mapper_path, "w") as f:
|
|
||||||
json.dump(file_mapper, f, indent=2)
|
|
||||||
|
|
||||||
return dataset_dir
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_loader_random_paths(base_test_env):
|
def test_dataset_loader_random_paths(base_test_env):
|
||||||
|
|
@ -50,7 +31,7 @@ def test_dataset_loader_random_paths(base_test_env):
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
||||||
}
|
}
|
||||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, f"test_data_{i}")
|
dataset_path = create_h5_dataset(test_dir, dummy_data, f"test_data_{i}")
|
||||||
|
|
||||||
# Test loading with multiple paths
|
# Test loading with multiple paths
|
||||||
loaded_dataset = DatasetLoader.load(
|
loaded_dataset = DatasetLoader.load(
|
||||||
|
|
@ -84,7 +65,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
||||||
"rejected_mask": torch.ones(seq_length, dtype=torch.bool)
|
"rejected_mask": torch.ones(seq_length, dtype=torch.bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, "dpo_data")
|
dataset_path = create_h5_dataset(test_dir, dummy_data, "dpo_data")
|
||||||
|
|
||||||
# Load DPO dataset
|
# Load DPO dataset
|
||||||
dpo_dataset = DatasetLoader.load(
|
dpo_dataset = DatasetLoader.load(
|
||||||
|
|
@ -120,7 +101,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
||||||
"loss_mask": torch.ones(seq_length, dtype=torch.bool)
|
"loss_mask": torch.ones(seq_length, dtype=torch.bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, "sft_data")
|
dataset_path = create_h5_dataset(test_dir, dummy_data, "sft_data")
|
||||||
|
|
||||||
# Load SFT dataset
|
# Load SFT dataset
|
||||||
sft_dataset = DatasetLoader.load(
|
sft_dataset = DatasetLoader.load(
|
||||||
|
|
@ -153,7 +134,7 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, "stride_test_data")
|
dataset_path = create_h5_dataset(test_dir, dummy_data, "stride_test_data")
|
||||||
|
|
||||||
# Test with custom stride
|
# Test with custom stride
|
||||||
custom_stride = 32
|
custom_stride = 32
|
||||||
|
|
@ -176,125 +157,3 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(dataset) > len(default_stride_dataset)
|
assert len(dataset) > len(default_stride_dataset)
|
||||||
|
|
||||||
|
|
||||||
def test_multi_segment_fetcher(base_test_env):
|
|
||||||
"""Test MultiSegmentFetcher functionality directly"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
|
|
||||||
# Create test data with multiple segments
|
|
||||||
seq_length = 100
|
|
||||||
dummy_data = {
|
|
||||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
|
||||||
"mask": torch.ones(seq_length, dtype=torch.bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
|
|
||||||
|
|
||||||
# Load the memory mapped files directly
|
|
||||||
multi_segments, _ = MmapFileHandler.load(dataset_path)
|
|
||||||
|
|
||||||
# Create fetcher
|
|
||||||
fetcher = MultiSegmentFetcher(multi_segments)
|
|
||||||
|
|
||||||
# Test fetching single key
|
|
||||||
sequence_data = fetcher.key_fetch(0, 10, "sequence")
|
|
||||||
assert sequence_data is not None
|
|
||||||
assert len(sequence_data) == 10
|
|
||||||
|
|
||||||
# Test fetching multiple keys
|
|
||||||
multi_data = fetcher.key_fetch(0, 10, ["sequence", "mask"])
|
|
||||||
assert "sequence" in multi_data
|
|
||||||
assert "mask" in multi_data
|
|
||||||
assert len(multi_data["sequence"]) == 10
|
|
||||||
assert len(multi_data["mask"]) == 10
|
|
||||||
|
|
||||||
# Test fetching all keys
|
|
||||||
all_data = fetcher.fetch_data(0, 10)
|
|
||||||
assert "sequence" in all_data
|
|
||||||
assert "mask" in all_data
|
|
||||||
|
|
||||||
|
|
||||||
def test_mmap_file_handler_direct(base_test_env):
|
|
||||||
"""Test MmapFileHandler directly without DatasetLoader"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
|
|
||||||
# Create test data with multiple segments
|
|
||||||
seq_length1 = 100
|
|
||||||
seq_length2 = 200
|
|
||||||
|
|
||||||
# Create data in the format expected by MmapFileHandler
|
|
||||||
dummy_data = {
|
|
||||||
"sequence": [
|
|
||||||
torch.randint(0, 1000, (seq_length1,), dtype=torch.int64),
|
|
||||||
torch.randint(0, 1000, (seq_length2,), dtype=torch.int64)
|
|
||||||
],
|
|
||||||
"mask": [
|
|
||||||
torch.ones(seq_length1, dtype=torch.bool),
|
|
||||||
torch.ones(seq_length2, dtype=torch.bool)
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Save data using MmapFileHandler
|
|
||||||
dataset_dir = os.path.join(test_dir, "mmap_direct_test")
|
|
||||||
MmapFileHandler.save(dataset_dir, dummy_data)
|
|
||||||
|
|
||||||
# Load data using MmapFileHandler
|
|
||||||
loaded_data, num_samples = MmapFileHandler.load(dataset_dir)
|
|
||||||
|
|
||||||
# Verify data structure
|
|
||||||
assert set(loaded_data.keys()) == set(dummy_data.keys())
|
|
||||||
assert num_samples == seq_length1 + seq_length2 # 300
|
|
||||||
|
|
||||||
# Verify data content
|
|
||||||
for key in dummy_data:
|
|
||||||
assert len(loaded_data[key]) == len(dummy_data[key])
|
|
||||||
for i in range(len(dummy_data[key])):
|
|
||||||
assert torch.equal(loaded_data[key][i], dummy_data[key][i])
|
|
||||||
|
|
||||||
def test_mmap_file_handler_dtypes(base_test_env):
|
|
||||||
"""Test MmapFileHandler with different data types"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
|
|
||||||
# Create test data with different dtypes
|
|
||||||
data = {
|
|
||||||
"float32": [torch.randn(100, dtype=torch.float32)],
|
|
||||||
"float64": [torch.randn(100, dtype=torch.float64)],
|
|
||||||
"int32": [torch.randint(0, 1000, (100,), dtype=torch.int32)],
|
|
||||||
"int64": [torch.randint(0, 1000, (100,), dtype=torch.int64)],
|
|
||||||
"bool": [torch.randint(0, 2, (100,), dtype=torch.bool)]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Save data
|
|
||||||
dataset_dir = os.path.join(test_dir, "dtype_test")
|
|
||||||
MmapFileHandler.save(dataset_dir, data)
|
|
||||||
|
|
||||||
# Load data
|
|
||||||
loaded_data, _ = MmapFileHandler.load(dataset_dir)
|
|
||||||
|
|
||||||
# Verify data types
|
|
||||||
for key in data:
|
|
||||||
assert loaded_data[key][0].dtype == data[key][0].dtype
|
|
||||||
assert torch.equal(loaded_data[key][0], data[key][0])
|
|
||||||
|
|
||||||
def test_mmap_file_handler_error_handling(base_test_env):
|
|
||||||
"""Test MmapFileHandler error handling"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
|
|
||||||
# Test loading without file_mapper.json
|
|
||||||
empty_dir = os.path.join(test_dir, "empty_dir")
|
|
||||||
os.makedirs(empty_dir, exist_ok=True)
|
|
||||||
with pytest.raises(FileNotFoundError):
|
|
||||||
MmapFileHandler.load(empty_dir)
|
|
||||||
|
|
||||||
# Test loading with invalid file_mapper.json
|
|
||||||
invalid_dir = os.path.join(test_dir, "invalid_dir")
|
|
||||||
os.makedirs(invalid_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Create empty file_mapper.json
|
|
||||||
with open(os.path.join(invalid_dir, "file_mapper.json"), "w") as f:
|
|
||||||
json.dump([{"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"}], f)
|
|
||||||
|
|
||||||
# This should raise FileNotFoundError because no binary files exist
|
|
||||||
with pytest.raises(FileNotFoundError):
|
|
||||||
MmapFileHandler.load(invalid_dir)
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue