From 6f3386f02c420dc389b75aaa937bba945f832f49 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 9 Nov 2025 16:25:17 +0800 Subject: [PATCH] =?UTF-8?q?fix(transformer):=20=E4=BC=98=E5=8C=96state=5Fd?= =?UTF-8?q?ict=20=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91,=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96attention=5Fmask=E7=9A=84=E5=A4=84=E7=90=86=E6=96=B9?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/model/transformer.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py index f67edc9..bff3634 100644 --- a/khaosz/model/transformer.py +++ b/khaosz/model/transformer.py @@ -46,12 +46,7 @@ def process_attention_mask( # (bsz, seq_len, start_pos + seq_len) if is_causal: - causal_mask = torch.tril( - torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device), - diagonal=start_pos - ) - causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len) - expanded_mask = expanded_mask & causal_mask + expanded_mask = torch.tril(expanded_mask, diagonal=start_pos) attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device) attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1) @@ -84,13 +79,12 @@ class Transformer(nn.Module): lm_head_key = 'lm_head.weight' embed_key = 'embed_tokens.weight' - if lm_head_key not in state_dict and embed_key in state_dict: - if self.config.tie_weight == True: - # same tensor - state_dict[lm_head_key] = state_dict[embed_key] - else: - # use clone to avoid sharing the same tensor - state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) + if self.config.tie_weight == True: + # same tensor + state_dict[lm_head_key] = state_dict[embed_key] + else: + # use clone to avoid sharing the same tensor + state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) return super().load_state_dict(state_dict, strict, assign)