feat(model): 实现旋转位置编码缓存动态扩展
This commit is contained in:
parent
254ec934be
commit
d25202a329
|
|
@ -77,15 +77,26 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
|
|||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, max_len: int, base: int=10000):
|
||||
super().__init__()
|
||||
cos_emb, sin_emb = get_rotary_emb(dim, max_len, base)
|
||||
self.register_buffer("cos_emb", cos_emb, persistent=False)
|
||||
self.register_buffer("sin_emb", sin_emb, persistent=False)
|
||||
self._rotary_buffers = {"cos_emb", "sin_emb"}
|
||||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
self.base = base
|
||||
self.max_len_cached = None
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int):
|
||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
self.max_len_cached = max_len
|
||||
|
||||
def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]:
|
||||
seq_len = x.size(1)
|
||||
cos = self.cos_emb[start_pos : start_pos + seq_len]
|
||||
sin = self.sin_emb[start_pos : start_pos + seq_len]
|
||||
|
||||
if self.max_len_cached < seq_len + start_pos:
|
||||
self._set_rotary_buffer(seq_len)
|
||||
|
||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
||||
|
||||
return (cos, sin)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue