104 lines
2.6 KiB
Python
104 lines
2.6 KiB
Python
import torch.nn as nn
|
|
from typing import Dict
|
|
|
|
|
|
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
    """Compute the gradient norm of each named parameter in the model.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        norm_type: Order of the norm passed to ``Tensor.norm`` (default 2).

    Returns:
        Mapping from parameter name to the norm of its gradient. Parameters
        that have no gradient (``param.grad is None``) map to ``0.0``.
    """
    norms: Dict[str, float] = {}
    for name, param in model.named_parameters():
        grad = param.grad
        # Compare against None explicitly: truth-testing a tensor raises for
        # multi-element tensors and is False for a single zero element.
        if grad is None:
            norms[name] = 0.0
            continue
        norms[name] = grad.detach().norm(norm_type).item()
    return norms
|
|
|
|
|
|
def grad_std(model: nn.Module) -> Dict[str, float]:
    """Compute the standard deviation of each named parameter's gradient.

    Args:
        model: Module whose ``named_parameters()`` are inspected.

    Returns:
        Mapping from parameter name to ``grad.std()`` taken over all elements
        of the gradient. Parameters that have no gradient
        (``param.grad is None``) map to ``0.0``.
    """
    stds: Dict[str, float] = {}
    for name, param in model.named_parameters():
        grad = param.grad
        # Compare against None explicitly: truth-testing a tensor raises for
        # multi-element tensors and is False for a single zero element.
        if grad is None:
            stds[name] = 0.0
            continue
        stds[name] = grad.detach().std().item()
    return stds
|
|
|
|
|
|
def grad_max(model: nn.Module) -> Dict[str, float]:
    """Find the maximum gradient value of each named parameter.

    The maximum is taken over the signed gradient values (no absolute value
    is applied).

    Args:
        model: Module whose ``named_parameters()`` are inspected.

    Returns:
        Mapping from parameter name to the largest element of its gradient.
        Parameters that have no gradient (``param.grad is None``) map to
        ``-inf``.
    """
    max_vals: Dict[str, float] = {}
    for name, param in model.named_parameters():
        grad = param.grad
        # Compare against None explicitly: truth-testing a tensor raises for
        # multi-element tensors and is False for a single zero element.
        if grad is None:
            max_vals[name] = -float("inf")
            continue
        max_vals[name] = grad.detach().max().item()
    return max_vals
|
|
|
|
|
|
def grad_min(model: nn.Module) -> Dict[str, float]:
    """Find the minimum gradient value of each named parameter.

    The minimum is taken over the signed gradient values (no absolute value
    is applied).

    Args:
        model: Module whose ``named_parameters()`` are inspected.

    Returns:
        Mapping from parameter name to the smallest element of its gradient.
        Parameters that have no gradient (``param.grad is None``) map to
        ``inf``.
    """
    min_vals: Dict[str, float] = {}
    for name, param in model.named_parameters():
        grad = param.grad
        # Compare against None explicitly: truth-testing a tensor raises for
        # multi-element tensors and is False for a single zero element.
        if grad is None:
            min_vals[name] = float("inf")
            continue
        min_vals[name] = grad.detach().min().item()
    return min_vals
|
|
|
|
|
|
def grad_mean(model: nn.Module) -> Dict[str, float]:
    """Compute the mean of each named parameter's gradient.

    Args:
        model: Module whose ``named_parameters()`` are inspected.

    Returns:
        Mapping from parameter name to the mean over all elements of its
        gradient. Parameters that have no gradient (``param.grad is None``)
        map to ``0.0``.
    """
    means: Dict[str, float] = {}
    for name, param in model.named_parameters():
        grad = param.grad
        # Compare against None explicitly: truth-testing a tensor raises for
        # multi-element tensors and is False for a single zero element.
        if grad is None:
            means[name] = 0.0
            continue
        means[name] = grad.detach().mean().item()
    return means
|
|
|
|
|
|
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
    """Count the NaN entries in each named parameter's gradient.

    Args:
        model: Module whose ``named_parameters()`` are inspected.

    Returns:
        Mapping from parameter name to the number of NaN elements in its
        gradient. Parameters that have no gradient (``param.grad is None``)
        map to ``0``.
    """
    nan_nums: Dict[str, int] = {}
    for name, param in model.named_parameters():
        grad = param.grad
        # Compare against None explicitly: truth-testing a tensor raises for
        # multi-element tensors and is False for a single zero element.
        if grad is None:
            nan_nums[name] = 0
            continue
        nan_nums[name] = grad.isnan().sum().item()
    return nan_nums
|
|
|
|
|
|
def ctx_get_loss(ctx):
    """Return the ``loss`` attribute of the given context object."""
    current_loss = ctx.loss
    return current_loss
|
|
|
|
|
|
def ctx_get_lr(ctx):
    """Return the ``"lr"`` entry of the last param group of ``ctx.optimizer``."""
    last_group = ctx.optimizer.param_groups[-1]
    return last_group["lr"]
|
|
|
|
|
|
def ctx_get_grad_norm(ctx):
    """Return ``grad_norm`` evaluated on ``ctx.model`` with default arguments."""
    module = ctx.model
    return grad_norm(module)
|
|
|
|
|
|
def ctx_get_grad_std(ctx):
    """Return ``grad_std`` evaluated on ``ctx.model``."""
    module = ctx.model
    return grad_std(module)
|
|
|
|
|
|
def ctx_get_grad_max(ctx):
    """Return ``grad_max`` evaluated on ``ctx.model``."""
    module = ctx.model
    return grad_max(module)
|
|
|
|
|
|
def ctx_get_grad_min(ctx):
    """Return ``grad_min`` evaluated on ``ctx.model``."""
    module = ctx.model
    return grad_min(module)
|
|
|
|
|
|
def ctx_get_grad_mean(ctx):
    """Return ``grad_mean`` evaluated on ``ctx.model``."""
    module = ctx.model
    return grad_mean(module)
|
|
|
|
|
|
def ctx_get_grad_nan_num(ctx):
    """Return ``grad_nan_num`` evaluated on ``ctx.model``."""
    module = ctx.model
    return grad_nan_num(module)
|