diff --git a/khaosz/trainer/checkpoint.py b/khaosz/trainer/checkpoint.py
index 10a9331..eaa086f 100644
--- a/khaosz/trainer/checkpoint.py
+++ b/khaosz/trainer/checkpoint.py
@@ -1,69 +1,104 @@
 import os
-import pickle as pkl
+import json
 import matplotlib.pyplot as plt
+from pathlib import Path
+from typing import Dict, Optional, Any
 
-from torch import Tensor
-from typing import Dict, Optional
+import torch.distributed as dist
+from torch.distributed.checkpoint import save, load
+
+
+def get_rank() -> int:
+    return dist.get_rank() if dist.is_initialized() else 0
 
 
 class Checkpoint:
     def __init__(
         self,
-        optimizer_state: Dict[str, Tensor],
-        scheduler_state: Dict[str, Tensor],
+        optimizer_state_dict: Dict[str, Any],
+        scheduler_state_dict: Optional[Dict[str, Any]] = None,
         epoch: int = 0,
         iteration: int = 0,
         metrics: Optional[Dict[str, list]] = None,
     ):
-        self.optimizer_state = optimizer_state
-        self.scheduler_state = scheduler_state
-        self.epoch, self.iteration = epoch, iteration
-        self.metrics = metrics
-
-    def save(self, save_dir: str, save_metric_plot=True) -> None:
-        os.makedirs(save_dir, exist_ok=True)
+        self.optimizer_state_dict = optimizer_state_dict
+        self.scheduler_state_dict = scheduler_state_dict
+        self.epoch = epoch
+        self.iteration = iteration
+        self.metrics = metrics or {}
+
+    def save(
+        self,
+        save_dir: str,
+        save_metric_plot: bool = True,
+    ) -> None:
 
-        train_state = {
-            "epoch": self.epoch,
-            "iteration": self.iteration,
-            "metrics": self.metrics,
-            "optimizer_state": self.optimizer_state,
-            "scheduler_state": self.scheduler_state,
+        save_path = Path(save_dir)
+        save_path.mkdir(parents=True, exist_ok=True)
+
+        rank = get_rank()
+        if rank == 0:
+            meta = {
+                "epoch": self.epoch,
+                "iteration": self.iteration,
+                "metrics": self.metrics,
+            }
+            with open(save_path / "meta.json", "w") as f:
+                json.dump(meta, f, indent=2)
+
+            if save_metric_plot and self.metrics:
+                self._plot_metrics(str(save_path))
+
+        state_dict = {
+            "optimizer": self.optimizer_state_dict,
+            "scheduler": self.scheduler_state_dict
         }
-
-        with open(os.path.join(save_dir, "train_state.pkl"), "wb") as f:
-            pkl.dump(train_state, f)
-
-        if save_metric_plot and self.metrics:
-            self._plot_metrics(save_dir)
-
+
+        save(state_dict, checkpoint_id=str(save_path))
+
     @classmethod
-    def load(cls, save_dir: str) -> "Checkpoint":
-        checkpoint_path = os.path.join(save_dir, "train_state.pkl")
-
-        if not os.path.exists(checkpoint_path):
-            raise FileNotFoundError(f"Checkpoint file {checkpoint_path} does not exist.")
-
-        with open(checkpoint_path, "rb") as f:
-            train_state = pkl.load(f)
-
+    def load(
+        cls,
+        save_dir: str,
+    ) -> "Checkpoint":
+
+        save_path = str(Path(save_dir))
+        rank = get_rank()
+
+        meta = {}
+        if rank == 0:
+            with open(Path(save_dir) / "meta.json", "r") as f:
+                meta = json.load(f)
+
+        if dist.is_initialized():
+            meta_list = [meta]
+            dist.broadcast_object_list(meta_list, src=0)
+            meta = meta_list[0]
+
+        state_dict = {
+            "optimizer": {},
+            "scheduler": {}
+        }
+        load(state_dict, checkpoint_id=save_path, no_dist=True)
+
         return cls(
-            optimizer_state=train_state["optimizer_state"],
-            scheduler_state=train_state["scheduler_state"],
-            epoch=train_state["epoch"],
-            iteration=train_state["iteration"],
-            metrics=train_state["metrics"]
+            optimizer_state_dict=state_dict["optimizer"],
+            scheduler_state_dict=state_dict["scheduler"],
+            epoch=meta["epoch"],
+            iteration=meta["iteration"],
+            metrics=meta.get("metrics", {}),
         )
-
+
     def _plot_metrics(self, save_dir: str):
-        for metric_name, metric_value in self.metrics.items():
+        for name, values in self.metrics.items():
+            if not values:
+                continue
             plt.figure(figsize=(10, 6))
-            plt.plot(metric_value, label=metric_name)
-            plt.xlabel('Step')
-            plt.ylabel('Value')
+            plt.plot(values, label=name)
+            plt.xlabel("Step")
+            plt.ylabel("Value")
+            plt.title(f"Training Metric: {name}")
             plt.legend()
             plt.grid(True, alpha=0.3)
-
-            save_path = os.path.join(save_dir, f"{metric_name}.png")
-            plt.savefig(save_path, dpi=150, bbox_inches='tight')
+            plt.savefig(os.path.join(save_dir, f"{name}.png"), dpi=150, bbox_inches="tight")
             plt.close()
\ No newline at end of file
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index 9e6b679..3477a06 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -6,7 +6,7 @@ from pathlib import Path
 from tqdm import tqdm
 from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import LRScheduler
-from typing import List, Optional, Protocol, TYPE_CHECKING
+from typing import List, Literal, Optional, Protocol, TYPE_CHECKING
 
 from khaosz.parallel import only_on_rank
 from khaosz.trainer.metric_util import (
@@ -104,15 +104,15 @@ class CheckpointCallback(TrainCallback):
         self.last_ckpt_iter = 0
 
-    @only_on_rank(0)
     def _save_checkpoint(self, context: 'TrainContext'):
         save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
         context.checkpoint = Checkpoint(
-            context.optimizer.state_dict(),
-            context.scheduler.state_dict(),
-            context.epoch,
-            context.iteration
+            optimizer_state_dict=context.optimizer.state_dict(),
+            scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
+            epoch=context.epoch,
+            iteration=context.iteration
         )
+
         context.checkpoint.save(save_path)
         self.last_ckpt_iter = context.iteration
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index ee3cea5..7af3ad2 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -40,19 +40,32 @@ class TrainContextBuilder:
             world_size=get_world_size(),
             rank=get_rank(),
         )
-
+
+        device = get_current_device()
+        self._context.model = self._context.model.to(device=device)
+
+        if self.config.nprocs > 1:
+
+            fn = self.config.parallel_wrapper
+            optimizer_fn = self.config.optimizer_factory
+            scheduler_fn = self.config.scheduler_factory
+
+            self._context.model = fn(self._context.model)
+            self._context.optimizer = optimizer_fn(self._context.model.parameters())
+            self._context.scheduler = scheduler_fn(self._context.optimizer)
+
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
         if checkpoint is None:
             checkpoint = Checkpoint(
-                optimizer_state=self.config.optimizer.state_dict(),
-                scheduler_state=self.config.scheduler.state_dict(),
+                optimizer_state_dict=self.config.optimizer.state_dict(),
+                scheduler_state_dict=self.config.scheduler.state_dict() if self.config.scheduler is not None else None,
             )
         else:
             # resume from the assigned checkpoint or assigned iteration
             self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
             self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
-            self._context.optimizer.load_state_dict(checkpoint.optimizer_state)
-            self._context.scheduler.load_state_dict(checkpoint.scheduler_state)
+            self._context.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
+            self._context.scheduler.load_state_dict(checkpoint.scheduler_state_dict)
 
         self._context.checkpoint = checkpoint
         return self
@@ -88,21 +101,6 @@ class TrainContextBuilder:
         )
         return self
 
-    def with_parallel(self) -> Self:
-        device = get_current_device()
-        self._context.model = self._context.model.to(device=device)
-
-        if self.config.nprocs > 1:
-
-            fn = self.config.parallel_wrapper
-            optimizer_fn = self.config.optimizer_factory
-            scheduler_fn = self.config.scheduler_factory
-
-            self._context.model = fn(self._context.model)
-            self._context.optimizer = optimizer_fn(self._context.model.parameters())
-            self._context.scheduler = scheduler_fn(self._context.optimizer)
-
-        return self
 
     def build(self) -> TrainContext:
         return self._context
\ No newline at end of file
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 5bee86d..cb167b9 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -38,7 +38,6 @@ class Trainer:
             .with_checkpoint(checkpoint)
             .with_dataloader()
             .with_strategy()
-            .with_parallel()
             .build())
 
     def _call_callbacks(self, method_name: str, context: TrainContext):
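
Review note (not part of the patch): the sketch below shows the intended single-process round trip of the reworked Checkpoint API. The model, optimizer, scheduler, metric values, and checkpoint path are all illustrative, and it assumes torch.distributed.checkpoint falls back to its non-distributed mode when no process group is initialized.

import torch
import torch.nn as nn

from khaosz.trainer.checkpoint import Checkpoint

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)

# Take one optimizer step so its state dict is non-empty.
model(torch.randn(2, 8)).sum().backward()
optimizer.step()
scheduler.step()

ckpt = Checkpoint(
    optimizer_state_dict=optimizer.state_dict(),
    scheduler_state_dict=scheduler.state_dict(),
    epoch=1,
    iteration=100,
    metrics={"loss": [2.31, 1.94, 1.72]},
)
# Writes meta.json (plus metric plots) from rank 0 and the DCP shards
# for the optimizer/scheduler state under the same directory.
ckpt.save("checkpoints/epoch_1_iter_100")

restored = Checkpoint.load("checkpoints/epoch_1_iter_100")
print(restored.epoch, restored.iteration, list(restored.metrics))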
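
One subtlety in the new load path deserves a callout: meta.json is read on rank 0 only and then mirrored to every other rank, so all processes resume from the same epoch and iteration. A standalone sketch of that pattern, assuming a process group is already initialized; the values are placeholders:

import torch.distributed as dist

meta = {}
if dist.get_rank() == 0:
    meta = {"epoch": 3, "iteration": 1200}  # placeholder values

# broadcast_object_list pickles the object on src and unpickles it on
# every other rank, so each rank ends up with rank 0's dict.
obj = [meta]
dist.broadcast_object_list(obj, src=0)
meta = obj[0]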
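
Also worth noting for reviewers: folding the old with_parallel step into the builder's construction path fixes an ordering hazard, not just tidiness. Previously the chain ran with_checkpoint before with_parallel, so on multi-process runs the optimizer and scheduler rebuilt by optimizer_factory and scheduler_factory silently discarded the state that with_checkpoint had just loaded; now the wrapped model, optimizer, and scheduler exist before any checkpoint state is restored into them.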