fix: 修复 dataset 和 checkpoint 的 bug
This commit is contained in:
parent
80e17418b4
commit
8a8d6369bc
|
|
@ -70,8 +70,7 @@ class Checkpoint:
|
|||
state_dict = torch.load(f)
|
||||
|
||||
return cls(
|
||||
optimizer_state_dict=state_dict["optimizer"],
|
||||
scheduler_state_dict=state_dict["scheduler"],
|
||||
state_dict=state_dict,
|
||||
epoch=meta["epoch"],
|
||||
iteration=meta["iteration"],
|
||||
metrics=meta.get("metrics", {}),
|
||||
|
|
|
|||
|
|
@ -19,6 +19,9 @@ class BaseSegmentFetcher:
|
|||
self.cum_lengths.append(total)
|
||||
self.total_length = total if segments else 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.total_length
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
||||
raise ValueError("begin_idx or end_idx out of bounds")
|
||||
|
|
@ -49,6 +52,10 @@ class MultiSegmentFetcher:
|
|||
for key, segments in muti_segments.items()
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
||||
return min(len_list)
|
||||
|
||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
||||
fetch_dict = {}
|
||||
keys = [keys] if isinstance(keys, str) else keys
|
||||
|
|
@ -73,8 +80,9 @@ class BaseDataset(Dataset, ABC):
|
|||
self.total_samples = None
|
||||
|
||||
def load(self, load_path: str):
|
||||
self.segments, self.total_samples = load_h5(load_path)
|
||||
self.segments = load_h5(load_path)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
self.total_samples = len(self.fetcher)
|
||||
|
||||
def get_index(self, index: int) -> int:
|
||||
assert self.total_samples > self.window_size
|
||||
|
|
|
|||
|
|
@ -8,20 +8,17 @@ from torch import Tensor
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
with h5py.File(file_path, 'w') as f:
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
||||
with h5py.File(full_file_path, 'w') as f:
|
||||
for key, tensors in tensor_group.items():
|
||||
grp = f.create_group(key)
|
||||
grp.attrs['num_tensors'] = len(tensors)
|
||||
|
||||
for idx, tensor in enumerate(tensors):
|
||||
arr = tensor.cpu().numpy()
|
||||
dset = grp.create_dataset(
|
||||
f'data_{idx}',
|
||||
data=arr
|
||||
)
|
||||
dset.attrs['numel'] = tensor.numel()
|
||||
grp.create_dataset(f'data_{idx}', data=arr)
|
||||
|
||||
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
|
||||
tensor_group: Dict[str, List[Tensor]] = {}
|
||||
|
|
@ -38,10 +35,6 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
|
|||
for dset_name in grp.keys():
|
||||
dset = grp[dset_name]
|
||||
dsets.append(torch.from_numpy(dset[:]).share_memory_())
|
||||
total_samples += dset.attrs.get('numel', np.prod(dset.shape))
|
||||
tensor_group[key] = dsets
|
||||
|
||||
num_keys = max(len(tensor_group), 1)
|
||||
sample_per_key = total_samples // num_keys
|
||||
|
||||
return tensor_group, sample_per_key
|
||||
return tensor_group
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
import os
|
||||
import torch
|
||||
import tempfile
|
||||
import torch.distributed as dist
|
||||
|
||||
from pathlib import Path
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from khaosz.data.checkpoint import Checkpoint
|
||||
from khaosz.parallel.setup import spawn_parallel_fn
|
||||
from khaosz.parallel.setup import get_rank, spawn_parallel_fn
|
||||
|
||||
def test_single_process():
|
||||
model = torch.nn.Linear(10, 5)
|
||||
|
|
@ -26,8 +26,7 @@ def test_single_process():
|
|||
scheduler.step()
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
optimizer_state_dict=optimizer.state_dict(),
|
||||
scheduler_state_dict=scheduler.state_dict(),
|
||||
state_dict=model.state_dict(),
|
||||
epoch=3,
|
||||
iteration=30,
|
||||
metrics={
|
||||
|
|
@ -45,21 +44,14 @@ def test_single_process():
|
|||
assert loaded_checkpoint.iteration == 30
|
||||
assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]
|
||||
|
||||
assert 'param_groups' in loaded_checkpoint.optimizer_state_dict
|
||||
assert 'state' in loaded_checkpoint.optimizer_state_dict
|
||||
|
||||
png_files = list(Path(tmpdir).glob("*.png"))
|
||||
assert png_files
|
||||
|
||||
def simple_training():
|
||||
rank = int(os.environ.get('LOCAL_RANK', 0))
|
||||
|
||||
# 简单的训练逻辑
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=10)
|
||||
|
||||
# 训练步骤
|
||||
for epoch in range(2):
|
||||
for iteration in range(5):
|
||||
x = torch.randn(16, 10)
|
||||
|
|
@ -71,18 +63,29 @@ def simple_training():
|
|||
scheduler.step()
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
optimizer_state_dict=optimizer.state_dict(),
|
||||
scheduler_state_dict=scheduler.state_dict(),
|
||||
state_dict=model.state_dict(),
|
||||
epoch=2,
|
||||
iteration=10,
|
||||
metrics={"loss": [0.3, 0.2, 0.1]}
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint.save(tmpdir)
|
||||
loaded = Checkpoint.load(tmpdir)
|
||||
assert loaded.epoch == 2
|
||||
print(f"Rank {rank}: Checkpoint test passed")
|
||||
rank = get_rank()
|
||||
|
||||
if rank == 0:
|
||||
shared_dir = tempfile.mkdtemp()
|
||||
checkpoint.save(shared_dir)
|
||||
else:
|
||||
shared_dir = None
|
||||
|
||||
|
||||
if dist.is_initialized():
|
||||
dir_list = [shared_dir]
|
||||
dist.broadcast_object_list(dir_list, src=0)
|
||||
shared_dir = dir_list[0]
|
||||
|
||||
|
||||
loaded = Checkpoint.load(shared_dir)
|
||||
assert loaded.epoch == 2
|
||||
|
||||
def test_multi_process():
|
||||
spawn_parallel_fn(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -6,18 +5,6 @@ from khaosz.data.file import save_h5
|
|||
from khaosz.data.dataset import *
|
||||
|
||||
|
||||
def create_h5_dataset(dir_path, data_dict, dataset_name):
|
||||
"""Helper function to create HDF5 dataset for testing"""
|
||||
dataset_path = os.path.join(dir_path, f"{dataset_name}.h5")
|
||||
|
||||
# Convert data_dict to the format expected by save_h5
|
||||
# save_h5 expects a list of tensors for each key
|
||||
tensor_group = {key: [tensor] for key, tensor in data_dict.items()}
|
||||
|
||||
save_h5(dataset_path, tensor_group)
|
||||
|
||||
return dataset_path
|
||||
|
||||
|
||||
def test_dataset_loader_random_paths(base_test_env):
|
||||
"""Test dataset loader with multiple random paths"""
|
||||
|
|
@ -27,23 +14,23 @@ def test_dataset_loader_random_paths(base_test_env):
|
|||
num_files = np.random.randint(2, 5)
|
||||
|
||||
for i in range(num_files):
|
||||
seq_length = np.random.randint(100, 200)
|
||||
seq_length = np.random.randint(200, 400)
|
||||
dummy_data = {
|
||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)],
|
||||
}
|
||||
dataset_path = create_h5_dataset(test_dir, dummy_data, f"test_data_{i}")
|
||||
save_h5(test_dir, f"data_{i}", dummy_data)
|
||||
|
||||
# Test loading with multiple paths
|
||||
loaded_dataset = DatasetLoader.load(
|
||||
train_type="seq",
|
||||
load_path=dataset_path,
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
)
|
||||
assert loaded_dataset is not None
|
||||
assert len(loaded_dataset) > 0
|
||||
|
||||
# Test that we can get items without errors
|
||||
for i in range(min(3, len(loaded_dataset))):
|
||||
for i in range(len(loaded_dataset)):
|
||||
item = loaded_dataset[i]
|
||||
assert "input_ids" in item
|
||||
assert "target_ids" in item
|
||||
|
|
@ -59,18 +46,18 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
|||
seq_length = np.random.randint(100, 200)
|
||||
|
||||
dummy_data = {
|
||||
"chosen": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
||||
"rejected": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
||||
"chosen_mask": torch.ones(seq_length, dtype=torch.bool),
|
||||
"rejected_mask": torch.ones(seq_length, dtype=torch.bool)
|
||||
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)]
|
||||
}
|
||||
|
||||
dataset_path = create_h5_dataset(test_dir, dummy_data, "dpo_data")
|
||||
save_h5(test_dir, "dpo_data", dummy_data)
|
||||
|
||||
# Load DPO dataset
|
||||
dpo_dataset = DatasetLoader.load(
|
||||
train_type="dpo",
|
||||
load_path=dataset_path,
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
)
|
||||
|
||||
|
|
@ -97,16 +84,16 @@ def test_sft_dataset_with_random_data(base_test_env):
|
|||
seq_length = np.random.randint(100, 200)
|
||||
|
||||
dummy_data = {
|
||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
||||
"loss_mask": torch.ones(seq_length, dtype=torch.bool)
|
||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)]
|
||||
}
|
||||
|
||||
dataset_path = create_h5_dataset(test_dir, dummy_data, "sft_data")
|
||||
save_h5(test_dir, "sft_data", dummy_data)
|
||||
|
||||
# Load SFT dataset
|
||||
sft_dataset = DatasetLoader.load(
|
||||
train_type="sft",
|
||||
load_path=dataset_path,
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
)
|
||||
|
||||
|
|
@ -131,16 +118,16 @@ def test_dataset_with_custom_stride(base_test_env):
|
|||
# Create test data
|
||||
seq_length = 200
|
||||
dummy_data = {
|
||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
}
|
||||
|
||||
dataset_path = create_h5_dataset(test_dir, dummy_data, "stride_test_data")
|
||||
save_h5(test_dir,"stride_test_data", dummy_data)
|
||||
|
||||
# Test with custom stride
|
||||
custom_stride = 32
|
||||
dataset = DatasetLoader.load(
|
||||
train_type="seq",
|
||||
load_path=dataset_path,
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
stride=custom_stride
|
||||
)
|
||||
|
|
@ -152,7 +139,7 @@ def test_dataset_with_custom_stride(base_test_env):
|
|||
# than with default stride (which equals window size)
|
||||
default_stride_dataset = DatasetLoader.load(
|
||||
train_type="seq",
|
||||
load_path=dataset_path,
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue