diff --git a/khaosz/core/transformer.py b/khaosz/core/transformer.py
index b5a2d66..ae2d9c3 100644
--- a/khaosz/core/transformer.py
+++ b/khaosz/core/transformer.py
@@ -71,7 +71,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     return x_out.to(dtype)
 
 
-def create_attention_mask(
+def process_attention_mask(
     seq_mask: Tensor,
     start_pos: int = 0,
     seq_len: int = 0,
@@ -98,6 +98,9 @@ def create_attention_mask(
     if seq_mask is None:
         return None
 
+    if seq_mask.dim() > 2:
+        return seq_mask
+
     batch_size = seq_mask.size(0)
     seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)  # (bsz, start_pos + seq_len)
 
@@ -306,7 +309,7 @@ class Transformer(nn.Module):
     def forward(
         self,
         input_ids: Tensor,
-        seq_mask: Optional[Tensor]=None,
+        input_mask: Optional[Tensor]=None,
         persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
         start_pos: int = 0
     ) -> Tensor:
@@ -318,8 +321,8 @@ class Transformer(nn.Module):
         freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
         has_kvcache = persistent_key_values is not None
 
-        attn_mask = create_attention_mask(
-            seq_mask,
+        attn_mask = process_attention_mask(
+            input_mask,
             start_pos=start_pos,
             seq_len=seq_len,
             is_causal=has_kvcache,
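
Note on the new early exit: the `if seq_mask.dim() > 2: return seq_mask` branch means a caller can now pass a mask that is already expanded (for example a `(bsz, 1, q_len, kv_len)` tensor) through `input_mask` and it will be forwarded untouched, while plain 2-D padding masks still go through the existing expansion path. The snippet below is a standalone sketch of that rule only; the helper name `mirror_mask_passthrough` and the tensor shapes are illustrative assumptions, not part of the khaosz API.

```python
import torch
from torch import Tensor
from typing import Optional

def mirror_mask_passthrough(seq_mask: Optional[Tensor]) -> Optional[Tensor]:
    # Mirrors only the early exits added in process_attention_mask:
    # None stays None, and any mask with more than 2 dims is returned as-is.
    if seq_mask is None:
        return None
    if seq_mask.dim() > 2:
        return seq_mask
    return None  # a 2-D padding mask would continue into the normal expansion logic

padding_mask = torch.ones(2, 8, dtype=torch.bool)           # (bsz, seq_len) padding mask
prebuilt_mask = torch.zeros(2, 1, 8, 8, dtype=torch.bool)   # (bsz, 1, q_len, kv_len) full mask

assert mirror_mask_passthrough(prebuilt_mask) is prebuilt_mask  # passed through untouched
assert mirror_mask_passthrough(padding_mask) is None            # still needs expansion
```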