fix(khaosz/trainer/data_util.py): 修复 build_loss_mask 函数中使用错误的输入张量
This commit is contained in:
parent
cd4877e490
commit
e43a5b9b66
|
|
@ -39,12 +39,13 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
|
||||||
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
|
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
|
||||||
return attention_mask.unsqueeze(0)
|
return attention_mask.unsqueeze(0)
|
||||||
|
|
||||||
def build_loss_mask(target_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
|
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
|
||||||
token_markers = torch.zeros_like(target_ids, dtype=torch.int8)
|
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
||||||
|
|
||||||
is_bos_token = target_ids.eq(bos_token_id)
|
is_bos_token = input_ids.eq(bos_token_id)
|
||||||
is_eos_token = target_ids.eq(eos_token_id)
|
is_eos_token = input_ids.eq(eos_token_id)
|
||||||
|
|
||||||
|
# Fix for the eos_token_id bug: BOS/EOS markers are located in input_ids, not target_ids.
|
||||||
token_markers[is_bos_token] = 1
|
token_markers[is_bos_token] = 1
|
||||||
token_markers[is_eos_token] = -1
|
token_markers[is_eos_token] = -1
|
||||||
|
|
||||||
|
|
@ -193,7 +194,8 @@ class SftDataset(BaseDataset):
|
||||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
|
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
|
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
|
||||||
|
|
||||||
loss_mask = build_loss_mask(y, self.bos_token_id, self.eos_token_id)
|
# Fix for the eos_token_id bug: build the loss mask from input_ids (x) instead of target_ids (y).
|
||||||
|
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
|
||||||
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
|
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
|
||||||
|
|
||||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
|
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue