fix: 增加旋转位置编码扩展
This commit is contained in:
parent
f2ffdf60d0
commit
64b78ecce3
|
|
@ -30,6 +30,7 @@ def get_rotary_emb(
|
|||
dim: int,
|
||||
max_len: int,
|
||||
base: float = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get the rotary embedding for the given dimension and maximum length.
|
||||
|
|
@ -37,12 +38,13 @@ def get_rotary_emb(
|
|||
dim (int): The dimension of the input.
|
||||
max_len (int): The maximum length of the input.
|
||||
base (float, optional): The base for the frequency. Defaults to 10000.
|
||||
device (torch.device, optional): The device to create tensors on. Defaults to None.
|
||||
Returns:
|
||||
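Tuple[Tensor, Tensor]: The cosine and sine rotary embedding tables (float32), each with one row per position and one column per frequency.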
|
||||
"""
|
||||
|
||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)
|
||||
t = torch.arange(0, max_len, dtype=torch.float64)
|
||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||
freqs = torch.outer(t, theta)
|
||||
|
||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||
|
|
@ -83,10 +85,10 @@ class RotaryEmbedding(nn.Module):
|
|||
self.max_len = max_len
|
||||
self.base = base
|
||||
self.max_len_cached = None
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
self._set_rotary_buffer(self.max_len, None)
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int):
|
||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
|
||||
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
|
||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
self.max_len_cached = max_len
|
||||
|
|
@ -95,7 +97,7 @@ class RotaryEmbedding(nn.Module):
|
|||
seq_len = x.size(1)
|
||||
|
||||
if self.max_len_cached < seq_len + start_pos:
|
||||
self._set_rotary_buffer(seq_len + start_pos)
|
||||
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
|
||||
|
||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
||||
|
|
|
|||
Loading…
Reference in New Issue