From d21682f97a0a2d2386082f4c60ebb44a0920e96c Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 5 Jan 2026 17:08:09 +0800 Subject: [PATCH] =?UTF-8?q?fix(trainer):=20=E4=BF=AE=E5=A4=8D=E6=A3=80?= =?UTF-8?q?=E6=9F=A5=E7=82=B9=E5=9B=9E=E8=B0=83=E5=8F=82=E6=95=B0=E9=A1=BA?= =?UTF-8?q?=E5=BA=8F=E5=92=8C=E6=9D=83=E9=87=8D=E4=BF=9D=E5=AD=98=E9=80=89?= =?UTF-8?q?=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/config/param_config.py | 5 +++-- khaosz/trainer/checkpoint.py | 14 +++++++------- khaosz/trainer/train_callback.py | 15 +++++++++++---- khaosz/trainer/trainer.py | 4 ++-- 4 files changed, 23 insertions(+), 15 deletions(-) diff --git a/khaosz/config/param_config.py b/khaosz/config/param_config.py index 47120e5..4d29813 100644 --- a/khaosz/config/param_config.py +++ b/khaosz/config/param_config.py @@ -52,10 +52,11 @@ class BaseModelIO: self.config.load(str(paths["config"])) self.tokenizer.load(str(paths["tokenizer"])) + if self.model is None: + self.model = Transformer(self.config) + if paths["model"].exists(): state_dict = st.load_file(str(paths["model"])) - if self.model is None: - self.model = Transformer(self.config) self.model.load_state_dict(state_dict) return self diff --git a/khaosz/trainer/checkpoint.py b/khaosz/trainer/checkpoint.py index a4ff101..10a9331 100644 --- a/khaosz/trainer/checkpoint.py +++ b/khaosz/trainer/checkpoint.py @@ -2,16 +2,15 @@ import os import pickle as pkl import matplotlib.pyplot as plt -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler +from torch import Tensor from typing import Dict, Optional class Checkpoint: def __init__( self, - optimizer_state: Optimizer, - scheduler_state: LRScheduler, + optimizer_state: Dict[str, Tensor], + scheduler_state: Dict[str, Tensor], epoch: int = 0, iteration: int = 0, metrics: Optional[Dict[str, list]] = None, @@ -36,7 +35,7 @@ class Checkpoint: pkl.dump(train_state, f) if save_metric_plot and self.metrics: - self._plot_metrics() + self._plot_metrics(save_dir) @classmethod def load(cls, save_dir: str) -> "Checkpoint": @@ -56,7 +55,7 @@ class Checkpoint: metrics=train_state["metrics"] ) - def _plot_metrics(self): + def _plot_metrics(self, save_dir: str): for metric_name, metric_value in self.metrics.items(): plt.figure(figsize=(10, 6)) plt.plot(metric_value, label=metric_name) @@ -65,5 +64,6 @@ class Checkpoint: plt.legend() plt.grid(True, alpha=0.3) - plt.savefig(f'{metric_name}.png', dpi=150, bbox_inches='tight') + save_path = os.path.join(save_dir, f"{metric_name}.png") + plt.savefig(save_path, dpi=150, bbox_inches='tight') plt.close() \ No newline at end of file diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index a44f0e2..9e6b679 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -72,8 +72,8 @@ class SchedulerCallback(TrainCallback): """ Scheduler callback for trainer. """ - def __init__(self, scheduler: LRScheduler): - self.scheduler: LRScheduler = scheduler + def __init__(self): + self.scheduler: LRScheduler = None def on_train_begin(self, context: 'TrainContext'): for group in context.optimizer.param_groups: @@ -92,9 +92,16 @@ class CheckpointCallback(TrainCallback): """ Checkpoint callback for trainer. """ - def __init__(self, interval: int, save_dir: str): - self.interval = interval + def __init__( + self, + save_dir: str, + interval: int, + weight_only: bool = False + ): self.save_dir = save_dir + self.interval = interval + self.weight_only = weight_only + self.last_ckpt_iter = 0 @only_on_rank(0) diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index dec2d0c..5bee86d 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -28,9 +28,9 @@ class Trainer: train_config = self.train_config return [ ProgressBarCallback(train_config.n_epoch), - CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir), + CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), GradientClippingCallback(train_config.max_grad_norm), - SchedulerCallback(train_config.scheduler), + SchedulerCallback(), ] def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: