fix(model): 修复加载状态字典时的键存在性检查

This commit is contained in:
ViperEkura 2025-11-25 21:03:10 +08:00
parent fb85aaf6a6
commit 5daf63a7a4
1 changed file with 3 additions and 2 deletions

View File

@@ -83,8 +83,9 @@ class Transformer(nn.Module):
             # same tensor
             state_dict[lm_head_key] = state_dict[embed_key]
         else:
-            # use clone to avoid sharing the same tensor
-            state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
+            if lm_head_key not in state_dict and embed_key in state_dict:
+                # use clone to avoid sharing the same tensor
+                state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
         return super().load_state_dict(state_dict, strict, assign)