feat(trainer): 增加state_dict 存储设定

This commit is contained in:
ViperEkura 2026-02-04 19:47:21 +08:00
parent 7a9b9d0659
commit a5869d89ba
3 changed files with 13 additions and 4 deletions

View File

@@ -105,6 +105,11 @@ class TrainConfig:
default=None,
metadata={"help": "Parallel function for training."}
)
state_dict_wrapper: Optional[Callable] = field(
default=None,
metadata={"help": "Parallel function for state dict saving."}
)
optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
default=None,
metadata={"help": "Optimizer factory for training."}

View File

@@ -1,12 +1,13 @@
import os
import json
import time
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LRScheduler
from typing import List, Literal, Optional, Protocol, TYPE_CHECKING
from typing import Callable, List, Optional, Protocol, TYPE_CHECKING
from khaosz.parallel import only_on_rank
from khaosz.trainer.metric_util import (
@@ -96,18 +97,20 @@ class CheckpointCallback(TrainCallback):
self,
save_dir: str,
interval: int,
weight_only: bool = False
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.last_ckpt_iter = 0
def _save_checkpoint(self, context: 'TrainContext'):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict()
context.checkpoint = Checkpoint(
optimizer_state_dict=context.optimizer.state_dict(),
optimizer_state_dict=state_dict,
scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
epoch=context.epoch,
iteration=context.iteration

View File

@@ -28,6 +28,7 @@ class TrainContext:
world_size: int = field(default=1)
rank: int = field(default=0)
kwargs: dict = field(default_factory=dict)
class TrainContextBuilder: