fix(khaosz/trainer/data_util.py): 修复 build_loss_mask 函数中使用错误的输入张量

This commit is contained in:
ViperEkura 2025-10-02 11:55:51 +08:00
parent cd4877e490
commit e43a5b9b66
1 changed file with 7 additions and 5 deletions

View File

@@ -39,12 +39,13 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
return attention_mask.unsqueeze(0)
def build_loss_mask(target_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(target_ids, dtype=torch.int8)
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_bos_token = target_ids.eq(bos_token_id)
is_eos_token = target_ids.eq(eos_token_id)
is_bos_token = input_ids.eq(bos_token_id)
is_eos_token = input_ids.eq(eos_token_id)
# fix the eos_token_id bug (change target_ids to input_ids)
token_markers[is_bos_token] = 1
token_markers[is_eos_token] = -1
@@ -193,7 +194,8 @@ class SftDataset(BaseDataset):
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
loss_mask = build_loss_mask(y, self.bos_token_id, self.eos_token_id)
# fix the eos_token_id bug (change target_ids to input_ids)
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}