fix(transformer): 调整注意力掩码处理逻辑
This commit is contained in:
parent
816bc78894
commit
8206c7855e
|
|
@ -91,12 +91,12 @@ def process_attention_mask(
|
||||||
Tensor: The attention mask tensor.
|
Tensor: The attention mask tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if start_pos != 0 and seq_mask is None:
|
|
||||||
# for single prompt chat
|
|
||||||
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
|
||||||
|
|
||||||
if seq_mask is None:
|
if seq_mask is None:
|
||||||
return None
|
if start_pos != 0:
|
||||||
|
# for single prompt chat
|
||||||
|
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
if seq_mask.dim() > 2:
|
if seq_mask.dim() > 2:
|
||||||
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue