diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py index bff3634..9ff30a5 100644 --- a/khaosz/model/transformer.py +++ b/khaosz/model/transformer.py @@ -83,8 +83,9 @@ class Transformer(nn.Module): # same tensor state_dict[lm_head_key] = state_dict[embed_key] else: - # use clone to avoid sharing the same tensor - state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) + if lm_head_key not in state_dict and embed_key in state_dict: + # use clone to avoid sharing the same tensor + state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) return super().load_state_dict(state_dict, strict, assign)