fix(data_util): 修复数据集索引计算逻辑并提取通用方法
This commit is contained in:
parent
12850d403c
commit
6a3135f401
|
|
@ -100,21 +100,26 @@ class BaseDataset(Dataset, ABC):
|
|||
formated_segment = {key: self.segments[key][i] for key in keys}
|
||||
pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb"))
|
||||
|
||||
|
||||
def load(self, load_path: Union[str, List[str]]):
    """Load pickled segments and rebuild the fetcher over them.

    Args:
        load_path: A single path, or a list of paths, handed to
            ``load_pkl_files``. A lone string is wrapped in a list first.

    Side effects:
        Sets ``self.segments`` and ``self.total_samples`` from the pair
        returned by ``load_pkl_files``, then sets ``self.fetcher`` to a
        ``MutiSegmentFetcher`` built from ``self.segments``.
    """
    # Normalise the argument so the loader always receives a list.
    if isinstance(load_path, str):
        file_paths = [load_path]
    else:
        file_paths = load_path

    loaded_segments, sample_count = load_pkl_files(file_paths)
    self.segments = loaded_segments
    self.total_samples = sample_count
    self.fetcher = MutiSegmentFetcher(self.segments)
|
||||
|
||||
def get_index(self, index: int) -> "tuple[int, int]":
    """Map a sample index to the ``(begin_idx, end_idx)`` window it selects.

    The window nominally starts at ``index * self.step_size`` and spans
    ``self.chunk_size`` positions. The start is capped at
    ``self.total_samples - self.chunk_size - 1`` so that, when the data
    holds more than ``chunk_size`` samples, ``end_idx + 1`` stays within
    ``self.total_samples``.

    Args:
        index: Position of the requested window.

    Returns:
        A ``(begin_idx, end_idx)`` pair with
        ``end_idx == begin_idx + self.chunk_size``.
    """
    # Floor the cap at 0: with ``total_samples <= chunk_size`` the raw cap is
    # negative and would turn index 0 into a negative start offset.
    last_start = max(self.total_samples - self.chunk_size - 1, 0)
    begin_idx = min(index * self.step_size, last_start)
    end_idx = begin_idx + self.chunk_size

    return begin_idx, end_idx
|
||||
|
||||
@abstractmethod
def __getitem__(self, index: int) -> Dict[str, Tensor]:
    """Return the sample stored at ``index``.

    This base declaration has no implementation and always raises;
    it is marked abstract so that subclasses supply the real lookup.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
    """Return how many windows can be requested from the loaded data.

    Returns:
        ``0`` when ``self.total_samples`` does not exceed
        ``self.chunk_size``; otherwise
        ``self.total_samples // self.step_size + 1``.
    """
    assert self.total_samples is not None
    # ``<=`` rather than ``<``: with exactly ``chunk_size`` samples the
    # start-offset cap ``total_samples - chunk_size - 1`` used by
    # ``get_index`` is -1, so no valid window exists.
    if self.total_samples <= self.chunk_size:
        return 0
    # NOTE(review): indices near the end are capped by ``get_index`` and so
    # select overlapping windows - confirm this count is the intended one.
    return self.total_samples // self.step_size + 1
|
||||
|
||||
|
||||
class SeqDataset(BaseDataset):
|
||||
|
|
@ -127,8 +132,7 @@ class SeqDataset(BaseDataset):
|
|||
|
||||
def __getitem__(self, index):
|
||||
# fix the range index bug
|
||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
|
||||
|
|
@ -145,8 +149,7 @@ class SftDataset(BaseDataset):
|
|||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
|
||||
|
|
@ -164,8 +167,7 @@ class DpoDataset(BaseDataset):
|
|||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
|
||||
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
|
||||
|
|
@ -184,8 +186,7 @@ class PpoDataset(BaseDataset):
|
|||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
|
||||
end_idx = begin_idx + self.chunk_size
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
|
||||
actions = self._fetch_data(begin_idx, end_idx, "actions"),
|
||||
|
|
|
|||
Loading…
Reference in New Issue