fix: 增加旋转位置编码扩展
This commit is contained in:
parent
f2ffdf60d0
commit
64b78ecce3
|
|
@ -30,6 +30,7 @@ def get_rotary_emb(
|
||||||
dim: int,
|
dim: int,
|
||||||
max_len: int,
|
max_len: int,
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Get the rotary embedding for the given dimension and maximum length.
|
Get the rotary embedding for the given dimension and maximum length.
|
||||||
|
|
@ -37,12 +38,13 @@ def get_rotary_emb(
|
||||||
dim (int): The dimension of the input.
|
dim (int): The dimension of the input.
|
||||||
max_len (int): The maximum length of the input.
|
max_len (int): The maximum length of the input.
|
||||||
base (float, optional): The base for the frequency. Defaults to 10000.
|
base (float, optional): The base for the frequency. Defaults to 10000.
|
||||||
|
device (torch.device, optional): The device to create tensors on. Defaults to None.
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Tensor, Tensor]: The cosine and sine rotary embedding tensors.
|
Tuple[Tensor, Tensor]: The cosine and sine rotary embedding tensors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)
|
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||||
t = torch.arange(0, max_len, dtype=torch.float64)
|
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||||
freqs = torch.outer(t, theta)
|
freqs = torch.outer(t, theta)
|
||||||
|
|
||||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||||
|
|
@ -83,10 +85,10 @@ class RotaryEmbedding(nn.Module):
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
self.base = base
|
self.base = base
|
||||||
self.max_len_cached = None
|
self.max_len_cached = None
|
||||||
self._set_rotary_buffer(self.max_len)
|
self._set_rotary_buffer(self.max_len, None)
|
||||||
|
|
||||||
def _set_rotary_buffer(self, max_len: int):
|
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
|
||||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
|
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
|
||||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||||
self.max_len_cached = max_len
|
self.max_len_cached = max_len
|
||||||
|
|
@ -95,7 +97,7 @@ class RotaryEmbedding(nn.Module):
|
||||||
seq_len = x.size(1)
|
seq_len = x.size(1)
|
||||||
|
|
||||||
if self.max_len_cached < seq_len + start_pos:
|
if self.max_len_cached < seq_len + start_pos:
|
||||||
self._set_rotary_buffer(seq_len + start_pos)
|
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
|
||||||
|
|
||||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
||||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue