diff --git a/khaosz/model/module.py b/khaosz/model/module.py index 154547b..ec73cc9 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -77,15 +77,26 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens class RotaryEmbedding(nn.Module): def __init__(self, dim: int, max_len: int, base: int=10000): super().__init__() - cos_emb, sin_emb = get_rotary_emb(dim, max_len, base) - self.register_buffer("cos_emb", cos_emb, persistent=False) - self.register_buffer("sin_emb", sin_emb, persistent=False) - self._rotary_buffers = {"cos_emb", "sin_emb"} + self.dim = dim + self.max_len = max_len + self.base = base + self.max_len_cached = None + self._set_rotary_buffer(self.max_len) + + def _set_rotary_buffer(self, max_len: int): + cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base) + self.register_buffer("cos_cached", cos_cached, persistent=False) + self.register_buffer("sin_cached", sin_cached, persistent=False) + self.max_len_cached = max_len def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]: seq_len = x.size(1) - cos = self.cos_emb[start_pos : start_pos + seq_len] - sin = self.sin_emb[start_pos : start_pos + seq_len] + + if self.max_len_cached < seq_len + start_pos: + self._set_rotary_buffer(seq_len) + + cos = self.cos_cached[start_pos : start_pos + seq_len] + sin = self.sin_cached[start_pos : start_pos + seq_len] return (cos, sin)