diff --git a/astrai/model/module.py b/astrai/model/module.py index fd76247..d3696c8 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -123,8 +123,7 @@ class RMSNorm(nn.Module): self.norm_eps = norm_eps def forward(self, x: Tensor) -> Tensor: - rms = F.rms_norm(x.float(), self.normalized_shape, self.weight, self.norm_eps) - return rms.to(x.dtype) + return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps) class MLP(nn.Module):