chore: 修改RMSNorm 实现
This commit is contained in:
parent
408f0cb513
commit
b0eff02446
|
|
@ -123,8 +123,7 @@ class RMSNorm(nn.Module):
|
||||||
self.norm_eps = norm_eps
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps)
|
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
||||||
return rms.to(x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue