From 582d4ae9a7a841a0596842b3394b7ded28cb351b Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 22 Feb 2026 21:14:10 +0800
Subject: [PATCH] refactor(data): replace mmap file loading with HDF5
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
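Notes:

The on-disk layout is one HDF5 group per key and one compressed dataset
per segment ('data_0', 'data_1', ...), with the element count stored in
a 'numel' attribute. A minimal round-trip sketch of the new helpers
(the file path is illustrative):

    import torch
    from khaosz.data.file import save_h5, load_h5

    segments = {"sequence": [torch.arange(10), torch.arange(5)]}
    save_h5("out/train.h5", segments)
    loaded, samples_per_key = load_h5("out/train.h5")

    assert torch.equal(loaded["sequence"][0], segments["sequence"][0])
    assert samples_per_key == 15  # 15 total elements under a single key

load_h5 returns CPU tensors moved into shared memory, so DataLoader
worker processes can read them without copying.
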
 khaosz/data/dataset.py     |  19 ++---
 khaosz/data/file.py        |  47 ++++++++++
 khaosz/data/mmap.py        |  82 ------------------
 tests/data/test_dataset.py | 167 +++----------------------------------
 4 files changed, 69 insertions(+), 246 deletions(-)
 create mode 100644 khaosz/data/file.py
 delete mode 100644 khaosz/data/mmap.py

diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py
index ad253ac..dae7fe0 100644
--- a/khaosz/data/dataset.py
+++ b/khaosz/data/dataset.py
@@ -1,18 +1,16 @@
 import torch
 import bisect
 
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from khaosz.data.mmap import MmapFileHandler
+from khaosz.data.file import load_h5
 from typing import Callable, List, Dict, Literal, Optional, Union
 
-Seg = List[Tensor]
-MultiSeg = Dict[str, Seg]
 
 
 class BaseSegmentFetcher:
-    def __init__(self, segments: Seg):
+    def __init__(self, segments: List[Tensor]):
         self.segments = segments
         self.cum_lengths = []
         total = 0
@@ -37,20 +35,21 @@ class BaseSegmentFetcher:
             prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
             start = max(begin_idx - prev_cum, 0)
             end = min(end_idx - prev_cum, len(self.segments[i]))
-            result_segments.append(self.segments[i][start:end])
+            data = self.segments[i][start:end]
+            result_segments.append(data)
 
         return torch.cat(result_segments, dim=0)
 
 
 class MultiSegmentFetcher:
-    def __init__(self, muti_segments: MultiSeg):
+    def __init__(self, muti_segments: Dict[str, List[Tensor]]):
         self.muti_keys = list(muti_segments.keys())
         self.muti_fetchers = {
             key: BaseSegmentFetcher(segments)
             for key, segments in muti_segments.items()
         }
 
-    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
+    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Dict[str, Tensor]]:
         fetch_dict = {}
         keys = [keys] if isinstance(keys, str) else keys
 
@@ -61,20 +60,20 @@ class MultiSegmentFetcher:
 
         return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
 
-    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
+    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Dict[str, Tensor]]:
         return self.key_fetch(begin_idx, end_idx, self.muti_keys)
 
 
 class BaseDataset(Dataset, ABC):
     def __init__(self, window_size: int, stride: int):
         super().__init__()
-        self.segments: MultiSeg = {}
+        self.segments: Dict[str, List[Tensor]] = {}
         self.window_size = window_size
         self.stride = stride
         self.total_samples = None
 
     def load(self, load_path: str):
-        self.segments, self.total_samples = MmapFileHandler.load(load_path)
+        self.segments, self.total_samples = load_h5(load_path)
         self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
diff --git a/khaosz/data/file.py b/khaosz/data/file.py
new file mode 100644
index 0000000..38a1349
--- /dev/null
+++ b/khaosz/data/file.py
@@ -0,0 +1,47 @@
+import os
+import h5py
+import numpy as np
+import torch
+
+from torch import Tensor
+from typing import Dict, List, Tuple
+
+
+def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]) -> None:
+    os.makedirs(os.path.dirname(file_path) or '.', exist_ok=True)
+    with h5py.File(file_path, 'w') as f:
+        for key, tensors in tensor_group.items():
+            grp = f.create_group(key)
+            grp.attrs['num_tensors'] = len(tensors)
+
+            for idx, tensor in enumerate(tensors):
+                arr = tensor.cpu().numpy()
+                dset = grp.create_dataset(
+                    f'data_{idx}',
+                    data=arr,
+                    compression='gzip',
+                    compression_opts=4,
+                    shuffle=True
+                )
+                dset.attrs['numel'] = tensor.numel()
+
+
+def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
+    tensor_group: Dict[str, List[Tensor]] = {}
+    total_samples = 0
+
+    with h5py.File(file_path, 'r') as f:
+        for key in f.keys():
+            grp = f[key]
+            dsets = []
+            # sort numerically: plain iteration is lexicographic and
+            # would put 'data_10' before 'data_2'
+            for dset_name in sorted(grp.keys(), key=lambda name: int(name.split('_')[-1])):
+                dset = grp[dset_name]
+                dsets.append(torch.from_numpy(dset[:]).share_memory_())
+                total_samples += dset.attrs.get('numel', np.prod(dset.shape))
+            tensor_group[key] = dsets
+
+    num_keys = max(len(tensor_group), 1)
+    sample_per_key = total_samples // num_keys
+
+    return tensor_group, sample_per_key
\ No newline at end of file
diff --git a/khaosz/data/mmap.py b/khaosz/data/mmap.py
deleted file mode 100644
index 8db521d..0000000
--- a/khaosz/data/mmap.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import os
-import json
-import torch
-
-from torch import Tensor
-from typing import List, Dict, Tuple
-
-class MmapFileHandler:
-    """
-    json metadata like this:
-
-    ```
-    [
-        {"file_name": "file1.pt", "size": 1000, "key": "key1"},
-        {"file_name": "file2.pt", "size": 2000, "key": "key2"}
-        ...
-    ]
-    ```
-    files like:
-
-    ```
-    folder_path:
-    - metadata.json
-    - file1.pt
-    - file2.pt
-    ...
-    ```
-    """
-    META_DATA = "metadata.json"
-
-    @staticmethod
-    def load(root_path: str, shared: bool=True) -> Tuple[Dict[str, List[Tensor]], int]:
-        metadata_list = []
-        tensor_group: Dict[str, List[Tensor]] = {}
-
-        file_mapper_path = os.path.join(root_path, MmapFileHandler.META_DATA)
-        if not os.path.exists(file_mapper_path):
-            raise FileNotFoundError(f"File mapper not found: {file_mapper_path}")
-
-        with open(file_mapper_path, "r") as f:
-            metadata_list = json.load(f)
-
-        for metadata in metadata_list:
-            file_key = metadata["key"]
-            file_name = metadata["file_name"]
-            file_path = os.path.join(root_path, file_name)
-            elm = torch.load(file_path, map_location="cpu", mmap=shared)
-
-            if file_key not in tensor_group:
-                tensor_group[file_key] = []
-            tensor_group[file_key].append(elm)
-
-        num_samples = sum(metadata["size"] for metadata in metadata_list)
-        num_keys = max(len(set(metadata['key'] for metadata in metadata_list)), 1)
-        sample_per_key = num_samples // num_keys
-
-        return tensor_group, sample_per_key
-
-    @staticmethod
-    def save(save_path: str, mmap_shared_group: Dict[str, List[Tensor]]) -> None:
-        os.makedirs(save_path, exist_ok=True)
-
-        metadata_list = []
-        for segment_key, segment_tensors in mmap_shared_group.items():
-            for idx, tensor in enumerate(segment_tensors):
-
-                try:
-                    with open(os.path.join(save_path, f"{segment_key}_{idx}.pt"), "wb") as f:
-                        torch.save(tensor.contiguous().cpu(), f)
-                except Exception as e:
-                    raise RuntimeError(f"Error saving tensor: {e}")
-
-                metadata_list.append({
-                    "file_name": f"{segment_key}_{idx}.pt",
-                    "size": tensor.numel(),
-                    "key": segment_key
-                })
-
-        metadata_path = os.path.join(save_path, MmapFileHandler.META_DATA)
-
-        with open(metadata_path, "w") as f:
-            json.dump(metadata_list, f)
\ No newline at end of file
diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index 419c3ad..b3f9243 100644
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -1,41 +1,22 @@
 import os
-import json
-import pytest
 import torch
 import numpy as np
 
-from khaosz.trainer import *
+from khaosz.data.file import save_h5
 from khaosz.data.dataset import *
 
 
-def create_mmap_dataset(dir_path, data_dict, dataset_name):
-    """Helper function to create memory-mapped dataset for testing"""
-    dataset_dir = os.path.join(dir_path, dataset_name)
-    os.makedirs(dataset_dir, exist_ok=True)
+def create_h5_dataset(dir_path, data_dict, dataset_name):
+    """Helper function to create an HDF5 dataset for testing"""
+    dataset_path = os.path.join(dir_path, f"{dataset_name}.h5")
 
-    file_mapper = []
+    # Convert data_dict to the format expected by save_h5,
+    # which takes a list of tensors for each key
+    tensor_group = {key: [tensor] for key, tensor in data_dict.items()}
 
-    for key, tensor in data_dict.items():
-        # Convert tensor to numpy array and save as binary file
-        file_name = f"{key}.pt"
-        file_path = os.path.join(dataset_dir, file_name)
-
-        with open(file_path, "wb") as f:
-            torch.save(tensor, f)
-
-        # Add to file mapper
-        file_mapper.append({
-            "file_name": file_name,
-            "size": tensor.numel(),
-            "key": key
-        })
+    save_h5(dataset_path, tensor_group)
 
-    # Save file mapper
-    mapper_path = os.path.join(dataset_dir, "metadata.json")
-    with open(mapper_path, "w") as f:
-        json.dump(file_mapper, f, indent=2)
-
-    return dataset_dir
+    return dataset_path
 
 
 def test_dataset_loader_random_paths(base_test_env):
@@ -50,7 +31,7 @@
         dummy_data = {
             "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
         }
-        dataset_path = create_mmap_dataset(test_dir, dummy_data, f"test_data_{i}")
+        dataset_path = create_h5_dataset(test_dir, dummy_data, f"test_data_{i}")
 
     # Test loading with multiple paths
     loaded_dataset = DatasetLoader.load(
@@ -84,7 +65,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
         "rejected_mask": torch.ones(seq_length, dtype=torch.bool)
     }
 
-    dataset_path = create_mmap_dataset(test_dir, dummy_data, "dpo_data")
+    dataset_path = create_h5_dataset(test_dir, dummy_data, "dpo_data")
 
     # Load DPO dataset
     dpo_dataset = DatasetLoader.load(
@@ -120,7 +101,7 @@ def test_sft_dataset_with_random_data(base_test_env):
         "loss_mask": torch.ones(seq_length, dtype=torch.bool)
    }
 
-    dataset_path = create_mmap_dataset(test_dir, dummy_data, "sft_data")
+    dataset_path = create_h5_dataset(test_dir, dummy_data, "sft_data")
 
     # Load SFT dataset
     sft_dataset = DatasetLoader.load(
@@ -153,7 +134,7 @@ def test_dataset_with_custom_stride(base_test_env):
         "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
     }
 
-    dataset_path = create_mmap_dataset(test_dir, dummy_data, "stride_test_data")
+    dataset_path = create_h5_dataset(test_dir, dummy_data, "stride_test_data")
 
     # Test with custom stride
     custom_stride = 32
@@ -176,125 +157,3 @@
     )
 
     assert len(dataset) > len(default_stride_dataset)
-
-
-def test_multi_segment_fetcher(base_test_env):
-    """Test MultiSegmentFetcher functionality directly"""
-    test_dir = base_test_env["test_dir"]
-
-    # Create test data with multiple segments
-    seq_length = 100
-    dummy_data = {
-        "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
-        "mask": torch.ones(seq_length, dtype=torch.bool)
-    }
-
-    dataset_path = create_mmap_dataset(test_dir, dummy_data, "multi_segment_test")
-
-    # Load the memory mapped files directly
-    multi_segments, _ = MmapFileHandler.load(dataset_path)
-
-    # Create fetcher
-    fetcher = MultiSegmentFetcher(multi_segments)
-
-    # Test fetching single key
-    sequence_data = fetcher.key_fetch(0, 10, "sequence")
-    assert sequence_data is not None
-    assert len(sequence_data) == 10
-
-    # Test fetching multiple keys
-    multi_data = fetcher.key_fetch(0, 10, ["sequence", "mask"])
-    assert "sequence" in multi_data
-    assert "mask" in multi_data
-    assert len(multi_data["sequence"]) == 10
-    assert len(multi_data["mask"]) == 10
-
-    # Test fetching all keys
-    all_data = fetcher.fetch_data(0, 10)
-    assert "sequence" in all_data
-    assert "mask" in all_data
-
-
-def test_mmap_file_handler_direct(base_test_env):
-    """Test MmapFileHandler directly without DatasetLoader"""
-    test_dir = base_test_env["test_dir"]
-
-    # Create test data with multiple segments
-    seq_length1 = 100
-    seq_length2 = 200
-
-    # Create data in the format expected by MmapFileHandler
-    dummy_data = {
-        "sequence": [
-            torch.randint(0, 1000, (seq_length1,), dtype=torch.int64),
-            torch.randint(0, 1000, (seq_length2,), dtype=torch.int64)
-        ],
-        "mask": [
-            torch.ones(seq_length1, dtype=torch.bool),
-            torch.ones(seq_length2, dtype=torch.bool)
-        ]
-    }
-
-    # Save data using MmapFileHandler
-    dataset_dir = os.path.join(test_dir, "mmap_direct_test")
-    MmapFileHandler.save(dataset_dir, dummy_data)
-
-    # Load data using MmapFileHandler
-    loaded_data, num_samples = MmapFileHandler.load(dataset_dir)
-
-    # Verify data structure
-    assert set(loaded_data.keys()) == set(dummy_data.keys())
-    assert num_samples == seq_length1 + seq_length2  # 300
-
-    # Verify data content
-    for key in dummy_data:
-        assert len(loaded_data[key]) == len(dummy_data[key])
-        for i in range(len(dummy_data[key])):
-            assert torch.equal(loaded_data[key][i], dummy_data[key][i])
-
-def test_mmap_file_handler_dtypes(base_test_env):
-    """Test MmapFileHandler with different data types"""
-    test_dir = base_test_env["test_dir"]
-
-    # Create test data with different dtypes
-    data = {
-        "float32": [torch.randn(100, dtype=torch.float32)],
-        "float64": [torch.randn(100, dtype=torch.float64)],
-        "int32": [torch.randint(0, 1000, (100,), dtype=torch.int32)],
-        "int64": [torch.randint(0, 1000, (100,), dtype=torch.int64)],
-        "bool": [torch.randint(0, 2, (100,), dtype=torch.bool)]
-    }
-
-    # Save data
-    dataset_dir = os.path.join(test_dir, "dtype_test")
-    MmapFileHandler.save(dataset_dir, data)
-
-    # Load data
-    loaded_data, _ = MmapFileHandler.load(dataset_dir)
-
-    # Verify data types
-    for key in data:
-        assert loaded_data[key][0].dtype == data[key][0].dtype
-        assert torch.equal(loaded_data[key][0], data[key][0])
-
-def test_mmap_file_handler_error_handling(base_test_env):
-    """Test MmapFileHandler error handling"""
-    test_dir = base_test_env["test_dir"]
-
-    # Test loading without file_mapper.json
-    empty_dir = os.path.join(test_dir, "empty_dir")
-    os.makedirs(empty_dir, exist_ok=True)
-    with pytest.raises(FileNotFoundError):
-        MmapFileHandler.load(empty_dir)
-
-    # Test loading with invalid file_mapper.json
-    invalid_dir = os.path.join(test_dir, "invalid_dir")
-    os.makedirs(invalid_dir, exist_ok=True)
-
-    # Create empty file_mapper.json
-    with open(os.path.join(invalid_dir, "file_mapper.json"), "w") as f:
-        json.dump([{"file_name": "file1.bin", "size": 1000, "dtype": "float32", "key": "key1"}], f)
-
-    # This should raise FileNotFoundError because no binary files exist
-    with pytest.raises(FileNotFoundError):
-        MmapFileHandler.load(invalid_dir)