From c7d044882281570edbda279e41ee36b637e2487e Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 4 Mar 2026 20:38:07 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8DStepMonitorCallback?= =?UTF-8?q?=E5=BA=8F=E5=88=97=E5=8C=96=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/train_callback.py | 40 +++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index d23dc1b..b11c32d 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -181,16 +181,40 @@ class StepMonitorCallback(TrainCallback): self.log_cache = [] self._metric_funcs = { - 'loss': lambda ctx: ctx.loss, - 'lr': lambda ctx: ctx.optimizer.param_groups[-1]['lr'], - 'grad_norm': lambda ctx: grad_norm(ctx.model), - 'grad_std': lambda ctx: grad_std(ctx.model), - 'grad_max': lambda ctx: grad_max(ctx.model), - 'grad_min': lambda ctx: grad_min(ctx.model), - 'grad_mean': lambda ctx: grad_mean(ctx.model), - 'grad_nan_num': lambda ctx: grad_nan_num(ctx.model) + 'loss': self._get_loss, + 'lr': self._get_lr, + 'grad_norm': self._get_grad_norm, + 'grad_std': self._get_grad_std, + 'grad_max': self._get_grad_max, + 'grad_min': self._get_grad_min, + 'grad_mean': self._get_grad_mean, + 'grad_nan_num': self._get_grad_nan_num } + + def _get_loss(self, ctx): + return ctx.loss + def _get_lr(self, ctx): + return ctx.optimizer.param_groups[-1]['lr'] + + def _get_grad_norm(self, ctx): + return grad_norm(ctx.model) + + def _get_grad_std(self, ctx): + return grad_std(ctx.model) + + def _get_grad_max(self, ctx): + return grad_max(ctx.model) + + def _get_grad_min(self, ctx): + return grad_min(ctx.model) + + def _get_grad_mean(self, ctx): + return grad_mean(ctx.model) + + def _get_grad_nan_num(self, ctx): + return grad_nan_num(ctx.model) + def _get_log_data(self, context: TrainContext): return { "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),