fix: 修复 dataset 和 checkpoint 的 bug

This commit is contained in:
ViperEkura 2026-03-02 11:12:21 +08:00
parent 80e17418b4
commit 8a8d6369bc
5 changed files with 56 additions and 66 deletions

View File

@ -70,8 +70,7 @@ class Checkpoint:
state_dict = torch.load(f) state_dict = torch.load(f)
return cls( return cls(
optimizer_state_dict=state_dict["optimizer"], state_dict=state_dict,
scheduler_state_dict=state_dict["scheduler"],
epoch=meta["epoch"], epoch=meta["epoch"],
iteration=meta["iteration"], iteration=meta["iteration"],
metrics=meta.get("metrics", {}), metrics=meta.get("metrics", {}),

View File

@ -19,6 +19,9 @@ class BaseSegmentFetcher:
self.cum_lengths.append(total) self.cum_lengths.append(total)
self.total_length = total if segments else 0 self.total_length = total if segments else 0
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length): if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
raise ValueError("begin_idx or end_idx out of bounds") raise ValueError("begin_idx or end_idx out of bounds")
@ -49,6 +52,10 @@ class MultiSegmentFetcher:
for key, segments in muti_segments.items() for key, segments in muti_segments.items()
} }
def __len__(self) -> int:
len_list = [len(seg) for seg in self.muti_fetchers.values()]
return min(len_list)
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict: def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
fetch_dict = {} fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys keys = [keys] if isinstance(keys, str) else keys
@ -73,8 +80,9 @@ class BaseDataset(Dataset, ABC):
self.total_samples = None self.total_samples = None
def load(self, load_path: str): def load(self, load_path: str):
self.segments, self.total_samples = load_h5(load_path) self.segments = load_h5(load_path)
self.fetcher = MultiSegmentFetcher(self.segments) self.fetcher = MultiSegmentFetcher(self.segments)
self.total_samples = len(self.fetcher)
def get_index(self, index: int) -> int: def get_index(self, index: int) -> int:
assert self.total_samples > self.window_size assert self.total_samples > self.window_size

View File

@ -8,20 +8,17 @@ from torch import Tensor
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]): def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(file_path, exist_ok=True)
with h5py.File(file_path, 'w') as f: full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, 'w') as f:
for key, tensors in tensor_group.items(): for key, tensors in tensor_group.items():
grp = f.create_group(key) grp = f.create_group(key)
grp.attrs['num_tensors'] = len(tensors) grp.attrs['num_tensors'] = len(tensors)
for idx, tensor in enumerate(tensors): for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy() arr = tensor.cpu().numpy()
dset = grp.create_dataset( grp.create_dataset(f'data_{idx}', data=arr)
f'data_{idx}',
data=arr
)
dset.attrs['numel'] = tensor.numel()
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]: def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
tensor_group: Dict[str, List[Tensor]] = {} tensor_group: Dict[str, List[Tensor]] = {}
@ -38,10 +35,6 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
for dset_name in grp.keys(): for dset_name in grp.keys():
dset = grp[dset_name] dset = grp[dset_name]
dsets.append(torch.from_numpy(dset[:]).share_memory_()) dsets.append(torch.from_numpy(dset[:]).share_memory_())
total_samples += dset.attrs.get('numel', np.prod(dset.shape))
tensor_group[key] = dsets tensor_group[key] = dsets
num_keys = max(len(tensor_group), 1) return tensor_group
sample_per_key = total_samples // num_keys
return tensor_group, sample_per_key

View File

@ -1,12 +1,12 @@
import os
import torch import torch
import tempfile import tempfile
import torch.distributed as dist
from pathlib import Path from pathlib import Path
from torch.optim import AdamW from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from khaosz.data.checkpoint import Checkpoint from khaosz.data.checkpoint import Checkpoint
from khaosz.parallel.setup import spawn_parallel_fn from khaosz.parallel.setup import get_rank, spawn_parallel_fn
def test_single_process(): def test_single_process():
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
@ -26,8 +26,7 @@ def test_single_process():
scheduler.step() scheduler.step()
checkpoint = Checkpoint( checkpoint = Checkpoint(
optimizer_state_dict=optimizer.state_dict(), state_dict=model.state_dict(),
scheduler_state_dict=scheduler.state_dict(),
epoch=3, epoch=3,
iteration=30, iteration=30,
metrics={ metrics={
@ -45,21 +44,14 @@ def test_single_process():
assert loaded_checkpoint.iteration == 30 assert loaded_checkpoint.iteration == 30
assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1] assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]
assert 'param_groups' in loaded_checkpoint.optimizer_state_dict
assert 'state' in loaded_checkpoint.optimizer_state_dict
png_files = list(Path(tmpdir).glob("*.png")) png_files = list(Path(tmpdir).glob("*.png"))
assert png_files assert png_files
def simple_training(): def simple_training():
rank = int(os.environ.get('LOCAL_RANK', 0))
# 简单的训练逻辑
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3) optimizer = AdamW(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=10) scheduler = CosineAnnealingLR(optimizer, T_max=10)
# 训练步骤
for epoch in range(2): for epoch in range(2):
for iteration in range(5): for iteration in range(5):
x = torch.randn(16, 10) x = torch.randn(16, 10)
@ -71,18 +63,29 @@ def simple_training():
scheduler.step() scheduler.step()
checkpoint = Checkpoint( checkpoint = Checkpoint(
optimizer_state_dict=optimizer.state_dict(), state_dict=model.state_dict(),
scheduler_state_dict=scheduler.state_dict(),
epoch=2, epoch=2,
iteration=10, iteration=10,
metrics={"loss": [0.3, 0.2, 0.1]} metrics={"loss": [0.3, 0.2, 0.1]}
) )
with tempfile.TemporaryDirectory() as tmpdir: rank = get_rank()
checkpoint.save(tmpdir)
loaded = Checkpoint.load(tmpdir) if rank == 0:
shared_dir = tempfile.mkdtemp()
checkpoint.save(shared_dir)
else:
shared_dir = None
if dist.is_initialized():
dir_list = [shared_dir]
dist.broadcast_object_list(dir_list, src=0)
shared_dir = dir_list[0]
loaded = Checkpoint.load(shared_dir)
assert loaded.epoch == 2 assert loaded.epoch == 2
print(f"Rank {rank}: Checkpoint test passed")
def test_multi_process(): def test_multi_process():
spawn_parallel_fn( spawn_parallel_fn(

View File

@ -1,4 +1,3 @@
import os
import torch import torch
import numpy as np import numpy as np
@ -6,18 +5,6 @@ from khaosz.data.file import save_h5
from khaosz.data.dataset import * from khaosz.data.dataset import *
def create_h5_dataset(dir_path, data_dict, dataset_name):
"""Helper function to create HDF5 dataset for testing"""
dataset_path = os.path.join(dir_path, f"{dataset_name}.h5")
# Convert data_dict to the format expected by save_h5
# save_h5 expects a list of tensors for each key
tensor_group = {key: [tensor] for key, tensor in data_dict.items()}
save_h5(dataset_path, tensor_group)
return dataset_path
def test_dataset_loader_random_paths(base_test_env): def test_dataset_loader_random_paths(base_test_env):
"""Test dataset loader with multiple random paths""" """Test dataset loader with multiple random paths"""
@ -27,23 +14,23 @@ def test_dataset_loader_random_paths(base_test_env):
num_files = np.random.randint(2, 5) num_files = np.random.randint(2, 5)
for i in range(num_files): for i in range(num_files):
seq_length = np.random.randint(100, 200) seq_length = np.random.randint(200, 400)
dummy_data = { dummy_data = {
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64), "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)],
} }
dataset_path = create_h5_dataset(test_dir, dummy_data, f"test_data_{i}") save_h5(test_dir, f"data_{i}", dummy_data)
# Test loading with multiple paths # Test loading with multiple paths
loaded_dataset = DatasetLoader.load( loaded_dataset = DatasetLoader.load(
train_type="seq", train_type="seq",
load_path=dataset_path, load_path=test_dir,
window_size=64, window_size=64,
) )
assert loaded_dataset is not None assert loaded_dataset is not None
assert len(loaded_dataset) > 0 assert len(loaded_dataset) > 0
# Test that we can get items without errors # Test that we can get items without errors
for i in range(min(3, len(loaded_dataset))): for i in range(len(loaded_dataset)):
item = loaded_dataset[i] item = loaded_dataset[i]
assert "input_ids" in item assert "input_ids" in item
assert "target_ids" in item assert "target_ids" in item
@ -59,18 +46,18 @@ def test_dpo_strategy_with_random_data(base_test_env):
seq_length = np.random.randint(100, 200) seq_length = np.random.randint(100, 200)
dummy_data = { dummy_data = {
"chosen": torch.randint(0, 1000, (seq_length,), dtype=torch.int64), "chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"rejected": torch.randint(0, 1000, (seq_length,), dtype=torch.int64), "rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"chosen_mask": torch.ones(seq_length, dtype=torch.bool), "chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
"rejected_mask": torch.ones(seq_length, dtype=torch.bool) "rejected_mask": [torch.ones(seq_length, dtype=torch.bool)]
} }
dataset_path = create_h5_dataset(test_dir, dummy_data, "dpo_data") save_h5(test_dir, "dpo_data", dummy_data)
# Load DPO dataset # Load DPO dataset
dpo_dataset = DatasetLoader.load( dpo_dataset = DatasetLoader.load(
train_type="dpo", train_type="dpo",
load_path=dataset_path, load_path=test_dir,
window_size=64, window_size=64,
) )
@ -97,16 +84,16 @@ def test_sft_dataset_with_random_data(base_test_env):
seq_length = np.random.randint(100, 200) seq_length = np.random.randint(100, 200)
dummy_data = { dummy_data = {
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64), "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"loss_mask": torch.ones(seq_length, dtype=torch.bool) "loss_mask": [torch.ones(seq_length, dtype=torch.bool)]
} }
dataset_path = create_h5_dataset(test_dir, dummy_data, "sft_data") save_h5(test_dir, "sft_data", dummy_data)
# Load SFT dataset # Load SFT dataset
sft_dataset = DatasetLoader.load( sft_dataset = DatasetLoader.load(
train_type="sft", train_type="sft",
load_path=dataset_path, load_path=test_dir,
window_size=64, window_size=64,
) )
@ -131,16 +118,16 @@ def test_dataset_with_custom_stride(base_test_env):
# Create test data # Create test data
seq_length = 200 seq_length = 200
dummy_data = { dummy_data = {
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64), "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
} }
dataset_path = create_h5_dataset(test_dir, dummy_data, "stride_test_data") save_h5(test_dir,"stride_test_data", dummy_data)
# Test with custom stride # Test with custom stride
custom_stride = 32 custom_stride = 32
dataset = DatasetLoader.load( dataset = DatasetLoader.load(
train_type="seq", train_type="seq",
load_path=dataset_path, load_path=test_dir,
window_size=64, window_size=64,
stride=custom_stride stride=custom_stride
) )
@ -152,7 +139,7 @@ def test_dataset_with_custom_stride(base_test_env):
# than with default stride (which equals window size) # than with default stride (which equals window size)
default_stride_dataset = DatasetLoader.load( default_stride_dataset = DatasetLoader.load(
train_type="seq", train_type="seq",
load_path=dataset_path, load_path=test_dir,
window_size=64, window_size=64,
) )