From 9fbc9481b501f5121317f72965acee2f82b94367 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 27 Sep 2025 13:37:10 +0800 Subject: [PATCH] =?UTF-8?q?refactor(core):=20=E4=BF=AE=E6=94=B9=E6=B3=A8?= =?UTF-8?q?=E6=84=8F=E5=8A=9B=E6=8E=A9=E7=A0=81=E5=A4=84=E7=90=86=E5=87=BD?= =?UTF-8?q?=E6=95=B0=E5=B9=B6=E9=87=8D=E5=91=BD=E5=90=8D=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/core/transformer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/khaosz/core/transformer.py b/khaosz/core/transformer.py index b5a2d66..ae2d9c3 100644 --- a/khaosz/core/transformer.py +++ b/khaosz/core/transformer.py @@ -71,7 +71,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: return x_out.to(dtype) -def create_attention_mask( +def process_attention_mask( seq_mask: Tensor, start_pos: int = 0, seq_len: int = 0, @@ -98,6 +98,9 @@ def create_attention_mask( if seq_mask is None: return None + if seq_mask.dim() > 2: + return seq_mask + batch_size = seq_mask.size(0) seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool) # (bsz, start_pos + seq_len) @@ -306,7 +309,7 @@ class Transformer(nn.Module): def forward( self, input_ids: Tensor, - seq_mask: Optional[Tensor]=None, + input_mask: Optional[Tensor]=None, persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None, start_pos: int = 0 ) -> Tensor: @@ -318,8 +321,8 @@ class Transformer(nn.Module): freqs_cis = self.freq_cis[start_pos:start_pos+seq_len] has_kvcache = persistent_key_values is not None - attn_mask = create_attention_mask( - seq_mask, + attn_mask = process_attention_mask( + input_mask, start_pos=start_pos, seq_len=seq_len, is_causal=has_kvcache,