fix(trainer): fix checkpoint callback argument order and weight-saving option

This commit is contained in:
ViperEkura 2026-01-05 17:08:09 +08:00
parent eba99e1f5e
commit d21682f97a
4 changed files with 23 additions and 15 deletions

View File

@@ -52,10 +52,11 @@ class BaseModelIO:
self.config.load(str(paths["config"]))
self.tokenizer.load(str(paths["tokenizer"]))
if self.model is None:
self.model = Transformer(self.config)
if paths["model"].exists():
state_dict = st.load_file(str(paths["model"]))
if self.model is None:
self.model = Transformer(self.config)
self.model.load_state_dict(state_dict)
return self

View File

@@ -2,16 +2,15 @@ import os
import pickle as pkl
import matplotlib.pyplot as plt
from torch.optim import Optimizer from torch import Tensor
from torch.optim.lr_scheduler import LRScheduler
from typing import Dict, Optional
class Checkpoint:
def __init__(
self,
optimizer_state: Optimizer, optimizer_state: Dict[str, Tensor],
scheduler_state: LRScheduler, scheduler_state: Dict[str, Tensor],
epoch: int = 0,
iteration: int = 0,
metrics: Optional[Dict[str, list]] = None,
@@ -36,7 +35,7 @@ class Checkpoint:
pkl.dump(train_state, f)
if save_metric_plot and self.metrics:
self._plot_metrics() self._plot_metrics(save_dir)
@classmethod
def load(cls, save_dir: str) -> "Checkpoint":
@@ -56,7 +55,7 @@ class Checkpoint:
metrics=train_state["metrics"]
)
def _plot_metrics(self): def _plot_metrics(self, save_dir: str):
for metric_name, metric_value in self.metrics.items():
plt.figure(figsize=(10, 6))
plt.plot(metric_value, label=metric_name)
@@ -65,5 +64,6 @@ class Checkpoint:
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(f'{metric_name}.png', dpi=150, bbox_inches='tight') save_path = os.path.join(save_dir, f"{metric_name}.png")
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()

View File

@@ -72,8 +72,8 @@ class SchedulerCallback(TrainCallback):
"""
Scheduler callback for trainer.
"""
def __init__(self, scheduler: LRScheduler): def __init__(self):
self.scheduler: LRScheduler = scheduler self.scheduler: LRScheduler = None
def on_train_begin(self, context: 'TrainContext'):
for group in context.optimizer.param_groups:
@@ -92,9 +92,16 @@ class CheckpointCallback(TrainCallback):
"""
Checkpoint callback for trainer.
"""
def __init__(self, interval: int, save_dir: str): def __init__(
self.interval = interval self,
save_dir: str,
interval: int,
weight_only: bool = False
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.last_ckpt_iter = 0
@only_on_rank(0)

View File

@@ -28,9 +28,9 @@ class Trainer:
train_config = self.train_config
return [
ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir), CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(train_config.scheduler), SchedulerCallback(),
]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: