feat(trainer): 增加 state_dict 存储设定
This commit is contained in:
parent
7a9b9d0659
commit
a5869d89ba
|
|
@ -105,6 +105,11 @@ class TrainConfig:
|
|||
default=None,
|
||||
metadata={"help": "Parallel function for training."}
|
||||
)
|
||||
state_dict_wrapper: Optional[Callable] = field(
|
||||
default=None,
|
||||
metadata={"help": "Parallel function for state dict saving."}
|
||||
)
|
||||
|
||||
optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Optimizer factory for training."}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
import os
|
||||
import json
|
||||
import time
|
||||
import torch.nn as nn
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from typing import List, Literal, Optional, Protocol, TYPE_CHECKING
|
||||
from typing import Callable, List, Optional, Protocol, TYPE_CHECKING
|
||||
|
||||
from khaosz.parallel import only_on_rank
|
||||
from khaosz.trainer.metric_util import (
|
||||
|
|
@ -96,18 +97,20 @@ class CheckpointCallback(TrainCallback):
|
|||
self,
|
||||
save_dir: str,
|
||||
interval: int,
|
||||
weight_only: bool = False
|
||||
weight_only: bool = False,
|
||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
|
||||
self.state_dict_fn = state_dict_fn
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
def _save_checkpoint(self, context: 'TrainContext'):
|
||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
||||
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict()
|
||||
context.checkpoint = Checkpoint(
|
||||
optimizer_state_dict=context.optimizer.state_dict(),
|
||||
optimizer_state_dict=state_dict,
|
||||
scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class TrainContext:
|
|||
|
||||
world_size: int = field(default=1)
|
||||
rank: int = field(default=0)
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class TrainContextBuilder:
|
||||
|
|
|
|||
Loading…
Reference in New Issue