refactor(core): 修改注意力掩码处理函数并重命名参数

This commit is contained in:
ViperEkura 2025-09-27 13:37:10 +08:00
parent 053f4a4dad
commit 9fbc9481b5
1 changed file with 7 additions and 4 deletions

View File

@@ -71,7 +71,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     return x_out.to(dtype)
-def create_attention_mask(
+def process_attention_mask(
     seq_mask: Tensor,
     start_pos: int = 0,
     seq_len: int = 0,
@@ -98,6 +98,9 @@ def create_attention_mask(
     if seq_mask is None:
         return None
+    if seq_mask.dim() > 2:
+        return seq_mask
     batch_size = seq_mask.size(0)
     seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
     # (bsz, start_pos + seq_len)
@@ -306,7 +309,7 @@ class Transformer(nn.Module):
     def forward(
         self,
         input_ids: Tensor,
-        seq_mask: Optional[Tensor]=None,
+        input_mask: Optional[Tensor]=None,
         persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
         start_pos: int = 0
     ) -> Tensor:
@@ -318,8 +321,8 @@ class Transformer(nn.Module):
         freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
         has_kvcache = persistent_key_values is not None
-        attn_mask = create_attention_mask(
-            seq_mask,
+        attn_mask = process_attention_mask(
+            input_mask,
             start_pos=start_pos,
             seq_len=seq_len,
             is_causal=has_kvcache,