diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py
index 38fda05..5c229de 100644
--- a/khaosz/trainer/data_util.py
+++ b/khaosz/trainer/data_util.py
@@ -155,8 +155,9 @@ class SeqDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
 
     def __getitem__(self, index):
-        begin_idx = index * self.chunk_size
-        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+        # fix the range index bug
+        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
+        end_idx = begin_idx + self.chunk_size
 
         x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
         y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
@@ -185,8 +186,8 @@ class SftDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
 
     def __getitem__(self, index):
-        begin_idx = index * self.chunk_size
-        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
+        end_idx = begin_idx + self.chunk_size
 
         x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
         y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
@@ -207,13 +208,13 @@ class DpoDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
 
     def __getitem__(self, index: int):
-        start_idx = index * self.chunk_size
-        end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
+        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
+        end_idx = begin_idx + self.chunk_size
 
-        chosen = self._fetch_data(start_idx, end_idx, "chosen").to(dtype=torch.long)
-        rejected = self._fetch_data(start_idx, end_idx, "rejected").to(dtype=torch.long)
-        chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
-        rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
+        chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
+        rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
+        chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
+        rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
 
         return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
 
@@ -227,9 +228,8 @@ class PpoDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
 
     def __getitem__(self, index: int) -> Dict[str, Tensor]:
-
-        begin_idx = index * self.chunk_size
-        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
+        end_idx = begin_idx + self.chunk_size
 
         input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
         actions = self._fetch_data(begin_idx, end_idx, "actions"),