From 8a8d6369bc136f1ea4d2bfe8eeacd06d15b2069a Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Mon, 2 Mar 2026 11:12:21 +0800
Subject: [PATCH] fix: fix dataset and checkpoint bugs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/data/checkpoint.py     |  3 +--
 khaosz/data/dataset.py        | 10 ++++++-
 khaosz/data/file.py           | 19 +++++--------
 tests/data/test_checkpoint.py | 39 ++++++++++++++-------------
 tests/data/test_dataset.py    | 51 +++++++++++++----------------------
 5 files changed, 56 insertions(+), 66 deletions(-)

diff --git a/khaosz/data/checkpoint.py b/khaosz/data/checkpoint.py
index 0afaa22..042789b 100644
--- a/khaosz/data/checkpoint.py
+++ b/khaosz/data/checkpoint.py
@@ -70,8 +70,7 @@ class Checkpoint:
             state_dict = torch.load(f)
 
         return cls(
-            optimizer_state_dict=state_dict["optimizer"],
-            scheduler_state_dict=state_dict["scheduler"],
+            state_dict=state_dict,
             epoch=meta["epoch"],
             iteration=meta["iteration"],
             metrics=meta.get("metrics", {}),
diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py
index deb71fe..1255b8c 100644
--- a/khaosz/data/dataset.py
+++ b/khaosz/data/dataset.py
@@ -18,6 +18,9 @@ class BaseSegmentFetcher:
             total += len(seg)
             self.cum_lengths.append(total)
         self.total_length = total if segments else 0
+
+    def __len__(self) -> int:
+        return self.total_length
 
     def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
         if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
@@ -48,6 +51,10 @@ class MultiSegmentFetcher:
             key: BaseSegmentFetcher(segments)
             for key, segments in muti_segments.items()
         }
+
+    def __len__(self) -> int:
+        len_list = [len(seg) for seg in self.muti_fetchers.values()]
+        return min(len_list)
 
     def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
         fetch_dict = {}
@@ -73,8 +80,9 @@ class BaseDataset(Dataset, ABC):
         self.total_samples = None
 
     def load(self, load_path: str):
-        self.segments, self.total_samples = load_h5(load_path)
+        self.segments = load_h5(load_path)
         self.fetcher = MultiSegmentFetcher(self.segments)
+        self.total_samples = len(self.fetcher)
 
     def get_index(self, index: int) -> int:
         assert self.total_samples > self.window_size
diff --git a/khaosz/data/file.py b/khaosz/data/file.py
index 92ad9c5..23bad89 100644
--- a/khaosz/data/file.py
+++ b/khaosz/data/file.py
@@ -8,20 +8,17 @@
 from torch import Tensor
 from typing import Dict, List, Tuple
 
 
-def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
-    os.makedirs(os.path.dirname(file_path), exist_ok=True)
-    with h5py.File(file_path, 'w') as f:
+def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
+    os.makedirs(file_path, exist_ok=True)
+    full_file_path = os.path.join(file_path, f"{file_name}.h5")
+    with h5py.File(full_file_path, 'w') as f:
         for key, tensors in tensor_group.items():
             grp = f.create_group(key)
             grp.attrs['num_tensors'] = len(tensors)
             for idx, tensor in enumerate(tensors):
                 arr = tensor.cpu().numpy()
-                dset = grp.create_dataset(
-                    f'data_{idx}',
-                    data=arr
-                )
-                dset.attrs['numel'] = tensor.numel()
+                grp.create_dataset(f'data_{idx}', data=arr)
 
 def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
     tensor_group: Dict[str, List[Tensor]] = {}
@@ -38,10 +35,6 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
 
             for dset_name in grp.keys():
                 dset = grp[dset_name]
                 dsets.append(torch.from_numpy(dset[:]).share_memory_())
-                total_samples += dset.attrs.get('numel', np.prod(dset.shape))
             tensor_group[key] = dsets
-    num_keys = max(len(tensor_group), 1)
-    sample_per_key = total_samples // num_keys
-
-    return tensor_group, sample_per_key
\ No newline at end of file
+    return tensor_group
\ No newline at end of file
diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py
index ec1dc46..05ae8e5 100644
--- a/tests/data/test_checkpoint.py
+++ b/tests/data/test_checkpoint.py
@@ -1,12 +1,12 @@
-import os
 import torch
 import tempfile
+import torch.distributed as dist
 
 from pathlib import Path
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from khaosz.data.checkpoint import Checkpoint
-from khaosz.parallel.setup import spawn_parallel_fn
+from khaosz.parallel.setup import get_rank, spawn_parallel_fn
 
 def test_single_process():
     model = torch.nn.Linear(10, 5)
@@ -26,8 +26,7 @@ def test_single_process():
         scheduler.step()
 
     checkpoint = Checkpoint(
-        optimizer_state_dict=optimizer.state_dict(),
-        scheduler_state_dict=scheduler.state_dict(),
+        state_dict=model.state_dict(),
         epoch=3,
         iteration=30,
         metrics={
@@ -45,21 +44,14 @@ def test_single_process():
         assert loaded_checkpoint.iteration == 30
         assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]
 
-        assert 'param_groups' in loaded_checkpoint.optimizer_state_dict
-        assert 'state' in loaded_checkpoint.optimizer_state_dict
-
         png_files = list(Path(tmpdir).glob("*.png"))
         assert png_files
 
 def simple_training():
-    rank = int(os.environ.get('LOCAL_RANK', 0))
-
-    # 简单的训练逻辑
     model = torch.nn.Linear(10, 5)
     optimizer = AdamW(model.parameters(), lr=1e-3)
     scheduler = CosineAnnealingLR(optimizer, T_max=10)
 
-    # 训练步骤
     for epoch in range(2):
         for iteration in range(5):
             x = torch.randn(16, 10)
@@ -71,18 +63,29 @@ def simple_training():
         scheduler.step()
 
     checkpoint = Checkpoint(
-        optimizer_state_dict=optimizer.state_dict(),
-        scheduler_state_dict=scheduler.state_dict(),
+        state_dict=model.state_dict(),
         epoch=2,
         iteration=10,
         metrics={"loss": [0.3, 0.2, 0.1]}
     )
 
-    with tempfile.TemporaryDirectory() as tmpdir:
-        checkpoint.save(tmpdir)
-        loaded = Checkpoint.load(tmpdir)
-        assert loaded.epoch == 2
-        print(f"Rank {rank}: Checkpoint test passed")
+    rank = get_rank()
+
+    if rank == 0:
+        shared_dir = tempfile.mkdtemp()
+        checkpoint.save(shared_dir)
+    else:
+        shared_dir = None
+
+    if dist.is_initialized():
+        dir_list = [shared_dir]
+        dist.broadcast_object_list(dir_list, src=0)
+        shared_dir = dir_list[0]
+
+    loaded = Checkpoint.load(shared_dir)
+    assert loaded.epoch == 2
 
 def test_multi_process():
     spawn_parallel_fn(
diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index b3f9243..d12aeb9 100644
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -1,4 +1,3 @@
-import os
 import torch
 import numpy as np
 
@@ -6,18 +5,6 @@
 from khaosz.data.file import save_h5
 from khaosz.data.dataset import *
 
 
-def create_h5_dataset(dir_path, data_dict, dataset_name):
-    """Helper function to create HDF5 dataset for testing"""
-    dataset_path = os.path.join(dir_path, f"{dataset_name}.h5")
-
-    # Convert data_dict to the format expected by save_h5
-    # save_h5 expects a list of tensors for each key
-    tensor_group = {key: [tensor] for key, tensor in data_dict.items()}
-
-    save_h5(dataset_path, tensor_group)
-
-    return dataset_path
-
 def test_dataset_loader_random_paths(base_test_env):
     """Test dataset loader with multiple random paths"""
@@ -28,23 +15,23 @@ def test_dataset_loader_random_paths(base_test_env):
 
     num_files = np.random.randint(2, 5)
     for i in range(num_files):
-        seq_length = np.random.randint(100, 200)
+        seq_length = np.random.randint(200, 400)
         dummy_data = {
-            "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
+            "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)],
         }
-        dataset_path = create_h5_dataset(test_dir, dummy_data, f"test_data_{i}")
+        save_h5(test_dir, f"data_{i}", dummy_data)
 
     # Test loading with multiple paths
     loaded_dataset = DatasetLoader.load(
         train_type="seq",
-        load_path=dataset_path,
+        load_path=test_dir,
         window_size=64,
     )
     assert loaded_dataset is not None
     assert len(loaded_dataset) > 0
 
     # Test that we can get items without errors
-    for i in range(min(3, len(loaded_dataset))):
+    for i in range(len(loaded_dataset)):
         item = loaded_dataset[i]
         assert "input_ids" in item
         assert "target_ids" in item
@@ -59,18 +46,18 @@ def test_dpo_strategy_with_random_data(base_test_env):
 
     seq_length = np.random.randint(100, 200)
 
     dummy_data = {
-        "chosen": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
-        "rejected": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
-        "chosen_mask": torch.ones(seq_length, dtype=torch.bool),
-        "rejected_mask": torch.ones(seq_length, dtype=torch.bool)
+        "chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
+        "rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
+        "chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
+        "rejected_mask": [torch.ones(seq_length, dtype=torch.bool)]
     }
-    dataset_path = create_h5_dataset(test_dir, dummy_data, "dpo_data")
+    save_h5(test_dir, "dpo_data", dummy_data)
 
     # Load DPO dataset
     dpo_dataset = DatasetLoader.load(
         train_type="dpo",
-        load_path=dataset_path,
+        load_path=test_dir,
         window_size=64,
     )
 
@@ -97,16 +84,16 @@ def test_sft_dataset_with_random_data(base_test_env):
 
     seq_length = np.random.randint(100, 200)
 
     dummy_data = {
-        "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
-        "loss_mask": torch.ones(seq_length, dtype=torch.bool)
+        "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
+        "loss_mask": [torch.ones(seq_length, dtype=torch.bool)]
     }
-    dataset_path = create_h5_dataset(test_dir, dummy_data, "sft_data")
+    save_h5(test_dir, "sft_data", dummy_data)
 
     # Load SFT dataset
     sft_dataset = DatasetLoader.load(
         train_type="sft",
-        load_path=dataset_path,
+        load_path=test_dir,
         window_size=64,
     )
 
@@ -131,16 +118,16 @@ def test_dataset_with_custom_stride(base_test_env):
 
     # Create test data
     seq_length = 200
     dummy_data = {
-        "sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
+        "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
     }
-    dataset_path = create_h5_dataset(test_dir, dummy_data, "stride_test_data")
+    save_h5(test_dir, "stride_test_data", dummy_data)
 
     # Test with custom stride
     custom_stride = 32
     dataset = DatasetLoader.load(
         train_type="seq",
-        load_path=dataset_path,
+        load_path=test_dir,
         window_size=64,
         stride=custom_stride
     )
@@ -152,7 +139,7 @@ def test_dataset_with_custom_stride(base_test_env):
     # Dataset with smaller stride should have more samples
     # than with default stride (which equals window size)
     default_stride_dataset = DatasetLoader.load(
         train_type="seq",
-        load_path=dataset_path,
+        load_path=test_dir,
         window_size=64,
     )
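
Note for reviewers: after this patch, save_h5 takes the target directory and a bare file name (it builds <dir>/<name>.h5 itself), load_h5 returns only the tensor-group dict, and sample counts come from __len__ on the fetchers, where MultiSegmentFetcher reports the minimum across keys so ragged groups stay index-safe. A minimal round-trip sketch of the new API follows; the directory name, tensor shapes, and segment counts are illustrative only, not taken from this patch:

    import torch
    from khaosz.data.file import save_h5, load_h5
    from khaosz.data.dataset import MultiSegmentFetcher

    # Two keys, four 256-token segments each; written to demo_dir/demo.h5
    group = {
        "sequence": [torch.randint(0, 1000, (256,)) for _ in range(4)],
        "loss_mask": [torch.ones(256, dtype=torch.bool) for _ in range(4)],
    }
    save_h5("demo_dir", "demo", group)

    # load_h5 now returns just the dict; lengths come from the fetcher
    loaded = load_h5("demo_dir/demo.h5")
    fetcher = MultiSegmentFetcher(loaded)
    assert len(fetcher) == 4 * 256  # min over keys; both keys total 1024 here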