From e43a5b9b669a2e4eace58c4d2e3e1b65df42d190 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Thu, 2 Oct 2025 11:55:51 +0800
Subject: [PATCH] =?UTF-8?q?fix(khaosz/trainer/data=5Futil.py):=20=E4=BF=AE?=
 =?UTF-8?q?=E5=A4=8D=20build=5Floss=5Fmask=20=E5=87=BD=E6=95=B0=E4=B8=AD?=
 =?UTF-8?q?=E4=BD=BF=E7=94=A8=E9=94=99=E8=AF=AF=E7=9A=84=E8=BE=93=E5=85=A5?=
 =?UTF-8?q?=E5=BC=A0=E9=87=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/trainer/data_util.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py
index 5e50a86..bd9449d 100644
--- a/khaosz/trainer/data_util.py
+++ b/khaosz/trainer/data_util.py
@@ -39,12 +39,13 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
     # fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
     return attention_mask.unsqueeze(0)
 
-def build_loss_mask(target_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
-    token_markers = torch.zeros_like(target_ids, dtype=torch.int8)
+def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
+    token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
 
-    is_bos_token = target_ids.eq(bos_token_id)
-    is_eos_token = target_ids.eq(eos_token_id)
+    is_bos_token = input_ids.eq(bos_token_id)
+    is_eos_token = input_ids.eq(eos_token_id)
 
+    # fix the eos_token_id bug(change target_ids to input_ids)
     token_markers[is_bos_token] = 1
     token_markers[is_eos_token] = -1
 
@@ -193,7 +194,8 @@ class SftDataset(BaseDataset):
         x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
         y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
 
-        loss_mask = build_loss_mask(y, self.bos_token_id, self.eos_token_id)
+        # fix the eos_token_id bug(change target_ids to input_ids)
+        loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
         attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
 
         return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}