diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py index cb4a8ac..44f20c4 100644 --- a/khaosz/config/train_config.py +++ b/khaosz/config/train_config.py @@ -105,6 +105,11 @@ class TrainConfig: default=None, metadata={"help": "Parallel function for training."} ) + state_dict_wrapper: Optional[Callable] = field( + default=None, + metadata={"help": "Parallel function for state dict saving."} + ) + optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field( default=None, metadata={"help": "Optimizer factory for training."} diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index da9cd13..b1c0a9d 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -1,12 +1,13 @@ import os import json import time +import torch.nn as nn from pathlib import Path from tqdm import tqdm from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import LRScheduler -from typing import List, Literal, Optional, Protocol, TYPE_CHECKING +from typing import Callable, List, Optional, Protocol, TYPE_CHECKING from khaosz.parallel import only_on_rank from khaosz.trainer.metric_util import ( @@ -96,18 +97,20 @@ class CheckpointCallback(TrainCallback): self, save_dir: str, interval: int, - weight_only: bool = False + weight_only: bool = False, + state_dict_fn: Optional[Callable[[nn.Module], dict]] = None ): self.save_dir = save_dir self.interval = interval self.weight_only = weight_only - + self.state_dict_fn = state_dict_fn self.last_ckpt_iter = 0 def _save_checkpoint(self, context: 'TrainContext'): save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}") + state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict() context.checkpoint = Checkpoint( - optimizer_state_dict=context.optimizer.state_dict(), + optimizer_state_dict=state_dict, scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None, epoch=context.epoch, iteration=context.iteration diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index 91ffa10..4792e2a 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -28,6 +28,7 @@ class TrainContext: world_size: int = field(default=1) rank: int = field(default=0) + kwargs: dict = field(default_factory=dict) class TrainContextBuilder: