refactor(core): rework the attention mask handling function and rename its parameters
parent 053f4a4dad
commit 9fbc9481b5
@@ -71,7 +71,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     return x_out.to(dtype)
 
 
-def create_attention_mask(
+def process_attention_mask(
     seq_mask: Tensor,
     start_pos: int = 0,
     seq_len: int = 0,
|
@@ -98,6 +98,9 @@ def create_attention_mask(
     if seq_mask is None:
         return None
 
+    if seq_mask.dim() > 2:
+        return seq_mask
+
     batch_size = seq_mask.size(0)
     seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
     # (bsz, start_pos + seq_len)
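For reference, a minimal runnable sketch of the renamed helper, reconstructing only the control flow visible in this diff. The is_causal default, the mask-expansion tail, and the final mask shape are assumptions, since the hunk cuts off before them:

import torch
from torch import Tensor
from typing import Optional

def process_attention_mask(
    seq_mask: Optional[Tensor],
    start_pos: int = 0,
    seq_len: int = 0,
    is_causal: bool = False,
) -> Optional[Tensor]:
    # No mask supplied: let attention run unmasked.
    if seq_mask is None:
        return None
    # New branch from this commit: a mask with more than 2 dims is
    # treated as pre-built and passed through untouched.
    if seq_mask.dim() > 2:
        return seq_mask
    # 2D padding mask: slice to the visible key range and cast to bool,
    # as in the hunk above (device handling omitted here).
    seq_mask = seq_mask[:, :start_pos + seq_len].to(dtype=torch.bool)
    # (bsz, total) -> (bsz, 1, seq_len, total); this expansion is a
    # plausible stand-in, since the hunk cuts off before this point.
    mask = seq_mask[:, None, None, :].expand(-1, 1, seq_len, -1)
    if is_causal:
        total = start_pos + seq_len
        # Query i may attend to key j iff j <= start_pos + i.
        causal = torch.tril(
            torch.ones(seq_len, total, dtype=torch.bool), diagonal=start_pos
        )
        mask = mask & causal
    return mask

# Pre-built (bsz, seq_len, seq_len) masks now short-circuit:
prebuilt = torch.tril(torch.ones(2, 16, 16, dtype=torch.bool))
assert process_attention_mask(prebuilt) is prebuilt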
|
@@ -306,7 +309,7 @@ class Transformer(nn.Module):
     def forward(
         self,
         input_ids: Tensor,
-        seq_mask: Optional[Tensor]=None,
+        input_mask: Optional[Tensor]=None,
         persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
         start_pos: int = 0
     ) -> Tensor:
|
@@ -318,8 +321,8 @@ class Transformer(nn.Module):
         freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
         has_kvcache = persistent_key_values is not None
 
-        attn_mask = create_attention_mask(
-            seq_mask,
+        attn_mask = process_attention_mask(
+            input_mask,
             start_pos=start_pos,
             seq_len=seq_len,
             is_causal=has_kvcache,
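And a hedged sketch of a caller driving the renamed forward() signature across a prefill and a single decode step. Only the keyword names (input_ids, input_mask, persistent_key_values, start_pos) come from the hunk above; the model object, greedy decoding, and cache layout are assumptions:

import torch
from torch import Tensor
from typing import List, Tuple

def prefill_then_decode(model, prompt_ids: Tensor, pad_mask: Tensor) -> Tensor:
    # Hypothetical driver for the renamed signature; the cache list
    # layout and tensor shapes here are illustrative assumptions.
    caches: List[Tuple[Tensor, Tensor]] = []
    logits = model(
        prompt_ids,                     # (bsz, seq_len)
        input_mask=pad_mask,            # (bsz, seq_len), formerly seq_mask
        persistent_key_values=caches,
        start_pos=0,
    )
    # One decode step: feed only the new token and advance start_pos;
    # the mask now covers start_pos + 1 positions, matching the
    # slicing in process_attention_mask.
    next_ids = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    new_mask = torch.cat(
        [pad_mask, torch.ones_like(next_ids, dtype=pad_mask.dtype)], dim=1
    )
    return model(
        next_ids,                       # (bsz, 1)
        input_mask=new_mask,
        persistent_key_values=caches,
        start_pos=prompt_ids.size(1),
    )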