diff --git a/khaosz/trainer/checkpoint.py b/khaosz/data/checkpoint.py similarity index 86% rename from khaosz/trainer/checkpoint.py rename to khaosz/data/checkpoint.py index eaa086f..538fa0c 100644 --- a/khaosz/trainer/checkpoint.py +++ b/khaosz/data/checkpoint.py @@ -1,15 +1,12 @@ import os import json +import torch +import torch.distributed as dist import matplotlib.pyplot as plt + from pathlib import Path from typing import Dict, Optional, Any - -import torch.distributed as dist -from torch.distributed.checkpoint import save, load - - -def get_rank() -> int: - return dist.get_rank() if dist.is_initialized() else 0 +from khaosz.parallel.setup import get_rank class Checkpoint: @@ -53,8 +50,8 @@ class Checkpoint: "optimizer": self.optimizer_state_dict, "scheduler": self.scheduler_state_dict } - - save(state_dict, checkpoint_id=str(save_path)) + with open(save_path / f"state_dict_rank_{get_rank()}.pt", "wb") as f: + torch.save(state_dict, f) @classmethod def load( @@ -62,9 +59,9 @@ class Checkpoint: save_dir: str, ) -> "Checkpoint": - save_path = str(Path(save_dir)) rank = get_rank() - + save_path = Path(save_dir) + meta = {} if rank == 0: with open(Path(save_dir) / "meta.json", "r") as f: @@ -75,11 +72,8 @@ class Checkpoint: dist.broadcast_object_list(meta_list, src=0) meta = meta_list[0] - state_dict = { - "optimizer": {}, - "scheduler": {} - } - load(state_dict, checkpoint_id=save_path, no_dist=True) + with open(save_path / f"state_dict_rank_{get_rank()}.pt", "rb") as f: + state_dict = torch.load(f) return cls( optimizer_state_dict=state_dict["optimizer"], diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py index 2e92aa4..d856750 100644 --- a/khaosz/trainer/__init__.py +++ b/khaosz/trainer/__init__.py @@ -1,5 +1,4 @@ from khaosz.trainer.trainer import Trainer -from khaosz.trainer.checkpoint import Checkpoint from khaosz.trainer.strategy import StrategyFactory from khaosz.trainer.schedule import SchedulerFactory @@ -16,9 +15,6 @@ __all__ = [ # trainer "Trainer", - # checkpoint - "Checkpoint", - # factory "StrategyFactory", "SchedulerFactory", diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 3477a06..da9cd13 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -17,7 +17,7 @@ from khaosz.trainer.metric_util import ( grad_std, grad_nan_num ) -from khaosz.trainer.checkpoint import Checkpoint +from khaosz.data.checkpoint import Checkpoint if TYPE_CHECKING: from khaosz.trainer.train_context import TrainContext diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index 7af3ad2..91ffa10 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -4,7 +4,7 @@ from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader from khaosz.data import ResumableDistributedSampler -from khaosz.trainer.checkpoint import Checkpoint +from khaosz.data.checkpoint import Checkpoint from khaosz.trainer.strategy import StrategyFactory, BaseStrategy from khaosz.config.train_config import TrainConfig from khaosz.parallel.setup import get_current_device, get_world_size, get_rank diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index cb167b9..678d2a0 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -9,7 +9,7 @@ from khaosz.trainer.train_callback import ( SchedulerCallback ) from khaosz.trainer.train_context import TrainContext, TrainContextBuilder -from khaosz.trainer.checkpoint import Checkpoint +from khaosz.data.checkpoint import Checkpoint from khaosz.parallel.setup import spawn_parallel_fn logger = logging.getLogger(__name__) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py new file mode 100644 index 0000000..79633df --- /dev/null +++ b/tests/test_checkpoint.py @@ -0,0 +1,52 @@ +from pathlib import Path +import tempfile +import torch +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR +from khaosz.data.checkpoint import Checkpoint + +def test_single_process(): + model = torch.nn.Linear(10, 5) + optimizer = AdamW(model.parameters(), lr=1e-3) + scheduler = CosineAnnealingLR(optimizer, T_max=10) + + for epoch in range(3): + for iteration in range(10): + + x = torch.randn(32, 10) + y = torch.randn(32, 5) + loss = model(x).mean() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + scheduler.step() + + checkpoint = Checkpoint( + optimizer_state_dict=optimizer.state_dict(), + scheduler_state_dict=scheduler.state_dict(), + epoch=3, + iteration=30, + metrics={ + "loss": [0.5, 0.4, 0.3, 0.2, 0.1], + "accuracy": [0.6, 0.7, 0.8, 0.85, 0.9] + } + ) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint.save(tmpdir, save_metric_plot=True) + + loaded_checkpoint = Checkpoint.load(tmpdir) + + assert loaded_checkpoint.epoch == 3 + assert loaded_checkpoint.iteration == 30 + assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1] + + assert 'param_groups' in loaded_checkpoint.optimizer_state_dict + assert 'state' in loaded_checkpoint.optimizer_state_dict + + png_files = list(Path(tmpdir).glob("*.png")) + assert png_files + +def test_multi_process(): + pass \ No newline at end of file diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py index a7db7f0..2ead17d 100644 --- a/tests/test_early_stopping.py +++ b/tests/test_early_stopping.py @@ -3,7 +3,7 @@ import torch import numpy as np from khaosz.config import * from khaosz.trainer import * - +from khaosz.data.checkpoint import Checkpoint def test_early_stopping_simulation(base_test_env, early_stopping_dataset): """Simulate early stopping behavior"""