fix(khaosz/trainer): 修复数据获取中的索引范围错误和参数传递问题
This commit is contained in:
parent
68a15005cb
commit
8434c19923
|
|
@ -72,8 +72,9 @@ class BaseSegmentFetcher:
|
|||
if begin_idx >= end_idx:
|
||||
return torch.tensor([], dtype=torch.long)
|
||||
|
||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx - 1)
|
||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx - 1)
|
||||
# fix the range index bug
|
||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||
|
||||
result_segments = []
|
||||
|
||||
|
|
@ -165,7 +166,6 @@ class SeqDataset(BaseDataset):
|
|||
return {"input_ids": x, "target_ids": y}
|
||||
|
||||
|
||||
|
||||
class SftDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -182,15 +182,15 @@ class SftDataset(BaseDataset):
|
|||
self.user_token_id = user_token_id
|
||||
self.multi_turn = multi_turn
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
|
||||
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
||||
|
||||
# fix the eos_token_id bug(change target_ids to input_ids)
|
||||
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
|
||||
|
|
|
|||
Loading…
Reference in New Issue