diff --git a/khaosz/model/module.py b/khaosz/model/module.py
index ec73cc9..748c53e 100644
--- a/khaosz/model/module.py
+++ b/khaosz/model/module.py
@@ -118,16 +118,12 @@ class RMSNorm(nn.Module):
     def __init__(self, n_dim, norm_eps):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(n_dim))
+        self.normalized_shape = (n_dim, )
         self.norm_eps = norm_eps
 
     def forward(self, x: Tensor) -> Tensor:
-        dtype = x.dtype
-        x = x.float()
-        mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
-        norm = x * torch.rsqrt(mean_square + self.norm_eps)
-        norm = norm.to(dtype)
-        out = norm * self.weight
-        return out
+        rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
+        return rms.to(x.dtype)
 
 
 class MLP(nn.Module):