fix: resolve several runtime issues
parent 6089a12cef
commit 80e17418b4
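
Checkpoint now keeps only the model state_dict (plus epoch, iteration, and
metrics) and round-trips a single state_dict.pt instead of per-rank
optimizer/scheduler files. BaseDataset.get_index asserts that
total_samples > window_size and clamps both window boundaries, and __len__
subtracts the window before dividing by the stride. load_h5 walks a
directory and reads every *.h5/*.hdf5 file instead of a single path, and
save_h5 drops gzip compression. Also fixes the scedule_config typo in
SchedulerFactory.load, makes CheckpointCallback fall back to the model's
state_dict rather than the optimizer's, and switches the training script
from FSDP to DDP in bfloat16.

A minimal sketch of the resulting checkpoint round trip (the import path,
the stand-in model, and the exact save()/load() call signatures are
assumptions for illustration, not taken from the repository):

    from pathlib import Path
    import torch
    from khaosz.parallel import Checkpoint  # assumed import path

    model = torch.nn.Linear(4, 4)           # hypothetical stand-in model
    save_path = Path("checkpoints/epoch_3")
    ckpt = Checkpoint(state_dict=model.state_dict(), epoch=3, iteration=1200)
    ckpt.save(save_path)                    # writes <save_path>/state_dict.pt
    restored = Checkpoint.load(save_path)   # every rank reads the same file
    model.load_state_dict(restored.state_dict)
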
@@ -1,15 +1,12 @@
-# cache
-__pycache__
-.pytest_cache
+# Ignore everything
+*
 
-# params
-params/*
+# Allow directories to be traversed
+!*/
 
-# vscode file
-.vscode
-
-# build file
-build
-*.egg-info
-
-*.ipynb
+# Allow specific file types and root files
+!*.py
+!*.md
+!*.png
+!LICENSE
+!pyproject.toml

@@ -12,14 +12,12 @@ from khaosz.parallel.setup import get_rank
 class Checkpoint:
     def __init__(
         self,
-        optimizer_state_dict: Dict[str, Any],
-        scheduler_state_dict: Optional[Dict[str, Any]] = None,
+        state_dict: Dict[str, Any],
         epoch: int = 0,
         iteration: int = 0,
         metrics: Optional[Dict[str, list]] = None,
     ):
-        self.optimizer_state_dict = optimizer_state_dict
-        self.scheduler_state_dict = scheduler_state_dict
+        self.state_dict = state_dict
         self.epoch = epoch
         self.iteration = iteration
         self.metrics = metrics or {}
@@ -46,12 +44,8 @@ class Checkpoint:
         if save_metric_plot and self.metrics:
             self._plot_metrics(str(save_path))
 
-        state_dict = {
-            "optimizer": self.optimizer_state_dict,
-            "scheduler": self.scheduler_state_dict
-        }
-        with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f:
-            torch.save(state_dict, f)
+        with open(save_path / f"state_dict.pt", "wb") as f:
+            torch.save(self.state_dict, f)
 
     @classmethod
     def load(
@@ -72,7 +66,7 @@ class Checkpoint:
         dist.broadcast_object_list(meta_list, src=0)
         meta = meta_list[0]
 
-        with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f:
+        with open(save_path / f"state_dict.pt", "rb") as f:
             state_dict = torch.load(f)
 
         return cls(

@@ -1,4 +1,3 @@
-import h5py
 import torch
 import bisect
 
@@ -78,8 +77,10 @@ class BaseDataset(Dataset, ABC):
         self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
-        begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
-        end_idx = begin_idx + self.window_size
+        assert self.total_samples > self.window_size
+
+        begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
+        end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
 
         return begin_idx, end_idx
 
@@ -91,7 +92,7 @@ class BaseDataset(Dataset, ABC):
         assert self.total_samples is not None
         if self.total_samples <= self.window_size:
             return 0
-        return self.total_samples // self.stride + 1
+        return (self.total_samples - 1 - self.window_size) // self.stride + 1
 
 
 class SeqDataset(BaseDataset):

@@ -2,6 +2,8 @@ import os
 import h5py
+import numpy as np
 import torch
 
+from pathlib import Path
 from torch import Tensor
 from typing import Dict, List, Tuple
 
@@ -17,10 +19,7 @@ def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
             arr = tensor.cpu().numpy()
             dset = grp.create_dataset(
                 f'data_{idx}',
-                data=arr,
-                compression='gzip',
-                compression_opts=4,
-                shuffle=True
+                data=arr
             )
             dset.attrs['numel'] = tensor.numel()
 
@@ -28,15 +27,19 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
     tensor_group: Dict[str, List[Tensor]] = {}
     total_samples = 0
 
-    with h5py.File(file_path, 'r') as f:
-        for key in f.keys():
-            grp = f[key]
-            dsets = []
-            for dset_name in grp.keys():
-                dset = grp[dset_name]
-                dsets.append(torch.from_numpy(dset[:]).share_memory_())
-                total_samples += dset.attrs.get('numel', np.prod(dset.shape))
-            tensor_group[key] = dsets
+    root_path = Path(file_path)
+    h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
+
+    for h5_file in h5_files:
+        with h5py.File(h5_file, 'r') as f:
+            for key in f.keys():
+                grp = f[key]
+                dsets = []
+                for dset_name in grp.keys():
+                    dset = grp[dset_name]
+                    dsets.append(torch.from_numpy(dset[:]).share_memory_())
+                    total_samples += dset.attrs.get('numel', np.prod(dset.shape))
+                tensor_group[key] = dsets
 
     num_keys = max(len(tensor_group), 1)
     sample_per_key = total_samples // num_keys

@@ -151,8 +151,8 @@ class SchedulerFactory:
     """
 
     @staticmethod
-    def load(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
-        kwargs = scedule_config.get_kwargs()
+    def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
+        kwargs = schedule_config.get_kwargs()
         schedule_type = kwargs.pop("schedule_type")
 
         if schedule_type == "cosine":

@@ -108,10 +108,10 @@ class CheckpointCallback(TrainCallback):
 
     def _save_checkpoint(self, context: 'TrainContext'):
         save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
-        state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict()
+        state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
 
         context.checkpoint = Checkpoint(
-            optimizer_state_dict=state_dict,
-            scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
+            state_dict=state_dict,
             epoch=context.epoch,
             iteration=context.iteration
         )

@@ -53,15 +53,13 @@ class TrainContextBuilder:
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
         if checkpoint is None:
             checkpoint = Checkpoint(
-                optimizer_state_dict=self._context.optimizer.state_dict(),
-                scheduler_state_dict=self._context.scheduler.state_dict(),
+                state_dict=self._context.model.state_dict(),
             )
         else:
             # resume from the assigned checkpoint or assigned iteration
             self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
             self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
-            self._context.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
-            self._context.scheduler.load_state_dict(checkpoint.scheduler_state_dict)
+            self._context.model.load_state_dict(checkpoint.state_dict)
 
         self._context.checkpoint = checkpoint
         return self

@@ -3,10 +3,8 @@ import argparse
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import torch.distributed.fsdp as fsdp
+from torch.nn.parallel import DistributedDataParallel as DDP
 
-from torch.distributed.fsdp.api import StateDictType
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from typing import List, Optional
 from functools import partial
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
@@ -59,19 +57,16 @@ def parse_args() -> argparse.Namespace:
 
     return args
 
-def fsdp_wrap(model: nn.Module):
-
-    fsdp_model = fsdp.FullyShardedDataParallel(
+def ddp_wrap(model: nn.Module):
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
+    ddp_model = DDP(
         model,
-        sharding_strategy=fsdp.ShardingStrategy.SHARD_GRAD_OP,
-        mixed_precision=fsdp.MixedPrecision(
-            param_dtype=torch.bfloat16,
-            reduce_dtype=torch.bfloat16,
-            buffer_dtype=torch.bfloat16,
-        ),
-        backward_prefetch=fsdp.BackwardPrefetch.BACKWARD_PRE
+        device_ids=[local_rank],
+        output_device=local_rank,
+        find_unused_parameters=False
     )
-    return fsdp_model
+    return ddp_model
 
 def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
     return optim.AdamW(model.parameters(), **kwargs)
@@ -79,12 +74,13 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
 def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
     return SchedulerFactory.load(optimizer, **kwargs)
 
-def prepare_checkpoint(model: nn.Module, optimizer: optim.Optimizer) -> dict:
-    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-        model_state_dict = model.state_dict()
-        optim_state_dict = FSDP.optim_state_dict(model, optimizer)
+def prepare_checkpoint(model: nn.Module) -> dict:
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    return state_dict
 
-    return model_state_dict, optim_state_dict
 
 def train(
     train_type: str,
@@ -166,7 +162,7 @@ def train(
         num_workers=num_workers,
         pin_memory=pin_memory,
         nprocs=nprocs,
-        parallel_wrapper=fsdp_wrap,
+        parallel_wrapper=ddp_wrap,
         state_dict_fn=prepare_checkpoint,
         device_ids=device_ids,
         device_type=device_type,