fix: fix some runtime issues

ViperEkura 2026-03-01 15:47:07 +08:00
parent 6089a12cef
commit 80e17418b4
8 changed files with 60 additions and 71 deletions

.gitignore
View File

@@ -1,15 +1,12 @@
-# cache
-__pycache__
-.pytest_cache
+# Ignore everything
+*
-# params
-params/*
+# Allow directories to be traversed
+!*/
-# vscode file
-.vscode
-# build file
-build
-*.egg-info
-*.ipynb
+# Allow specific file types and root files
+!*.py
+!*.md
+!*.png
 !LICENSE
 !pyproject.toml
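Note on the rewritten .gitignore: it now works as an allowlist. The "*" rule ignores everything, "!*/" re-includes directories so Git can still traverse into them, and the remaining "!" patterns re-include Python sources, Markdown, PNG files, LICENSE, and pyproject.toml. Under these rules a path like params/model.pt (an illustrative name) stays ignored, while a source file such as khaosz/parallel/setup.py is matched by !*.py.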

View File

@@ -12,14 +12,12 @@ from khaosz.parallel.setup import get_rank
 class Checkpoint:
     def __init__(
         self,
-        optimizer_state_dict: Dict[str, Any],
-        scheduler_state_dict: Optional[Dict[str, Any]] = None,
+        state_dict: Dict[str, Any],
         epoch: int = 0,
         iteration: int = 0,
         metrics: Optional[Dict[str, list]] = None,
     ):
-        self.optimizer_state_dict = optimizer_state_dict
-        self.scheduler_state_dict = scheduler_state_dict
+        self.state_dict = state_dict
         self.epoch = epoch
         self.iteration = iteration
         self.metrics = metrics or {}
@@ -46,12 +44,8 @@ class Checkpoint:
         if save_metric_plot and self.metrics:
             self._plot_metrics(str(save_path))
-        state_dict = {
-            "optimizer": self.optimizer_state_dict,
-            "scheduler": self.scheduler_state_dict
-        }
-        with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f:
-            torch.save(state_dict, f)
+        with open(save_path / f"state_dict.pt", "wb") as f:
+            torch.save(self.state_dict, f)
 
     @classmethod
     def load(
@@ -72,7 +66,7 @@ class Checkpoint:
         dist.broadcast_object_list(meta_list, src=0)
         meta = meta_list[0]
-        with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f:
+        with open(save_path / f"state_dict.pt", "rb") as f:
             state_dict = torch.load(f)
         return cls(
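In effect, checkpoint I/O collapses from one file per rank to a single shared file, and the payload is now the model state rather than optimizer/scheduler dicts. A minimal sketch of the new shape using plain torch calls (the stand-in model and path are illustrative; Checkpoint's full save/load signatures are not shown in this diff):

import os
import torch
import torch.nn as nn

os.makedirs("checkpoints", exist_ok=True)
model = nn.Linear(8, 8)  # illustrative stand-in for the project model

# After this commit every rank writes/reads one shared file ...
torch.save(model.state_dict(), "checkpoints/state_dict.pt")

# ... instead of a per-rank state_dict_rank_{get_rank()}.pt.
restored = torch.load("checkpoints/state_dict.pt")
model.load_state_dict(restored)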

View File

@@ -1,4 +1,3 @@
-import h5py
 import torch
 import bisect
@@ -78,8 +77,10 @@ class BaseDataset(Dataset, ABC):
         self.fetcher = MultiSegmentFetcher(self.segments)
 
     def get_index(self, index: int) -> int:
-        begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
-        end_idx = begin_idx + self.window_size
+        assert self.total_samples > self.window_size
+        begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
+        end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
         return begin_idx, end_idx
@@ -91,7 +92,7 @@
         assert self.total_samples is not None
         if self.total_samples <= self.window_size:
             return 0
-        return self.total_samples // self.stride + 1
+        return (self.total_samples - 1 - self.window_size) // self.stride + 1
 
 class SeqDataset(BaseDataset):
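To see what the bounds fix buys, here is a standalone re-implementation of the new arithmetic (not an import from the project), with a worked example: for total_samples=10, window_size=4, stride=2 the dataset now reports 3 windows, and every window stays inside the valid index range.

def num_windows(total_samples, window_size, stride):
    # mirrors the new __len__: no windows if the data cannot fit one
    if total_samples <= window_size:
        return 0
    return (total_samples - 1 - window_size) // stride + 1

def window_bounds(index, total_samples, window_size, stride):
    # mirrors the new get_index: both ends clamped into range
    begin = min(index * stride, total_samples - 1 - window_size)
    end = min(begin + window_size, total_samples - 1)
    return begin, end

total, win, stride = 10, 4, 2
for i in range(num_windows(total, win, stride)):
    b, e = window_bounds(i, total, win, stride)
    assert 0 <= b and e <= total - 1
    print(i, (b, e))  # -> (0, 4), (2, 6), (4, 8)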

View File

@@ -2,6 +2,8 @@ import os
 import h5py
 import numpy as np
 import torch
+from pathlib import Path
 from torch import Tensor
 from typing import Dict, List, Tuple
@@ -17,10 +19,7 @@ def save_h5(file_path: str, tensor_group: Dict[str, List[Tensor]]):
             arr = tensor.cpu().numpy()
             dset = grp.create_dataset(
                 f'data_{idx}',
-                data=arr,
-                compression='gzip',
-                compression_opts=4,
-                shuffle=True
+                data=arr
             )
             dset.attrs['numel'] = tensor.numel()
@@ -28,15 +27,19 @@ def load_h5(file_path: str) -> Tuple[Dict[str, List[Tensor]], int]:
     tensor_group: Dict[str, List[Tensor]] = {}
     total_samples = 0
-    with h5py.File(file_path, 'r') as f:
-        for key in f.keys():
-            grp = f[key]
-            dsets = []
-            for dset_name in grp.keys():
-                dset = grp[dset_name]
-                dsets.append(torch.from_numpy(dset[:]).share_memory_())
-                total_samples += dset.attrs.get('numel', np.prod(dset.shape))
-            tensor_group[key] = dsets
+    root_path = Path(file_path)
+    h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
+    for h5_file in h5_files:
+        with h5py.File(h5_file, 'r') as f:
+            for key in f.keys():
+                grp = f[key]
+                dsets = []
+                for dset_name in grp.keys():
+                    dset = grp[dset_name]
+                    dsets.append(torch.from_numpy(dset[:]).share_memory_())
+                    total_samples += dset.attrs.get('numel', np.prod(dset.shape))
+                tensor_group[key] = dsets
 
     num_keys = max(len(tensor_group), 1)
     sample_per_key = total_samples // num_keys
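A hedged usage sketch of the new contract: save_h5 still writes one file per call, while load_h5's file_path argument is now treated as a directory root that is scanned recursively for *.h5 and *.hdf5 files (note the parameter keeps its old name). The import path for the two helpers is not shown in the diff, so it is assumed here.

import torch
# save_h5 / load_h5 are the helpers defined above; import them from
# wherever they live in the khaosz package (path not shown in the diff).

shards = {"train": [torch.arange(6), torch.arange(4)]}
save_h5("data/shard_0.h5", shards)  # still a single file per call
save_h5("data/shard_1.h5", shards)

# load_h5 now walks the whole tree under "data" instead of opening one file:
tensor_group, sample_per_key = load_h5("data")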

View File

@@ -151,8 +151,8 @@ class SchedulerFactory:
     """
 
     @staticmethod
-    def load(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
-        kwargs = scedule_config.get_kwargs()
+    def load(optimizer, schedule_config: ScheduleConfig) -> BaseScheduler:
+        kwargs = schedule_config.get_kwargs()
         schedule_type = kwargs.pop("schedule_type")
 
         if schedule_type == "cosine":
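At a call site the corrected keyword reads as below; this is only a sketch, since ScheduleConfig construction is not shown in this diff and schedule_config is assumed to be built elsewhere with a get_kwargs() that carries a "schedule_type" entry such as "cosine".

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 8)  # illustrative model
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
scheduler = SchedulerFactory.load(optimizer, schedule_config)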

View File

@@ -108,10 +108,10 @@ class CheckpointCallback(TrainCallback):
     def _save_checkpoint(self, context: 'TrainContext'):
         save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
-        state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict()
+        state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
         context.checkpoint = Checkpoint(
-            optimizer_state_dict=state_dict,
-            scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
+            state_dict=state_dict,
             epoch=context.epoch,
             iteration=context.iteration
         )
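Since the fallback now snapshots context.model instead of context.optimizer, a custom state_dict_fn is only needed when the model must be unwrapped first. A minimal sketch (the exact CheckpointCallback constructor is not shown here):

import torch.nn as nn

def unwrap_state_dict(model: nn.Module) -> dict:
    # Strip a DDP-style container so keys are not prefixed with "module."
    inner = model.module if hasattr(model, "module") else model
    return inner.state_dict()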

View File

@@ -53,15 +53,13 @@ class TrainContextBuilder:
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
         if checkpoint is None:
             checkpoint = Checkpoint(
-                optimizer_state_dict=self._context.optimizer.state_dict(),
-                scheduler_state_dict=self._context.scheduler.state_dict(),
+                state_dict=self._context.model.state_dict(),
             )
         else:
             # resume from the assigned checkpoint or assigned iteration
             self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
             self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
-            self._context.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
-            self._context.scheduler.load_state_dict(checkpoint.scheduler_state_dict)
+            self._context.model.load_state_dict(checkpoint.state_dict)
 
         self._context.checkpoint = checkpoint
         return self
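Worth noting about this change: with Checkpoint reduced to model weights, resuming through with_checkpoint restores parameters, epoch, and iteration only; optimizer and scheduler state are rebuilt from scratch rather than loaded, so optimizer moments and the learning-rate position reset when resuming mid-run.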

View File

@@ -3,10 +3,8 @@ import argparse
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import torch.distributed.fsdp as fsdp
+from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.distributed.fsdp.api import StateDictType
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from typing import List, Optional
 from functools import partial
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
@@ -59,19 +57,16 @@
     return args
 
-def fsdp_wrap(model: nn.Module):
-    fsdp_model = fsdp.FullyShardedDataParallel(
+def ddp_wrap(model: nn.Module):
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
+    ddp_model = DDP(
         model,
-        sharding_strategy=fsdp.ShardingStrategy.SHARD_GRAD_OP,
-        mixed_precision=fsdp.MixedPrecision(
-            param_dtype=torch.bfloat16,
-            reduce_dtype=torch.bfloat16,
-            buffer_dtype=torch.bfloat16,
-        ),
-        backward_prefetch=fsdp.BackwardPrefetch.BACKWARD_PRE
+        device_ids=[local_rank],
+        output_device=local_rank,
+        find_unused_parameters=False
     )
-    return fsdp_model
+    return ddp_model
 
 def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
     return optim.AdamW(model.parameters(), **kwargs)
@@ -79,12 +74,13 @@ def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
 def create_scheduler(optimizer: optim.Optimizer, **kwargs) -> optim.lr_scheduler.LRScheduler:
     return SchedulerFactory.load(optimizer, **kwargs)
 
-def prepare_checkpoint(model: nn.Module, optimizer: optim.Optimizer) -> dict:
-    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
-        model_state_dict = model.state_dict()
-        optim_state_dict = FSDP.optim_state_dict(model, optimizer)
-    return model_state_dict, optim_state_dict
+def prepare_checkpoint(model: nn.Module) -> dict:
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    return state_dict
 
 def train(
     train_type: str,
@@ -166,7 +162,7 @@ def train(
         num_workers=num_workers,
         pin_memory=pin_memory,
         nprocs=nprocs,
-        parallel_wrapper=fsdp_wrap,
+        parallel_wrapper=ddp_wrap,
         state_dict_fn=prepare_checkpoint,
         device_ids=device_ids,
         device_type=device_type,
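End to end, the script now expects a DDP-style launch. A hedged sketch of the surrounding setup (the script name and model are placeholders; ddp_wrap and prepare_checkpoint are the functions above, and the process group must exist before DDP construction):

# launched once per worker by e.g.: torchrun --nproc_per_node=4 train.py ...
# torchrun sets the LOCAL_RANK that ddp_wrap reads from os.environ
import torch.distributed as dist
import torch.nn as nn

dist.init_process_group(backend="nccl")  # required before wrapping
model = ddp_wrap(nn.Linear(8, 8))        # placeholder module
state_dict = prepare_checkpoint(model)   # unwraps .module for clean keys
dist.destroy_process_group()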