fix(trainer): 修复检查点回调参数顺序和权重保存选项
This commit is contained in:
parent
eba99e1f5e
commit
d21682f97a
|
|
@ -52,10 +52,11 @@ class BaseModelIO:
|
||||||
self.config.load(str(paths["config"]))
|
self.config.load(str(paths["config"]))
|
||||||
self.tokenizer.load(str(paths["tokenizer"]))
|
self.tokenizer.load(str(paths["tokenizer"]))
|
||||||
|
|
||||||
|
if self.model is None:
|
||||||
|
self.model = Transformer(self.config)
|
||||||
|
|
||||||
if paths["model"].exists():
|
if paths["model"].exists():
|
||||||
state_dict = st.load_file(str(paths["model"]))
|
state_dict = st.load_file(str(paths["model"]))
|
||||||
if self.model is None:
|
|
||||||
self.model = Transformer(self.config)
|
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
|
||||||
|
|
@ -2,16 +2,15 @@ import os
|
||||||
import pickle as pkl
|
import pickle as pkl
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from torch.optim import Optimizer
|
from torch import Tensor
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
class Checkpoint:
|
class Checkpoint:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
optimizer_state: Optimizer,
|
optimizer_state: Dict[str, Tensor],
|
||||||
scheduler_state: LRScheduler,
|
scheduler_state: Dict[str, Tensor],
|
||||||
epoch: int = 0,
|
epoch: int = 0,
|
||||||
iteration: int = 0,
|
iteration: int = 0,
|
||||||
metrics: Optional[Dict[str, list]] = None,
|
metrics: Optional[Dict[str, list]] = None,
|
||||||
|
|
@ -36,7 +35,7 @@ class Checkpoint:
|
||||||
pkl.dump(train_state, f)
|
pkl.dump(train_state, f)
|
||||||
|
|
||||||
if save_metric_plot and self.metrics:
|
if save_metric_plot and self.metrics:
|
||||||
self._plot_metrics()
|
self._plot_metrics(save_dir)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, save_dir: str) -> "Checkpoint":
|
def load(cls, save_dir: str) -> "Checkpoint":
|
||||||
|
|
@ -56,7 +55,7 @@ class Checkpoint:
|
||||||
metrics=train_state["metrics"]
|
metrics=train_state["metrics"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _plot_metrics(self):
|
def _plot_metrics(self, save_dir: str):
|
||||||
for metric_name, metric_value in self.metrics.items():
|
for metric_name, metric_value in self.metrics.items():
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
plt.plot(metric_value, label=metric_name)
|
plt.plot(metric_value, label=metric_name)
|
||||||
|
|
@ -65,5 +64,6 @@ class Checkpoint:
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.grid(True, alpha=0.3)
|
plt.grid(True, alpha=0.3)
|
||||||
|
|
||||||
plt.savefig(f'{metric_name}.png', dpi=150, bbox_inches='tight')
|
save_path = os.path.join(save_dir, f"{metric_name}.png")
|
||||||
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
@ -72,8 +72,8 @@ class SchedulerCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Scheduler callback for trainer.
|
Scheduler callback for trainer.
|
||||||
"""
|
"""
|
||||||
def __init__(self, scheduler: LRScheduler):
|
def __init__(self):
|
||||||
self.scheduler: LRScheduler = scheduler
|
self.scheduler: LRScheduler = None
|
||||||
|
|
||||||
def on_train_begin(self, context: 'TrainContext'):
|
def on_train_begin(self, context: 'TrainContext'):
|
||||||
for group in context.optimizer.param_groups:
|
for group in context.optimizer.param_groups:
|
||||||
|
|
@ -92,9 +92,16 @@ class CheckpointCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Checkpoint callback for trainer.
|
Checkpoint callback for trainer.
|
||||||
"""
|
"""
|
||||||
def __init__(self, interval: int, save_dir: str):
|
def __init__(
|
||||||
self.interval = interval
|
self,
|
||||||
|
save_dir: str,
|
||||||
|
interval: int,
|
||||||
|
weight_only: bool = False
|
||||||
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
|
self.interval = interval
|
||||||
|
self.weight_only = weight_only
|
||||||
|
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,9 @@ class Trainer:
|
||||||
train_config = self.train_config
|
train_config = self.train_config
|
||||||
return [
|
return [
|
||||||
ProgressBarCallback(train_config.n_epoch),
|
ProgressBarCallback(train_config.n_epoch),
|
||||||
CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir),
|
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
|
||||||
GradientClippingCallback(train_config.max_grad_norm),
|
GradientClippingCallback(train_config.max_grad_norm),
|
||||||
SchedulerCallback(train_config.scheduler),
|
SchedulerCallback(),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue