fix(trainer): 修复检查点回调参数顺序和权重保存选项
This commit is contained in:
parent
eba99e1f5e
commit
d21682f97a
|
|
@ -52,10 +52,11 @@ class BaseModelIO:
|
|||
self.config.load(str(paths["config"]))
|
||||
self.tokenizer.load(str(paths["tokenizer"]))
|
||||
|
||||
if self.model is None:
|
||||
self.model = Transformer(self.config)
|
||||
|
||||
if paths["model"].exists():
|
||||
state_dict = st.load_file(str(paths["model"]))
|
||||
if self.model is None:
|
||||
self.model = Transformer(self.config)
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -2,16 +2,15 @@ import os
|
|||
import pickle as pkl
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch import Tensor
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class Checkpoint:
|
||||
def __init__(
|
||||
self,
|
||||
optimizer_state: Optimizer,
|
||||
scheduler_state: LRScheduler,
|
||||
optimizer_state: Dict[str, Tensor],
|
||||
scheduler_state: Dict[str, Tensor],
|
||||
epoch: int = 0,
|
||||
iteration: int = 0,
|
||||
metrics: Optional[Dict[str, list]] = None,
|
||||
|
|
@ -36,7 +35,7 @@ class Checkpoint:
|
|||
pkl.dump(train_state, f)
|
||||
|
||||
if save_metric_plot and self.metrics:
|
||||
self._plot_metrics()
|
||||
self._plot_metrics(save_dir)
|
||||
|
||||
@classmethod
|
||||
def load(cls, save_dir: str) -> "Checkpoint":
|
||||
|
|
@ -56,7 +55,7 @@ class Checkpoint:
|
|||
metrics=train_state["metrics"]
|
||||
)
|
||||
|
||||
def _plot_metrics(self):
|
||||
def _plot_metrics(self, save_dir: str):
|
||||
for metric_name, metric_value in self.metrics.items():
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(metric_value, label=metric_name)
|
||||
|
|
@ -65,5 +64,6 @@ class Checkpoint:
|
|||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
plt.savefig(f'{metric_name}.png', dpi=150, bbox_inches='tight')
|
||||
save_path = os.path.join(save_dir, f"{metric_name}.png")
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
|
@ -72,8 +72,8 @@ class SchedulerCallback(TrainCallback):
|
|||
"""
|
||||
Scheduler callback for trainer.
|
||||
"""
|
||||
def __init__(self, scheduler: LRScheduler):
|
||||
self.scheduler: LRScheduler = scheduler
|
||||
def __init__(self):
|
||||
self.scheduler: LRScheduler = None
|
||||
|
||||
def on_train_begin(self, context: 'TrainContext'):
|
||||
for group in context.optimizer.param_groups:
|
||||
|
|
@ -92,9 +92,16 @@ class CheckpointCallback(TrainCallback):
|
|||
"""
|
||||
Checkpoint callback for trainer.
|
||||
"""
|
||||
def __init__(self, interval: int, save_dir: str):
|
||||
self.interval = interval
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str,
|
||||
interval: int,
|
||||
weight_only: bool = False
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
@only_on_rank(0)
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ class Trainer:
|
|||
train_config = self.train_config
|
||||
return [
|
||||
ProgressBarCallback(train_config.n_epoch),
|
||||
CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir),
|
||||
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
||||
GradientClippingCallback(train_config.max_grad_norm),
|
||||
SchedulerCallback(train_config.scheduler),
|
||||
SchedulerCallback(),
|
||||
]
|
||||
|
||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||
|
|
|
|||
Loading…
Reference in New Issue