fix: fix several runtime issues

ViperEkura 2026-03-01 15:47:07 +08:00
parent 6089a12cef
commit 80e17418b4
8 changed files with 60 additions and 71 deletions

.gitignore
View File

@@ -1,15 +1,12 @@
-# cache
-__pycache__
-.pytest_cache
-
-# params
-params/*
-
-# vscode file
-.vscode
-
-# build file
-build
-*.egg-info
-*.ipynb
+# Ignore everything
+*
+
+# Allow directories to be traversed
+!*/
+
+# Allow specific file types and root files
+!*.py
+!*.md
+!*.png
+!LICENSE
+!pyproject.toml
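Note on the rewritten rules: git does not descend into ignored directories, so "*" alone would also swallow every subdirectory and re-include patterns such as "!*.py" could never match nested files; the "!*/" line un-ignores directories themselves, restoring traversal so that only the whitelisted file types and root files survive.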

View File

@@ -12,14 +12,12 @@ from khaosz.parallel.setup import get_rank
 class Checkpoint:
     def __init__(
         self,
-        optimizer_state_dict: Dict[str, Any],
-        scheduler_state_dict: Optional[Dict[str, Any]] = None,
+        state_dict: Dict[str, Any],
         epoch: int = 0,
         iteration: int = 0,
         metrics: Optional[Dict[str, list]] = None,
     ):
-        self.optimizer_state_dict = optimizer_state_dict
-        self.scheduler_state_dict = scheduler_state_dict
+        self.state_dict = state_dict
         self.epoch = epoch
         self.iteration = iteration
         self.metrics = metrics or {}
@@ -46,12 +44,8 @@ class Checkpoint:
         if save_metric_plot and self.metrics:
             self._plot_metrics(str(save_path))
 
-        state_dict = {
-            "optimizer": self.optimizer_state_dict,
-            "scheduler": self.scheduler_state_dict
-        }
-        with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f:
-            torch.save(state_dict, f)
+        with open(save_path / f"state_dict.pt", "wb") as f:
+            torch.save(self.state_dict, f)
 
     @classmethod
     def load(
@@ -72,7 +66,7 @@ class Checkpoint:
         dist.broadcast_object_list(meta_list, src=0)
         meta = meta_list[0]
 
-        with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f:
+        with open(save_path / f"state_dict.pt", "rb") as f:
             state_dict = torch.load(f)
 
         return cls(
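Net effect of these hunks: Checkpoint now stores a single state_dict, and every rank reads and writes the same state_dict.pt instead of a per-rank state_dict_rank_{rank}.pt. A minimal standalone sketch of the new on-disk round trip (directory name hypothetical, not the repo's API):

import torch
from pathlib import Path

save_path = Path("checkpoints/epoch_1_iter_100")  # hypothetical location
save_path.mkdir(parents=True, exist_ok=True)

# save: one shared file for all ranks (was state_dict_rank_{get_rank()}.pt)
state_dict = {"weight": torch.zeros(2, 2)}  # stand-in for real model state
torch.save(state_dict, save_path / "state_dict.pt")

# load: every rank opens the same path
restored = torch.load(save_path / "state_dict.pt")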

View File

@@ -1,4 +1,3 @@
-import h5py
 import torch
 import bisect
 
@@ -78,8 +77,10 @@ class BaseDataset(Dataset, ABC):
         self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
-        begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
-        end_idx = begin_idx + self.window_size
+        assert self.total_samples > self.window_size
+
+        begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
+        end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
+
         return begin_idx, end_idx
 
@@ -91,7 +92,7 @@ class BaseDataset(Dataset, ABC):
         assert self.total_samples is not None
         if self.total_samples <= self.window_size:
             return 0
-        return self.total_samples // self.stride + 1
+        return (self.total_samples - 1 - self.window_size) // self.stride + 1
 
 
 class SeqDataset(BaseDataset):
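A worked example of the corrected windowing arithmetic, as a standalone re-implementation of the two changed formulas (not the library API). With total_samples=10, window_size=4, stride=3, the old __len__ returned 10 // 3 + 1 = 4, but indices 2 and 3 both clamped to the same (5, 9) window; the new formula counts only distinct windows:

total_samples, window_size, stride = 10, 4, 3

def get_index(index: int):
    # mirrors BaseDataset.get_index after this commit
    assert total_samples > window_size
    begin_idx = min(index * stride, total_samples - 1 - window_size)
    end_idx = min(begin_idx + window_size, total_samples - 1)
    return begin_idx, end_idx

length = (total_samples - 1 - window_size) // stride + 1  # == 2
print([get_index(i) for i in range(length)])  # [(0, 4), (3, 7)]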

View File

@@ -2,6 +2,8 @@ import os
 import h5py
 import numpy as np
 import torch
+
+from pathlib import Path
 from torch import Tensor
 from typing import Dict, List, Tuple
 
@@ -17,10 +19,7 @@ def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
             arr = tensor.cpu().numpy()
             dset = grp.create_dataset(
                 f'data_{idx}',
-                data=arr,
-                compression='gzip',
-                compression_opts=4,
-                shuffle=True
+                data=arr
             )
             dset.attrs['numel'] = tensor.numel()
@@ -28,7 +27,11 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
     tensor_group: Dict[str, List[Tensor]] = {}
     total_samples = 0
 
-    with h5py.File(file_path, 'r') as f:
+    root_path = Path(file_path)
+    h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
+
+    for h5_file in h5_files:
+        with h5py.File(h5_file, 'r') as f:
             for key in f.keys():
                 grp = f[key]
                 dsets = []
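A round-trip sketch of the new I/O behavior (file and group names hypothetical): save_h5 now writes datasets without the gzip filter, and load_h5 treats its argument as a directory, recursing for *.h5 and *.hdf5 files instead of opening a single file:

import h5py
import numpy as np
from pathlib import Path

root = Path("dataset_dir")  # hypothetical dataset root
root.mkdir(exist_ok=True)

with h5py.File(root / "part0.h5", "w") as f:
    grp = f.create_group("input_ids")  # hypothetical group name
    dset = grp.create_dataset("data_0", data=np.arange(8))  # no compression args
    dset.attrs["numel"] = 8

# discovery mirrors the new load_h5
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
for h5_file in h5_files:
    with h5py.File(h5_file, "r") as f:
        for key in f.keys():
            print(key, f[key]["data_0"][...])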

View File

@@ -151,8 +151,8 @@ class SchedulerFactory:
     """
 
     @staticmethod
-    def load(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
-        kwargs = scedule_config.get_kwargs()
+    def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
+        kwargs = schedule_config.get_kwargs()
         schedule_type = kwargs.pop("schedule_type")
 
         if schedule_type == "cosine":

View File

@@ -108,10 +108,10 @@ class CheckpointCallback(TrainCallback):
     def _save_checkpoint(self, context: 'TrainContext'):
         save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
-        state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict()
+        state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
 
         context.checkpoint = Checkpoint(
-            optimizer_state_dict=state_dict,
-            scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
+            state_dict=state_dict,
             epoch=context.epoch,
             iteration=context.iteration
         )

View File

@@ -53,15 +53,13 @@ class TrainContextBuilder:
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
         if checkpoint is None:
            checkpoint = Checkpoint(
-                optimizer_state_dict=self._context.optimizer.state_dict(),
-                scheduler_state_dict=self._context.scheduler.state_dict(),
+                state_dict=self._context.model.state_dict(),
            )
        else:
            # resume from the assigned checkpoint or assigned iteration
            self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
            self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
-            self._context.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
-            self._context.scheduler.load_state_dict(checkpoint.scheduler_state_dict)
+            self._context.model.load_state_dict(checkpoint.state_dict)
 
        self._context.checkpoint = checkpoint
        return self
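One consequence of this refactor, inferred from the diff rather than stated anywhere: a resumed run now restores model weights only, while optimizer moments and the scheduler position start fresh, since their state is no longer checkpointed. A self-contained illustration of the new scheme:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters())

# new scheme: only the model's state_dict is checkpointed
ckpt = {"state_dict": model.state_dict(), "epoch": 0, "iteration": 0}

# on resume, weights come back...
model.load_state_dict(ckpt["state_dict"])
# ...but optimizer/scheduler state is intentionally not restored, since the
# commit removed optimizer_state_dict and scheduler_state_dict from Checkpoint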

View File

@@ -3,10 +3,8 @@ import argparse
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import torch.distributed.fsdp as fsdp
-from torch.distributed.fsdp.api import StateDictType
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.nn.parallel import DistributedDataParallel as DDP
 from typing import List, Optional
 from functools import partial
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
 
@@ -59,19 +57,16 @@ def parse_args() -> argparse.Namespace:
     return args
 
-def fsdp_wrap(model: nn.Module):
-    fsdp_model = fsdp.FullyShardedDataParallel(
+def ddp_wrap(model: nn.Module):
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
+    ddp_model = DDP(
         model,
-        sharding_strategy=fsdp.ShardingStrategy.SHARD_GRAD_OP,
-        mixed_precision=fsdp.MixedPrecision(
-            param_dtype=torch.bfloat16,
-            reduce_dtype=torch.bfloat16,
-            buffer_dtype=torch.bfloat16,
-        ),
-        backward_prefetch=fsdp.BackwardPrefetch.BACKWARD_PRE
+        device_ids=[local_rank],
+        output_device=local_rank,
+        find_unused_parameters=False
     )
-    return fsdp_model
+    return ddp_model
 
 def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
     return optim.AdamW(model.parameters(), **kwargs)
@@ -79,12 +74,13 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
 def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
     return SchedulerFactory.load(optimizer, **kwargs)
 
-def prepare_checkpoint(model: nn.Module, optimizer: optim.Optimizer) -> dict:
-    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-        model_state_dict = model.state_dict()
-        optim_state_dict = FSDP.optim_state_dict(model, optimizer)
-
-    return model_state_dict, optim_state_dict
+def prepare_checkpoint(model: nn.Module) -> dict:
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+
+    return state_dict
 
 def train(
     train_type: str,
@@ -166,7 +162,7 @@ def train(
         num_workers=num_workers,
         pin_memory=pin_memory,
         nprocs=nprocs,
-        parallel_wrapper=fsdp_wrap,
+        parallel_wrapper=ddp_wrap,
         state_dict_fn=prepare_checkpoint,
         device_ids=device_ids,
         device_type=device_type,
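One reason prepare_checkpoint unwraps .module: DDP registers the wrapped network under a "module." key prefix, so a state dict saved from the wrapper would not load back into a bare nn.Module. A single-process demonstration (gloo backend; address and port are arbitrary local values):

import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

ddp_model = DDP(nn.Linear(2, 2))
print(list(ddp_model.state_dict())[0])         # 'module.weight'
print(list(ddp_model.module.state_dict())[0])  # 'weight'

dist.destroy_process_group()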