fix: 增加旋转位置编码扩展

This commit is contained in:
ViperEkura 2026-04-06 13:29:39 +08:00
parent f2ffdf60d0
commit 64b78ecce3
1 changed file with 8 additions and 6 deletions

View File

@ -30,6 +30,7 @@ def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
"""
Get the rotary embedding for the given dimension and maximum length.
@ -37,12 +38,13 @@ def get_rotary_emb(
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (torch.device, optional): The device on which the tensors are created. Defaults to None (PyTorch's current default device).
Returns:
Tuple[Tensor, Tensor]: The cosine and sine rotary embedding tables (float32), each with one row per position up to max_len.
"""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)
t = torch.arange(0, max_len, dtype=torch.float64)
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float()
@ -83,10 +85,10 @@ class RotaryEmbedding(nn.Module):
self.max_len = max_len
self.base = base
self.max_len_cached = None
self._set_rotary_buffer(self.max_len)
self._set_rotary_buffer(self.max_len, None)
def _set_rotary_buffer(self, max_len: int):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len
@ -95,7 +97,7 @@ class RotaryEmbedding(nn.Module):
seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(seq_len + start_pos)
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len]