fix(model): 修复加载状态字典时的权重共享问题 (fix weight-sharing issue when loading the state dict)

This commit is contained in:
ViperEkura 2025-11-05 23:38:45 +08:00
parent 69d9374f51
commit 7ccc4ab9ac
1 changed file with 9 additions and 5 deletions

View File

@ -78,12 +78,16 @@ class Transformer(nn.Module):
self._init_parameters() self._init_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
    """Load parameters, synthesizing ``lm_head.weight`` when the checkpoint omits it.

    If the checkpoint has ``embed_tokens.weight`` but no ``lm_head.weight``:
    when ``config.tie_weight`` is true, the lm_head entry aliases the
    embedding tensor (weight tying); otherwise a clone is inserted so the
    two parameters do not share storage.

    Args:
        state_dict: Mapping of parameter names to tensors. Not mutated.
        strict: Forwarded to ``nn.Module.load_state_dict``.
        assign: Forwarded to ``nn.Module.load_state_dict``.

    Returns:
        Whatever ``nn.Module.load_state_dict`` returns (the
        missing/unexpected-keys result).
    """
    lm_head_key = 'lm_head.weight'
    embed_key = 'embed_tokens.weight'

    if lm_head_key not in state_dict and embed_key in state_dict:
        # Work on a shallow copy: a generic Mapping need not support item
        # assignment, and the caller's dict should not be mutated.
        state_dict = dict(state_dict)
        if self.config.tie_weight:
            # Tied weights: lm_head deliberately shares the embedding tensor.
            state_dict[lm_head_key] = state_dict[embed_key]
        else:
            # Clone so lm_head and the embedding do not share storage.
            state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
    return super().load_state_dict(state_dict, strict, assign)