fix(khaosz/trainer/data_util): 修复数据集索引范围错误

This commit is contained in:
ViperEkura 2025-10-07 20:04:45 +08:00
parent 12793bc2d3
commit efbe3de9d3
1 changed files with 13 additions and 13 deletions

View File

@ -155,8 +155,9 @@ class SeqDataset(BaseDataset):
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index):
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
# fix the range index bug
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
@ -185,8 +186,8 @@ class SftDataset(BaseDataset):
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index):
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
@ -207,13 +208,13 @@ class DpoDataset(BaseDataset):
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int):
start_idx = index * self.chunk_size
end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
chosen = self._fetch_data(start_idx, end_idx, "chosen").to(dtype=torch.long)
rejected = self._fetch_data(start_idx, end_idx, "rejected").to(dtype=torch.long)
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
@ -227,9 +228,8 @@ class PpoDataset(BaseDataset):
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
actions = self._fetch_data(begin_idx, end_idx, "actions"),