fix(khaosz/trainer): move checkpoint saving logic into CheckpointCallback
parent 648e4e177b
commit 89211c16f6
@@ -1,3 +1,4 @@
+import os
 from tqdm import tqdm
 from khaosz.core.parameter import Checkpoint
 from torch.nn.utils import clip_grad_norm_
@@ -104,6 +105,13 @@ class CheckpointCallback(TrainerCallback):
         self.checkpoint_interval = checkpoint_interval
         self.last_ckpt_iter = 0
 
+    @staticmethod
+    def _save_checkpoint(trainer: 'Trainer'):
+        current_iter = len(trainer.checkpoint.loss_list)
+        save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
+        trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict()
+        trainer.checkpoint.save(save_path)
+
     def on_train_begin(self, trainer: 'Trainer', **kwargs):
         _ = trainer
         checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
@@ -112,14 +120,14 @@ class CheckpointCallback(TrainerCallback):
     def on_batch_end(self, trainer: 'Trainer', **kwargs):
         current_iter = kwargs.get('current_iter')
         if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
-            trainer._save_checkpoint()
+            CheckpointCallback._save_checkpoint(trainer)
             self.last_ckpt_iter = current_iter
 
     def on_train_end(self, trainer: 'Trainer', **kwargs):
         checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
         current_iter = len(checkpoint.loss_list)
         if current_iter != self.last_ckpt_iter:
-            trainer._save_checkpoint()
+            CheckpointCallback._save_checkpoint(trainer)
 
 
 class GradientClippingCallback(TrainerCallback):
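For orientation, the following is a minimal, runnable sketch of the behaviour the relocated helper and the on_batch_end gate implement. The Fake* and IntervalGate names are stand-ins invented for illustration, not khaosz classes; only the attribute names they expose (checkpoint.loss_list, checkpoint.optim_state, train_config.checkpoint_dir, train_config.optimizer) come from the diff above.

import os
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FakeCheckpoint:                        # stand-in: the real Checkpoint also saves weights
    loss_list: list = field(default_factory=list)
    optim_state: Optional[dict] = None

    def save(self, path: str):
        os.makedirs(path, exist_ok=True)     # here we only create the directory

class FakeOptimizer:                         # stand-in exposing state_dict() like a torch optimizer
    def state_dict(self):
        return {"lr": 1e-4}

@dataclass
class FakeTrainConfig:
    checkpoint_dir: str = "ckpt"
    optimizer: FakeOptimizer = field(default_factory=FakeOptimizer)

@dataclass
class FakeTrainer:                           # only the attributes _save_checkpoint reads
    checkpoint: FakeCheckpoint = field(default_factory=FakeCheckpoint)
    train_config: FakeTrainConfig = field(default_factory=FakeTrainConfig)

def save_checkpoint(trainer):                # mirrors the new static _save_checkpoint
    current_iter = len(trainer.checkpoint.loss_list)
    save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
    trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict()
    trainer.checkpoint.save(save_path)

class IntervalGate:                          # mirrors on_batch_end's interval gating
    def __init__(self, checkpoint_interval: int):
        self.checkpoint_interval = checkpoint_interval
        self.last_ckpt_iter = 0

    def on_batch_end(self, trainer, current_iter: int):
        if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
            save_checkpoint(trainer)
            self.last_ckpt_iter = current_iter

trainer = FakeTrainer()
gate = IntervalGate(checkpoint_interval=2)
for step in range(1, 6):
    trainer.checkpoint.loss_list.append(1.0 / step)
    gate.on_batch_end(trainer, current_iter=step)   # saves ckpt/iter_2 and ckpt/iter_4

Making _save_checkpoint a @staticmethod means the callback no longer reaches into a private Trainer method; anything exposing those attributes can drive the same save path.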
@@ -1,4 +1,3 @@
-import os
 import torch
 from typing import Optional, List
 from torch.utils.data import DataLoader, RandomSampler
@@ -29,7 +28,6 @@ class Trainer:
         )
         self.train_config = train_config
         self.schedule_config = schedule_config
-
         self.callbacks = callbacks or self._get_default_callbacks()
 
     def _get_default_callbacks(self) -> List[TrainerCallback]:
@@ -49,19 +47,13 @@ class Trainer:
             batch_size=self.train_config.batch_size,
             sampler=sampler
         )
 
-    def _save_checkpoint(self):
-        current_iter = len(self.checkpoint.loss_list)
-        save_path = os.path.join(self.train_config.checkpoint_dir, f"iter_{current_iter}")
-        self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
-        self.checkpoint.save(save_path)
-
     def _call_callbacks(self, method_name: str, **kwargs):
         for callback in self.callbacks:
            method = getattr(callback, method_name, None)
            if method:
                method(self, **kwargs)
 
     def train(
         self,
         train_checkpoint: Optional[Checkpoint] = None
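On the Trainer side, the only checkpoint-related responsibility left is dispatching events to callbacks. Below is a small self-contained sketch of that getattr-based dispatch; LoggingCallback and MiniTrainer are hypothetical names for illustration, not part of khaosz.

class LoggingCallback:
    def on_batch_end(self, trainer, **kwargs):
        print(f"batch done at iter {kwargs.get('current_iter')}")

class MiniTrainer:                                   # illustrative stand-in, not khaosz's Trainer
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def _call_callbacks(self, method_name: str, **kwargs):
        for callback in self.callbacks:
            method = getattr(callback, method_name, None)   # look the hook up by name
            if method:                                       # skip callbacks without this hook
                method(self, **kwargs)

trainer = MiniTrainer([LoggingCallback()])
trainer._call_callbacks("on_batch_end", current_iter=100)   # -> "batch done at iter 100"

Because _call_callbacks looks hooks up by name and skips missing ones, a callback only has to implement the events it cares about, which is what lets the checkpoint logic live entirely in CheckpointCallback.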