feat(model): 实现旋转位置编码缓存动态扩展

This commit is contained in:
ViperEkura 2025-11-09 14:35:29 +08:00
parent 254ec934be
commit d25202a329
1 changed file with 17 additions and 6 deletions

View File

@ -77,15 +77,26 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
class RotaryEmbedding(nn.Module):
    """Serve rotary position-embedding tables from a cache that grows on demand.

    The cos/sin tables are produced by ``get_rotary_emb`` and stored as
    non-persistent buffers (``cos_cached`` / ``sin_cached``), so they are not
    written to the ``state_dict``. When ``forward`` is asked for positions
    beyond the cached length, the tables are rebuilt to cover them.

    Args:
        dim: Passed through to ``get_rotary_emb`` as its first argument.
        max_len: Number of positions to precompute at construction time.
        base: Passed through to ``get_rotary_emb`` as its third argument.
    """

    def __init__(self, dim: int, max_len: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.base = base
        # Number of positions currently covered by the cached tables;
        # assigned by _set_rotary_buffer.
        self.max_len_cached = None
        self._set_rotary_buffer(self.max_len)

    def _set_rotary_buffer(self, max_len: int) -> None:
        """(Re)build the cached cos/sin tables so they cover ``max_len`` positions."""
        cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)

        # When regrowing after the module has been moved (e.g. ``.cuda()``),
        # keep the new tables on the device of the buffers they replace.
        # NOTE(review): assumes get_rotary_emb builds on the default device
        # when no device is given -- confirm against its definition.
        previous = getattr(self, "cos_cached", None)
        if previous is not None:
            cos_cached = cos_cached.to(previous.device)
            sin_cached = sin_cached.to(previous.device)

        # Re-registering an existing buffer name replaces the old tensor.
        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)
        self.max_len_cached = max_len

    def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
        """Return the (cos, sin) table slices for the positions occupied by ``x``.

        Args:
            x: Input tensor; ``x.size(1)`` is used as the sequence length
                (assumes a (batch, seq, ...) layout -- TODO confirm with callers).
            start_pos: Offset of the first position of ``x`` in the full sequence.

        Returns:
            ``(cos, sin)`` sliced to ``[start_pos, start_pos + seq_len)`` along
            the first axis of the cached tables.
        """
        seq_len = x.size(1)
        end_pos = start_pos + seq_len

        # Grow to the last position actually needed. Growing to ``seq_len``
        # alone would leave the cache too short whenever start_pos > 0,
        # and the slices below would be silently truncated.
        if self.max_len_cached < end_pos:
            self._set_rotary_buffer(end_pos)

        cos = self.cos_cached[start_pos:end_pos]
        sin = self.sin_cached[start_pos:end_pos]
        return (cos, sin)