refactor: 修改 StepMonitorCallback, 分离职责

This commit is contained in:
ViperEkura 2026-03-04 19:45:39 +08:00
parent b53e10aac4
commit 5713b55500
5 changed files with 48 additions and 57 deletions

View File

@ -1,11 +1,9 @@
import os
import json
import torch
import torch.distributed as dist
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, Optional, Any
from typing import Dict, Any
from khaosz.parallel.setup import get_rank
@ -15,17 +13,14 @@ class Checkpoint:
state_dict: Dict[str, Any],
epoch: int = 0,
iteration: int = 0,
metrics: Optional[Dict[str, list]] = None,
):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration
self.metrics = metrics or {}
def save(
self,
save_dir: str,
save_metric_plot: bool = True,
) -> None:
save_path = Path(save_dir)
@ -36,14 +31,10 @@ class Checkpoint:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"metrics": self.metrics,
}
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
if save_metric_plot and self.metrics:
self._plot_metrics(str(save_path))
with open(save_path / f"state_dict.pt", "wb") as f:
torch.save(self.state_dict, f)
@ -73,19 +64,4 @@ class Checkpoint:
state_dict=state_dict,
epoch=meta["epoch"],
iteration=meta["iteration"],
metrics=meta.get("metrics", {}),
)
def _plot_metrics(self, save_dir: str):
    """Render one PNG line chart per tracked metric into *save_dir*.

    Each entry in ``self.metrics`` (metric name -> list of per-step values)
    is plotted as a single labelled line and saved as ``<name>.png``.
    Empty series are skipped rather than producing blank charts.

    Args:
        save_dir: Directory that receives the PNG files (assumed to exist —
            callers create it before invoking this helper).
    """
    out_dir = Path(save_dir)
    for name, values in self.metrics.items():
        if not values:
            # No data recorded for this metric; skip it.
            continue
        fig = plt.figure(figsize=(10, 6))
        try:
            plt.plot(values, label=name)
            plt.xlabel("Step")
            plt.ylabel("Value")
            plt.title(f"Training Metric: {name}")
            plt.legend()
            plt.grid(True, alpha=0.3)
            # dpi / bbox reproduce the previous hard-coded output settings;
            # savefig accepts path-like objects, so Path replaces os.path.join.
            plt.savefig(out_dir / f"{name}.png", dpi=150, bbox_inches="tight")
        finally:
            # Always release the figure, even if savefig raises, so repeated
            # checkpoints don't leak matplotlib figures.
            plt.close(fig)
)

View File

@ -7,7 +7,7 @@ from pathlib import Path
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LRScheduler
from typing import Callable, Optional, Protocol
from typing import Callable, List, Optional, Protocol
from khaosz.parallel import only_on_rank
from khaosz.trainer.metric_util import (
@ -104,6 +104,7 @@ class CheckpointCallback(TrainCallback):
self.state_dict_fn = state_dict_fn
self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict()
@ -161,14 +162,24 @@ class ProgressBarCallback(TrainCallback):
class StepMonitorCallback(TrainCallback):
def __init__(self, log_dir=None, log_interval=100, metrics=None):
def __init__(
self,
log_dir:str,
save_interval:int,
log_interval:int=10,
metrics:List[str]=None
):
self.step_num = 0
self.last_save_step = 0
self.save_interval = save_interval
self.log_interval = log_interval
self.metrics = metrics or ['loss', 'lr']
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
self.log_dir.mkdir(parents=True, exist_ok=True)
self.log_cache = []
self._metric_funcs = {
'loss': lambda ctx: ctx.loss,
'lr': lambda ctx: ctx.optimizer.param_groups[-1]['lr'],
@ -189,13 +200,31 @@ class StepMonitorCallback(TrainCallback):
}
@only_on_rank(0)
def _add_log(self, log_data):
    """Buffer one metrics record in memory (executed on rank 0 only).

    Records accumulate in ``self.log_cache`` until ``_save_log`` writes
    them to disk.
    """
    self.log_cache.append(log_data)
@only_on_rank(0)
def _save_log(self, epoch, iteration):
    """Write every buffered record to a JSON-lines file (rank 0 only).

    The file is ``epoch_{epoch}_iter_{iteration}_metric.jsonl`` inside
    ``self.log_dir``.  Mode ``'w'`` means each call writes a full snapshot
    of the cache, not just records added since the previous flush.

    Args:
        epoch: Current training epoch, used in the file name.
        iteration: Current iteration, used in the file name.
            (Renamed from ``iter``, which shadowed the builtin; the method
            is only called positionally within this class, so the rename
            is backward-compatible.)
    """
    log_file = self.log_dir / f"epoch_{epoch}_iter_{iteration}_metric.jsonl"
    with open(log_file, 'w') as f:
        for record in self.log_cache:
            f.write(json.dumps(record) + '\n')
    # NOTE(review): the cache is intentionally left intact so later flushes
    # (on_train_end / on_error) still contain the full history — confirm
    # the resulting duplication across snapshot files is intended.
def on_step_end(self, context):
    """Per-step hook: periodically buffer metrics and flush them to disk.

    Every ``log_interval`` steps the current metrics are captured via
    ``_get_log_data`` and appended to the in-memory cache; once
    ``save_interval`` steps have elapsed since the last flush, the cache
    is written out through ``_save_log``.
    """
    # Step 0 also logs, since 0 % log_interval == 0.
    if self.step_num % self.log_interval == 0:
        log_data = self._get_log_data(context)
        self._add_log(log_data)
    if self.step_num - self.last_save_step >= self.save_interval:
        self._save_log(context.epoch, context.iteration)
        self.last_save_step = self.step_num
    # Incremented after logging, so intervals count completed steps.
    self.step_num += 1
if self.step_num % self.log_interval != 0:
return
log_data = self._get_log_data(context)
log_file = self.log_dir / f"epoch_{context.epoch}.jsonl"
with open(log_file, 'a') as f:
f.write(json.dumps(log_data) + '\n')
def on_train_end(self, context):
    """Final flush: persist any still-buffered metric records at shutdown."""
    self._save_log(context.epoch, context.iteration)

def on_error(self, context):
    """Crash-path flush: persist buffered records so a failure loses no logs."""
    self._save_log(context.epoch, context.iteration)

View File

@ -4,7 +4,8 @@ from khaosz.config import TrainConfig
from khaosz.trainer.train_callback import (
TrainCallback,
ProgressBarCallback,
CheckpointCallback,
CheckpointCallback,
StepMonitorCallback,
GradientClippingCallback,
SchedulerCallback
)
@ -30,6 +31,7 @@ class Trainer:
return [
ProgressBarCallback(train_config.n_epoch),
CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
StepMonitorCallback(train_config.checkpoint_dir, train_config.checkpoint_interval),
GradientClippingCallback(train_config.max_grad_norm),
SchedulerCallback(),
]

View File

@ -4,19 +4,14 @@ import numpy as np
import tempfile
import shutil
import torch
import pytest
import matplotlib
from torch.utils.data import Dataset
from torch.utils.data import Dataset
from khaosz.config.model_config import ModelConfig
from khaosz.data.tokenizer import BpeTokenizer
from khaosz.model.transformer import Transformer
matplotlib.use("Agg")
class RandomDataset(Dataset):
def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(np.random.randint(100, 200))

View File

@ -2,7 +2,6 @@ import torch
import tempfile
import torch.distributed as dist
from pathlib import Path
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from khaosz.data.checkpoint import Checkpoint
@ -28,25 +27,16 @@ def test_single_process():
checkpoint = Checkpoint(
state_dict=model.state_dict(),
epoch=3,
iteration=30,
metrics={
"loss": [0.5, 0.4, 0.3, 0.2, 0.1],
"accuracy": [0.6, 0.7, 0.8, 0.85, 0.9]
}
iteration=30
)
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir, save_metric_plot=True)
checkpoint.save(tmpdir)
loaded_checkpoint = Checkpoint.load(tmpdir)
assert loaded_checkpoint.epoch == 3
assert loaded_checkpoint.iteration == 30
assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1]
png_files = list(Path(tmpdir).glob("*.png"))
assert png_files
def simple_training():
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
@ -66,7 +56,6 @@ def simple_training():
state_dict=model.state_dict(),
epoch=2,
iteration=10,
metrics={"loss": [0.3, 0.2, 0.1]}
)
rank = get_rank()