fix: 修复StepMonitorCallback序列化问题

This commit is contained in:
ViperEkura 2026-03-04 20:38:07 +08:00
parent 1d43a1785e
commit c7d0448822
1 changed files with 32 additions and 8 deletions

View File

@ -181,16 +181,40 @@ class StepMonitorCallback(TrainCallback):
self.log_cache = [] self.log_cache = []
self._metric_funcs = { self._metric_funcs = {
'loss': lambda ctx: ctx.loss, 'loss': self._get_loss,
'lr': lambda ctx: ctx.optimizer.param_groups[-1]['lr'], 'lr': self._get_lr,
'grad_norm': lambda ctx: grad_norm(ctx.model), 'grad_norm': self._get_grad_norm,
'grad_std': lambda ctx: grad_std(ctx.model), 'grad_std': self._get_grad_std,
'grad_max': lambda ctx: grad_max(ctx.model), 'grad_max': self._get_grad_max,
'grad_min': lambda ctx: grad_min(ctx.model), 'grad_min': self._get_grad_min,
'grad_mean': lambda ctx: grad_mean(ctx.model), 'grad_mean': self._get_grad_mean,
'grad_nan_num': lambda ctx: grad_nan_num(ctx.model) 'grad_nan_num': self._get_grad_nan_num
} }
def _get_loss(self, ctx):
return ctx.loss
def _get_lr(self, ctx):
return ctx.optimizer.param_groups[-1]['lr']
def _get_grad_norm(self, ctx):
return grad_norm(ctx.model)
def _get_grad_std(self, ctx):
return grad_std(ctx.model)
def _get_grad_max(self, ctx):
return grad_max(ctx.model)
def _get_grad_min(self, ctx):
return grad_min(ctx.model)
def _get_grad_mean(self, ctx):
return grad_mean(ctx.model)
def _get_grad_nan_num(self, ctx):
return grad_nan_num(ctx.model)
def _get_log_data(self, context: TrainContext): def _get_log_data(self, context: TrainContext):
return { return {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'), "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),