diff --git a/khaosz/core/transformer.py b/khaosz/core/transformer.py
index 3ec47f3..a324c63 100644
--- a/khaosz/core/transformer.py
+++ b/khaosz/core/transformer.py
@@ -91,12 +91,12 @@ def process_attention_mask(
         Tensor: The attention mask tensor.
     """
 
-    if start_pos != 0 and seq_mask is None:
-        # for single prompt chat
-        seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
-    if seq_mask is None:
-        return None
+    if start_pos != 0:
+        # for single prompt chat
+        seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
+    else:
+        return None
 
     if seq_mask.dim() > 2:
         # shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
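For context, a minimal standalone sketch of how the refactored branch behaves. The function name, full signature, and defaults below are assumptions (only the if/else body is taken from the hunk); the real process_attention_mask in khaosz/core/transformer.py continues past the code shown here.

import torch

# Hypothetical sketch, not the actual khaosz implementation: only the
# if/else control flow mirrors the hunk above.
def process_attention_mask_sketch(start_pos, seq_len, seq_mask=None, device="cpu"):
    if start_pos != 0:
        # incremental decoding (single prompt chat): build a fresh
        # all-True mask over the current sequence
        seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
    else:
        # prefill from position 0: no explicit mask is returned
        return None

    # At this point seq_mask is always the freshly built 2-D ones mask,
    # so any caller-supplied mask has not been consulted by this branch.
    return seq_mask

print(process_attention_mask_sketch(start_pos=4, seq_len=5))  # all-True tensor, shape (1, 5)
print(process_attention_mask_sketch(start_pos=0, seq_len=5))  # None

Note one behavioral consequence visible in the hunk itself: under the new control flow, a caller-supplied seq_mask is overwritten when start_pos != 0 and ignored (the function returns None) when start_pos == 0, so the subsequent seq_mask.dim() > 2 branch only ever sees the freshly built 2-D mask within the code shown here.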