diff --git a/khaosz/config/param_config.py b/khaosz/config/param_config.py
index 47120e5..4d29813 100644
--- a/khaosz/config/param_config.py
+++ b/khaosz/config/param_config.py
@@ -52,10 +52,11 @@ class BaseModelIO:
         self.config.load(str(paths["config"]))
         self.tokenizer.load(str(paths["tokenizer"]))
 
+        if self.model is None:
+            self.model = Transformer(self.config)
+
         if paths["model"].exists():
             state_dict = st.load_file(str(paths["model"]))
-            if self.model is None:
-                self.model = Transformer(self.config)
             self.model.load_state_dict(state_dict)
 
         return self
diff --git a/khaosz/trainer/checkpoint.py b/khaosz/trainer/checkpoint.py
index a4ff101..10a9331 100644
--- a/khaosz/trainer/checkpoint.py
+++ b/khaosz/trainer/checkpoint.py
@@ -2,16 +2,15 @@
 import os
 import pickle as pkl
 import matplotlib.pyplot as plt
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LRScheduler
+from torch import Tensor
 from typing import Dict, Optional
 
 
 class Checkpoint:
     def __init__(
         self,
-        optimizer_state: Optimizer,
-        scheduler_state: LRScheduler,
+        optimizer_state: Dict[str, Tensor],
+        scheduler_state: Dict[str, Tensor],
         epoch: int = 0,
         iteration: int = 0,
         metrics: Optional[Dict[str, list]] = None,
@@ -36,7 +35,7 @@
             pkl.dump(train_state, f)
 
         if save_metric_plot and self.metrics:
-            self._plot_metrics()
+            self._plot_metrics(save_dir)
 
     @classmethod
     def load(cls, save_dir: str) -> "Checkpoint":
@@ -56,7 +55,7 @@
             metrics=train_state["metrics"]
         )
 
-    def _plot_metrics(self):
+    def _plot_metrics(self, save_dir: str):
         for metric_name, metric_value in self.metrics.items():
             plt.figure(figsize=(10, 6))
             plt.plot(metric_value, label=metric_name)
@@ -65,5 +64,6 @@
             plt.legend()
             plt.grid(True, alpha=0.3)
 
-            plt.savefig(f'{metric_name}.png', dpi=150, bbox_inches='tight')
+            save_path = os.path.join(save_dir, f"{metric_name}.png")
+            plt.savefig(save_path, dpi=150, bbox_inches='tight')
             plt.close()
\ No newline at end of file
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index a44f0e2..9e6b679 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -72,8 +72,8 @@ class SchedulerCallback(TrainCallback):
     """
    Scheduler callback for trainer.
     """
-    def __init__(self, scheduler: LRScheduler):
-        self.scheduler: LRScheduler = scheduler
+    def __init__(self):
+        self.scheduler: LRScheduler = None
 
     def on_train_begin(self, context: 'TrainContext'):
         for group in context.optimizer.param_groups:
@@ -92,9 +92,16 @@
 class CheckpointCallback(TrainCallback):
     """
     Checkpoint callback for trainer.
     """
-    def __init__(self, interval: int, save_dir: str):
-        self.interval = interval
+    def __init__(
+        self,
+        save_dir: str,
+        interval: int,
+        weight_only: bool = False
+    ):
         self.save_dir = save_dir
+        self.interval = interval
+        self.weight_only = weight_only
+        self.last_ckpt_iter = 0
 
     @only_on_rank(0)
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index dec2d0c..5bee86d 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -28,9 +28,9 @@ class Trainer:
         train_config = self.train_config
         return [
             ProgressBarCallback(train_config.n_epoch),
-            CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir),
+            CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
             GradientClippingCallback(train_config.max_grad_norm),
-            SchedulerCallback(train_config.scheduler),
+            SchedulerCallback(),
         ]
 
     def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: