From 5a356d66e1a44968960e47b855538122d638c251 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Tue, 30 Sep 2025 20:22:12 +0800
Subject: [PATCH] refactor(khaosz/trainer/data_util): refactor mask functions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/trainer/data_util.py | 40 +++++++++++++++++++---------------------
 1 file changed, 19 insertions(+), 21 deletions(-)

diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py
index ba012f6..5e50a86 100644
--- a/khaosz/trainer/data_util.py
+++ b/khaosz/trainer/data_util.py
@@ -25,11 +25,25 @@ def load_pkl_files(paths: List[str]):
     return segments, total_samples
 
 
-def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
-    token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
+def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
+    seq_len = input_ids.size(0)
+    turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
 
-    is_bos_token = input_ids.eq(bos_token_id)
-    is_eos_token = input_ids.eq(eos_token_id)
+    iq = turn_id.view(seq_len, 1)
+    ik = turn_id.view(1, seq_len)
+
+    # turn visibility: iq >= ik attends to all earlier turns; tril keeps it causal
+    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
+    attention_mask = torch.tril(seq_mask)
+
+    # unsqueeze to (1, seq_len, seq_len) so the mask broadcasts over the batch
+    return attention_mask.unsqueeze(0)
+
+def build_loss_mask(target_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
+    token_markers = torch.zeros_like(target_ids, dtype=torch.int8)
+
+    is_bos_token = target_ids.eq(bos_token_id)
+    is_eos_token = target_ids.eq(eos_token_id)
 
     token_markers[is_bos_token] = 1
     token_markers[is_eos_token] = -1
@@ -37,25 +51,9 @@ def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) ->
     cumulative_markers = torch.cumsum(token_markers, dim=-1)
     min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
     loss_mask = cumulative_markers - min_cumulative
-    
+
     return loss_mask.to(dtype=torch.bool)
 
-def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
-    seq_len = input_ids.size(0)
-    is_user_token = input_ids.eq(user_token_id)
-    turn_id = is_user_token.cumsum(dim=-1)
-
-    iq = turn_id.view(seq_len, 1)
-    ik = turn_id.view(1, seq_len)
-
-    # fix the causual attention mask
-    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
-    causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
-    attention_mask = seq_mask & causal_mask
-
-    # fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
-    return attention_mask.unsqueeze(0)
-
 class BaseSegmentFetcher:
 
     def __init__(self, segments: List[Tensor]):
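
Editor's note: a minimal usage sketch of the two builders as refactored
above. The token ids below (user=3, bos=1, eos=2) are made-up values for
illustration only; the real ids come from the project's tokenizer. The
sketch shows the intended composition: build_attention_mask derives a
per-turn causal visibility mask, and build_loss_mask keeps the loss only
on bos..eos spans.

    import torch
    from khaosz.trainer.data_util import build_attention_mask, build_loss_mask

    # toy two-turn sequence: [user, bos, x, x, eos, user, bos, x, eos]
    ids = torch.tensor([3, 1, 5, 6, 2, 3, 1, 7, 2])

    attn = build_attention_mask(ids, user_token_id=3, multi_turn=True)
    # attn has shape (1, 9, 9); position i sees position j only when
    # j <= i (causal, via tril) and turn(j) <= turn(i) (iq >= ik)

    loss = build_loss_mask(ids, bos_token_id=1, eos_token_id=2)
    # loss has shape (9,); True from each bos up to (not including) its
    # matching eos: [F, T, T, T, F, F, T, T, F]

With multi_turn=False the condition becomes iq == ik, so each turn
attends only to its own tokens; torch.tril then makes the visibility
causal within that turn.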