diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py index 5c229de..7b6570e 100644 --- a/khaosz/trainer/data_util.py +++ b/khaosz/trainer/data_util.py @@ -72,8 +72,9 @@ class BaseSegmentFetcher: if begin_idx >= end_idx: return torch.tensor([], dtype=torch.long) - seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx - 1) - seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx - 1) + # Range-index fix: select the segments overlapping the half-open range [begin_idx, end_idx) (assumes cum_lengths holds cumulative segment end offsets; TODO confirm) + seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx) + seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx) result_segments = [] @@ -165,7 +166,6 @@ class SeqDataset(BaseDataset): return {"input_ids": x, "target_ids": y} - class SftDataset(BaseDataset): def __init__( self, @@ -182,15 +182,15 @@ class SftDataset(BaseDataset): self.user_token_id = user_token_id self.multi_turn = multi_turn - def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: - return self.fetcher.key_fetch(begin_idx, end_idx, key) + def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: + return self.fetcher.key_fetch(begin_idx, end_idx, "sequence") def __getitem__(self, index): begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) end_idx = begin_idx + self.chunk_size - x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) - y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long) + x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long) + y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long) # fix the eos_token_id bug(change target_ids to input_ids) loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)