refactor: 修改 StepMonitorCallback, 分离职责

This commit is contained in:
ViperEkura 2026-03-04 19:45:39 +08:00
parent b53e10aac4
commit 5713b55500
5 changed files with 48 additions and 57 deletions

View File

@ -1,11 +1,9 @@
import os
import json import json
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Any from typing import Dict, Any
from khaosz.parallel.setup import get_rank from khaosz.parallel.setup import get_rank
@ -15,17 +13,14 @@ class Checkpoint:
state_dict: Dict[str, Any], state_dict: Dict[str, Any],
epoch: int = 0, epoch: int = 0,
iteration: int = 0, iteration: int = 0,
metrics: Optional[Dict[str, list]] = None,
): ):
self.state_dict = state_dict self.state_dict = state_dict
self.epoch = epoch self.epoch = epoch
self.iteration = iteration self.iteration = iteration
self.metrics = metrics or {}
def save( def save(
self, self,
save_dir: str, save_dir: str,
save_metric_plot: bool = True,
) -> None: ) -> None:
save_path = Path(save_dir) save_path = Path(save_dir)
@ -36,14 +31,10 @@ class Checkpoint:
meta = { meta = {
"epoch": self.epoch, "epoch": self.epoch,
"iteration": self.iteration, "iteration": self.iteration,
"metrics": self.metrics,
} }
with open(save_path / "meta.json", "w") as f: with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2) json.dump(meta, f, indent=2)
if save_metric_plot and self.metrics:
self._plot_metrics(str(save_path))
with open(save_path / f"state_dict.pt", "wb") as f: with open(save_path / f"state_dict.pt", "wb") as f:
torch.save(self.state_dict, f) torch.save(self.state_dict, f)
@ -73,19 +64,4 @@ class Checkpoint:
state_dict=state_dict, state_dict=state_dict,
epoch=meta["epoch"], epoch=meta["epoch"],
iteration=meta["iteration"], iteration=meta["iteration"],
metrics=meta.get("metrics", {}),
) )
def _plot_metrics(self, save_dir: str):
for name, values in self.metrics.items():
if not values:
continue
plt.figure(figsize=(10, 6))
plt.plot(values, label=name)
plt.xlabel("Step")
plt.ylabel("Value")
plt.title(f"Training Metric: {name}")
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(os.path.join(save_dir, f"{name}.png"), dpi=150, bbox_inches="tight")
plt.close()

View File

@ -7,7 +7,7 @@ from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from typing import Callable, Optional, Protocol from typing import Callable, List, Optional, Protocol
from khaosz.parallel import only_on_rank from khaosz.parallel import only_on_rank
from khaosz.trainer.metric_util import ( from khaosz.trainer.metric_util import (
@ -104,6 +104,7 @@ class CheckpointCallback(TrainCallback):
self.state_dict_fn = state_dict_fn self.state_dict_fn = state_dict_fn
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}") save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict() state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
@ -161,14 +162,24 @@ class ProgressBarCallback(TrainCallback):
class StepMonitorCallback(TrainCallback): class StepMonitorCallback(TrainCallback):
def __init__(self, log_dir=None, log_interval=100, metrics=None): def __init__(
self,
log_dir:str,
save_interval:int,
log_interval:int=10,
metrics:List[str]=None
):
self.step_num = 0 self.step_num = 0
self.last_save_step = 0
self.save_interval = save_interval
self.log_interval = log_interval self.log_interval = log_interval
self.metrics = metrics or ['loss', 'lr'] self.metrics = metrics or ['loss', 'lr']
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs" self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
self.log_dir.mkdir(parents=True, exist_ok=True) self.log_dir.mkdir(parents=True, exist_ok=True)
self.log_cache = []
self._metric_funcs = { self._metric_funcs = {
'loss': lambda ctx: ctx.loss, 'loss': lambda ctx: ctx.loss,
'lr': lambda ctx: ctx.optimizer.param_groups[-1]['lr'], 'lr': lambda ctx: ctx.optimizer.param_groups[-1]['lr'],
@ -189,13 +200,31 @@ class StepMonitorCallback(TrainCallback):
} }
@only_on_rank(0) @only_on_rank(0)
def _add_log(self, log_data):
self.log_cache.append(log_data)
@only_on_rank(0)
def _save_log(self, epoch, iter):
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
with open(log_file, 'w') as f:
for log in self.log_cache:
f.write(json.dumps(log) + '\n')
def on_step_end(self, context): def on_step_end(self, context):
self.step_num += 1 if self.step_num % self.log_interval == 0:
if self.step_num % self.log_interval != 0:
return
log_data = self._get_log_data(context) log_data = self._get_log_data(context)
self._add_log(log_data)
if self.step_num - self.last_save_step >= self.save_interval:
self._save_log(context.epoch, context.iteration)
self.last_save_step = self.step_num
self.step_num += 1
def on_train_end(self, context):
self._save_log(context.epoch, context.iteration)
def on_error(self, context):
self._save_log(context.epoch, context.iteration)
log_file = self.log_dir / f"epoch_{context.epoch}.jsonl"
with open(log_file, 'a') as f:
f.write(json.dumps(log_data) + '\n')

View File

@ -5,6 +5,7 @@ from khaosz.trainer.train_callback import (
TrainCallback, TrainCallback,
ProgressBarCallback, ProgressBarCallback,
CheckpointCallback, CheckpointCallback,
StepMonitorCallback,
GradientClippingCallback, GradientClippingCallback,
SchedulerCallback SchedulerCallback
) )
@ -30,6 +31,7 @@ class Trainer:
return [ return [
ProgressBarCallback(train_config.n_epoch), ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
StepMonitorCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
GradientClippingCallback(train_config.max_grad_norm), GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(), SchedulerCallback(),
] ]

View File

@ -4,19 +4,14 @@ import numpy as np
import tempfile import tempfile
import shutil import shutil
import torch import torch
import pytest import pytest
import matplotlib
from torch.utils.data import Dataset
from torch.utils.data import Dataset
from khaosz.config.model_config import ModelConfig from khaosz.config.model_config import ModelConfig
from khaosz.data.tokenizer import BpeTokenizer from khaosz.data.tokenizer import BpeTokenizer
from khaosz.model.transformer import Transformer from khaosz.model.transformer import Transformer
matplotlib.use("Agg")
class RandomDataset(Dataset): class RandomDataset(Dataset):
def __init__(self, length=None, max_length=64, vocab_size=1000): def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(np.random.randint(100, 200)) self.length = length or int(np.random.randint(100, 200))

View File

@ -2,7 +2,6 @@ import torch
import tempfile import tempfile
import torch.distributed as dist import torch.distributed as dist
from pathlib import Path
from torch.optim import AdamW from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from khaosz.data.checkpoint import Checkpoint from khaosz.data.checkpoint import Checkpoint
@ -28,25 +27,16 @@ def test_single_process():
checkpoint = Checkpoint( checkpoint = Checkpoint(
state_dict=model.state_dict(), state_dict=model.state_dict(),
epoch=3, epoch=3,
iteration=30, iteration=30
metrics={
"loss": [0.5, 0.4, 0.3, 0.2, 0.1],
"accuracy": [0.6, 0.7, 0.8, 0.85, 0.9]
}
) )
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir, save_metric_plot=True) checkpoint.save(tmpdir)
loaded_checkpoint = Checkpoint.load(tmpdir) loaded_checkpoint = Checkpoint.load(tmpdir)
assert loaded_checkpoint.epoch == 3 assert loaded_checkpoint.epoch == 3
assert loaded_checkpoint.iteration == 30 assert loaded_checkpoint.iteration == 30
assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]
png_files = list(Path(tmpdir).glob("*.png"))
assert png_files
def simple_training(): def simple_training():
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3) optimizer = AdamW(model.parameters(), lr=1e-3)
@ -66,7 +56,6 @@ def simple_training():
state_dict=model.state_dict(), state_dict=model.state_dict(),
epoch=2, epoch=2,
iteration=10, iteration=10,
metrics={"loss": [0.3, 0.2, 0.1]}
) )
rank = get_rank() rank = get_rank()