fix(khaosz/trainer/data_util.py): 修复 build_loss_mask 函数中使用错误的输入张量
This commit is contained in:
parent
cd4877e490
commit
e43a5b9b66
|
|
@ -39,12 +39,13 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
|
|||
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
|
||||
return attention_mask.unsqueeze(0)
|
||||
|
||||
def build_loss_mask(target_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
|
||||
token_markers = torch.zeros_like(target_ids, dtype=torch.int8)
|
||||
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
|
||||
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
||||
|
||||
is_bos_token = target_ids.eq(bos_token_id)
|
||||
is_eos_token = target_ids.eq(eos_token_id)
|
||||
is_bos_token = input_ids.eq(bos_token_id)
|
||||
is_eos_token = input_ids.eq(eos_token_id)
|
||||
|
||||
# fix the eos_token_id bug(change target_ids to input_ids)
|
||||
token_markers[is_bos_token] = 1
|
||||
token_markers[is_eos_token] = -1
|
||||
|
||||
|
|
@ -193,7 +194,8 @@ class SftDataset(BaseDataset):
|
|||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
|
||||
|
||||
loss_mask = build_loss_mask(y, self.bos_token_id, self.eos_token_id)
|
||||
# fix the eos_token_id bug(change target_ids to input_ids)
|
||||
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
|
||||
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
|
||||
|
||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
|
||||
|
|
|
|||
Loading…
Reference in New Issue