fix(model): 修复加载状态字典时的权重共享问题
This commit is contained in:
parent
69d9374f51
commit
7ccc4ab9ac
|
|
@ -78,12 +78,16 @@ class Transformer(nn.Module):
|
||||||
self._init_parameters()
|
self._init_parameters()
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
if self.config.tie_weight == True:
|
lm_head_key = 'lm_head.weight'
|
||||||
lm_head_key = 'lm_head.weight'
|
embed_key = 'embed_tokens.weight'
|
||||||
embed_key = 'embed_tokens.weight'
|
|
||||||
|
|
||||||
if lm_head_key not in state_dict and embed_key in state_dict:
|
if lm_head_key not in state_dict and embed_key in state_dict:
|
||||||
|
if self.config.tie_weight == True:
|
||||||
|
# same tensor
|
||||||
state_dict[lm_head_key] = state_dict[embed_key]
|
state_dict[lm_head_key] = state_dict[embed_key]
|
||||||
|
else:
|
||||||
|
# use clone to avoid sharing the same tensor
|
||||||
|
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||||
|
|
||||||
return super().load_state_dict(state_dict, strict, assign)
|
return super().load_state_dict(state_dict, strict, assign)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue