refactor(data): 修改文件加载方案
This commit is contained in:
parent
0ca4871e80
commit
582d4ae9a7
|
|
@ -1,18 +1,17 @@
|
|||
import h5py
|
||||
import torch
|
||||
import bisect
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from khaosz.data.mmap import MmapFileHandler
|
||||
from khaosz.data.file import load_h5
|
||||
from typing import Callable, List, Dict, Literal, Optional, Union
|
||||
|
||||
Seg = List[Tensor]
|
||||
MultiSeg = Dict[str, Seg]
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
|
||||
def __init__(self, segments: Seg):
|
||||
def __init__(self, segments: List[Tensor]):
|
||||
self.segments = segments
|
||||
self.cum_lengths = []
|
||||
total = 0
|
||||
|
|
@ -37,20 +36,21 @@ class BaseSegmentFetcher:
|
|||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
||||
start = max(begin_idx - prev_cum, 0)
|
||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
||||
result_segments.append(self.segments[i][start:end])
|
||||
data = self.segments[i][start:end]
|
||||
result_segments.append(data)
|
||||
|
||||
return torch.cat(result_segments, dim=0)
|
||||
|
||||
|
||||
class MultiSegmentFetcher:
|
||||
def __init__(self, muti_segments: MultiSeg):
|
||||
def __init__(self, muti_segments: Dict):
|
||||
self.muti_keys = list(muti_segments.keys())
|
||||
self.muti_fetchers = {
|
||||
key: BaseSegmentFetcher(segments)
|
||||
for key, segments in muti_segments.items()
|
||||
}
|
||||
|
||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
|
||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
||||
fetch_dict = {}
|
||||
keys = [keys] if isinstance(keys, str) else keys
|
||||
|
||||
|
|
@ -61,20 +61,20 @@ class MultiSegmentFetcher:
|
|||
|
||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
||||
|
||||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__()
|
||||
self.segments: MultiSeg = {}
|
||||
self.segments = {}
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.total_samples = None
|
||||
|
||||
def load(self, load_path: str):
|
||||
self.segments, self.total_samples = MmapFileHandler.load(load_path)
|
||||
self.segments, self.total_samples = load_h5(load_path)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
|
||||
def get_index(self, index: int) -> int:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
import os
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
with h5py.File(file_path, 'w') as f:
|
||||
for key, tensors in tensor_group.items():
|
||||
grp = f.create_group(key)
|
||||
grp.attrs['num_tensors'] = len(tensors)
|
||||
|
||||
for idx, tensor in enumerate(tensors):
|
||||
arr = tensor.cpu().numpy()
|
||||
dset = grp.create_dataset(
|
||||
f'data_{idx}',
|
||||
data=arr,
|
||||
compression='gzip',
|
||||
compression_opts=4,
|
||||
shuffle=True
|
||||
)
|
||||
dset.attrs['numel'] = tensor.numel()
|
||||
|
||||
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
|
||||
tensor_group: Dict[str, List[Tensor]] = {}
|
||||
total_samples = 0
|
||||
|
||||
with h5py.File(file_path, 'r') as f:
|
||||
for key in f.keys():
|
||||
grp = f[key]
|
||||
dsets = []
|
||||
for dset_name in grp.keys():
|
||||
dset = grp[dset_name]
|
||||
dsets.append(torch.from_numpy(dset[:]).share_memory_())
|
||||
total_samples += dset.attrs.get('numel', np.prod(dset.shape))
|
||||
tensor_group[key] = dsets
|
||||
|
||||
num_keys = max(len(tensor_group), 1)
|
||||
sample_per_key = total_samples // num_keys
|
||||
|
||||
return tensor_group, sample_per_key
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
|
||||
from torch import Tensor
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
class MmapFileHandler:
|
||||
"""
|
||||
json metadata like this:
|
||||
|
||||
```
|
||||
[
|
||||
{"file_name": "file1.pt", "size": 1000, "key": "key1"},
|
||||
{"file_name": "file2.pt", "size": 2000, "key": "key2"}
|
||||
...
|
||||
]
|
||||
```
|
||||
files like:
|
||||
|
||||
```
|
||||
folder_path:
|
||||
- metadata.json
|
||||
- file1.pt
|
||||
- file2.pt
|
||||
...
|
||||
```
|
||||
"""
|
||||
META_DATA = "metadata.json"
|
||||
|
||||
@staticmethod
|
||||
def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]:
|
||||
metadata_list = []
|
||||
tensor_group: Dict[str, List[Tensor]] = {}
|
||||
|
||||
file_mapper_path = os.path.join(root_path, MmapFileHandler.META_DATA)
|
||||
if not os.path.exists(file_mapper_path):
|
||||
raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
|
||||
|
||||
with open(file_mapper_path, "r") as f:
|
||||
metadata_list = json.load(f)
|
||||
|
||||
for metadata in metadata_list:
|
||||
file_key = metadata["key"]
|
||||
file_name = metadata["file_name"]
|
||||
file_path = os.path.join(root_path, file_name)
|
||||
elm = torch.load(file_path, map_location="cpu", mmap=shared)
|
||||
|
||||
if file_key not in tensor_group:
|
||||
tensor_group[file_key] = []
|
||||
tensor_group[file_key].append(elm)
|
||||
|
||||
num_samples = sum(metadata["size"] for metadata in metadata_list)
|
||||
num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
|
||||
sample_per_key = num_samples // num_keys
|
||||
|
||||
return tensor_group, sample_per_key
|
||||
|
||||
@staticmethod
|
||||
def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
metadata_list = []
|
||||
for segment_key, segment_tensors in mmap_shared_group.items():
|
||||
for idx, tensor in enumerate(segment_tensors):
|
||||
|
||||
try:
|
||||
with open(os.path.join(save_path, f"{segment_key}_{idx}.pt"), "wb") as f:
|
||||
torch.save(tensor.contiguous().cpu(), f)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error saving tensor: {e}")
|
||||
|
||||
metadata_list.append({
|
||||
"file_name": f"{segment_key}_{idx}.pt",
|
||||
"size": tensor.numel(),
|
||||
"key": segment_key
|
||||
})
|
||||
|
||||
metadata_path = os.path.join(save_path, MmapFileHandler.META_DATA)
|
||||
|
||||
with open(metadata_path, "w") as f:
|
||||
json.dump(metadata_list, f)
|
||||
|
|
@ -1,41 +1,22 @@
|
|||
import os
|
||||
import json
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from khaosz.trainer import *
|
||||
from khaosz.data.file import save_h5
|
||||
from khaosz.data.dataset import *
|
||||
|
||||
|
||||
def create_mmap_dataset(dir_path, data_dict, dataset_name):
|
||||
"""Helper function to create memory-mapped dataset for testing"""
|
||||
dataset_dir = os.path.join(dir_path, dataset_name)
|
||||
os.makedirs(dataset_dir, exist_ok=True)
|
||||
def create_h5_dataset(dir_path, data_dict, dataset_name):
|
||||
"""Helper function to create HDF5 dataset for testing"""
|
||||
dataset_path = os.path.join(dir_path, f"{dataset_name}.h5")
|
||||
|
||||
file_mapper = []
|
||||
# Convert data_dict to the format expected by save_h5
|
||||
# save_h5 expects a list of tensors for each key
|
||||
tensor_group = {key: [tensor] for key, tensor in data_dict.items()}
|
||||
|
||||
for key, tensor in data_dict.items():
|
||||
# Convert tensor to numpy array and save as binary file
|
||||
file_name = f"{key}.pt"
|
||||
file_path = os.path.join(dataset_dir, file_name)
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
torch.save(tensor, f)
|
||||
|
||||
# Add to file mapper
|
||||
file_mapper.append({
|
||||
"file_name": file_name,
|
||||
"size": tensor.numel(),
|
||||
"key": key
|
||||
})
|
||||
save_h5(dataset_path, tensor_group)
|
||||
|
||||
# Save file mapper
|
||||
mapper_path = os.path.join(dataset_dir, "metadata.json")
|
||||
with open(mapper_path, "w") as f:
|
||||
json.dump(file_mapper, f, indent=2)
|
||||
|
||||
return dataset_dir
|
||||
return dataset_path
|
||||
|
||||
|
||||
def test_dataset_loader_random_paths(base_test_env):
|
||||
|
|
@ -50,7 +31,7 @@ def test_dataset_loader_random_paths(base_test_env):
|
|||
dummy_data = {
|
||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
||||
}
|
||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, f"test_data_{i}")
|
||||
dataset_path = create_h5_dataset(test_dir, dummy_data, f"test_data_{i}")
|
||||
|
||||
# Test loading with multiple paths
|
||||
loaded_dataset = DatasetLoader.load(
|
||||
|
|
@ -84,7 +65,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
|||
"rejected_mask": torch.ones(seq_length, dtype=torch.bool)
|
||||
}
|
||||
|
||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, "dpo_data")
|
||||
dataset_path = create_h5_dataset(test_dir, dummy_data, "dpo_data")
|
||||
|
||||
# Load DPO dataset
|
||||
dpo_dataset = DatasetLoader.load(
|
||||
|
|
@ -120,7 +101,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
|||
"loss_mask": torch.ones(seq_length, dtype=torch.bool)
|
||||
}
|
||||
|
||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, "sft_data")
|
||||
dataset_path = create_h5_dataset(test_dir, dummy_data, "sft_data")
|
||||
|
||||
# Load SFT dataset
|
||||
sft_dataset = DatasetLoader.load(
|
||||
|
|
@ -153,7 +134,7 @@ def test_dataset_with_custom_stride(base_test_env):
|
|||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
||||
}
|
||||
|
||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, "stride_test_data")
|
||||
dataset_path = create_h5_dataset(test_dir, dummy_data, "stride_test_data")
|
||||
|
||||
# Test with custom stride
|
||||
custom_stride = 32
|
||||
|
|
@ -176,125 +157,3 @@ def test_dataset_with_custom_stride(base_test_env):
|
|||
)
|
||||
|
||||
assert len(dataset) > len(default_stride_dataset)
|
||||
|
||||
|
||||
def test_multi_segment_fetcher(base_test_env):
|
||||
"""Test MultiSegmentFetcher functionality directly"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
# Create test data with multiple segments
|
||||
seq_length = 100
|
||||
dummy_data = {
|
||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
||||
"mask": torch.ones(seq_length, dtype=torch.bool)
|
||||
}
|
||||
|
||||
dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
|
||||
|
||||
# Load the memory mapped files directly
|
||||
multi_segments, _ = MmapFileHandler.load(dataset_path)
|
||||
|
||||
# Create fetcher
|
||||
fetcher = MultiSegmentFetcher(multi_segments)
|
||||
|
||||
# Test fetching single key
|
||||
sequence_data = fetcher.key_fetch(0, 10, "sequence")
|
||||
assert sequence_data is not None
|
||||
assert len(sequence_data) == 10
|
||||
|
||||
# Test fetching multiple keys
|
||||
multi_data = fetcher.key_fetch(0, 10, ["sequence", "mask"])
|
||||
assert "sequence" in multi_data
|
||||
assert "mask" in multi_data
|
||||
assert len(multi_data["sequence"]) == 10
|
||||
assert len(multi_data["mask"]) == 10
|
||||
|
||||
# Test fetching all keys
|
||||
all_data = fetcher.fetch_data(0, 10)
|
||||
assert "sequence" in all_data
|
||||
assert "mask" in all_data
|
||||
|
||||
|
||||
def test_mmap_file_handler_direct(base_test_env):
|
||||
"""Test MmapFileHandler directly without DatasetLoader"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
# Create test data with multiple segments
|
||||
seq_length1 = 100
|
||||
seq_length2 = 200
|
||||
|
||||
# Create data in the format expected by MmapFileHandler
|
||||
dummy_data = {
|
||||
"sequence": [
|
||||
torch.randint(0, 1000, (seq_length1,), dtype=torch.int64),
|
||||
torch.randint(0, 1000, (seq_length2,), dtype=torch.int64)
|
||||
],
|
||||
"mask": [
|
||||
torch.ones(seq_length1, dtype=torch.bool),
|
||||
torch.ones(seq_length2, dtype=torch.bool)
|
||||
]
|
||||
}
|
||||
|
||||
# Save data using MmapFileHandler
|
||||
dataset_dir = os.path.join(test_dir, "mmap_direct_test")
|
||||
MmapFileHandler.save(dataset_dir, dummy_data)
|
||||
|
||||
# Load data using MmapFileHandler
|
||||
loaded_data, num_samples = MmapFileHandler.load(dataset_dir)
|
||||
|
||||
# Verify data structure
|
||||
assert set(loaded_data.keys()) == set(dummy_data.keys())
|
||||
assert num_samples == seq_length1 + seq_length2 # 300
|
||||
|
||||
# Verify data content
|
||||
for key in dummy_data:
|
||||
assert len(loaded_data[key]) == len(dummy_data[key])
|
||||
for i in range(len(dummy_data[key])):
|
||||
assert torch.equal(loaded_data[key][i], dummy_data[key][i])
|
||||
|
||||
def test_mmap_file_handler_dtypes(base_test_env):
|
||||
"""Test MmapFileHandler with different data types"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
# Create test data with different dtypes
|
||||
data = {
|
||||
"float32": [torch.randn(100, dtype=torch.float32)],
|
||||
"float64": [torch.randn(100, dtype=torch.float64)],
|
||||
"int32": [torch.randint(0, 1000, (100,), dtype=torch.int32)],
|
||||
"int64": [torch.randint(0, 1000, (100,), dtype=torch.int64)],
|
||||
"bool": [torch.randint(0, 2, (100,), dtype=torch.bool)]
|
||||
}
|
||||
|
||||
# Save data
|
||||
dataset_dir = os.path.join(test_dir, "dtype_test")
|
||||
MmapFileHandler.save(dataset_dir, data)
|
||||
|
||||
# Load data
|
||||
loaded_data, _ = MmapFileHandler.load(dataset_dir)
|
||||
|
||||
# Verify data types
|
||||
for key in data:
|
||||
assert loaded_data[key][0].dtype == data[key][0].dtype
|
||||
assert torch.equal(loaded_data[key][0], data[key][0])
|
||||
|
||||
def test_mmap_file_handler_error_handling(base_test_env):
|
||||
"""Test MmapFileHandler error handling"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
# Test loading without file_mapper.json
|
||||
empty_dir = os.path.join(test_dir, "empty_dir")
|
||||
os.makedirs(empty_dir, exist_ok=True)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
MmapFileHandler.load(empty_dir)
|
||||
|
||||
# Test loading with invalid file_mapper.json
|
||||
invalid_dir = os.path.join(test_dir, "invalid_dir")
|
||||
os.makedirs(invalid_dir, exist_ok=True)
|
||||
|
||||
# Create empty file_mapper.json
|
||||
with open(os.path.join(invalid_dir, "file_mapper.json"), "w") as f:
|
||||
json.dump([{"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"}], f)
|
||||
|
||||
# This should raise FileNotFoundError because no binary files exist
|
||||
with pytest.raises(FileNotFoundError):
|
||||
MmapFileHandler.load(invalid_dir)
|
||||
|
|
|
|||
Loading…
Reference in New Issue