From 1c3a693d796498f2c12fb1a0b93ecc8036750380 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 15 Nov 2025 13:54:04 +0800 Subject: [PATCH] =?UTF-8?q?feat(model):=20=E4=BC=98=E5=8C=96RMSNorm?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/model/module.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/khaosz/model/module.py b/khaosz/model/module.py index ec73cc9..748c53e 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -118,16 +118,12 @@ class RMSNorm(nn.Module): def __init__(self, n_dim, norm_eps): super().__init__() self.weight = nn.Parameter(torch.ones(n_dim)) + self.normalized_shape = (n_dim, ) self.norm_eps = norm_eps def forward(self, x: Tensor) -> Tensor: - dtype = x.dtype - x = x.float() - mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True) - norm = x * torch.rsqrt(mean_square + self.norm_eps) - norm = norm.to(dtype) - out = norm * self.weight - return out + rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps) + return rms.to(x.dtype) class MLP(nn.Module):