diff --git a/khaosz/trainer/metric_util.py b/khaosz/trainer/metric_util.py index 8710f11..5d425f4 100644 --- a/khaosz/trainer/metric_util.py +++ b/khaosz/trainer/metric_util.py @@ -62,4 +62,28 @@ def grad_nan_num(model: nn.Module) -> Dict[str, int]: if param.grad: nan_num = param.grad.isnan().sum().item() nan_nums[name] = nan_num - return nan_nums \ No newline at end of file + return nan_nums + +def ctx_get_loss(ctx): + return ctx.loss + +def ctx_get_lr(ctx): + return ctx.optimizer.param_groups[-1]['lr'] + +def ctx_get_grad_norm(ctx): + return grad_norm(ctx.model) + +def ctx_get_grad_std(ctx): + return grad_std(ctx.model) + +def ctx_get_grad_max(ctx): + return grad_max(ctx.model) + +def ctx_get_grad_min(ctx): + return grad_min(ctx.model) + +def ctx_get_grad_mean(ctx): + return grad_mean(ctx.model) + +def ctx_get_grad_nan_num(ctx): + return grad_nan_num(ctx.model) \ No newline at end of file diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index b11c32d..6a161a8 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -11,12 +11,14 @@ from typing import Callable, List, Optional, Protocol from khaosz.parallel import only_on_rank from khaosz.trainer.metric_util import ( - grad_max, - grad_min, - grad_norm, - grad_mean, - grad_std, - grad_nan_num + ctx_get_loss, + ctx_get_lr, + ctx_get_grad_max, + ctx_get_grad_min, + ctx_get_grad_norm, + ctx_get_grad_mean, + ctx_get_grad_std, + ctx_get_grad_nan_num ) from khaosz.data.checkpoint import Checkpoint from khaosz.trainer.train_context import TrainContext @@ -181,40 +183,16 @@ class StepMonitorCallback(TrainCallback): self.log_cache = [] self._metric_funcs = { - 'loss': self._get_loss, - 'lr': self._get_lr, - 'grad_norm': self._get_grad_norm, - 'grad_std': self._get_grad_std, - 'grad_max': self._get_grad_max, - 'grad_min': self._get_grad_min, - 'grad_mean': self._get_grad_mean, - 'grad_nan_num': self._get_grad_nan_num + 'loss': ctx_get_loss, + 'lr': ctx_get_lr, + 'grad_norm': ctx_get_grad_norm, + 'grad_std': ctx_get_grad_std, + 'grad_max': ctx_get_grad_max, + 'grad_min': ctx_get_grad_min, + 'grad_mean': ctx_get_grad_mean, + 'grad_nan_num': ctx_get_grad_nan_num } - def _get_loss(self, ctx): - return ctx.loss - - def _get_lr(self, ctx): - return ctx.optimizer.param_groups[-1]['lr'] - - def _get_grad_norm(self, ctx): - return grad_norm(ctx.model) - - def _get_grad_std(self, ctx): - return grad_std(ctx.model) - - def _get_grad_max(self, ctx): - return grad_max(ctx.model) - - def _get_grad_min(self, ctx): - return grad_min(ctx.model) - - def _get_grad_mean(self, ctx): - return grad_mean(ctx.model) - - def _get_grad_nan_num(self, ctx): - return grad_nan_num(ctx.model) - def _get_log_data(self, context: TrainContext): return { "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),