refactor(khaosz/trainer/data_util): refactor the mask functions
commit 5a356d66e1
parent 78e5dbb3be
khaosz/trainer/data_util.py

@@ -25,11 +25,25 @@ def load_pkl_files(paths: List[str]):
     return segments, total_samples
 
 
-def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
-    token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
+def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
+    seq_len = input_ids.size(0)
+    turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
 
-    is_bos_token = input_ids.eq(bos_token_id)
-    is_eos_token = input_ids.eq(eos_token_id)
+    iq = turn_id.view(seq_len, 1)
+    ik = turn_id.view(1, seq_len)
+
+    # fix the causal attention mask (iq >= ik condition)
+    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
+    attention_mask = torch.tril(seq_mask)
+
+    # unsqueeze so the mask broadcasts to (bsz, 1, seq_len, seq_len)
+    return attention_mask.unsqueeze(0)
+
+
+def build_loss_mask(target_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
+    token_markers = torch.zeros_like(target_ids, dtype=torch.int8)
+
+    is_bos_token = target_ids.eq(bos_token_id)
+    is_eos_token = target_ids.eq(eos_token_id)
 
     token_markers[is_bos_token] = 1
     token_markers[is_eos_token] = -1
@@ -37,25 +51,9 @@ def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) ->
     cumulative_markers = torch.cumsum(token_markers, dim=-1)
     min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
     loss_mask = cumulative_markers - min_cumulative
 
     return loss_mask.to(dtype=torch.bool)
-
-
-def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
-    seq_len = input_ids.size(0)
-    is_user_token = input_ids.eq(user_token_id)
-    turn_id = is_user_token.cumsum(dim=-1)
-
-    iq = turn_id.view(seq_len, 1)
-    ik = turn_id.view(1, seq_len)
-
-    # fix the causal attention mask
-    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
-    causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
-    attention_mask = seq_mask & causal_mask
-
-    # unsqueeze so the mask broadcasts to (bsz, 1, seq_len, seq_len)
-    return attention_mask.unsqueeze(0)
 
 
 class BaseSegmentFetcher:
     def __init__(self, segments: List[Tensor]):
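
Review note (not part of the commit): the behavioral question in this refactor is whether torch.tril(seq_mask) matches the old seq_mask & causal_mask. For a boolean matrix it does, since tril zeroes the upper triangle, which is exactly what AND-ing with a lower-triangular all-ones mask did. A minimal sketch checking the equivalence on a toy sequence; user_token_id=7 and the input ids below are made up for illustration:

import torch

def old_mask(input_ids, user_token_id, multi_turn):
    # pre-refactor: explicit causal mask AND-ed with the turn mask
    seq_len = input_ids.size(0)
    turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
    iq, ik = turn_id.view(seq_len, 1), turn_id.view(1, seq_len)
    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return (seq_mask & causal).unsqueeze(0)

def new_mask(input_ids, user_token_id, multi_turn):
    # post-refactor: tril applied directly to the boolean turn mask
    seq_len = input_ids.size(0)
    turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
    iq, ik = turn_id.view(seq_len, 1), turn_id.view(1, seq_len)
    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
    return torch.tril(seq_mask).unsqueeze(0)

# two turns, each opened by the (made-up) user token 7
toy_ids = torch.tensor([7, 3, 4, 7, 5, 6])
for multi_turn in (True, False):
    assert torch.equal(old_mask(toy_ids, 7, multi_turn),
                       new_mask(toy_ids, 7, multi_turn))

Because turn_id is non-decreasing, iq >= ik always holds below the diagonal, so multi_turn=True reduces to a plain causal mask, while multi_turn=False restricts attention to a block-diagonal causal pattern where each turn attends only within itself.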
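
Similarly, a trace of build_loss_mask's cumsum trick (again not part of the commit; bos_token_id=1 and eos_token_id=2 are made up): each bos contributes +1 and each eos -1, so the running sum is positive exactly from a bos up to, but not including, its eos; subtracting the running minimum keeps the mask correct when a training window starts mid-span, i.e. when an eos shows up before any bos.

import torch

target_ids = torch.tensor([0, 1, 5, 5, 2, 0, 1, 5, 2])  # bos=1, eos=2, content=5

token_markers = torch.zeros_like(target_ids, dtype=torch.int8)
token_markers[target_ids.eq(1)] = 1    # bos opens a span
token_markers[target_ids.eq(2)] = -1   # eos closes it

cumulative = torch.cumsum(token_markers, dim=-1)  # [0, 1, 1, 1, 0, 0, 1, 1, 0]
shifted = cumulative - cumulative.min(dim=-1, keepdim=True).values
print(shifted.bool())
# tensor([False,  True,  True,  True, False, False,  True,  True, False])
# True from each bos through the token before its eos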