From 7ccc4ab9acd712a266627830cbaa7e1c8cda52c5 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Wed, 5 Nov 2025 23:38:45 +0800
Subject: [PATCH] =?UTF-8?q?fix(model):=20=E4=BF=AE=E5=A4=8D=E5=8A=A0?=
 =?UTF-8?q?=E8=BD=BD=E7=8A=B6=E6=80=81=E5=AD=97=E5=85=B8=E6=97=B6=E7=9A=84?=
 =?UTF-8?q?=E6=9D=83=E9=87=8D=E5=85=B1=E4=BA=AB=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/model/transformer.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py
index 06dacda..63904bc 100644
--- a/khaosz/model/transformer.py
+++ b/khaosz/model/transformer.py
@@ -78,12 +78,16 @@ class Transformer(nn.Module):
         self._init_parameters()
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
-        if self.config.tie_weight == True:
-            lm_head_key = 'lm_head.weight'
-            embed_key = 'embed_tokens.weight'
-
-            if lm_head_key not in state_dict and embed_key in state_dict:
+        lm_head_key = 'lm_head.weight'
+        embed_key = 'embed_tokens.weight'
+
+        if lm_head_key not in state_dict and embed_key in state_dict:
+            if self.config.tie_weight == True:
+                # same tensor
                 state_dict[lm_head_key] = state_dict[embed_key]
+            else:
+                # use clone to avoid sharing the same tensor
+                state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
 
         return super().load_state_dict(state_dict, strict, assign)
 