feat(trainer): 重构检查点系统支持分布式训练

This commit is contained in:
ViperEkura 2026-01-08 15:01:19 +08:00
parent d21682f97a
commit 3d8047fa1b
4 changed files with 106 additions and 74 deletions

View File

@@ -1,69 +1,104 @@
import os
import pickle as pkl
import json
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, Optional, Any
from torch import Tensor
from typing import Dict, Optional
import torch.distributed as dist
from torch.distributed.checkpoint import save, load
def get_rank() -> int:
    """Return this process's distributed rank, or 0 when the
    torch.distributed process group has not been initialized
    (i.e. single-process runs)."""
    if not dist.is_initialized():
        return 0
    return dist.get_rank()
class Checkpoint:
def __init__(
self,
optimizer_state: Dict[str, Tensor],
scheduler_state: Dict[str, Tensor],
optimizer_state_dict: Dict[str, Any],
scheduler_state_dict: Optional[Dict[str, Any]] = None,
epoch: int = 0,
iteration: int = 0,
metrics: Optional[Dict[str, list]] = None,
):
self.optimizer_state = optimizer_state
self.scheduler_state = scheduler_state
self.epoch, self.iteration = epoch, iteration
self.metrics = metrics
def save(self, save_dir: str, save_metric_plot=True) -> None:
os.makedirs(save_dir, exist_ok=True)
self.optimizer_state_dict = optimizer_state_dict
self.scheduler_state_dict = scheduler_state_dict
self.epoch = epoch
self.iteration = iteration
self.metrics = metrics or {}
def save(
self,
save_dir: str,
save_metric_plot: bool = True,
) -> None:
train_state = {
"epoch": self.epoch,
"iteration": self.iteration,
"metrics": self.metrics,
"optimizer_state": self.optimizer_state,
"scheduler_state": self.scheduler_state,
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)
rank = get_rank()
if rank == 0:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"metrics": self.metrics,
}
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
if save_metric_plot and self.metrics:
self._plot_metrics(str(save_path))
state_dict = {
"optimizer": self.optimizer_state_dict,
"scheduler": self.scheduler_state_dict
}
with open(os.path.join(save_dir, "train_state.pkl"), "wb") as f:
pkl.dump(train_state, f)
if save_metric_plot and self.metrics:
self._plot_metrics(save_dir)
save(state_dict, checkpoint_id=str(save_path))
@classmethod
def load(cls, save_dir: str) -> "Checkpoint":
checkpoint_path = os.path.join(save_dir, "train_state.pkl")
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint file {checkpoint_path} does not exist.")
with open(checkpoint_path, "rb") as f:
train_state = pkl.load(f)
def load(
cls,
save_dir: str,
) -> "Checkpoint":
save_path = str(Path(save_dir))
rank = get_rank()
meta = {}
if rank == 0:
with open(Path(save_dir) / "meta.json", "r") as f:
meta = json.load(f)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
state_dict = {
"optimizer": {},
"scheduler": {}
}
load(state_dict, checkpoint_id=save_path, no_dist=True)
return cls(
optimizer_state=train_state["optimizer_state"],
scheduler_state=train_state["scheduler_state"],
epoch=train_state["epoch"],
iteration=train_state["iteration"],
metrics=train_state["metrics"]
optimizer_state_dict=state_dict["optimizer"],
scheduler_state_dict=state_dict["scheduler"],
epoch=meta["epoch"],
iteration=meta["iteration"],
metrics=meta.get("metrics", {}),
)
def _plot_metrics(self, save_dir: str):
for metric_name, metric_value in self.metrics.items():
for name, values in self.metrics.items():
if not values:
continue
plt.figure(figsize=(10, 6))
plt.plot(metric_value, label=metric_name)
plt.xlabel('Step')
plt.ylabel('Value')
plt.plot(values, label=name)
plt.xlabel("Step")
plt.ylabel("Value")
plt.title(f"Training Metric: {name}")
plt.legend()
plt.grid(True, alpha=0.3)
save_path = os.path.join(save_dir, f"{metric_name}.png")
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.savefig(os.path.join(save_dir, f"{name}.png"), dpi=150, bbox_inches="tight")
plt.close()

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LRScheduler
from typing import List, Optional, Protocol, TYPE_CHECKING
from typing import List, Literal, Optional, Protocol, TYPE_CHECKING
from khaosz.parallel import only_on_rank
from khaosz.trainer.metric_util import (
@@ -104,15 +104,15 @@ class CheckpointCallback(TrainCallback):
self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: 'TrainContext'):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
context.checkpoint = Checkpoint(
context.optimizer.state_dict(),
context.scheduler.state_dict(),
context.epoch,
context.iteration
optimizer_state_dict=context.optimizer.state_dict(),
scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
epoch=context.epoch,
iteration=context.iteration
)
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration

View File

@@ -40,19 +40,32 @@ class TrainContextBuilder:
world_size=get_world_size(),
rank=get_rank(),
)
device = get_current_device()
self._context.model = self._context.model.to(device=device)
if self.config.nprocs > 1:
fn = self.config.parallel_wrapper
optimizer_fn = self.config.optimizer_factory
scheduler_fn = self.config.scheduler_factory
self._context.model = fn(self._context.model)
self._context.optimizer = optimizer_fn(self._context.model.parameters())
self._context.scheduler = scheduler_fn(self._context.optimizer)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
if checkpoint is None:
checkpoint = Checkpoint(
optimizer_state=self.config.optimizer.state_dict(),
scheduler_state=self.config.scheduler.state_dict(),
optimizer_state_dict=self.config.optimizer.state_dict(),
scheduler_state_dict=self.config.scheduler.state_dict() if self.config.scheduler is not None else None,
)
else:
# resume from the assigned checkpoint or assigned iteration
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.optimizer.load_state_dict(checkpoint.optimizer_state)
self._context.scheduler.load_state_dict(checkpoint.scheduler_state)
self._context.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
self._context.scheduler.load_state_dict(checkpoint.scheduler_state_dict)
self._context.checkpoint = checkpoint
return self
@@ -88,21 +101,6 @@ class TrainContextBuilder:
)
return self
def with_parallel(self) -> Self:
    """Move the model onto this process's device and, for multi-process
    runs, wrap it with the configured parallel wrapper, then rebuild the
    optimizer and scheduler from their factories.

    Returns:
        Self, for builder-style chaining.
    """
    device = get_current_device()
    self._context.model = self._context.model.to(device=device)
    # Only wrap when more than one process is configured; a single-process
    # run keeps the bare model and the already-built optimizer/scheduler.
    if self.config.nprocs > 1:
        fn = self.config.parallel_wrapper
        optimizer_fn = self.config.optimizer_factory
        scheduler_fn = self.config.scheduler_factory
        self._context.model = fn(self._context.model)
        # Rebuild optimizer/scheduler so they reference the WRAPPED model's
        # parameters rather than the original module's — presumably the
        # factories capture lr/config internally; TODO confirm against config.
        self._context.optimizer = optimizer_fn(self._context.model.parameters())
        self._context.scheduler = scheduler_fn(self._context.optimizer)
    return self
def build(self) -> TrainContext:
return self._context

View File

@@ -38,7 +38,6 @@ class Trainer:
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.with_parallel()
.build())
def _call_callbacks(self, method_name: str, context: TrainContext):