From 8206c7855e4da3b91a2e1eb9e1d74e528cc50020 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 29 Sep 2025 11:31:42 +0800 Subject: [PATCH] =?UTF-8?q?fix(transformer):=20=E8=B0=83=E6=95=B4=E6=B3=A8?= =?UTF-8?q?=E6=84=8F=E5=8A=9B=E6=8E=A9=E7=A0=81=E5=A4=84=E7=90=86=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/core/transformer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/khaosz/core/transformer.py b/khaosz/core/transformer.py index 3ec47f3..a324c63 100644 --- a/khaosz/core/transformer.py +++ b/khaosz/core/transformer.py @@ -91,12 +91,12 @@ def process_attention_mask( Tensor: The attention mask tensor. """ - if start_pos != 0 and seq_mask is None: - # for single prompt chat - seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) - if seq_mask is None: - return None + if start_pos != 0: + # for single prompt chat + seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) + else: + return None if seq_mask.dim() > 2: # shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)