refactor(core): 修改注意力掩码处理函数并重命名参数
This commit is contained in:
parent
053f4a4dad
commit
9fbc9481b5
|
|
@ -71,7 +71,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
|
||||||
return x_out.to(dtype)
|
return x_out.to(dtype)
|
||||||
|
|
||||||
def create_attention_mask(
|
def process_attention_mask(
|
||||||
seq_mask: Tensor,
|
seq_mask: Tensor,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
seq_len: int = 0,
|
seq_len: int = 0,
|
||||||
|
|
@ -98,6 +98,9 @@ def create_attention_mask(
|
||||||
if seq_mask is None:
|
if seq_mask is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if seq_mask.dim() > 2:
|
||||||
|
return seq_mask
|
||||||
|
|
||||||
batch_size = seq_mask.size(0)
|
batch_size = seq_mask.size(0)
|
||||||
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
|
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||||
# (bsz, start_pos + seq_len)
|
# (bsz, start_pos + seq_len)
|
||||||
|
|
@ -306,7 +309,7 @@ class Transformer(nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
seq_mask: Optional[Tensor]=None,
|
input_mask: Optional[Tensor]=None,
|
||||||
persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
|
persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
|
@ -318,8 +321,8 @@ class Transformer(nn.Module):
|
||||||
freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
|
freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
|
||||||
has_kvcache = persistent_key_values is not None
|
has_kvcache = persistent_key_values is not None
|
||||||
|
|
||||||
attn_mask = create_attention_mask(
|
attn_mask = process_attention_mask(
|
||||||
seq_mask,
|
input_mask,
|
||||||
start_pos=start_pos,
|
start_pos=start_pos,
|
||||||
seq_len=seq_len,
|
seq_len=seq_len,
|
||||||
is_causal=has_kvcache,
|
is_causal=has_kvcache,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue