feat(trainer): 增加state_dict 存储设定

This commit is contained in:
ViperEkura 2026-02-04 19:47:21 +08:00
parent 7a9b9d0659
commit a5869d89ba
3 changed files with 13 additions and 4 deletions

View File

@@ -105,6 +105,11 @@ class TrainConfig:
default=None, default=None,
metadata={"help": "Parallel function for training."} metadata={"help": "Parallel function for training."}
) )
state_dict_wrapper: Optional[Callable] = field(
default=None,
metadata={"help": "Parallel function for state dict saving."}
)
optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field( optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field(
default=None, default=None,
metadata={"help": "Optimizer factory for training."} metadata={"help": "Optimizer factory for training."}

View File

@@ -1,12 +1,13 @@
import os import os
import json import json
import time import time
import torch.nn as nn
from pathlib import Path from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from typing import List, Literal, Optional, Protocol, TYPE_CHECKING from typing import Callable, List, Optional, Protocol, TYPE_CHECKING
from khaosz.parallel import only_on_rank from khaosz.parallel import only_on_rank
from khaosz.trainer.metric_util import ( from khaosz.trainer.metric_util import (
@@ -96,18 +97,20 @@ class CheckpointCallback(TrainCallback):
self, self,
save_dir: str, save_dir: str,
interval: int, interval: int,
weight_only: bool = False weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
def _save_checkpoint(self, context: 'TrainContext'): def _save_checkpoint(self, context: 'TrainContext'):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}") save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict()
context.checkpoint = Checkpoint( context.checkpoint = Checkpoint(
optimizer_state_dict=context.optimizer.state_dict(), optimizer_state_dict=state_dict,
scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None, scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None,
epoch=context.epoch, epoch=context.epoch,
iteration=context.iteration iteration=context.iteration

View File

@@ -28,6 +28,7 @@ class TrainContext:
world_size: int = field(default=1) world_size: int = field(default=1)
rank: int = field(default=0) rank: int = field(default=0)
kwargs: dict = field(default_factory=dict)
class TrainContextBuilder: class TrainContextBuilder: