refactor: 修改metric_util.py
This commit is contained in:
parent
2331713fde
commit
96744ac2d2
|
|
@ -63,3 +63,27 @@ def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
||||||
nan_num = param.grad.isnan().sum().item()
|
nan_num = param.grad.isnan().sum().item()
|
||||||
nan_nums[name] = nan_num
|
nan_nums[name] = nan_num
|
||||||
return nan_nums
|
return nan_nums
|
||||||
|
|
||||||
|
def ctx_get_loss(ctx):
|
||||||
|
return ctx.loss
|
||||||
|
|
||||||
|
def ctx_get_lr(ctx):
|
||||||
|
return ctx.optimizer.param_groups[-1]['lr']
|
||||||
|
|
||||||
|
def ctx_get_grad_norm(ctx):
|
||||||
|
return grad_norm(ctx.model)
|
||||||
|
|
||||||
|
def ctx_get_grad_std(ctx):
|
||||||
|
return grad_std(ctx.model)
|
||||||
|
|
||||||
|
def ctx_get_grad_max(ctx):
|
||||||
|
return grad_max(ctx.model)
|
||||||
|
|
||||||
|
def ctx_get_grad_min(ctx):
|
||||||
|
return grad_min(ctx.model)
|
||||||
|
|
||||||
|
def ctx_get_grad_mean(ctx):
|
||||||
|
return grad_mean(ctx.model)
|
||||||
|
|
||||||
|
def ctx_get_grad_nan_num(ctx):
|
||||||
|
return grad_nan_num(ctx.model)
|
||||||
|
|
@ -11,12 +11,14 @@ from typing import Callable, List, Optional, Protocol
|
||||||
|
|
||||||
from khaosz.parallel import only_on_rank
|
from khaosz.parallel import only_on_rank
|
||||||
from khaosz.trainer.metric_util import (
|
from khaosz.trainer.metric_util import (
|
||||||
grad_max,
|
ctx_get_loss,
|
||||||
grad_min,
|
ctx_get_lr,
|
||||||
grad_norm,
|
ctx_get_grad_max,
|
||||||
grad_mean,
|
ctx_get_grad_min,
|
||||||
grad_std,
|
ctx_get_grad_norm,
|
||||||
grad_nan_num
|
ctx_get_grad_mean,
|
||||||
|
ctx_get_grad_std,
|
||||||
|
ctx_get_grad_nan_num
|
||||||
)
|
)
|
||||||
from khaosz.data.checkpoint import Checkpoint
|
from khaosz.data.checkpoint import Checkpoint
|
||||||
from khaosz.trainer.train_context import TrainContext
|
from khaosz.trainer.train_context import TrainContext
|
||||||
|
|
@ -181,40 +183,16 @@ class StepMonitorCallback(TrainCallback):
|
||||||
self.log_cache = []
|
self.log_cache = []
|
||||||
|
|
||||||
self._metric_funcs = {
|
self._metric_funcs = {
|
||||||
'loss': self._get_loss,
|
'loss': ctx_get_loss,
|
||||||
'lr': self._get_lr,
|
'lr': ctx_get_lr,
|
||||||
'grad_norm': self._get_grad_norm,
|
'grad_norm': ctx_get_grad_norm,
|
||||||
'grad_std': self._get_grad_std,
|
'grad_std': ctx_get_grad_std,
|
||||||
'grad_max': self._get_grad_max,
|
'grad_max': ctx_get_grad_max,
|
||||||
'grad_min': self._get_grad_min,
|
'grad_min': ctx_get_grad_min,
|
||||||
'grad_mean': self._get_grad_mean,
|
'grad_mean': ctx_get_grad_mean,
|
||||||
'grad_nan_num': self._get_grad_nan_num
|
'grad_nan_num': ctx_get_grad_nan_num
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_loss(self, ctx):
|
|
||||||
return ctx.loss
|
|
||||||
|
|
||||||
def _get_lr(self, ctx):
|
|
||||||
return ctx.optimizer.param_groups[-1]['lr']
|
|
||||||
|
|
||||||
def _get_grad_norm(self, ctx):
|
|
||||||
return grad_norm(ctx.model)
|
|
||||||
|
|
||||||
def _get_grad_std(self, ctx):
|
|
||||||
return grad_std(ctx.model)
|
|
||||||
|
|
||||||
def _get_grad_max(self, ctx):
|
|
||||||
return grad_max(ctx.model)
|
|
||||||
|
|
||||||
def _get_grad_min(self, ctx):
|
|
||||||
return grad_min(ctx.model)
|
|
||||||
|
|
||||||
def _get_grad_mean(self, ctx):
|
|
||||||
return grad_mean(ctx.model)
|
|
||||||
|
|
||||||
def _get_grad_nan_num(self, ctx):
|
|
||||||
return grad_nan_num(ctx.model)
|
|
||||||
|
|
||||||
def _get_log_data(self, context: TrainContext):
|
def _get_log_data(self, context: TrainContext):
|
||||||
return {
|
return {
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue