refactor(khaosz/trainer/data_util): 重构mask函数

This commit is contained in:
ViperEkura 2025-09-30 20:22:12 +08:00
parent 78e5dbb3be
commit 5a356d66e1
1 changed files with 19 additions and 21 deletions

View File

@ -25,11 +25,25 @@ def load_pkl_files(paths: List[str]):
return segments, total_samples
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
seq_len = input_ids.size(0)
turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
is_bos_token = input_ids.eq(bos_token_id)
is_eos_token = input_ids.eq(eos_token_id)
iq = turn_id.view(seq_len, 1)
ik = turn_id.view(1, seq_len)
# fix the causual attention mask(iq >= ik condition)
seq_mask = (iq >= ik) if multi_turn else (iq == ik)
attention_mask = torch.tril(seq_mask)
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
return attention_mask.unsqueeze(0)
def build_loss_mask(target_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(target_ids, dtype=torch.int8)
is_bos_token = target_ids.eq(bos_token_id)
is_eos_token = target_ids.eq(eos_token_id)
token_markers[is_bos_token] = 1
token_markers[is_eos_token] = -1
@ -37,25 +51,9 @@ def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) ->
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask.to(dtype=torch.bool)
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
seq_len = input_ids.size(0)
is_user_token = input_ids.eq(user_token_id)
turn_id = is_user_token.cumsum(dim=-1)
iq = turn_id.view(seq_len, 1)
ik = turn_id.view(1, seq_len)
# fix the causual attention mask
seq_mask = (iq >= ik) if multi_turn else (iq == ik)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
attention_mask = seq_mask & causal_mask
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
return attention_mask.unsqueeze(0)
class BaseSegmentFetcher:
def __init__(self, segments: List[Tensor]):