From 5713b555005aee889dbbd18d2aa7256319669fc1 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 4 Mar 2026 19:45:39 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BF=AE=E6=94=B9=20StepMonitorCal?= =?UTF-8?q?lback,=20=20=E5=88=86=E7=A6=BB=E8=81=8C=E8=B4=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/data/checkpoint.py | 28 ++---------------- khaosz/trainer/train_callback.py | 51 +++++++++++++++++++++++++------- khaosz/trainer/trainer.py | 4 ++- tests/conftest.py | 7 +---- tests/data/test_checkpoint.py | 15 ++-------- 5 files changed, 48 insertions(+), 57 deletions(-) diff --git a/khaosz/data/checkpoint.py b/khaosz/data/checkpoint.py index 042789b..d18ee2f 100644 --- a/khaosz/data/checkpoint.py +++ b/khaosz/data/checkpoint.py @@ -1,11 +1,9 @@ -import os import json import torch import torch.distributed as dist -import matplotlib.pyplot as plt from pathlib import Path -from typing import Dict, Optional, Any +from typing import Dict, Any from khaosz.parallel.setup import get_rank @@ -15,17 +13,14 @@ class Checkpoint: state_dict: Dict[str, Any], epoch: int = 0, iteration: int = 0, - metrics: Optional[Dict[str, list]] = None, ): self.state_dict = state_dict self.epoch = epoch self.iteration = iteration - self.metrics = metrics or {} def save( self, save_dir: str, - save_metric_plot: bool = True, ) -> None: save_path = Path(save_dir) @@ -36,14 +31,10 @@ class Checkpoint: meta = { "epoch": self.epoch, "iteration": self.iteration, - "metrics": self.metrics, } with open(save_path / "meta.json", "w") as f: json.dump(meta, f, indent=2) - if save_metric_plot and self.metrics: - self._plot_metrics(str(save_path)) - with open(save_path / f"state_dict.pt", "wb") as f: torch.save(self.state_dict, f) @@ -73,19 +64,4 @@ class Checkpoint: state_dict=state_dict, epoch=meta["epoch"], iteration=meta["iteration"], - metrics=meta.get("metrics", {}), - ) - - def _plot_metrics(self, save_dir: str): - for name, values in self.metrics.items(): - if not values: - continue - plt.figure(figsize=(10, 6)) - plt.plot(values, label=name) - plt.xlabel("Step") - plt.ylabel("Value") - plt.title(f"Training Metric: {name}") - plt.legend() - plt.grid(True, alpha=0.3) - plt.savefig(os.path.join(save_dir, f"{name}.png"), dpi=150, bbox_inches="tight") - plt.close() \ No newline at end of file + ) \ No newline at end of file diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 65da86d..d23dc1b 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -7,7 +7,7 @@ from pathlib import Path from tqdm import tqdm from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import LRScheduler -from typing import Callable, Optional, Protocol +from typing import Callable, List, Optional, Protocol from khaosz.parallel import only_on_rank from khaosz.trainer.metric_util import ( @@ -104,6 +104,7 @@ class CheckpointCallback(TrainCallback): self.state_dict_fn = state_dict_fn self.last_ckpt_iter = 0 + @only_on_rank(0) def _save_checkpoint(self, context: TrainContext): save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}") state_dict = self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict() @@ -161,14 +162,24 @@ class ProgressBarCallback(TrainCallback): class StepMonitorCallback(TrainCallback): - def __init__(self, log_dir=None, log_interval=100, metrics=None): - + def __init__( + self, + log_dir:str, + save_interval:int, + log_interval:int=10, + metrics:List[str]=None + ): self.step_num = 0 + self.last_save_step = 0 + self.save_interval = save_interval self.log_interval = log_interval self.metrics = metrics or ['loss', 'lr'] + self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs" self.log_dir.mkdir(parents=True, exist_ok=True) + self.log_cache = [] + self._metric_funcs = { 'loss': lambda ctx: ctx.loss, 'lr': lambda ctx: ctx.optimizer.param_groups[-1]['lr'], @@ -189,13 +200,31 @@ class StepMonitorCallback(TrainCallback): } @only_on_rank(0) + def _add_log(self, log_data): + self.log_cache.append(log_data) + + @only_on_rank(0) + def _save_log(self, epoch, iter): + log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl" + + with open(log_file, 'w') as f: + for log in self.log_cache: + f.write(json.dumps(log) + '\n') + def on_step_end(self, context): + if self.step_num % self.log_interval == 0: + log_data = self._get_log_data(context) + self._add_log(log_data) + + if self.step_num - self.last_save_step >= self.save_interval: + self._save_log(context.epoch, context.iteration) + self.last_save_step = self.step_num + self.step_num += 1 - if self.step_num % self.log_interval != 0: - return - - log_data = self._get_log_data(context) - - log_file = self.log_dir / f"epoch_{context.epoch}.jsonl" - with open(log_file, 'a') as f: - f.write(json.dumps(log_data) + '\n') \ No newline at end of file + + def on_train_end(self, context): + self._save_log(context.epoch, context.iteration) + + def on_error(self, context): + self._save_log(context.epoch, context.iteration) + \ No newline at end of file diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 18739bd..b997956 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -4,7 +4,8 @@ from khaosz.config import TrainConfig from khaosz.trainer.train_callback import ( TrainCallback, ProgressBarCallback, - CheckpointCallback, + CheckpointCallback, + StepMonitorCallback, GradientClippingCallback, SchedulerCallback ) @@ -30,6 +31,7 @@ class Trainer: return [ ProgressBarCallback(train_config.n_epoch), CheckpointCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), + StepMonitorCallback(train_config.checkpoint_dir, train_config.checkpoint_interval), GradientClippingCallback(train_config.max_grad_norm), SchedulerCallback(), ] diff --git a/tests/conftest.py b/tests/conftest.py index cb81fa6..032d3f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,19 +4,14 @@ import numpy as np import tempfile import shutil import torch - import pytest -import matplotlib -from torch.utils.data import Dataset +from torch.utils.data import Dataset from khaosz.config.model_config import ModelConfig from khaosz.data.tokenizer import BpeTokenizer from khaosz.model.transformer import Transformer -matplotlib.use("Agg") - - class RandomDataset(Dataset): def __init__(self, length=None, max_length=64, vocab_size=1000): self.length = length or int(np.random.randint(100, 200)) diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index 05ae8e5..2d264c5 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -2,7 +2,6 @@ import torch import tempfile import torch.distributed as dist -from pathlib import Path from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR from khaosz.data.checkpoint import Checkpoint @@ -28,25 +27,16 @@ def test_single_process(): checkpoint = Checkpoint( state_dict=model.state_dict(), epoch=3, - iteration=30, - metrics={ - "loss": [0.5, 0.4, 0.3, 0.2, 0.1], - "accuracy": [0.6, 0.7, 0.8, 0.85, 0.9] - } + iteration=30 ) with tempfile.TemporaryDirectory() as tmpdir: - checkpoint.save(tmpdir, save_metric_plot=True) + checkpoint.save(tmpdir) loaded_checkpoint = Checkpoint.load(tmpdir) assert loaded_checkpoint.epoch == 3 assert loaded_checkpoint.iteration == 30 - assert loaded_checkpoint.metrics["loss"] == [0.5, 0.4, 0.3, 0.2, 0.1] - - png_files = list(Path(tmpdir).glob("*.png")) - assert png_files - def simple_training(): model = torch.nn.Linear(10, 5) optimizer = AdamW(model.parameters(), lr=1e-3) @@ -66,7 +56,6 @@ def simple_training(): state_dict=model.state_dict(), epoch=2, iteration=10, - metrics={"loss": [0.3, 0.2, 0.1]} ) rank = get_rank()