refactor(core): modify attention mask handling function and rename parameters

ViperEkura 2025-09-27 13:37:10 +08:00
parent 053f4a4dad
commit 9fbc9481b5
1 changed file with 7 additions and 4 deletions


@@ -71,7 +71,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     return x_out.to(dtype)
 
 
-def create_attention_mask(
+def process_attention_mask(
     seq_mask: Tensor,
     start_pos: int = 0,
     seq_len: int = 0,
@@ -98,6 +98,9 @@ def create_attention_mask(
     if seq_mask is None:
         return None
 
+    if seq_mask.dim() > 2:
+        return seq_mask
+
     batch_size = seq_mask.size(0)
     seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
     # (bsz, start_pos + seq_len)
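For context, a minimal sketch of the mask handling this hunk implies. Only the lines shown in the diff are from the source; the function name suffix, the device default, and the final broadcast step are assumptions.

import torch
from torch import Tensor
from typing import Optional

def process_attention_mask_sketch(
    seq_mask: Optional[Tensor],
    start_pos: int = 0,
    seq_len: int = 0,
    device: torch.device = torch.device("cpu"),
) -> Optional[Tensor]:
    # No mask supplied: attention runs unmasked.
    if seq_mask is None:
        return None
    # New in this commit: a mask with more than two dims is treated as
    # pre-expanded (e.g. (bsz, n_heads, q_len, kv_len)) and passed through.
    if seq_mask.dim() > 2:
        return seq_mask
    # Otherwise it is a (bsz, total_len) padding mask: keep only the
    # positions visible to this step, as in the hunk above.
    seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
    # (bsz, start_pos + seq_len) -> broadcastable over heads and queries;
    # this expansion step is an assumption, not shown in the diff.
    return seq_mask[:, None, None, :]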
@@ -306,7 +309,7 @@ class Transformer(nn.Module):
     def forward(
         self,
         input_ids: Tensor,
-        seq_mask: Optional[Tensor]=None,
+        input_mask: Optional[Tensor]=None,
         persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
         start_pos: int = 0
     ) -> Tensor:
@@ -318,8 +321,8 @@ class Transformer(nn.Module):
         freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
         has_kvcache = persistent_key_values is not None
 
-        attn_mask = create_attention_mask(
-            seq_mask,
+        attn_mask = process_attention_mask(
+            input_mask,
             start_pos=start_pos,
             seq_len=seq_len,
             is_causal=has_kvcache,
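A hedged usage sketch of the renamed call path: `model` stands in for an assumed `Transformer` instance (calls left commented since no instance is constructed here), and all shapes are illustrative.

import torch

bsz, seq_len, vocab = 2, 16, 32000          # illustrative sizes
input_ids = torch.randint(0, vocab, (bsz, seq_len))

# A 2D padding mask (True = attend) is truncated and expanded as before:
input_mask = torch.ones(bsz, seq_len, dtype=torch.bool)
# logits = model(input_ids, input_mask=input_mask, start_pos=0)

# A pre-expanded mask now short-circuits via the new dim() > 2 branch
# and is used verbatim, e.g. an explicit causal mask:
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# logits = model(input_ids, input_mask=causal[None, None], start_pos=0)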