From efbe3de9d3127e3ca274991efad2abda6da4d436 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Tue, 7 Oct 2025 20:04:45 +0800
Subject: [PATCH] =?UTF-8?q?fix(khaosz/trainer/data=5Futil):=20=E4=BF=AE?=
 =?UTF-8?q?=E5=A4=8D=E6=95=B0=E6=8D=AE=E9=9B=86=E7=B4=A2=E5=BC=95=E8=8C=83?=
 =?UTF-8?q?=E5=9B=B4=E9=94=99=E8=AF=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/trainer/data_util.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py
index 38fda05..5c229de 100644
--- a/khaosz/trainer/data_util.py
+++ b/khaosz/trainer/data_util.py
@@ -155,8 +155,9 @@ class SeqDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
 
     def __getitem__(self, index):
-        begin_idx = index * self.chunk_size
-        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+        # fix the range index bug
+        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
+        end_idx = begin_idx + self.chunk_size
 
         x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
         y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
@@ -185,8 +186,8 @@ class SftDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
 
     def __getitem__(self, index):
-        begin_idx = index * self.chunk_size
-        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
+        end_idx = begin_idx + self.chunk_size
 
         x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
         y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
@@ -207,13 +208,13 @@ class DpoDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
 
     def __getitem__(self, index: int):
-        start_idx = index * self.chunk_size
-        end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
+        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
+        end_idx = begin_idx + self.chunk_size
 
-        chosen = self._fetch_data(start_idx, end_idx, "chosen").to(dtype=torch.long)
-        rejected = self._fetch_data(start_idx, end_idx, "rejected").to(dtype=torch.long)
-        chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
-        rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
+        chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
+        rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
+        chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
+        rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
 
         return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
 
@@ -227,9 +228,8 @@ class PpoDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, key)
 
     def __getitem__(self, index: int) -> Dict[str, Tensor]:
-
-        begin_idx = index * self.chunk_size
-        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
+        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
+        end_idx = begin_idx + self.chunk_size
 
         input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
         actions = self._fetch_data(begin_idx, end_idx, "actions"),