diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py index 5e50a86..bd9449d 100644 --- a/khaosz/trainer/data_util.py +++ b/khaosz/trainer/data_util.py @@ -39,12 +39,13 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool # fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast return attention_mask.unsqueeze(0) -def build_loss_mask(target_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor: - token_markers = torch.zeros_like(target_ids, dtype=torch.int8) +def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor: + token_markers = torch.zeros_like(input_ids, dtype=torch.int8) - is_bos_token = target_ids.eq(bos_token_id) - is_eos_token = target_ids.eq(eos_token_id) + is_bos_token = input_ids.eq(bos_token_id) + is_eos_token = input_ids.eq(eos_token_id) + # fix the eos_token_id bug(change target_ids to input_ids) token_markers[is_bos_token] = 1 token_markers[is_eos_token] = -1 @@ -193,7 +194,8 @@ class SftDataset(BaseDataset): x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long) - loss_mask = build_loss_mask(y, self.bos_token_id, self.eos_token_id) + # fix the eos_token_id bug(change target_ids to input_ids) + loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id) attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn) return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}