From a5869d89baf6234ff7f62b77353ef3d51bbd85ce Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 4 Feb 2026 19:47:21 +0800 Subject: [PATCH] =?UTF-8?q?feat(trainer):=20=E5=A2=9E=E5=8A=A0state=5Fdict?= =?UTF-8?q?=20=E5=AD=98=E5=82=A8=E8=AE=BE=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/config/train_config.py | 5 +++++ khaosz/trainer/train_callback.py | 11 +++++++---- khaosz/trainer/train_context.py | 1 + 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py index cb4a8ac..44f20c4 100644 --- a/khaosz/config/train_config.py +++ b/khaosz/config/train_config.py @@ -105,6 +105,11 @@ class TrainConfig: default=None, metadata={"help": "Parallel function for training."} ) + state_dict_wrapper: Optional[Callable] = field( + default=None, + metadata={"help": "Function for extracting the state dict when saving checkpoints."} + ) + optimizer_factory: Optional[Callable[[nn.Module], Optimizer]] = field( default=None, metadata={"help": "Optimizer factory for training."} diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index da9cd13..b1c0a9d 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -1,12 +1,13 @@ import os import json import time +import torch.nn as nn from pathlib import Path from tqdm import tqdm from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import LRScheduler -from typing import List, Literal, Optional, Protocol, TYPE_CHECKING +from typing import Callable, List, Optional, Protocol, TYPE_CHECKING from khaosz.parallel import only_on_rank from khaosz.trainer.metric_util import ( @@ -96,18 +97,20 @@ class CheckpointCallback(TrainCallback): self, save_dir: str, interval: int, - weight_only: bool = False + weight_only: bool = False, + state_dict_fn: Optional[Callable[[nn.Module], dict]] = None ): self.save_dir = save_dir self.interval = interval 
self.weight_only = weight_only - + self.state_dict_fn = state_dict_fn self.last_ckpt_iter = 0 def _save_checkpoint(self, context: 'TrainContext'): save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}") + state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.optimizer.state_dict() context.checkpoint = Checkpoint( - optimizer_state_dict=context.optimizer.state_dict(), + optimizer_state_dict=state_dict, scheduler_state_dict=context.scheduler.state_dict() if context.scheduler else None, epoch=context.epoch, iteration=context.iteration diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index 91ffa10..4792e2a 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -28,6 +28,7 @@ class TrainContext: world_size: int = field(default=1) rank: int = field(default=0) + kwargs: dict = field(default_factory=dict) class TrainContextBuilder: