diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py
index 06dacda..63904bc 100644
--- a/khaosz/model/transformer.py
+++ b/khaosz/model/transformer.py
@@ -78,12 +78,16 @@ class Transformer(nn.Module):
         self._init_parameters()
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
-        if self.config.tie_weight == True:
-            lm_head_key = 'lm_head.weight'
-            embed_key = 'embed_tokens.weight'
-
-            if lm_head_key not in state_dict and embed_key in state_dict:
+        lm_head_key = 'lm_head.weight'
+        embed_key = 'embed_tokens.weight'
+
+        if lm_head_key not in state_dict and embed_key in state_dict:
+            if self.config.tie_weight:
+                # tied weights: the head reuses the embedding tensor directly
                 state_dict[lm_head_key] = state_dict[embed_key]
+            else:
+                # untied weights: clone so the two parameters do not share storage
+                state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
         return super().load_state_dict(state_dict, strict, assign)
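
For reviewers, a minimal standalone sketch of the aliasing behavior the patch distinguishes (assuming only PyTorch; the dict contents and variable names are illustrative, not from the repo): direct assignment keeps both keys backed by the same storage, while `torch.clone` produces an independent copy.

```python
import torch

# Illustrative state dict; the key name matches the one the loader checks.
state_dict = {"embed_tokens.weight": torch.randn(8, 4)}

# Tied case: direct assignment aliases the same storage.
tied = state_dict["embed_tokens.weight"]

# Untied case: clone yields an independent copy.
untied = torch.clone(state_dict["embed_tokens.weight"])

# Simulate an in-place update to the embedding weights.
state_dict["embed_tokens.weight"].add_(1.0)

print(tied.data_ptr() == state_dict["embed_tokens.weight"].data_ptr())  # True: shared storage
print(torch.equal(untied, state_dict["embed_tokens.weight"]))           # False: copy unaffected
```

The clone in the `else` branch is what keeps a checkpoint that lacks `lm_head.weight` from silently tying the head to the embedding after loading when `tie_weight` is off.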