fix(khaosz/trainer/data_util): 修复数据集索引范围错误
This commit is contained in:
parent
12793bc2d3
commit
efbe3de9d3
|
|
@ -155,8 +155,9 @@ class SeqDataset(BaseDataset):
|
|||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx = index * self.chunk_size
|
||||
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
||||
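# Clamp begin_idx so that the chunk and its one-position-shifted target (fetched up to end_idx + 1) stay within total_samples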
|
||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
||||
|
|
@ -185,8 +186,8 @@ class SftDataset(BaseDataset):
|
|||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx = index * self.chunk_size
|
||||
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
|
||||
|
|
@ -207,13 +208,13 @@ class DpoDataset(BaseDataset):
|
|||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
start_idx = index * self.chunk_size
|
||||
end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
|
||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
|
||||
chosen = self._fetch_data(start_idx, end_idx, "chosen").to(dtype=torch.long)
|
||||
rejected = self._fetch_data(start_idx, end_idx, "rejected").to(dtype=torch.long)
|
||||
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
|
||||
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
|
||||
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
|
||||
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
|
||||
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
|
||||
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
|
||||
|
||||
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
|
||||
|
||||
|
|
@ -227,9 +228,8 @@ class PpoDataset(BaseDataset):
|
|||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
|
||||
begin_idx = index * self.chunk_size
|
||||
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
|
||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
|
||||
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
|
||||
actions = self._fetch_data(begin_idx, end_idx, "actions"),
|
||||
|
|
|
|||
Loading…
Reference in New Issue