fix(trainer): 修复检查点回调参数顺序和权重保存选项

This commit is contained in:
ViperEkura 2026-01-05 17:08:09 +08:00
parent eba99e1f5e
commit d21682f97a
4 changed files with 23 additions and 15 deletions

View File

@ -52,10 +52,11 @@ class BaseModelIO:
self.config.load(str(paths["config"]))
self.tokenizer.load(str(paths["tokenizer"]))
if self.model is None:
self.model = Transformer(self.config)
if paths["model"].exists():
state_dict = st.load_file(str(paths["model"]))
if self.model is None:
self.model = Transformer(self.config)
self.model.load_state_dict(state_dict)
return self

View File

@ -2,16 +2,15 @@ import os
import pickle as pkl
import matplotlib.pyplot as plt
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch import Tensor
from typing import Dict, Optional
class Checkpoint:
def __init__(
self,
optimizer_state: Optimizer,
scheduler_state: LRScheduler,
optimizer_state: Dict[str, Tensor],
scheduler_state: Dict[str, Tensor],
epoch: int = 0,
iteration: int = 0,
metrics: Optional[Dict[str, list]] = None,
@ -36,7 +35,7 @@ class Checkpoint:
pkl.dump(train_state, f)
if save_metric_plot and self.metrics:
self._plot_metrics()
self._plot_metrics(save_dir)
@classmethod
def load(cls, save_dir: str) -> "Checkpoint":
@ -56,7 +55,7 @@ class Checkpoint:
metrics=train_state["metrics"]
)
def _plot_metrics(self):
def _plot_metrics(self, save_dir: str):
for metric_name, metric_value in self.metrics.items():
plt.figure(figsize=(10, 6))
plt.plot(metric_value, label=metric_name)
@ -65,5 +64,6 @@ class Checkpoint:
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(f'{metric_name}.png', dpi=150, bbox_inches='tight')
save_path = os.path.join(save_dir, f"{metric_name}.png")
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()

View File

@ -72,8 +72,8 @@ class SchedulerCallback(TrainCallback):
"""
Scheduler callback for trainer.
"""
def __init__(self, scheduler: LRScheduler):
self.scheduler: LRScheduler = scheduler
def __init__(self):
self.scheduler: LRScheduler = None
def on_train_begin(self, context: 'TrainContext'):
for group in context.optimizer.param_groups:
@ -92,9 +92,16 @@ class CheckpointCallback(TrainCallback):
"""
Checkpoint callback for trainer.
"""
def __init__(self, interval: int, save_dir: str):
self.interval = interval
def __init__(
self,
save_dir: str,
interval: int,
weight_only: bool = False
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.last_ckpt_iter = 0
@only_on_rank(0)

View File

@ -28,9 +28,9 @@ class Trainer:
train_config = self.train_config
return [
ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir),
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(train_config.scheduler),
SchedulerCallback(),
]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: