feat(trainer): 增加 state_dict 存储设定
This commit is contained in:
parent
7a9b9d0659
commit
a5869d89ba
|
|
@ -105,6 +105,11 @@ class TrainConfig:
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Parallel function for training."}
|
metadata={"help": "Parallel function for training."}
|
||||||
)
|
)
|
||||||
|
state_dict_wrapper: Optional[Callable] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Parallel function for state dict saving."}
|
||||||
|
)
|
||||||
|
|
||||||
optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
|
optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Optimizer factory for training."}
|
metadata={"help": "Optimizer factory for training."}
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from typing import List, Literal, Optional, Protocol, TYPE_CHECKING
|
from typing import Callable, List, Optional, Protocol, TYPE_CHECKING
|
||||||
|
|
||||||
from khaosz.parallel import only_on_rank
|
from khaosz.parallel import only_on_rank
|
||||||
from khaosz.trainer.metric_util import (
|
from khaosz.trainer.metric_util import (
|
||||||
|
|
@ -96,18 +97,20 @@ class CheckpointCallback(TrainCallback):
|
||||||
self,
|
self,
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
interval: int,
|
interval: int,
|
||||||
weight_only: bool = False
|
weight_only: bool = False,
|
||||||
|
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
|
self.state_dict_fn = state_dict_fn
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
def _save_checkpoint(self, context: 'TrainContext'):
|
def _save_checkpoint(self, context: 'TrainContext'):
|
||||||
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
|
||||||
|
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict()
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
optimizer_state_dict=context.optimizer.state_dict(),
|
optimizer_state_dict=state_dict,
|
||||||
scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
|
scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
|
||||||
epoch=context.epoch,
|
epoch=context.epoch,
|
||||||
iteration=context.iteration
|
iteration=context.iteration
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ class TrainContext:
|
||||||
|
|
||||||
world_size: int = field(default=1)
|
world_size: int = field(default=1)
|
||||||
rank: int = field(default=0)
|
rank: int = field(default=0)
|
||||||
|
kwargs: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue