From b0eff024461ab6b622ee1a6c58ac220f3fc32cd6 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 6 Apr 2026 20:27:01 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E4=BF=AE=E6=94=B9RMSNorm=20=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/model/module.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/astrai/model/module.py b/astrai/model/module.py index fd76247..d3696c8 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -123,8 +123,7 @@ class RMSNorm(nn.Module): self.norm_eps = norm_eps def forward(self, x: Tensor) -> Tensor: - rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps) - return rms.to(x.dtype) + return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps) class MLP(nn.Module):