From 12793bc2d31d500ef6f53b04f436ce559a7728e0 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Tue, 7 Oct 2025 13:03:32 +0800
Subject: [PATCH] =?UTF-8?q?feat(khaosz/trainer):=20=E6=96=B0=E5=A2=9E?=
 =?UTF-8?q?=E6=A2=AF=E5=BA=A6=E7=BB=9F=E8=AE=A1=E5=B7=A5=E5=85=B7=E5=87=BD?=
 =?UTF-8?q?=E6=95=B0=E5=B9=B6=E9=87=8D=E6=9E=84=E8=AE=AD=E7=BB=83=E5=9B=9E?=
 =?UTF-8?q?=E8=B0=83=E6=9C=BA=E5=88=B6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (text placed here, before the first "diff --git", does not
become part of the commit message):
 * metric_util.py: a gradient is now detected with `param.grad is not None`.
   The previous `if param.grad:` evaluated the truth value of a tensor,
   which raises for tensors holding more than one element.
 * grad_max / grad_min docstrings no longer say "absolute": the code takes
   the plain max()/min() of the gradient.
 * StepMonitorCallback docstrings now describe the parameters that exist
   and the default log directory that the code uses; the no-op
   `try ... except Exception: raise` in _handle_log was removed.
 * The number of added lines per hunk is unchanged, so the hunk headers and
   the diffstat below still hold. The blob ids on the `index` lines predate
   these edits.
 * Indentation and blank-line whitespace were rebuilt from a
   whitespace-collapsed copy of this patch. TODO: confirm the context and
   removed lines against the tree before applying.

 khaosz/trainer/__init__.py       |   4 +-
 khaosz/trainer/metric_util.py    |  65 ++++++++++
 khaosz/trainer/train_callback.py | 204 +++++++++++++++++++++++--------
 khaosz/trainer/trainer.py        |   1 +
 4 files changed, 219 insertions(+), 55 deletions(-)
 create mode 100644 khaosz/trainer/metric_util.py

diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py
index b4df66c..5c15fbd 100644
--- a/khaosz/trainer/__init__.py
+++ b/khaosz/trainer/__init__.py
@@ -12,7 +12,8 @@ from khaosz.trainer.train_callback import (
     ProgressBarCallback,
     CheckpointCallback,
     TrainCallback,
-    SchedulerCallback
+    SchedulerCallback,
+    StepMonitorCallback
 )
 
 __all__ = [
@@ -30,4 +31,5 @@ __all__ = [
     "CheckpointCallback",
     "TrainCallback",
     "SchedulerCallback",
+    "StepMonitorCallback"
 ]
\ No newline at end of file
diff --git a/khaosz/trainer/metric_util.py b/khaosz/trainer/metric_util.py
new file mode 100644
index 0000000..8710f11
--- /dev/null
+++ b/khaosz/trainer/metric_util.py
@@ -0,0 +1,65 @@
+import torch.nn as nn
+from typing import Dict
+
+def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
+    """ Compute gradient norm for each parameter in the model (0.0 when it has no gradient). """
+    norms = {}
+    for name, param in model.named_parameters():
+        norms[name] = 0.0
+        if param.grad is not None:
+            norm = param.grad.data.norm(norm_type).item()
+            norms[name] = norm
+    return norms
+
+def grad_std(model: nn.Module) -> Dict[str, float]:
+    """ Compute standard deviation of gradients for each parameter (0.0 when it has no gradient). """
+    stds = {}
+    for name, param in model.named_parameters():
+        stds[name] = 0.0
+        if param.grad is not None:
+            std = param.grad.data.std().item()
+            stds[name] = std
+    return stds
+
+def grad_max(model: nn.Module) -> Dict[str, float]:
+    """ Find the maximum (signed) gradient value for each parameter; -inf when it has no gradient. """
+    max_vals = {}
+    for name, param in model.named_parameters():
+        max_vals[name] = -float('inf')
+        if param.grad is not None:
+            max_val = param.grad.data.max().item()
+            max_vals[name] = max_val
+
+    return max_vals
+
+def grad_min(model: nn.Module) -> Dict[str, float]:
+    """ Find the minimum (signed) gradient value for each parameter; +inf when it has no gradient. """
+    min_vals = {}
+    for name, param in model.named_parameters():
+        min_vals[name] = float('inf')
+        if param.grad is not None:
+            min_val = param.grad.data.min().item()
+            min_vals[name] = min_val
+
+    return min_vals
+
+def grad_mean(model: nn.Module) -> Dict[str, float]:
+    """ Compute mean of gradients for each parameter (0.0 when it has no gradient). """
+    means = {}
+    for name, param in model.named_parameters():
+        means[name] = 0.0
+        if param.grad is not None:
+            mean = param.grad.data.mean().item()
+            means[name] = mean
+
+    return means
+
+def grad_nan_num(model: nn.Module) -> Dict[str, int]:
+    """ Count the number of NaNs in gradients for each parameter (0 when it has no gradient). """
+    nan_nums = {}
+    for name, param in model.named_parameters():
+        nan_nums[name] = 0
+        if param.grad is not None:
+            nan_num = param.grad.isnan().sum().item()
+            nan_nums[name] = nan_num
+    return nan_nums
\ No newline at end of file
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index dcb868e..c83921f 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -1,9 +1,22 @@
 import os
+import json
+import time
+
+from pathlib import Path
 from tqdm import tqdm
 from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import LambdaLR
-from typing import Optional, Protocol, TYPE_CHECKING
+from typing import List, Optional, Protocol, TYPE_CHECKING
+
 from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
+from khaosz.trainer.metric_util import (
+    grad_max,
+    grad_min,
+    grad_norm,
+    grad_mean,
+    grad_std,
+    grad_nan_num
+)
 
 if TYPE_CHECKING:
     from khaosz.trainer.trainer import Trainer
@@ -38,60 +51,9 @@ class TrainCallback(Protocol):
 
     def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
         """ Called at the end of each batch. """
-
-
-class ProgressBarCallback(TrainCallback):
-    """
-    Progress bar callback for trainer.
-    """
-    def __init__(self):
-        self.progress_bar: tqdm = None
     
-    def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
-        self.progress_bar = tqdm(
-            context.dataloader,
-            desc=f"Epoch {context.epoch+1}/{trainer.train_config.n_epoch}",
-            dynamic_ncols=True
-        )
-
-    def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
-        _ = trainer
-        self.progress_bar.set_postfix({
-            "loss": f"{context.loss:.4f}",
-            "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
-        })
-        self.progress_bar.update(1)
-
-    def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
-        _ = trainer, context
-        if self.progress_bar:
-            self.progress_bar.close()
-
-
-class CheckpointCallback(TrainCallback):
-    """
-    Checkpoint callback for trainer.
-    """
-    def __init__(self, checkpoint_interval: int):
-        self.checkpoint_interval = checkpoint_interval
-        self.last_ckpt_iter = 0
-
-    def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
-        save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
-        context.checkpoint.sampler_state = context.sampler.state_dict()
-        context.checkpoint.optimizer_state = context.optimizer.state_dict()
-        context.checkpoint.save(save_path)
-        self.last_ckpt_iter = context.current_iter
-
-    def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
-        context.checkpoint.loss_list.append(context.loss)
-
-        if context.current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
-            self._save_checkpoint(trainer, context)
-
-    def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
-        if context.current_iter != self.last_ckpt_iter:
-            self._save_checkpoint(trainer, context)
+    def on_error(self, trainer: 'Trainer', context: 'TrainContext'):
+        """ Called when an error occurs during training. """
 
 
 class GradientClippingCallback(TrainCallback):
@@ -132,3 +94,137 @@ class SchedulerCallback(TrainCallback):
         _ = trainer, context
         if self.scheduler:
             self.scheduler.step()
+
+
+class CheckpointCallback(TrainCallback):
+    """
+    Checkpoint callback for trainer.
+    """
+    def __init__(self, checkpoint_interval: int):
+        self.checkpoint_interval = checkpoint_interval
+        self.last_ckpt_iter = 0
+
+    def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
+        save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
+        context.checkpoint.sampler_state = context.sampler.state_dict()
+        context.checkpoint.optimizer_state = context.optimizer.state_dict()
+        context.checkpoint.save(save_path)
+        self.last_ckpt_iter = context.current_iter
+
+    def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
+        context.checkpoint.loss_list.append(context.loss)
+
+        if context.current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
+            self._save_checkpoint(trainer, context)
+
+    def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
+        if context.current_iter != self.last_ckpt_iter:
+            self._save_checkpoint(trainer, context)
+
+
+class ProgressBarCallback(TrainCallback):
+    """
+    Progress bar callback for trainer.
+    """
+    def __init__(self):
+        self.progress_bar: tqdm = None
+
+    def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'):
+        self.progress_bar = tqdm(
+            context.dataloader,
+            desc=f"Epoch {context.epoch+1}/{trainer.train_config.n_epoch}",
+            dynamic_ncols=True
+        )
+
+    def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
+        _ = trainer
+        self.progress_bar.set_postfix({
+            "loss": f"{context.loss:.4f}",
+            "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}"
+        })
+        self.progress_bar.update(1)
+
+    def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'):
+        _ = trainer, context
+        if self.progress_bar:
+            self.progress_bar.close()
+
+
+class StepMonitorCallback(TrainCallback):
+    """
+    Step-level metric logger callback for trainer.
+
+    Every `log_interval` steps, the selected metrics are collected and written
+    as JSON to a file under `log_dir` named after the epoch and iteration.
+    """
+
+    def __init__(
+        self,
+        log_dir: Optional[str] = None,
+        log_interval: int = 100,
+        metrics: Optional[List[str]] = None
+    ):
+        """
+        Args:
+            log_dir: Directory to save log files. If None, "<cwd>/logs" is used.
+            log_interval: Log every N steps
+            metrics: List of metrics to log, default ['loss', 'lr']. Supported:
+                ['loss', 'lr', 'grad_norm', 'grad_std', 'grad_max', 'grad_min',
+                'grad_mean', 'grad_nan_num']
+        """
+
+        self.log_dir = Path(log_dir) if log_dir else Path(os.getcwd()) / "logs"
+        self.log_interval = log_interval
+        self.metrics = metrics or ['loss', 'lr']
+        self.step_num = 0
+
+        self.log_dir.mkdir(parents=True, exist_ok=True)
+
+    def _handle_info(self, trainer: 'Trainer', context: 'TrainContext'):
+        """ Build the log record (timestamp, epoch, iter, metric values) for the current step. """
+
+        log_data = {
+            "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
+            "epoch": context.epoch,
+            "iter": context.current_iter,
+            "metrics": self.metrics,
+        }
+
+        for metric in self.metrics:
+            if metric == 'loss':
+                log_data[metric] = context.loss
+            elif metric == 'lr':
+                log_data[metric] = context.optimizer.param_groups[-1]['lr']
+            elif metric == 'grad_norm':
+                log_data[metric] = grad_norm(trainer.parameter.model)
+            elif metric == 'grad_std':
+                log_data[metric] = grad_std(trainer.parameter.model)
+            elif metric == 'grad_max':
+                log_data[metric] = grad_max(trainer.parameter.model)
+            elif metric == 'grad_min':
+                log_data[metric] = grad_min(trainer.parameter.model)
+            elif metric == 'grad_mean':
+                log_data[metric] = grad_mean(trainer.parameter.model)
+            elif metric == 'grad_nan_num':
+                log_data[metric] = grad_nan_num(trainer.parameter.model)
+            else:
+                raise ValueError(f"Invalid metric: {metric}")
+
+        return log_data
+
+    def _handle_log(self, trainer: 'Trainer', context: 'TrainContext'):
+        """ Write the current log record as JSON to a file under log_dir. """
+        log_data = self._handle_info(trainer, context)
+        log_name = f"log_epoch_{context.epoch}_iter_{context.current_iter}.json"
+        log_file = self.log_dir / log_name
+
+        # Append mode: a record already in this epoch/iter file is kept, not overwritten.
+        with open(log_file, 'a') as f:
+            json.dump(log_data, f, indent=4)
+
+    def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'):
+        if self.step_num % self.log_interval == 0:
+            self._handle_log(trainer, context)
+
+        self.step_num += 1
+        
\ No newline at end of file
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index fefa70f..c287358 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -86,6 +86,7 @@ class Trainer:
 
         except Exception as e:
             logger.error(f"Training failed: {str(e)}", exc_info=True)
+            self._call_callbacks('on_error', context)
             raise
         finally:
             self._call_callbacks('on_train_end', context)