fix(transformer): 优化 state_dict 的处理逻辑，优化 attention_mask 的处理方式
This commit is contained in:
parent
d25202a329
commit
6f3386f02c
|
|
@ -46,12 +46,7 @@ def process_attention_mask(
|
||||||
# (bsz, seq_len, start_pos + seq_len)
|
# (bsz, seq_len, start_pos + seq_len)
|
||||||
|
|
||||||
if is_causal:
|
if is_causal:
|
||||||
causal_mask = torch.tril(
|
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
||||||
torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
|
|
||||||
diagonal=start_pos
|
|
||||||
)
|
|
||||||
causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
|
|
||||||
expanded_mask = expanded_mask & causal_mask
|
|
||||||
|
|
||||||
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
||||||
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
|
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
|
||||||
|
|
@ -84,13 +79,12 @@ class Transformer(nn.Module):
|
||||||
lm_head_key = 'lm_head.weight'
|
lm_head_key = 'lm_head.weight'
|
||||||
embed_key = 'embed_tokens.weight'
|
embed_key = 'embed_tokens.weight'
|
||||||
|
|
||||||
if lm_head_key not in state_dict and embed_key in state_dict:
|
if self.config.tie_weight == True:
|
||||||
if self.config.tie_weight == True:
|
# same tensor
|
||||||
# same tensor
|
state_dict[lm_head_key] = state_dict[embed_key]
|
||||||
state_dict[lm_head_key] = state_dict[embed_key]
|
else:
|
||||||
else:
|
# use clone to avoid sharing the same tensor
|
||||||
# use clone to avoid sharing the same tensor
|
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||||
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
|
||||||
|
|
||||||
return super().load_state_dict(state_dict, strict, assign)
|
return super().load_state_dict(state_dict, strict, assign)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue