From 64b78ecce315cab7c05171e12b8e804ff67d9f65 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 6 Apr 2026 13:29:39 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=A2=9E=E5=8A=A0=E6=97=8B=E8=BD=AC?= =?UTF-8?q?=E4=BD=8D=E7=BD=AE=E7=BC=96=E7=A0=81=E6=89=A9=E5=B1=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/model/module.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/astrai/model/module.py b/astrai/model/module.py index 28c64f6..fd76247 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -30,6 +30,7 @@ def get_rotary_emb( dim: int, max_len: int, base: float = 10000, + device: Optional[torch.device] = None, ) -> Tuple[Tensor, Tensor]: """ Get the rotary embedding for the given dimension and maximum length. @@ -37,12 +38,13 @@ def get_rotary_emb( dim (int): The dimension of the input. max_len (int): The maximum length of the input. base (float, optional): The base for the frequency. Defaults to 10000. + device (optional): The device to create tensors on. Defaults to None. Returns: Tensor: The rotary embedding tensor. """ - theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim) - t = torch.arange(0, max_len, dtype=torch.float64) + theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) + t = torch.arange(0, max_len, dtype=torch.float64, device=device) freqs = torch.outer(t, theta) return torch.cos(freqs).float(), torch.sin(freqs).float() @@ -83,10 +85,10 @@ class RotaryEmbedding(nn.Module): self.max_len = max_len self.base = base self.max_len_cached = None - self._set_rotary_buffer(self.max_len) + self._set_rotary_buffer(self.max_len, None) - def _set_rotary_buffer(self, max_len: int): - cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base) + def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None): + cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device) self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False) self.max_len_cached = max_len @@ -95,7 +97,7 @@ class RotaryEmbedding(nn.Module): seq_len = x.size(1) if self.max_len_cached < seq_len + start_pos: - self._set_rotary_buffer(seq_len + start_pos) + self._set_rotary_buffer(self.max_len_cached * 2, x.device) cos = self.cos_cached[start_pos : start_pos + seq_len] sin = self.sin_cached[start_pos : start_pos + seq_len]