feat(model): 优化RMSNorm实现方式

This commit is contained in:
ViperEkura 2025-11-15 13:54:04 +08:00
parent e99ef9d6d8
commit 1c3a693d79
1 changed files with 3 additions and 7 deletions

View File

@ -118,16 +118,12 @@ class RMSNorm(nn.Module):
def __init__(self, n_dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(n_dim))
self.normalized_shape = (n_dim, )
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
dtype = x.dtype
x = x.float()
mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
norm = x * torch.rsqrt(mean_square + self.norm_eps)
norm = norm.to(dtype)
out = norm * self.weight
return out
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
return rms.to(x.dtype)
class MLP(nn.Module):