fix: 修复 dataset 和 checkpoint 的 bug
This commit is contained in:
parent
80e17418b4
commit
8a8d6369bc
|
|
@ -70,8 +70,7 @@ class Checkpoint:
|
||||||
state_dict = torch.load(f)
|
state_dict = torch.load(f)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
optimizer_state_dict=state_dict["optimizer"],
|
state_dict=state_dict,
|
||||||
scheduler_state_dict=state_dict["scheduler"],
|
|
||||||
epoch=meta["epoch"],
|
epoch=meta["epoch"],
|
||||||
iteration=meta["iteration"],
|
iteration=meta["iteration"],
|
||||||
metrics=meta.get("metrics", {}),
|
metrics=meta.get("metrics", {}),
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,9 @@ class BaseSegmentFetcher:
|
||||||
total += len(seg)
|
total += len(seg)
|
||||||
self.cum_lengths.append(total)
|
self.cum_lengths.append(total)
|
||||||
self.total_length = total if segments else 0
|
self.total_length = total if segments else 0
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self.total_length
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
|
||||||
|
|
@ -48,6 +51,10 @@ class MultiSegmentFetcher:
|
||||||
key: BaseSegmentFetcher(segments)
|
key: BaseSegmentFetcher(segments)
|
||||||
for key, segments in muti_segments.items()
|
for key, segments in muti_segments.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
||||||
|
return min(len_list)
|
||||||
|
|
||||||
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Dict:
|
||||||
fetch_dict = {}
|
fetch_dict = {}
|
||||||
|
|
@ -73,8 +80,9 @@ class BaseDataset(Dataset, ABC):
|
||||||
self.total_samples = None
|
self.total_samples = None
|
||||||
|
|
||||||
def load(self, load_path: str):
|
def load(self, load_path: str):
|
||||||
self.segments, self.total_samples = load_h5(load_path)
|
self.segments = load_h5(load_path)
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||||
|
self.total_samples = len(self.fetcher)
|
||||||
|
|
||||||
def get_index(self, index: int) -> int:
|
def get_index(self, index: int) -> int:
|
||||||
assert self.total_samples > self.window_size
|
assert self.total_samples > self.window_size
|
||||||
|
|
|
||||||
|
|
@ -8,20 +8,17 @@ from torch import Tensor
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
with h5py.File(file_path, 'w') as f:
|
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
||||||
|
with h5py.File(full_file_path, 'w') as f:
|
||||||
for key, tensors in tensor_group.items():
|
for key, tensors in tensor_group.items():
|
||||||
grp = f.create_group(key)
|
grp = f.create_group(key)
|
||||||
grp.attrs['num_tensors'] = len(tensors)
|
grp.attrs['num_tensors'] = len(tensors)
|
||||||
|
|
||||||
for idx, tensor in enumerate(tensors):
|
for idx, tensor in enumerate(tensors):
|
||||||
arr = tensor.cpu().numpy()
|
arr = tensor.cpu().numpy()
|
||||||
dset = grp.create_dataset(
|
grp.create_dataset(f'data_{idx}', data=arr)
|
||||||
f'data_{idx}',
|
|
||||||
data=arr
|
|
||||||
)
|
|
||||||
dset.attrs['numel'] = tensor.numel()
|
|
||||||
|
|
||||||
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
|
def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
|
||||||
tensor_group: Dict[str, List[Tensor]] = {}
|
tensor_group: Dict[str, List[Tensor]] = {}
|
||||||
|
|
@ -38,10 +35,6 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
|
||||||
for dset_name in grp.keys():
|
for dset_name in grp.keys():
|
||||||
dset = grp[dset_name]
|
dset = grp[dset_name]
|
||||||
dsets.append(torch.from_numpy(dset[:]).share_memory_())
|
dsets.append(torch.from_numpy(dset[:]).share_memory_())
|
||||||
total_samples += dset.attrs.get('numel', np.prod(dset.shape))
|
|
||||||
tensor_group[key] = dsets
|
tensor_group[key] = dsets
|
||||||
|
|
||||||
num_keys = max(len(tensor_group), 1)
|
return tensor_group
|
||||||
sample_per_key = total_samples // num_keys
|
|
||||||
|
|
||||||
return tensor_group, sample_per_key
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
from khaosz.data.checkpoint import Checkpoint
|
from khaosz.data.checkpoint import Checkpoint
|
||||||
from khaosz.parallel.setup import spawn_parallel_fn
|
from khaosz.parallel.setup import get_rank, spawn_parallel_fn
|
||||||
|
|
||||||
def test_single_process():
|
def test_single_process():
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
|
|
@ -26,8 +26,7 @@ def test_single_process():
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
optimizer_state_dict=optimizer.state_dict(),
|
state_dict=model.state_dict(),
|
||||||
scheduler_state_dict=scheduler.state_dict(),
|
|
||||||
epoch=3,
|
epoch=3,
|
||||||
iteration=30,
|
iteration=30,
|
||||||
metrics={
|
metrics={
|
||||||
|
|
@ -45,21 +44,14 @@ def test_single_process():
|
||||||
assert loaded_checkpoint.iteration == 30
|
assert loaded_checkpoint.iteration == 30
|
||||||
assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]
|
assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]
|
||||||
|
|
||||||
assert 'param_groups' in loaded_checkpoint.optimizer_state_dict
|
|
||||||
assert 'state' in loaded_checkpoint.optimizer_state_dict
|
|
||||||
|
|
||||||
png_files = list(Path(tmpdir).glob("*.png"))
|
png_files = list(Path(tmpdir).glob("*.png"))
|
||||||
assert png_files
|
assert png_files
|
||||||
|
|
||||||
def simple_training():
|
def simple_training():
|
||||||
rank = int(os.environ.get('LOCAL_RANK', 0))
|
|
||||||
|
|
||||||
# 简单的训练逻辑
|
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
scheduler = CosineAnnealingLR(optimizer, T_max=10)
|
scheduler = CosineAnnealingLR(optimizer, T_max=10)
|
||||||
|
|
||||||
# 训练步骤
|
|
||||||
for epoch in range(2):
|
for epoch in range(2):
|
||||||
for iteration in range(5):
|
for iteration in range(5):
|
||||||
x = torch.randn(16, 10)
|
x = torch.randn(16, 10)
|
||||||
|
|
@ -71,18 +63,29 @@ def simple_training():
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
optimizer_state_dict=optimizer.state_dict(),
|
state_dict=model.state_dict(),
|
||||||
scheduler_state_dict=scheduler.state_dict(),
|
|
||||||
epoch=2,
|
epoch=2,
|
||||||
iteration=10,
|
iteration=10,
|
||||||
metrics={"loss": [0.3, 0.2, 0.1]}
|
metrics={"loss": [0.3, 0.2, 0.1]}
|
||||||
)
|
)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
rank = get_rank()
|
||||||
checkpoint.save(tmpdir)
|
|
||||||
loaded = Checkpoint.load(tmpdir)
|
if rank == 0:
|
||||||
assert loaded.epoch == 2
|
shared_dir = tempfile.mkdtemp()
|
||||||
print(f"Rank {rank}: Checkpoint test passed")
|
checkpoint.save(shared_dir)
|
||||||
|
else:
|
||||||
|
shared_dir = None
|
||||||
|
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
dir_list = [shared_dir]
|
||||||
|
dist.broadcast_object_list(dir_list, src=0)
|
||||||
|
shared_dir = dir_list[0]
|
||||||
|
|
||||||
|
|
||||||
|
loaded = Checkpoint.load(shared_dir)
|
||||||
|
assert loaded.epoch == 2
|
||||||
|
|
||||||
def test_multi_process():
|
def test_multi_process():
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -6,18 +5,6 @@ from khaosz.data.file import save_h5
|
||||||
from khaosz.data.dataset import *
|
from khaosz.data.dataset import *
|
||||||
|
|
||||||
|
|
||||||
def create_h5_dataset(dir_path, data_dict, dataset_name):
|
|
||||||
"""Helper function to create HDF5 dataset for testing"""
|
|
||||||
dataset_path = os.path.join(dir_path, f"{dataset_name}.h5")
|
|
||||||
|
|
||||||
# Convert data_dict to the format expected by save_h5
|
|
||||||
# save_h5 expects a list of tensors for each key
|
|
||||||
tensor_group = {key: [tensor] for key, tensor in data_dict.items()}
|
|
||||||
|
|
||||||
save_h5(dataset_path, tensor_group)
|
|
||||||
|
|
||||||
return dataset_path
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_loader_random_paths(base_test_env):
|
def test_dataset_loader_random_paths(base_test_env):
|
||||||
"""Test dataset loader with multiple random paths"""
|
"""Test dataset loader with multiple random paths"""
|
||||||
|
|
@ -27,23 +14,23 @@ def test_dataset_loader_random_paths(base_test_env):
|
||||||
num_files = np.random.randint(2, 5)
|
num_files = np.random.randint(2, 5)
|
||||||
|
|
||||||
for i in range(num_files):
|
for i in range(num_files):
|
||||||
seq_length = np.random.randint(100, 200)
|
seq_length = np.random.randint(200, 400)
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64) for _ in range(10)],
|
||||||
}
|
}
|
||||||
dataset_path = create_h5_dataset(test_dir, dummy_data, f"test_data_{i}")
|
save_h5(test_dir, f"data_{i}", dummy_data)
|
||||||
|
|
||||||
# Test loading with multiple paths
|
# Test loading with multiple paths
|
||||||
loaded_dataset = DatasetLoader.load(
|
loaded_dataset = DatasetLoader.load(
|
||||||
train_type="seq",
|
train_type="seq",
|
||||||
load_path=dataset_path,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
)
|
)
|
||||||
assert loaded_dataset is not None
|
assert loaded_dataset is not None
|
||||||
assert len(loaded_dataset) > 0
|
assert len(loaded_dataset) > 0
|
||||||
|
|
||||||
# Test that we can get items without errors
|
# Test that we can get items without errors
|
||||||
for i in range(min(3, len(loaded_dataset))):
|
for i in range(len(loaded_dataset)):
|
||||||
item = loaded_dataset[i]
|
item = loaded_dataset[i]
|
||||||
assert "input_ids" in item
|
assert "input_ids" in item
|
||||||
assert "target_ids" in item
|
assert "target_ids" in item
|
||||||
|
|
@ -59,18 +46,18 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
||||||
seq_length = np.random.randint(100, 200)
|
seq_length = np.random.randint(100, 200)
|
||||||
|
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"chosen": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"rejected": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"chosen_mask": torch.ones(seq_length, dtype=torch.bool),
|
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||||
"rejected_mask": torch.ones(seq_length, dtype=torch.bool)
|
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)]
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset_path = create_h5_dataset(test_dir, dummy_data, "dpo_data")
|
save_h5(test_dir, "dpo_data", dummy_data)
|
||||||
|
|
||||||
# Load DPO dataset
|
# Load DPO dataset
|
||||||
dpo_dataset = DatasetLoader.load(
|
dpo_dataset = DatasetLoader.load(
|
||||||
train_type="dpo",
|
train_type="dpo",
|
||||||
load_path=dataset_path,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -97,16 +84,16 @@ def test_sft_dataset_with_random_data(base_test_env):
|
||||||
seq_length = np.random.randint(100, 200)
|
seq_length = np.random.randint(100, 200)
|
||||||
|
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"loss_mask": torch.ones(seq_length, dtype=torch.bool)
|
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)]
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset_path = create_h5_dataset(test_dir, dummy_data, "sft_data")
|
save_h5(test_dir, "sft_data", dummy_data)
|
||||||
|
|
||||||
# Load SFT dataset
|
# Load SFT dataset
|
||||||
sft_dataset = DatasetLoader.load(
|
sft_dataset = DatasetLoader.load(
|
||||||
train_type="sft",
|
train_type="sft",
|
||||||
load_path=dataset_path,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -131,16 +118,16 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
# Create test data
|
# Create test data
|
||||||
seq_length = 200
|
seq_length = 200
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": torch.randint(0, 1000, (seq_length,), dtype=torch.int64),
|
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset_path = create_h5_dataset(test_dir, dummy_data, "stride_test_data")
|
save_h5(test_dir,"stride_test_data", dummy_data)
|
||||||
|
|
||||||
# Test with custom stride
|
# Test with custom stride
|
||||||
custom_stride = 32
|
custom_stride = 32
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type="seq",
|
train_type="seq",
|
||||||
load_path=dataset_path,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
stride=custom_stride
|
stride=custom_stride
|
||||||
)
|
)
|
||||||
|
|
@ -152,7 +139,7 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
# than with default stride (which equals window size)
|
# than with default stride (which equals window size)
|
||||||
default_stride_dataset = DatasetLoader.load(
|
default_stride_dataset = DatasetLoader.load(
|
||||||
train_type="seq",
|
train_type="seq",
|
||||||
load_path=dataset_path,
|
load_path=test_dir,
|
||||||
window_size=64,
|
window_size=64,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue