fix(model): 修复加载状态字典时的键存在性检查
This commit is contained in:
parent
fb85aaf6a6
commit
5daf63a7a4
|
|
@ -83,8 +83,9 @@ class Transformer(nn.Module):
|
|||
# same tensor
|
||||
state_dict[lm_head_key] = state_dict[embed_key]
|
||||
else:
|
||||
# use clone to avoid sharing the same tensor
|
||||
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||
if lm_head_key not in state_dict and embed_key in state_dict:
|
||||
# use clone to avoid sharing the same tensor
|
||||
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||
|
||||
return super().load_state_dict(state_dict, strict, assign)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue