feat(model): 优化RMSNorm实现方式
This commit is contained in:
parent
e99ef9d6d8
commit
1c3a693d79
|
|
@ -118,16 +118,12 @@ class RMSNorm(nn.Module):
|
|||
def __init__(self, n_dim, norm_eps):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(n_dim))
|
||||
self.normalized_shape = (n_dim, )
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
|
||||
norm = x * torch.rsqrt(mean_square + self.norm_eps)
|
||||
norm = norm.to(dtype)
|
||||
out = norm * self.weight
|
||||
return out
|
||||
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
|
||||
return rms.to(x.dtype)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
|
|
|||
Loading…
Reference in New Issue