fix(khaosz/trainer): 将保存检查点逻辑移至CheckpointCallback

This commit is contained in:
ViperEkura 2025-09-29 13:38:46 +08:00
parent 648e4e177b
commit 89211c16f6
2 changed files with 11 additions and 11 deletions

View File

@@ -1,3 +1,4 @@
import os
from tqdm import tqdm
from khaosz.core.parameter import Checkpoint
from torch.nn.utils import clip_grad_norm_
@@ -104,6 +105,13 @@ class CheckpointCallback(TrainerCallback):
self.checkpoint_interval = checkpoint_interval
self.last_ckpt_iter = 0
@staticmethod
def _save_checkpoint(trainer: 'Trainer'):
current_iter = len(trainer.checkpoint.loss_list)
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict()
trainer.checkpoint.save(save_path)
def on_train_begin(self, trainer: 'Trainer', **kwargs):
_ = trainer
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
@@ -112,14 +120,14 @@ class CheckpointCallback(TrainerCallback):
def on_batch_end(self, trainer: 'Trainer', **kwargs):
current_iter = kwargs.get('current_iter')
if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
trainer._save_checkpoint()
CheckpointCallback._save_checkpoint(trainer)
self.last_ckpt_iter = current_iter
def on_train_end(self, trainer: 'Trainer', **kwargs):
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
current_iter = len(checkpoint.loss_list)
if current_iter != self.last_ckpt_iter:
trainer._save_checkpoint()
CheckpointCallback._save_checkpoint(trainer)
class GradientClippingCallback(TrainerCallback):

View File

@@ -1,4 +1,3 @@
import os
import torch
from typing import Optional, List
from torch.utils.data import DataLoader, RandomSampler
@@ -29,7 +28,6 @@ class Trainer:
)
self.train_config = train_config
self.schedule_config = schedule_config
self.callbacks = callbacks or self._get_default_callbacks()
def _get_default_callbacks(self) -> List[TrainerCallback]:
@@ -50,12 +48,6 @@ class Trainer:
sampler=sampler
)
def _save_checkpoint(self):
current_iter = len(self.checkpoint.loss_list)
save_path = os.path.join(self.train_config.checkpoint_dir, f"iter_{current_iter}")
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
self.checkpoint.save(save_path)
def _call_callbacks(self, method_name: str, **kwargs):
for callback in self.callbacks:
method = getattr(callback, method_name, None)