docs(transformer): update process_attention_mask function docs
parent 7ccc4ab9ac
commit 805773c7fe
@@ -18,10 +18,10 @@ def process_attention_mask(
     Create attention mask for GQA
     Args:
         seq_mask (Tensor): A tensor indicating whether each position is valid or not.
-        input_tensor (Tensor): The input tensor.
         start_pos (int): The starting position of the sequence.
         seq_len (int): The length of the sequence.
         is_causal (bool): Whether the attention is causal or not.
+        device (torch.device): The device to use.
     Returns:
         Tensor: The attention mask tensor.
     """