feat(model): 优化RMSNorm实现方式
This commit is contained in:
parent
e99ef9d6d8
commit
1c3a693d79
|
|
@ -118,16 +118,12 @@ class RMSNorm(nn.Module):
|
||||||
def __init__(self, n_dim, norm_eps):
|
def __init__(self, n_dim, norm_eps):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.ones(n_dim))
|
self.weight = nn.Parameter(torch.ones(n_dim))
|
||||||
|
self.normalized_shape = (n_dim, )
|
||||||
self.norm_eps = norm_eps
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
dtype = x.dtype
|
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
|
||||||
x = x.float()
|
return rms.to(x.dtype)
|
||||||
mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
|
|
||||||
norm = x * torch.rsqrt(mean_square + self.norm_eps)
|
|
||||||
norm = norm.to(dtype)
|
|
||||||
out = norm * self.weight
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue