fix(transformer): adjust attention mask handling logic
parent 816bc78894
commit 8206c7855e
@@ -91,12 +91,12 @@ def process_attention_mask(
         Tensor: The attention mask tensor.
     """

-    if start_pos != 0 and seq_mask is None:
-        # for single prompt chat
-        seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
-
     if seq_mask is None:
-        return None
+        if start_pos != 0:
+            # for single prompt chat
+            seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
+        else:
+            return None

     if seq_mask.dim() > 2:
         # shape (bsz, seq_len) or (bsz, n_heads, seq_len, seq_len + start_pos)
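For context, a minimal runnable sketch of the refactored branch. The function name build_seq_mask and everything outside the lines shown in the hunk are assumptions for illustration, as is the reading that start_pos != 0 marks an incremental decode step against a KV cache; only the seq_mask-is-None handling mirrors the commit.

import torch

def build_seq_mask(seq_mask, start_pos, seq_len, device):
    # Mirrors the refactored branch: when the caller supplies no mask,
    # a decode step (start_pos != 0, presumably mid-conversation with a
    # KV cache) gets an all-ones mask so every position is attended to,
    # while prefill (start_pos == 0) keeps None, which downstream code
    # treats as "no masking needed".
    if seq_mask is None:
        if start_pos != 0:
            # for single prompt chat
            seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
        else:
            return None
    return seq_mask

# Decode step with no caller mask -> tensor([[True]])
print(build_seq_mask(None, start_pos=4, seq_len=1, device="cpu"))
# Prefill with no caller mask -> None
print(build_seq_mask(None, start_pos=0, seq_len=8, device="cpu"))

The old and new versions are behaviorally equivalent; the commit only nests the start_pos check inside the seq_mask-is-None check so the "no mask supplied" case is handled in one place.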