From 80e17418b49fcf15fa705ae88bb0b986519b31b9 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 1 Mar 2026 15:47:07 +0800
Subject: [PATCH] fix: resolve several runtime issues
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Checkpoint now stores a single consolidated state dict instead of
  per-rank optimizer/scheduler dicts.
- Fix window indexing and __len__ in BaseDataset.
- load_h5 now collects every .h5/.hdf5 file under the given path.
- Fix the "scedule_config" typo in SchedulerFactory.load.
- Switch tools/train.py from FSDP to DDP wrapping.
- Rewrite .gitignore as an allowlist.
---
 .gitignore                       | 23 +++++++++----------
 khaosz/data/checkpoint.py        | 16 +++++---------
 khaosz/data/dataset.py           |  9 ++++----
 khaosz/data/file.py              | 29 +++++++++++++-----------
 khaosz/trainer/schedule.py       |  4 ++--
 khaosz/trainer/train_callback.py |  6 ++---
 khaosz/trainer/train_context.py  |  6 ++---
 tools/train.py                   | 38 ++++++++++++++------------------
 8 files changed, 60 insertions(+), 71 deletions(-)

diff --git a/.gitignore b/.gitignore
index e13969b..c261f0c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,15 +1,12 @@
-# cache
-__pycache__
-.pytest_cache
+# Ignore everything
+*
 
-# params
-params/*
+# Allow directories to be traversed
+!*/
 
-# vscode file
-.vscode
-
-# build file
-build
-*.egg-info
-
-*.ipynb
\ No newline at end of file
+# Allow specific file types and root files
+!*.py
+!*.md
+!*.png
+!LICENSE
+!pyproject.toml
\ No newline at end of file
diff --git a/khaosz/data/checkpoint.py b/khaosz/data/checkpoint.py
index 538fa0c..0afaa22 100644
--- a/khaosz/data/checkpoint.py
+++ b/khaosz/data/checkpoint.py
@@ -12,14 +12,12 @@ from khaosz.parallel.setup import get_rank
 class Checkpoint:
     def __init__(
         self,
-        optimizer_state_dict: Dict[str, Any],
-        scheduler_state_dict: Optional[Dict[str, Any]] = None,
+        state_dict: Dict[str, Any],
         epoch: int = 0,
         iteration: int = 0,
         metrics: Optional[Dict[str, list]] = None,
     ):
-        self.optimizer_state_dict = optimizer_state_dict
-        self.scheduler_state_dict = scheduler_state_dict
+        self.state_dict = state_dict
         self.epoch = epoch
         self.iteration = iteration
         self.metrics = metrics or {}
@@ -46,12 +44,8 @@ class Checkpoint:
         if save_metric_plot and self.metrics:
             self._plot_metrics(str(save_path))
 
-        state_dict = {
-            "optimizer": self.optimizer_state_dict,
-            "scheduler": self.scheduler_state_dict
-        }
-        with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f:
-            torch.save(state_dict, f)
+        with open(save_path / f"state_dict.pt", "wb") as f:
+            torch.save(self.state_dict, f)
 
     @classmethod
     def load(
@@ -72,7 +66,7 @@ class Checkpoint:
             dist.broadcast_object_list(meta_list, src=0)
             meta = meta_list[0]
 
-        with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f:
+        with open(save_path / f"state_dict.pt", "rb") as f:
             state_dict = torch.load(f)
 
         return cls(
diff --git a/khaosz/data/dataset.py b/khaosz/data/dataset.py
index dae7fe0..deb71fe 100644
--- a/khaosz/data/dataset.py
+++ b/khaosz/data/dataset.py
@@ -1,4 +1,3 @@
-import h5py
 import torch
 import bisect
 
@@ -78,8 +77,10 @@ class BaseDataset(Dataset, ABC):
         self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
-        begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
-        end_idx = begin_idx + self.window_size
+        assert self.total_samples > self.window_size
+
+        begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
+        end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
 
         return begin_idx, end_idx
 
@@ -91,7 +92,7 @@ class BaseDataset(Dataset, ABC):
         assert self.total_samples is not None
         if self.total_samples <= self.window_size:
             return 0
-        return self.total_samples // self.stride + 1
+        return (self.total_samples - 1 - self.window_size) // self.stride + 1
 
 
 class SeqDataset(BaseDataset):
diff --git a/khaosz/data/file.py b/khaosz/data/file.py
index 38a1349..92ad9c5 100644
--- a/khaosz/data/file.py
+++ b/khaosz/data/file.py
@@ -2,6 +2,8 @@ import os
 import h5py
 import numpy as np
 import torch
+
+from pathlib import Path
 from torch import Tensor
 from typing import Dict, List, Tuple
 
@@ -17,10 +19,7 @@ def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
             arr = tensor.cpu().numpy()
             dset = grp.create_dataset(
                 f'data_{idx}',
-                data=arr,
-                compression='gzip',
-                compression_opts=4,
-                shuffle=True
+                data=arr
             )
             dset.attrs['numel'] = tensor.numel()
 
@@ -28,15 +27,19 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
     tensor_group: Dict[str, List[Tensor]] = {}
     total_samples = 0
 
-    with h5py.File(file_path, 'r') as f:
-        for key in f.keys():
-            grp = f[key]
-            dsets = []
-            for dset_name in grp.keys():
-                dset = grp[dset_name]
-                dsets.append(torch.from_numpy(dset[:]).share_memory_())
-                total_samples += dset.attrs.get('numel', np.prod(dset.shape))
-            tensor_group[key] = dsets
+    root_path = Path(file_path)
+    h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
+
+    for h5_file in h5_files:
+        with h5py.File(h5_file, 'r') as f:
+            for key in f.keys():
+                grp = f[key]
+                dsets = []
+                for dset_name in grp.keys():
+                    dset = grp[dset_name]
+                    dsets.append(torch.from_numpy(dset[:]).share_memory_())
+                    total_samples += dset.attrs.get('numel', np.prod(dset.shape))
+                tensor_group[key] = dsets
 
     num_keys = max(len(tensor_group), 1)
     sample_per_key = total_samples // num_keys
diff --git a/khaosz/trainer/schedule.py b/khaosz/trainer/schedule.py
index 1b9c094..84135f0 100644
--- a/khaosz/trainer/schedule.py
+++ b/khaosz/trainer/schedule.py
@@ -151,8 +151,8 @@ class SchedulerFactory:
     """
 
     @staticmethod
-    def load(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
-        kwargs = scedule_config.get_kwargs()
+    def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
+        kwargs = schedule_config.get_kwargs()
         schedule_type = kwargs.pop("schedule_type")
 
         if schedule_type == "cosine":
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index b1c0a9d..e2913de 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -108,10 +108,10 @@ class CheckpointCallback(TrainCallback):
 
     def _save_checkpoint(self, context: 'TrainContext'):
         save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
-        state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict()
+        state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
+
         context.checkpoint = Checkpoint(
-            optimizer_state_dict=state_dict,
-            scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
+            state_dict=state_dict,
             epoch=context.epoch,
             iteration=context.iteration
         )
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index e1095f9..90af6c7 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -53,15 +53,13 @@ class TrainContextBuilder:
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
         if checkpoint is None:
             checkpoint = Checkpoint(
-                optimizer_state_dict=self._context.optimizer.state_dict(),
-                scheduler_state_dict=self._context.scheduler.state_dict(),
+                state_dict=self._context.model.state_dict(),
             )
         else:
             # resume from the assigned checkpoint or assigned iteration
             self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
             self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
-            self._context.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
-            self._context.scheduler.load_state_dict(checkpoint.scheduler_state_dict)
+            self._context.model.load_state_dict(checkpoint.state_dict)
 
         self._context.checkpoint = checkpoint
         return self
diff --git a/tools/train.py b/tools/train.py
index dc65172..64dddb7 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -3,10 +3,8 @@ import argparse
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import torch.distributed.fsdp as fsdp
+from torch.nn.parallel import DistributedDataParallel as DDP
 
-from torch.distributed.fsdp.api import StateDictType
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from typing import List, Optional
 from functools import partial
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
@@ -59,19 +57,16 @@ def parse_args() -> argparse.Namespace:
     return args
 
 
-def fsdp_wrap(model: nn.Module):
-    
-    fsdp_model = fsdp.FullyShardedDataParallel(
+def ddp_wrap(model: nn.Module):
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
+    ddp_model = DDP(
         model,
-        sharding_strategy=fsdp.ShardingStrategy.SHARD_GRAD_OP,
-        mixed_precision=fsdp.MixedPrecision(
-            param_dtype=torch.bfloat16,
-            reduce_dtype=torch.bfloat16,
-            buffer_dtype=torch.bfloat16,
-        ),
-        backward_prefetch=fsdp.BackwardPrefetch.BACKWARD_PRE
+        device_ids=[local_rank],
+        output_device=local_rank,
+        find_unused_parameters=False
     )
-    return fsdp_model
+    return ddp_model
 
 def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
     return optim.AdamW(model.parameters(), **kwargs)
@@ -79,12 +74,13 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
     return optim.AdamW(model.parameters(), **kwargs)
 def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
     return SchedulerFactory.load(optimizer, **kwargs)
-def prepare_checkpoint(model: nn.Module, optimizer: optim.Optimizer) -> dict:
-    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-        model_state_dict = model.state_dict()
-        optim_state_dict = FSDP.optim_state_dict(model, optimizer)
-    
-    return model_state_dict, optim_state_dict
+def prepare_checkpoint(model: nn.Module) -> dict:
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    return state_dict
+
 
 def train(
     train_type: str,
@@ -166,7 +162,7 @@
         num_workers=num_workers,
         pin_memory=pin_memory,
         nprocs=nprocs,
-        parallel_wrapper=fsdp_wrap,
+        parallel_wrapper=ddp_wrap,
         state_dict_fn=prepare_checkpoint,
         device_ids=device_ids,
         device_type=device_type,