From 5daf63a7a4204918c64be6effd3f8786eaaeee43 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Tue, 25 Nov 2025 21:03:10 +0800
Subject: [PATCH] fix(model): fix the key existence check when loading the
 state dict
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/model/transformer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py
index bff3634..9ff30a5 100644
--- a/khaosz/model/transformer.py
+++ b/khaosz/model/transformer.py
@@ -83,8 +83,9 @@ class Transformer(nn.Module):
             # same tensor
             state_dict[lm_head_key] = state_dict[embed_key]
         else:
-            # use clone to avoid sharing the same tensor
-            state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
+            if lm_head_key not in state_dict and embed_key in state_dict:
+                # use clone to avoid sharing the same tensor
+                state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
 
         return super().load_state_dict(state_dict, strict, assign)
 
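
Note: for context, below is a minimal, self-contained sketch of the situation this
patch guards against: a checkpoint saved from a weight-tied model carries only the
embedding weights, so the untied model has to synthesize the head weights at load
time, and it must first check that the head key is absent and the embedding key is
actually present. The TiedHead class, the tie_weights flag, and the key names
"embed.weight" / "lm_head.weight" are illustrative assumptions, not the real Khaosz
Transformer module or its key prefixes.

import torch
import torch.nn as nn


class TiedHead(nn.Module):
    """Toy stand-in for the Transformer: an embedding plus an output head
    whose weight may or may not be tied to the embedding (illustrative only)."""

    def __init__(self, vocab_size=8, dim=4, tie_weights=True):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.tie_weights = tie_weights
        if tie_weights:
            # share one parameter between embedding and head
            self.lm_head.weight = self.embed.weight

    def load_state_dict(self, state_dict, strict=True, assign=False):
        # assumed key names; the real module uses its own prefixes
        embed_key = "embed.weight"
        lm_head_key = "lm_head.weight"
        if self.tie_weights:
            # same tensor
            state_dict[lm_head_key] = state_dict[embed_key]
        else:
            # the patched check: only synthesize head weights when the
            # checkpoint lacks them and the embedding weights are present
            if lm_head_key not in state_dict and embed_key in state_dict:
                # use clone to avoid sharing the same tensor
                state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
        return super().load_state_dict(state_dict, strict, assign)


if __name__ == "__main__":
    # Checkpoint saved from a weight-tied model: it holds the embedding
    # weights but no separate lm_head.weight entry.
    ckpt = {"embed.weight": torch.randn(8, 4)}

    model = TiedHead(tie_weights=False)
    model.load_state_dict(dict(ckpt))

    # The head received its own copy rather than a view of the embedding.
    assert model.lm_head.weight.data_ptr() != model.embed.weight.data_ptr()
    assert torch.equal(model.lm_head.weight, model.embed.weight)
    print("checkpoint loaded without sharing the lm_head tensor")

Cloning in the untied branch (rather than aliasing, as in the tied branch) keeps the
head and the embedding independent after loading, matching the comment in the patch.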