refactor(data): rework the file loading scheme

ViperEkura committed 2026-02-22 21:14:10 +08:00
parent 0ca4871e80 · commit 582d4ae9a7
4 changed files with 67 additions and 246 deletions

khaosz/data/dataset.py

@@ -1,18 +1,17 @@
+import h5py
 import torch
 import bisect
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from khaosz.data.mmap import MmapFileHandler
+from khaosz.data.file import load_h5
 from typing import Callable, List, Dict, Literal, Optional, Union
-
-Seg = List[Tensor]
-MultiSeg = Dict[str, Seg]
 
 class BaseSegmentFetcher:
-    def __init__(self, segments: Seg):
+    def __init__(self, segments: List[Tensor]):
         self.segments = segments
         self.cum_lengths = []
         total = 0
@@ -37,20 +36,21 @@ class BaseSegmentFetcher:
             prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
             start = max(begin_idx - prev_cum, 0)
             end = min(end_idx - prev_cum, len(self.segments[i]))
-            result_segments.append(self.segments[i][start:end])
+            data = self.segments[i][start:end]
+            result_segments.append(data)
 
         return torch.cat(result_segments, dim=0)
 
 class MultiSegmentFetcher:
-    def __init__(self, muti_segments: MultiSeg):
+    def __init__(self, muti_segments: Dict):
         self.muti_keys = list(muti_segments.keys())
         self.muti_fetchers = {
             key: BaseSegmentFetcher(segments)
             for key, segments in muti_segments.items()
         }
 
-    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
+    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
         fetch_dict = {}
         keys = [keys] if isinstance(keys, str) else keys
@@ -61,20 +61,20 @@ class MultiSegmentFetcher:
         return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
 
-    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
+    def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
         return self.key_fetch(begin_idx, end_idx, self.muti_keys)
 
 class BaseDataset(Dataset, ABC):
     def __init__(self, window_size: int, stride: int):
         super().__init__()
-        self.segments: MultiSeg = {}
+        self.segments = {}
         self.window_size = window_size
         self.stride = stride
         self.total_samples = None
 
     def load(self, load_path: str):
-        self.segments, self.total_samples = MmapFileHandler.load(load_path)
+        self.segments, self.total_samples = load_h5(load_path)
         self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
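
The fetch logic above treats a list of variable-length segments as one flat index space: `cum_lengths` holds running totals, and `bisect` finds the first segment overlapping the request. A self-contained sketch of that lookup (hypothetical standalone code, not part of the repo):

```python
import bisect
import torch

# Hypothetical illustration of the cum_lengths + bisect lookup used by
# BaseSegmentFetcher (assumed layout: cum_lengths[i] = total length of
# segments[0..i]).
segments = [torch.arange(0, 5), torch.arange(5, 8), torch.arange(8, 14)]
cum_lengths, total = [], 0
for seg in segments:
    total += len(seg)
    cum_lengths.append(total)          # [5, 8, 14]

def fetch(begin_idx: int, end_idx: int) -> torch.Tensor:
    # First segment whose cumulative length exceeds begin_idx.
    first = bisect.bisect_right(cum_lengths, begin_idx)
    out = []
    for i in range(first, len(segments)):
        prev_cum = cum_lengths[i - 1] if i > 0 else 0
        start = max(begin_idx - prev_cum, 0)
        end = min(end_idx - prev_cum, len(segments[i]))
        if start >= end:
            break
        out.append(segments[i][start:end])
    return torch.cat(out, dim=0)

assert torch.equal(fetch(3, 10), torch.arange(3, 10))  # spans all three segments
```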

khaosz/data/file.py (new file, 44 lines)

@@ -0,0 +1,44 @@
import os
import h5py
import numpy as np
import torch
from torch import Tensor
from typing import Dict, List, Tuple

def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
    # Guard against a bare filename: os.makedirs("") raises.
    dir_name = os.path.dirname(file_path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    with h5py.File(file_path, 'w') as f:
        for key, tensors in tensor_group.items():
            grp = f.create_group(key)
            grp.attrs['num_tensors'] = len(tensors)
            for idx, tensor in enumerate(tensors):
                arr = tensor.cpu().numpy()
                dset = grp.create_dataset(
                    f'data_{idx}',
                    data=arr,
                    compression='gzip',
                    compression_opts=4,
                    shuffle=True
                )
                dset.attrs['numel'] = tensor.numel()

def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
    tensor_group: Dict[str, List[Tensor]] = {}
    total_samples = 0
    with h5py.File(file_path, 'r') as f:
        for key in f.keys():
            grp = f[key]
            dsets = []
            # Sort numerically: h5py iterates names alphabetically,
            # which would put 'data_10' before 'data_2'.
            for dset_name in sorted(grp.keys(), key=lambda n: int(n.split('_')[-1])):
                dset = grp[dset_name]
                dsets.append(torch.from_numpy(dset[:]).share_memory_())
                total_samples += dset.attrs.get('numel', np.prod(dset.shape))
            tensor_group[key] = dsets
    num_keys = max(len(tensor_group), 1)
    sample_per_key = total_samples // num_keys
    return tensor_group, sample_per_key
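
A minimal round trip through the new API (the path, keys, and shapes below are invented for illustration):

```python
import torch
from khaosz.data.file import save_h5, load_h5

# Hypothetical example data: two variable-length segments per key.
group = {
    "sequence": [torch.randint(0, 1000, (100,)), torch.randint(0, 1000, (200,))],
    "mask": [torch.ones(100, dtype=torch.bool), torch.ones(200, dtype=torch.bool)],
}
save_h5("out/data.h5", group)   # one compressed HDF5 file, one group per key

loaded, samples_per_key = load_h5("out/data.h5")
assert torch.equal(loaded["sequence"][0], group["sequence"][0])
assert samples_per_key == (100 + 200 + 100 + 200) // 2  # total numel // num_keys
```

Unlike the mmap scheme it replaces, load_h5 materializes every tensor in RAM and marks its storage shared for DataLoader workers via share_memory_().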

khaosz/data/mmap.py (deleted file)

@@ -1,82 +0,0 @@
import os
import json
import torch
from torch import Tensor
from typing import List, Dict, Tuple

class MmapFileHandler:
    """
    json metadata like this:
    ```
    [
        {"file_name": "file1.pt", "size": 1000, "key": "key1"},
        {"file_name": "file2.pt", "size": 2000, "key": "key2"}
        ...
    ]
    ```
    files like:
    ```
    folder_path:
        - metadata.json
        - file1.pt
        - file2.pt
        ...
    ```
    """
    META_DATA = "metadata.json"

    @staticmethod
    def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]:
        metadata_list = []
        tensor_group: Dict[str, List[Tensor]] = {}
        file_mapper_path = os.path.join(root_path, MmapFileHandler.META_DATA)

        if not os.path.exists(file_mapper_path):
            raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")

        with open(file_mapper_path, "r") as f:
            metadata_list = json.load(f)

        for metadata in metadata_list:
            file_key = metadata["key"]
            file_name = metadata["file_name"]
            file_path = os.path.join(root_path, file_name)
            elm = torch.load(file_path, map_location="cpu", mmap=shared)
            if file_key not in tensor_group:
                tensor_group[file_key] = []
            tensor_group[file_key].append(elm)

        num_samples = sum(metadata["size"] for metadata in metadata_list)
        num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
        sample_per_key = num_samples // num_keys

        return tensor_group, sample_per_key

    @staticmethod
    def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
        os.makedirs(save_path, exist_ok=True)
        metadata_list = []

        for segment_key, segment_tensors in mmap_shared_group.items():
            for idx, tensor in enumerate(segment_tensors):
                try:
                    with open(os.path.join(save_path, f"{segment_key}_{idx}.pt"), "wb") as f:
                        torch.save(tensor.contiguous().cpu(), f)
                except Exception as e:
                    raise RuntimeError(f"Error saving tensor: {e}")
                metadata_list.append({
                    "file_name": f"{segment_key}_{idx}.pt",
                    "size": tensor.numel(),
                    "key": segment_key
                })

        metadata_path = os.path.join(save_path, MmapFileHandler.META_DATA)
        with open(metadata_path, "w") as f:
            json.dump(metadata_list, f)
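
The removed handler leaned on torch.load's mmap flag, which keeps tensor storage backed by the file on disk until pages are touched (supported in recent PyTorch releases); the HDF5 replacement trades that laziness for gzip compression and a single file per dataset. A one-line sketch of the old behaviour, with a hypothetical path:

```python
import torch

# Hypothetical path; with mmap=True the tensor is file-backed rather than
# read eagerly into RAM (recent PyTorch releases).
tensor = torch.load("dataset_dir/sequence_0.pt", map_location="cpu", mmap=True)
```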

tests (dataset test module)

@@ -1,41 +1,22 @@
 import os
-import json
-import pytest
 import torch
 import numpy as np
-from khaosz.trainer import *
+from khaosz.data.file import save_h5
 from khaosz.data.dataset import *
 
-def create_mmap_dataset(dir_path, data_dict, dataset_name):
-    """Helper function to create memory-mapped dataset for testing"""
-    dataset_dir = os.path.join(dir_path, dataset_name)
-    os.makedirs(dataset_dir, exist_ok=True)
-    file_mapper = []
-
-    for key, tensor in data_dict.items():
-        # Convert tensor to numpy array and save as binary file
-        file_name = f"{key}.pt"
-        file_path = os.path.join(dataset_dir, file_name)
-        with open(file_path, "wb") as f:
-            torch.save(tensor, f)
-        # Add to file mapper
-        file_mapper.append({
-            "file_name": file_name,
-            "size": tensor.numel(),
-            "key": key
-        })
-
-    # Save file mapper
-    mapper_path = os.path.join(dataset_dir, "metadata.json")
-    with open(mapper_path, "w") as f:
-        json.dump(file_mapper, f, indent=2)
-    return dataset_dir
+def create_h5_dataset(dir_path, data_dict, dataset_name):
+    """Helper function to create HDF5 dataset for testing"""
+    dataset_path = os.path.join(dir_path, f"{dataset_name}.h5")
+    # Convert data_dict to the format expected by save_h5:
+    # save_h5 expects a list of tensors for each key
+    tensor_group = {key: [tensor] for key, tensor in data_dict.items()}
+    save_h5(dataset_path, tensor_group)
+    return dataset_path
 
 def test_dataset_loader_random_paths(base_test_env):
@@ -50,7 +31,7 @@ def test_dataset_loader_random_paths(base_test_env):
         dummy_data = {
             "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
         }
-        dataset_path = create_mmap_dataset(test_dir, dummy_data, f"test_data_{i}")
+        dataset_path = create_h5_dataset(test_dir, dummy_data, f"test_data_{i}")
 
     # Test loading with multiple paths
     loaded_dataset = DatasetLoader.load(
@@ -84,7 +65,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
         "rejected_mask": torch.ones(seq_length, dtype=torch.bool)
     }
 
-    dataset_path = create_mmap_dataset(test_dir, dummy_data, "dpo_data")
+    dataset_path = create_h5_dataset(test_dir, dummy_data, "dpo_data")
 
     # Load DPO dataset
     dpo_dataset = DatasetLoader.load(
@@ -120,7 +101,7 @@ def test_sft_dataset_with_random_data(base_test_env):
         "loss_mask": torch.ones(seq_length, dtype=torch.bool)
     }
 
-    dataset_path = create_mmap_dataset(test_dir, dummy_data, "sft_data")
+    dataset_path = create_h5_dataset(test_dir, dummy_data, "sft_data")
 
     # Load SFT dataset
     sft_dataset = DatasetLoader.load(
@@ -153,7 +134,7 @@ def test_dataset_with_custom_stride(base_test_env):
         "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
     }
 
-    dataset_path = create_mmap_dataset(test_dir, dummy_data, "stride_test_data")
+    dataset_path = create_h5_dataset(test_dir, dummy_data, "stride_test_data")
 
     # Test with custom stride
     custom_stride = 32
@@ -176,125 +157,3 @@ def test_dataset_with_custom_stride(base_test_env):
     )
 
     assert len(dataset) > len(default_stride_dataset)
-
-def test_multi_segment_fetcher(base_test_env):
-    """Test MultiSegmentFetcher functionality directly"""
-    test_dir = base_test_env["test_dir"]
-
-    # Create test data with multiple segments
-    seq_length = 100
-    dummy_data = {
-        "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
-        "mask": torch.ones(seq_length, dtype=torch.bool)
-    }
-    dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
-
-    # Load the memory mapped files directly
-    multi_segments, _ = MmapFileHandler.load(dataset_path)
-
-    # Create fetcher
-    fetcher = MultiSegmentFetcher(multi_segments)
-
-    # Test fetching single key
-    sequence_data = fetcher.key_fetch(0, 10, "sequence")
-    assert sequence_data is not None
-    assert len(sequence_data) == 10
-
-    # Test fetching multiple keys
-    multi_data = fetcher.key_fetch(0, 10, ["sequence", "mask"])
-    assert "sequence" in multi_data
-    assert "mask" in multi_data
-    assert len(multi_data["sequence"]) == 10
-    assert len(multi_data["mask"]) == 10
-
-    # Test fetching all keys
-    all_data = fetcher.fetch_data(0, 10)
-    assert "sequence" in all_data
-    assert "mask" in all_data
-
-def test_mmap_file_handler_direct(base_test_env):
-    """Test MmapFileHandler directly without DatasetLoader"""
-    test_dir = base_test_env["test_dir"]
-
-    # Create test data with multiple segments
-    seq_length1 = 100
-    seq_length2 = 200
-
-    # Create data in the format expected by MmapFileHandler
-    dummy_data = {
-        "sequence": [
-            torch.randint(0, 1000, (seq_length1,), dtype=torch.int64),
-            torch.randint(0, 1000, (seq_length2,), dtype=torch.int64)
-        ],
-        "mask": [
-            torch.ones(seq_length1, dtype=torch.bool),
-            torch.ones(seq_length2, dtype=torch.bool)
-        ]
-    }
-
-    # Save data using MmapFileHandler
-    dataset_dir = os.path.join(test_dir, "mmap_direct_test")
-    MmapFileHandler.save(dataset_dir, dummy_data)
-
-    # Load data using MmapFileHandler
-    loaded_data, num_samples = MmapFileHandler.load(dataset_dir)
-
-    # Verify data structure
-    assert set(loaded_data.keys()) == set(dummy_data.keys())
-    assert num_samples == seq_length1 + seq_length2  # 300
-
-    # Verify data content
-    for key in dummy_data:
-        assert len(loaded_data[key]) == len(dummy_data[key])
-        for i in range(len(dummy_data[key])):
-            assert torch.equal(loaded_data[key][i], dummy_data[key][i])
-
-def test_mmap_file_handler_dtypes(base_test_env):
-    """Test MmapFileHandler with different data types"""
-    test_dir = base_test_env["test_dir"]
-
-    # Create test data with different dtypes
-    data = {
-        "float32": [torch.randn(100, dtype=torch.float32)],
-        "float64": [torch.randn(100, dtype=torch.float64)],
-        "int32": [torch.randint(0, 1000, (100,), dtype=torch.int32)],
-        "int64": [torch.randint(0, 1000, (100,), dtype=torch.int64)],
-        "bool": [torch.randint(0, 2, (100,), dtype=torch.bool)]
-    }
-
-    # Save data
-    dataset_dir = os.path.join(test_dir, "dtype_test")
-    MmapFileHandler.save(dataset_dir, data)
-
-    # Load data
-    loaded_data, _ = MmapFileHandler.load(dataset_dir)
-
-    # Verify data types
-    for key in data:
-        assert loaded_data[key][0].dtype == data[key][0].dtype
-        assert torch.equal(loaded_data[key][0], data[key][0])
-
-def test_mmap_file_handler_error_handling(base_test_env):
-    """Test MmapFileHandler error handling"""
-    test_dir = base_test_env["test_dir"]
-
-    # Test loading without file_mapper.json
-    empty_dir = os.path.join(test_dir, "empty_dir")
-    os.makedirs(empty_dir, exist_ok=True)
-
-    with pytest.raises(FileNotFoundError):
-        MmapFileHandler.load(empty_dir)
-
-    # Test loading with invalid file_mapper.json
-    invalid_dir = os.path.join(test_dir, "invalid_dir")
-    os.makedirs(invalid_dir, exist_ok=True)
-
-    # Create empty file_mapper.json
-    with open(os.path.join(invalid_dir, "file_mapper.json"), "w") as f:
-        json.dump([{"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"}], f)
-
-    # This should raise FileNotFoundError because no binary files exist
-    with pytest.raises(FileNotFoundError):
-        MmapFileHandler.load(invalid_dir)
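
The removed round-trip and dtype tests get no direct HDF5 counterparts in this commit. A minimal replacement sketch for save_h5/load_h5, reusing the suite's base_test_env fixture (the test name and path are hypothetical):

```python
import os
import torch
from khaosz.data.file import save_h5, load_h5

def test_h5_round_trip(base_test_env):
    """Hypothetical replacement for the removed MmapFileHandler round-trip test."""
    test_dir = base_test_env["test_dir"]
    data = {
        "sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64),
                     torch.randint(0, 1000, (200,), dtype=torch.int64)],
        "mask": [torch.ones(100, dtype=torch.bool),
                 torch.ones(200, dtype=torch.bool)],
    }
    path = os.path.join(test_dir, "round_trip.h5")
    save_h5(path, data)

    loaded, samples_per_key = load_h5(path)
    assert set(loaded.keys()) == set(data.keys())
    assert samples_per_key == 300  # total numel (600) // number of keys (2)
    for key in data:
        for orig, got in zip(data[key], loaded[key]):
            assert got.dtype == orig.dtype
            assert torch.equal(got, orig)
```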