From d25202a329d1c6ff1c32193153eed8f24dabce58 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 9 Nov 2025 14:35:29 +0800
Subject: [PATCH] =?UTF-8?q?feat(model):=20=E5=AE=9E=E7=8E=B0=E6=97=8B?=
 =?UTF-8?q?=E8=BD=AC=E4=BD=8D=E7=BD=AE=E7=BC=96=E7=A0=81=E7=BC=93=E5=AD=98?=
 =?UTF-8?q?=E5=8A=A8=E6=80=81=E6=89=A9=E5=B1=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Review notes (this section is ignored by `git am`):
 - forward() now grows the cache to start_pos + seq_len. Growing it to
   seq_len alone could shrink the table and return short or empty slices
   whenever start_pos > 0.
 - Rebuilt tables are moved to the device of the table they replace.
 - The line breaks of this patch were restored by hand and blank context
   lines are emitted empty. Hunk counts and the diffstat were recomputed;
   the `index` blob ids were not. Run `git apply --check` first (add
   `--ignore-whitespace` if the blank line before `return` carries
   indentation in the target file).

 khaosz/model/module.py | 32 ++++++++++++++++++++++++++------
 1 file changed, 26 insertions(+), 6 deletions(-)

diff --git a/khaosz/model/module.py b/khaosz/model/module.py
index 154547b..ec73cc9 100644
--- a/khaosz/model/module.py
+++ b/khaosz/model/module.py
@@ -77,14 +77,34 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, max_len: int, base: int=10000):
         super().__init__()
-        cos_emb, sin_emb = get_rotary_emb(dim, max_len, base)
-        self.register_buffer("cos_emb", cos_emb, persistent=False)
-        self.register_buffer("sin_emb", sin_emb, persistent=False)
-        self._rotary_buffers = {"cos_emb", "sin_emb"}
+        self.dim = dim
+        self.max_len = max_len
+        self.base = base
+        self.max_len_cached = None
+        self._set_rotary_buffer(self.max_len)
+
+    def _set_rotary_buffer(self, max_len: int, device=None):
+        """Rebuild the cached cos/sin tables with get_rotary_emb for `max_len`."""
+        cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base)
+        if device is not None:
+            # A rebuilt table does not inherit an earlier module.to(...).
+            cos_cached = cos_cached.to(device)
+            sin_cached = sin_cached.to(device)
+        self.register_buffer("cos_cached", cos_cached, persistent=False)
+        self.register_buffer("sin_cached", sin_cached, persistent=False)
+        self.max_len_cached = max_len

     def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]:
+        """Return (cos, sin) rows for positions [start_pos, start_pos + x.size(1))."""
         seq_len = x.size(1)
-        cos = self.cos_emb[start_pos : start_pos + seq_len]
-        sin = self.sin_emb[start_pos : start_pos + seq_len]
+        end_pos = start_pos + seq_len
+
+        if self.max_len_cached < end_pos:
+            # Grow to the last requested position, not just seq_len, so a
+            # non-zero start_pos never slices past the end of the table.
+            self._set_rotary_buffer(end_pos, device=self.cos_cached.device)
+
+        cos = self.cos_cached[start_pos:end_pos]
+        sin = self.sin_cached[start_pos:end_pos]

         return (cos, sin)