fix: 修复一些运行时问题
This commit is contained in:
parent
6089a12cef
commit
80e17418b4
|
|
@ -1,15 +1,12 @@
|
||||||
# cache
|
# Ignore everything
|
||||||
__pycache__
|
*
|
||||||
.pytest_cache
|
|
||||||
|
|
||||||
# params
|
# Allow directories to be traversed
|
||||||
params/*
|
!*/
|
||||||
|
|
||||||
# vscode file
|
# Allow specific file types and root files
|
||||||
.vscode
|
!*.py
|
||||||
|
!*.md
|
||||||
# build file
|
!*.png
|
||||||
build
|
!LICENSE
|
||||||
*.egg-info
|
!pyproject.toml
|
||||||
|
|
||||||
*.ipynb
|
|
||||||
|
|
@ -12,14 +12,12 @@ from khaosz.parallel.setup import get_rank
|
||||||
class Checkpoint:
|
class Checkpoint:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
optimizer_state_dict: Dict[str, Any],
|
state_dict: Dict[str, Any],
|
||||||
scheduler_state_dict: Optional[Dict[str, Any]] = None,
|
|
||||||
epoch: int = 0,
|
epoch: int = 0,
|
||||||
iteration: int = 0,
|
iteration: int = 0,
|
||||||
metrics: Optional[Dict[str, list]] = None,
|
metrics: Optional[Dict[str, list]] = None,
|
||||||
):
|
):
|
||||||
self.optimizer_state_dict = optimizer_state_dict
|
self.state_dict = state_dict
|
||||||
self.scheduler_state_dict = scheduler_state_dict
|
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
self.iteration = iteration
|
self.iteration = iteration
|
||||||
self.metrics = metrics or {}
|
self.metrics = metrics or {}
|
||||||
|
|
@ -46,12 +44,8 @@ class Checkpoint:
|
||||||
if save_metric_plot and self.metrics:
|
if save_metric_plot and self.metrics:
|
||||||
self._plot_metrics(str(save_path))
|
self._plot_metrics(str(save_path))
|
||||||
|
|
||||||
state_dict = {
|
with open(save_path / f"state_dict.pt", "wb") as f:
|
||||||
"optimizer": self.optimizer_state_dict,
|
torch.save(self.state_dict, f)
|
||||||
"scheduler": self.scheduler_state_dict
|
|
||||||
}
|
|
||||||
with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f:
|
|
||||||
torch.save(state_dict, f)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
|
|
@ -72,7 +66,7 @@ class Checkpoint:
|
||||||
dist.broadcast_object_list(meta_list, src=0)
|
dist.broadcast_object_list(meta_list, src=0)
|
||||||
meta = meta_list[0]
|
meta = meta_list[0]
|
||||||
|
|
||||||
with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f:
|
with open(save_path / f"state_dict.pt", "rb") as f:
|
||||||
state_dict = torch.load(f)
|
state_dict = torch.load(f)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import h5py
|
|
||||||
import torch
|
import torch
|
||||||
import bisect
|
import bisect
|
||||||
|
|
||||||
|
|
@ -78,8 +77,10 @@ class BaseDataset(Dataset, ABC):
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||||
|
|
||||||
def get_index(self, index: int) -> int:
|
def get_index(self, index: int) -> int:
|
||||||
begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
|
assert self.total_samples > self.window_size
|
||||||
end_idx = begin_idx + self.window_size
|
|
||||||
|
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
||||||
|
end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
|
||||||
|
|
||||||
return begin_idx, end_idx
|
return begin_idx, end_idx
|
||||||
|
|
||||||
|
|
@ -91,7 +92,7 @@ class BaseDataset(Dataset, ABC):
|
||||||
assert self.total_samples is not None
|
assert self.total_samples is not None
|
||||||
if self.total_samples <= self.window_size:
|
if self.total_samples <= self.window_size:
|
||||||
return 0
|
return 0
|
||||||
return self.total_samples // self.stride + 1
|
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
||||||
|
|
||||||
|
|
||||||
class SeqDataset(BaseDataset):
|
class SeqDataset(BaseDataset):
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ import os
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
|
@ -17,10 +19,7 @@ def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
arr = tensor.cpu().numpy()
|
arr = tensor.cpu().numpy()
|
||||||
dset = grp.create_dataset(
|
dset = grp.create_dataset(
|
||||||
f'data_{idx}',
|
f'data_{idx}',
|
||||||
data=arr,
|
data=arr
|
||||||
compression='gzip',
|
|
||||||
compression_opts=4,
|
|
||||||
shuffle=True
|
|
||||||
)
|
)
|
||||||
dset.attrs['numel'] = tensor.numel()
|
dset.attrs['numel'] = tensor.numel()
|
||||||
|
|
||||||
|
|
@ -28,15 +27,19 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
|
||||||
tensor_group: Dict[str, List[Tensor]] = {}
|
tensor_group: Dict[str, List[Tensor]] = {}
|
||||||
total_samples = 0
|
total_samples = 0
|
||||||
|
|
||||||
with h5py.File(file_path, 'r') as f:
|
root_path = Path(file_path)
|
||||||
for key in f.keys():
|
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
||||||
grp = f[key]
|
|
||||||
dsets = []
|
for h5_file in h5_files:
|
||||||
for dset_name in grp.keys():
|
with h5py.File(h5_file, 'r') as f:
|
||||||
dset = grp[dset_name]
|
for key in f.keys():
|
||||||
dsets.append(torch.from_numpy(dset[:]).share_memory_())
|
grp = f[key]
|
||||||
total_samples += dset.attrs.get('numel', np.prod(dset.shape))
|
dsets = []
|
||||||
tensor_group[key] = dsets
|
for dset_name in grp.keys():
|
||||||
|
dset = grp[dset_name]
|
||||||
|
dsets.append(torch.from_numpy(dset[:]).share_memory_())
|
||||||
|
total_samples += dset.attrs.get('numel', np.prod(dset.shape))
|
||||||
|
tensor_group[key] = dsets
|
||||||
|
|
||||||
num_keys = max(len(tensor_group), 1)
|
num_keys = max(len(tensor_group), 1)
|
||||||
sample_per_key = total_samples // num_keys
|
sample_per_key = total_samples // num_keys
|
||||||
|
|
|
||||||
|
|
@ -151,8 +151,8 @@ class SchedulerFactory:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
|
def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
|
||||||
kwargs = scedule_config.get_kwargs()
|
kwargs = schedule_config.get_kwargs()
|
||||||
schedule_type = kwargs.pop("schedule_type")
|
schedule_type = kwargs.pop("schedule_type")
|
||||||
|
|
||||||
if schedule_type == "cosine":
|
if schedule_type == "cosine":
|
||||||
|
|
|
||||||
|
|
@ -108,10 +108,10 @@ class CheckpointCallback(TrainCallback):
|
||||||
|
|
||||||
def _save_checkpoint(self, context: 'TrainContext'):
|
def _save_checkpoint(self, context: 'TrainContext'):
|
||||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
||||||
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict()
|
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
|
||||||
|
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
optimizer_state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
|
|
||||||
epoch=context.epoch,
|
epoch=context.epoch,
|
||||||
iteration=context.iteration
|
iteration=context.iteration
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -53,15 +53,13 @@ class TrainContextBuilder:
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
if checkpoint is None:
|
if checkpoint is None:
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
optimizer_state_dict=self._context.optimizer.state_dict(),
|
state_dict=self._context.model.state_dict(),
|
||||||
scheduler_state_dict=self._context.scheduler.state_dict(),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# resume from the assigned checkpoint or assigned iteration
|
# resume from the assigned checkpoint or assigned iteration
|
||||||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
||||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
||||||
self._context.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
|
self._context.model.load_state_dict(checkpoint.state_dict)
|
||||||
self._context.scheduler.load_state_dict(checkpoint.scheduler_state_dict)
|
|
||||||
|
|
||||||
self._context.checkpoint = checkpoint
|
self._context.checkpoint = checkpoint
|
||||||
return self
|
return self
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,8 @@ import argparse
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.distributed.fsdp as fsdp
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from torch.distributed.fsdp.api import StateDictType
|
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||||
|
|
@ -59,19 +57,16 @@ def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def fsdp_wrap(model: nn.Module):
|
def ddp_wrap(model: nn.Module):
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
fsdp_model = fsdp.FullyShardedDataParallel(
|
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
||||||
|
ddp_model = DDP(
|
||||||
model,
|
model,
|
||||||
sharding_strategy=fsdp.ShardingStrategy.SHARD_GRAD_OP,
|
device_ids=[local_rank],
|
||||||
mixed_precision=fsdp.MixedPrecision(
|
output_device=local_rank,
|
||||||
param_dtype=torch.bfloat16,
|
find_unused_parameters=False
|
||||||
reduce_dtype=torch.bfloat16,
|
|
||||||
buffer_dtype=torch.bfloat16,
|
|
||||||
),
|
|
||||||
backward_prefetch=fsdp.BackwardPrefetch.BACKWARD_PRE
|
|
||||||
)
|
)
|
||||||
return fsdp_model
|
return ddp_model
|
||||||
|
|
||||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||||
return optim.AdamW(model.parameters(), **kwargs)
|
return optim.AdamW(model.parameters(), **kwargs)
|
||||||
|
|
@ -79,12 +74,13 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||||
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
|
def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
|
||||||
return SchedulerFactory.load(optimizer, **kwargs)
|
return SchedulerFactory.load(optimizer, **kwargs)
|
||||||
|
|
||||||
def prepare_checkpoint(model: nn.Module, optimizer: optim.Optimizer) -> dict:
|
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
|
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||||
model_state_dict = model.state_dict()
|
state_dict = model.module.state_dict()
|
||||||
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
|
else:
|
||||||
|
state_dict = model.state_dict()
|
||||||
return model_state_dict, optim_state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
train_type: str,
|
train_type: str,
|
||||||
|
|
@ -166,7 +162,7 @@ def train(
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
parallel_wrapper=fsdp_wrap,
|
parallel_wrapper=ddp_wrap,
|
||||||
state_dict_fn=prepare_checkpoint,
|
state_dict_fn=prepare_checkpoint,
|
||||||
device_ids=device_ids,
|
device_ids=device_ids,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue