diff --git a/astrai/model/module.py b/astrai/model/module.py index 28c64f6..fd76247 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -30,6 +30,7 @@ def get_rotary_emb( dim: int, max_len: int, base: float = 10000, + device: Optional[torch.device] = None, ) -> Tuple[Tensor, Tensor]: """ Get the rotary embedding for the given dimension and maximum length. @@ -37,12 +38,13 @@ def get_rotary_emb( dim (int): The dimension of the input. max_len (int): The maximum length of the input. base (float, optional): The base for the frequency. Defaults to 10000. + device (torch.device, optional): The device to create tensors on. Defaults to None (PyTorch's current default device). Returns: - Tensor: The rotary embedding tensor. + Tuple[Tensor, Tensor]: The cosine and sine tables (float32), one row per position. """ - theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim) - t = torch.arange(0, max_len, dtype=torch.float64) + theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) + t = torch.arange(0, max_len, dtype=torch.float64, device=device) freqs = torch.outer(t, theta) return torch.cos(freqs).float(), torch.sin(freqs).float() @@ -83,10 +85,10 @@ class RotaryEmbedding(nn.Module): self.max_len = max_len self.base = base self.max_len_cached = None - self._set_rotary_buffer(self.max_len) + self._set_rotary_buffer(self.max_len, None) - def _set_rotary_buffer(self, max_len: int): - cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base) + def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None): + cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device) self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False) self.max_len_cached = max_len @@ -95,7 +97,8 @@ class RotaryEmbedding(nn.Module): seq_len = x.size(1) if self.max_len_cached < seq_len + start_pos: - self._set_rotary_buffer(seq_len + start_pos) + # Grow geometrically to amortize rebuilds, but never below the length this call needs; doubling alone can still leave the cache too short. + self._set_rotary_buffer(max(self.max_len_cached * 2, seq_len + start_pos), x.device) cos = self.cos_cached[start_pos : start_pos + seq_len] sin = 
self.sin_cached[start_pos : start_pos + seq_len]