fix: 修复 dataset 和 checkpoint 的 bug

This commit is contained in:
ViperEkura 2026-03-02 11:12:21 +08:00
parent 80e17418b4
commit 8a8d6369bc
5 changed files with 56 additions and 66 deletions

View File

@ -70,8 +70,7 @@ class Checkpoint:
state_dict = torch.load(f)
return cls(
optimizer_state_dict=state_dict["optimizer"],
scheduler_state_dict=state_dict["scheduler"],
state_dict=state_dict,
epoch=meta["epoch"],
iteration=meta["iteration"],
metrics=meta.get("metrics", {}),

View File

@ -19,6 +19,9 @@ class BaseSegmentFetcher:
self.cum_lengths.append(total)
self.total_length = total if segments else 0
def __len__(self) -> int:
    """Return the total number of elements spanned by all segments.

    ``total_length`` is the cumulative length accumulated in ``__init__``
    (0 when no segments were given).
    """
    return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
raise ValueError("begin_idx or end_idx out of bounds")
@ -49,6 +52,10 @@ class MultiSegmentFetcher:
for key, segments in muti_segments.items()
}
def __len__(self) -> int:
    """Return the common addressable length across all keyed fetchers.

    Uses the shortest per-key length so any index valid for this object
    is valid for every key. Returns 0 for an empty fetcher dict, matching
    BaseSegmentFetcher's empty-segment behavior — bare ``min()`` would
    raise ``ValueError`` on an empty sequence.
    """
    return min((len(fetcher) for fetcher in self.muti_fetchers.values()), default=0)
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
@ -73,8 +80,9 @@ class BaseDataset(Dataset, ABC):
self.total_samples = None
def load(self, load_path: str):
    """Load segment data from *load_path* and build the fetcher.

    ``load_h5`` returns only the tensor-group dict (see its
    ``return tensor_group``); the sample count is derived from the
    fetcher's length instead of being unpacked from ``load_h5``.
    """
    # NOTE(review): the original text contained both the old tuple-unpack
    # of load_h5 and this new single-assignment; only the new form is kept.
    self.segments = load_h5(load_path)
    self.fetcher = MultiSegmentFetcher(self.segments)
    self.total_samples = len(self.fetcher)
def get_index(self, index: int) -> int:
assert self.total_samples > self.window_size

View File

@ -8,20 +8,17 @@ from torch import Tensor
from typing import Dict, List, Tuple
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write *tensor_group* to ``<file_path>/<file_name>.h5``.

    Args:
        file_path: Directory to write into (created if missing).
        file_name: Base name of the HDF5 file; ``.h5`` is appended.
        tensor_group: Mapping of key -> list of tensors. Each key becomes
            an HDF5 group containing datasets ``data_0 .. data_{n-1}``,
            with the group's ``num_tensors`` attribute recording the count.
    """
    # NOTE(review): the original text interleaved the old 2-arg signature
    # (file_path, tensor_group) with this new 3-arg one; call sites such as
    # save_h5(test_dir, "data_0", dummy_data) confirm the 3-arg form.
    os.makedirs(file_path, exist_ok=True)
    full_file_path = os.path.join(file_path, f"{file_name}.h5")
    with h5py.File(full_file_path, 'w') as f:
        for key, tensors in tensor_group.items():
            grp = f.create_group(key)
            grp.attrs['num_tensors'] = len(tensors)
            for idx, tensor in enumerate(tensors):
                # Move to CPU and convert to ndarray — h5py stores numpy data.
                grp.create_dataset(f'data_{idx}', data=tensor.cpu().numpy())
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
tensor_group: Dict[str, List[Tensor]] = {}
@ -38,10 +35,6 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
for dset_name in grp.keys():
dset = grp[dset_name]
dsets.append(torch.from_numpy(dset[:]).share_memory_())
total_samples += dset.attrs.get('numel', np.prod(dset.shape))
tensor_group[key] = dsets
num_keys = max(len(tensor_group), 1)
sample_per_key = total_samples // num_keys
return tensor_group, sample_per_key
return tensor_group

View File

@ -1,12 +1,12 @@
import os
import torch
import tempfile
import torch.distributed as dist
from pathlib import Path
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from khaosz.data.checkpoint import Checkpoint
from khaosz.parallel.setup import spawn_parallel_fn
from khaosz.parallel.setup import get_rank, spawn_parallel_fn
def test_single_process():
model = torch.nn.Linear(10, 5)
@ -26,8 +26,7 @@ def test_single_process():
scheduler.step()
checkpoint = Checkpoint(
optimizer_state_dict=optimizer.state_dict(),
scheduler_state_dict=scheduler.state_dict(),
state_dict=model.state_dict(),
epoch=3,
iteration=30,
metrics={
@ -45,21 +44,14 @@ def test_single_process():
assert loaded_checkpoint.iteration == 30
assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]
assert 'param_groups' in loaded_checkpoint.optimizer_state_dict
assert 'state' in loaded_checkpoint.optimizer_state_dict
png_files = list(Path(tmpdir).glob("*.png"))
assert png_files
def simple_training():
rank = int(os.environ.get('LOCAL_RANK', 0))
# 简单的训练逻辑
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=10)
# 训练步骤
for epoch in range(2):
for iteration in range(5):
x = torch.randn(16, 10)
@ -71,18 +63,29 @@ def simple_training():
scheduler.step()
checkpoint = Checkpoint(
optimizer_state_dict=optimizer.state_dict(),
scheduler_state_dict=scheduler.state_dict(),
state_dict=model.state_dict(),
epoch=2,
iteration=10,
metrics={"loss": [0.3, 0.2, 0.1]}
)
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir)
loaded = Checkpoint.load(tmpdir)
assert loaded.epoch == 2
print(f"Rank {rank}: Checkpoint test passed")
rank = get_rank()
if rank == 0:
shared_dir = tempfile.mkdtemp()
checkpoint.save(shared_dir)
else:
shared_dir = None
if dist.is_initialized():
dir_list = [shared_dir]
dist.broadcast_object_list(dir_list, src=0)
shared_dir = dir_list[0]
loaded = Checkpoint.load(shared_dir)
assert loaded.epoch == 2
def test_multi_process():
spawn_parallel_fn(

View File

@ -1,4 +1,3 @@
import os
import torch
import numpy as np
@ -6,18 +5,6 @@ from khaosz.data.file import save_h5
from khaosz.data.dataset import *
def create_h5_dataset(dir_path, data_dict, dataset_name):
    """Helper function to create HDF5 dataset for testing"""
    # NOTE(review): this helper is removed by this commit — it calls the old
    # 2-argument save_h5(path, group); the updated tests call
    # save_h5(dir, name, group) directly instead.
    dataset_path = os.path.join(dir_path, f"{dataset_name}.h5")
    # Convert data_dict to the format expected by save_h5
    # save_h5 expects a list of tensors for each key
    tensor_group = {key: [tensor] for key, tensor in data_dict.items()}
    save_h5(dataset_path, tensor_group)
    return dataset_path
def test_dataset_loader_random_paths(base_test_env):
"""Test dataset loader with multiple random paths"""
@ -27,23 +14,23 @@ def test_dataset_loader_random_paths(base_test_env):
num_files = np.random.randint(2, 5)
for i in range(num_files):
seq_length = np.random.randint(100, 200)
seq_length = np.random.randint(200, 400)
dummy_data = {
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)],
}
dataset_path = create_h5_dataset(test_dir, dummy_data, f"test_data_{i}")
save_h5(test_dir, f"data_{i}", dummy_data)
# Test loading with multiple paths
loaded_dataset = DatasetLoader.load(
train_type="seq",
load_path=dataset_path,
load_path=test_dir,
window_size=64,
)
assert loaded_dataset is not None
assert len(loaded_dataset) > 0
# Test that we can get items without errors
for i in range(min(3, len(loaded_dataset))):
for i in range(len(loaded_dataset)):
item = loaded_dataset[i]
assert "input_ids" in item
assert "target_ids" in item
@ -59,18 +46,18 @@ def test_dpo_strategy_with_random_data(base_test_env):
seq_length = np.random.randint(100, 200)
dummy_data = {
"chosen": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
"rejected": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
"chosen_mask": torch.ones(seq_length, dtype=torch.bool),
"rejected_mask": torch.ones(seq_length, dtype=torch.bool)
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)]
}
dataset_path = create_h5_dataset(test_dir, dummy_data, "dpo_data")
save_h5(test_dir, "dpo_data", dummy_data)
# Load DPO dataset
dpo_dataset = DatasetLoader.load(
train_type="dpo",
load_path=dataset_path,
load_path=test_dir,
window_size=64,
)
@ -97,16 +84,16 @@ def test_sft_dataset_with_random_data(base_test_env):
seq_length = np.random.randint(100, 200)
dummy_data = {
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
"loss_mask": torch.ones(seq_length, dtype=torch.bool)
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)]
}
dataset_path = create_h5_dataset(test_dir, dummy_data, "sft_data")
save_h5(test_dir, "sft_data", dummy_data)
# Load SFT dataset
sft_dataset = DatasetLoader.load(
train_type="sft",
load_path=dataset_path,
load_path=test_dir,
window_size=64,
)
@ -131,16 +118,16 @@ def test_dataset_with_custom_stride(base_test_env):
# Create test data
seq_length = 200
dummy_data = {
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
}
dataset_path = create_h5_dataset(test_dir, dummy_data, "stride_test_data")
save_h5(test_dir,"stride_test_data", dummy_data)
# Test with custom stride
custom_stride = 32
dataset = DatasetLoader.load(
train_type="seq",
load_path=dataset_path,
load_path=test_dir,
window_size=64,
stride=custom_stride
)
@ -152,7 +139,7 @@ def test_dataset_with_custom_stride(base_test_env):
# than with default stride (which equals window size)
default_stride_dataset = DatasetLoader.load(
train_type="seq",
load_path=dataset_path,
load_path=test_dir,
window_size=64,
)