refactor(data): 修改文件加载方案

This commit is contained in:
ViperEkura 2026-02-22 21:14:10 +08:00
parent 0ca4871e80
commit 582d4ae9a7
4 changed files with 67 additions and 246 deletions

View File

@ -1,18 +1,17 @@
import h5py
import torch
import bisect
from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset
from khaosz.data.mmap import MmapFileHandler
from khaosz.data.file import load_h5
from typing import Callable, List, Dict, Literal, Optional, Union
Seg = List[Tensor]
MultiSeg = Dict[str, Seg]
class BaseSegmentFetcher:
def __init__(self, segments: Seg):
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
@ -37,20 +36,21 @@ class BaseSegmentFetcher:
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
data = self.segments[i][start:end]
result_segments.append(data)
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
def __init__(self, muti_segments: MultiSeg):
def __init__(self, muti_segments: Dict):
self.muti_keys = list(muti_segments.keys())
self.muti_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in muti_segments.items()
}
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
@ -61,20 +61,20 @@ class MultiSegmentFetcher:
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
class BaseDataset(Dataset, ABC):
def __init__(self, window_size: int, stride: int):
super().__init__()
self.segments: MultiSeg = {}
self.segments = {}
self.window_size = window_size
self.stride = stride
self.total_samples = None
def load(self, load_path: str):
self.segments, self.total_samples = MmapFileHandler.load(load_path)
self.segments, self.total_samples = load_h5(load_path)
self.fetcher = MultiSegmentFetcher(self.segments)
def get_index(self, index: int) -> int:

44
khaosz/data/file.py Normal file
View File

@ -0,0 +1,44 @@
import os
import h5py
import numpy as np
import torch
from torch import Tensor
from typing import Dict, List, Tuple
def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with h5py.File(file_path, 'w') as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
grp.attrs['num_tensors'] = len(tensors)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
dset = grp.create_dataset(
f'data_{idx}',
data=arr,
compression='gzip',
compression_opts=4,
shuffle=True
)
dset.attrs['numel'] = tensor.numel()
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
tensor_group: Dict[str, List[Tensor]] = {}
total_samples = 0
with h5py.File(file_path, 'r') as f:
for key in f.keys():
grp = f[key]
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
dsets.append(torch.from_numpy(dset[:]).share_memory_())
total_samples += dset.attrs.get('numel', np.prod(dset.shape))
tensor_group[key] = dsets
num_keys = max(len(tensor_group), 1)
sample_per_key = total_samples // num_keys
return tensor_group, sample_per_key

View File

@ -1,82 +0,0 @@
import os
import json
import torch
from torch import Tensor
from typing import List, Dict, Tuple
class MmapFileHandler:
"""
json metadata like this:
```
[
{"file_name": "file1.pt", "size": 1000, "key": "key1"},
{"file_name": "file2.pt", "size": 2000, "key": "key2"}
...
]
```
files like:
```
folder_path:
- metadata.json
- file1.pt
- file2.pt
...
```
"""
META_DATA = "metadata.json"
@staticmethod
def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]:
metadata_list = []
tensor_group: Dict[str, List[Tensor]] = {}
file_mapper_path = os.path.join(root_path, MmapFileHandler.META_DATA)
if not os.path.exists(file_mapper_path):
raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
with open(file_mapper_path, "r") as f:
metadata_list = json.load(f)
for metadata in metadata_list:
file_key = metadata["key"]
file_name = metadata["file_name"]
file_path = os.path.join(root_path, file_name)
elm = torch.load(file_path, map_location="cpu", mmap=shared)
if file_key not in tensor_group:
tensor_group[file_key] = []
tensor_group[file_key].append(elm)
num_samples = sum(metadata["size"] for metadata in metadata_list)
num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
sample_per_key = num_samples // num_keys
return tensor_group, sample_per_key
@staticmethod
def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
os.makedirs(save_path, exist_ok=True)
metadata_list = []
for segment_key, segment_tensors in mmap_shared_group.items():
for idx, tensor in enumerate(segment_tensors):
try:
with open(os.path.join(save_path, f"{segment_key}_{idx}.pt"), "wb") as f:
torch.save(tensor.contiguous().cpu(), f)
except Exception as e:
raise RuntimeError(f"Error saving tensor: {e}")
metadata_list.append({
"file_name": f"{segment_key}_{idx}.pt",
"size": tensor.numel(),
"key": segment_key
})
metadata_path = os.path.join(save_path, MmapFileHandler.META_DATA)
with open(metadata_path, "w") as f:
json.dump(metadata_list, f)

View File

@ -1,41 +1,22 @@
import os
import json
import pytest
import torch
import numpy as np
from khaosz.trainer import *
from khaosz.data.file import save_h5
from khaosz.data.dataset import *
def create_mmap_dataset(dir_path, data_dict, dataset_name):
"""Helper function to create memory-mapped dataset for testing"""
dataset_dir = os.path.join(dir_path, dataset_name)
os.makedirs(dataset_dir, exist_ok=True)
def create_h5_dataset(dir_path, data_dict, dataset_name):
"""Helper function to create HDF5 dataset for testing"""
dataset_path = os.path.join(dir_path, f"{dataset_name}.h5")
file_mapper = []
# Convert data_dict to the format expected by save_h5
# save_h5 expects a list of tensors for each key
tensor_group = {key: [tensor] for key, tensor in data_dict.items()}
for key, tensor in data_dict.items():
# Convert tensor to numpy array and save as binary file
file_name = f"{key}.pt"
file_path = os.path.join(dataset_dir, file_name)
save_h5(dataset_path, tensor_group)
with open(file_path, "wb") as f:
torch.save(tensor, f)
# Add to file mapper
file_mapper.append({
"file_name": file_name,
"size": tensor.numel(),
"key": key
})
# Save file mapper
mapper_path = os.path.join(dataset_dir, "metadata.json")
with open(mapper_path, "w") as f:
json.dump(file_mapper, f, indent=2)
return dataset_dir
return dataset_path
def test_dataset_loader_random_paths(base_test_env):
@ -50,7 +31,7 @@ def test_dataset_loader_random_paths(base_test_env):
dummy_data = {
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
}
dataset_path = create_mmap_dataset(test_dir, dummy_data, f"test_data_{i}")
dataset_path = create_h5_dataset(test_dir, dummy_data, f"test_data_{i}")
# Test loading with multiple paths
loaded_dataset = DatasetLoader.load(
@ -84,7 +65,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
"rejected_mask": torch.ones(seq_length, dtype=torch.bool)
}
dataset_path = create_mmap_dataset(test_dir, dummy_data, "dpo_data")
dataset_path = create_h5_dataset(test_dir, dummy_data, "dpo_data")
# Load DPO dataset
dpo_dataset = DatasetLoader.load(
@ -120,7 +101,7 @@ def test_sft_dataset_with_random_data(base_test_env):
"loss_mask": torch.ones(seq_length, dtype=torch.bool)
}
dataset_path = create_mmap_dataset(test_dir, dummy_data, "sft_data")
dataset_path = create_h5_dataset(test_dir, dummy_data, "sft_data")
# Load SFT dataset
sft_dataset = DatasetLoader.load(
@ -153,7 +134,7 @@ def test_dataset_with_custom_stride(base_test_env):
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
}
dataset_path = create_mmap_dataset(test_dir, dummy_data, "stride_test_data")
dataset_path = create_h5_dataset(test_dir, dummy_data, "stride_test_data")
# Test with custom stride
custom_stride = 32
@ -176,125 +157,3 @@ def test_dataset_with_custom_stride(base_test_env):
)
assert len(dataset) > len(default_stride_dataset)
def test_multi_segment_fetcher(base_test_env):
"""Test MultiSegmentFetcher functionality directly"""
test_dir = base_test_env["test_dir"]
# Create test data with multiple segments
seq_length = 100
dummy_data = {
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
"mask": torch.ones(seq_length, dtype=torch.bool)
}
dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
# Load the memory mapped files directly
multi_segments, _ = MmapFileHandler.load(dataset_path)
# Create fetcher
fetcher = MultiSegmentFetcher(multi_segments)
# Test fetching single key
sequence_data = fetcher.key_fetch(0, 10, "sequence")
assert sequence_data is not None
assert len(sequence_data) == 10
# Test fetching multiple keys
multi_data = fetcher.key_fetch(0, 10, ["sequence", "mask"])
assert "sequence" in multi_data
assert "mask" in multi_data
assert len(multi_data["sequence"]) == 10
assert len(multi_data["mask"]) == 10
# Test fetching all keys
all_data = fetcher.fetch_data(0, 10)
assert "sequence" in all_data
assert "mask" in all_data
def test_mmap_file_handler_direct(base_test_env):
"""Test MmapFileHandler directly without DatasetLoader"""
test_dir = base_test_env["test_dir"]
# Create test data with multiple segments
seq_length1 = 100
seq_length2 = 200
# Create data in the format expected by MmapFileHandler
dummy_data = {
"sequence": [
torch.randint(0, 1000, (seq_length1,), dtype=torch.int64),
torch.randint(0, 1000, (seq_length2,), dtype=torch.int64)
],
"mask": [
torch.ones(seq_length1, dtype=torch.bool),
torch.ones(seq_length2, dtype=torch.bool)
]
}
# Save data using MmapFileHandler
dataset_dir = os.path.join(test_dir, "mmap_direct_test")
MmapFileHandler.save(dataset_dir, dummy_data)
# Load data using MmapFileHandler
loaded_data, num_samples = MmapFileHandler.load(dataset_dir)
# Verify data structure
assert set(loaded_data.keys()) == set(dummy_data.keys())
assert num_samples == seq_length1 + seq_length2 # 300
# Verify data content
for key in dummy_data:
assert len(loaded_data[key]) == len(dummy_data[key])
for i in range(len(dummy_data[key])):
assert torch.equal(loaded_data[key][i], dummy_data[key][i])
def test_mmap_file_handler_dtypes(base_test_env):
"""Test MmapFileHandler with different data types"""
test_dir = base_test_env["test_dir"]
# Create test data with different dtypes
data = {
"float32": [torch.randn(100, dtype=torch.float32)],
"float64": [torch.randn(100, dtype=torch.float64)],
"int32": [torch.randint(0, 1000, (100,), dtype=torch.int32)],
"int64": [torch.randint(0, 1000, (100,), dtype=torch.int64)],
"bool": [torch.randint(0, 2, (100,), dtype=torch.bool)]
}
# Save data
dataset_dir = os.path.join(test_dir, "dtype_test")
MmapFileHandler.save(dataset_dir, data)
# Load data
loaded_data, _ = MmapFileHandler.load(dataset_dir)
# Verify data types
for key in data:
assert loaded_data[key][0].dtype == data[key][0].dtype
assert torch.equal(loaded_data[key][0], data[key][0])
def test_mmap_file_handler_error_handling(base_test_env):
"""Test MmapFileHandler error handling"""
test_dir = base_test_env["test_dir"]
# Test loading without file_mapper.json
empty_dir = os.path.join(test_dir, "empty_dir")
os.makedirs(empty_dir, exist_ok=True)
with pytest.raises(FileNotFoundError):
MmapFileHandler.load(empty_dir)
# Test loading with invalid file_mapper.json
invalid_dir = os.path.join(test_dir, "invalid_dir")
os.makedirs(invalid_dir, exist_ok=True)
# Create empty file_mapper.json
with open(os.path.join(invalid_dir, "file_mapper.json"), "w") as f:
json.dump([{"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"}], f)
# This should raise FileNotFoundError because no binary files exist
with pytest.raises(FileNotFoundError):
MmapFileHandler.load(invalid_dir)