chore: 修改RMSNorm 实现

This commit is contained in:
ViperEkura 2026-04-06 20:27:01 +08:00
parent 408f0cb513
commit b0eff02446
1 changed files with 1 additions and 2 deletions

View File

@ -123,8 +123,7 @@ class RMSNorm(nn.Module):
self.norm_eps = norm_eps self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps) return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
return rms.to(x.dtype)
class MLP(nn.Module): class MLP(nn.Module):