diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py
index 63904bc..e6af718 100644
--- a/khaosz/model/transformer.py
+++ b/khaosz/model/transformer.py
@@ -18,10 +18,10 @@ def process_attention_mask(
     Create attention mask for GQA
     Args:
         seq_mask (Tensor): A tensor indicating whether each position is valid or not.
+        input_tensor (Tensor): The input tensor.
         start_pos (int): The starting position of the sequence.
         seq_len (int): The length of the sequence.
         is_causal (bool): Whether the attention is causal or not.
-        device (torch.device): The device to use.
     Returns:
         Tensor: The attention mask tensor.
     """
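For context, here is a minimal sketch of what a function matching this docstring might look like after the change. Only the parameter names and return type come from the docstring above; the body, and the assumption that `input_tensor` now supplies the mask's device in place of the removed `device` argument, are hypothetical reconstructions, not code from the repo.

```python
import torch
from torch import Tensor

def process_attention_mask(
    seq_mask: Tensor,
    input_tensor: Tensor,
    start_pos: int,
    seq_len: int,
    is_causal: bool,
) -> Tensor:
    # Assumption: the device is inferred from the input tensor,
    # replacing the explicit `device` argument removed in this diff.
    device = input_tensor.device
    total_len = start_pos + seq_len

    # Padding mask: assume seq_mask is (batch, total_len); reshape to
    # (batch, 1, 1, total_len) so it broadcasts over heads and queries.
    mask = seq_mask[:, None, None, :].to(device=device, dtype=torch.bool)

    if is_causal:
        # Causal mask for cached decoding: query i (offset by start_pos)
        # may attend to key j only when j <= start_pos + i.
        causal = torch.tril(
            torch.ones(seq_len, total_len, dtype=torch.bool, device=device),
            diagonal=start_pos,
        )
        mask = mask & causal[None, None, :, :]

    return mask
```

Under these assumptions a caller would pass the hidden states themselves, e.g. `process_attention_mask(seq_mask, hidden_states, start_pos, seq_len, is_causal=True)`, and the mask lands on the same device as the model inputs without the caller threading a `torch.device` through.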