fix(data_util): fix dataset index calculation logic and extract a common helper method
parent 12850d403c
commit 6a3135f401
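
For reference, a minimal standalone sketch of the clamped window arithmetic that the new get_index helper centralizes. The attribute names step_size, chunk_size and total_samples mirror the dataset attributes in the diff below; the free-function form and the sample values are illustrative only, not part of the commit.

# Standalone sketch of the clamped chunk-window arithmetic behind get_index.
# step_size, chunk_size and total_samples mirror the dataset attributes in the
# diff below; the free-function form and the numbers are illustrative only.
def get_index(index: int, total_samples: int, chunk_size: int, step_size: int) -> tuple:
    # Clamp the window start so the chunk never runs past the last sample.
    begin_idx = min(index * step_size, total_samples - chunk_size - 1)
    end_idx = begin_idx + chunk_size
    return begin_idx, end_idx

print(get_index(8, total_samples=100, chunk_size=10, step_size=10))  # (80, 90)
print(get_index(9, total_samples=100, chunk_size=10, step_size=10))  # (89, 99) -- clamped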
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset, Sampler
 from typing import Callable, List, Dict, Literal, Optional, Union

 MutiSeg = Dict[str, List[Tensor]]
 Seg = Dict[str, Tensor]

@@ -99,22 +99,27 @@ class BaseDataset(Dataset, ABC):
         for i in range(segment_size):
             formated_segment = {key: self.segments[key][i] for key in keys}
             pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb"))


     def load(self, load_path: Union[str, List[str]]):
         paths = [load_path] if isinstance(load_path, str) else load_path
         self.segments, self.total_samples = load_pkl_files(paths)
         self.fetcher = MutiSegmentFetcher(self.segments)

+    def get_index(self, index: int) -> int:
+        begin_idx = min(index * self.step_size, self.total_samples - self.chunk_size - 1)
+        end_idx = begin_idx + self.chunk_size
+
+        return begin_idx, end_idx
+
     @abstractmethod
     def __getitem__(self, index: int) -> Dict[str, Tensor]:
         raise NotImplementedError

     def __len__(self) -> int:
         assert self.total_samples is not None
-        if self.total_samples < self.chunk_size:
+        if self.total_samples <= self.chunk_size:
             return 0
-        return (self.total_samples - self.chunk_size) // self.step_size + 1
+        return self.total_samples // self.step_size + 1


 class SeqDataset(BaseDataset):
@@ -127,8 +132,7 @@ class SeqDataset(BaseDataset):

     def __getitem__(self, index):
         # fix the range index bug
-        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
-        end_idx = begin_idx + self.chunk_size
+        begin_idx, end_idx = self.get_index(index)

         x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
         y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
@@ -145,8 +149,7 @@ class SftDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, key)

     def __getitem__(self, index):
-        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
-        end_idx = begin_idx + self.chunk_size
+        begin_idx, end_idx = self.get_index(index)

         x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
         y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
@@ -164,8 +167,7 @@ class DpoDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, key)

     def __getitem__(self, index: int):
-        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
-        end_idx = begin_idx + self.chunk_size
+        begin_idx, end_idx = self.get_index(index)

         chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
         rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
@@ -184,8 +186,7 @@ class PpoDataset(BaseDataset):
         return self.fetcher.key_fetch(begin_idx, end_idx, key)

     def __getitem__(self, index: int) -> Dict[str, Tensor]:
-        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
-        end_idx = begin_idx + self.chunk_size
+        begin_idx, end_idx = self.get_index(index)

         input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
         actions = self._fetch_data(begin_idx, end_idx, "actions"),
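
As a quick sanity check on the revised length calculation, a small sketch that evaluates the old and the new __len__ formulas side by side. total_samples, chunk_size and step_size are the dataset attributes from the diff; the concrete numbers are illustrative only.

# Old vs. new chunk-count formulas from BaseDataset.__len__, on sample numbers.
total_samples, chunk_size, step_size = 100, 10, 10

old_len = (total_samples - chunk_size) // step_size + 1  # (100 - 10) // 10 + 1 == 10
new_len = total_samples // step_size + 1                 # 100 // 10 + 1 == 11

print(old_len, new_len)  # 10 11
# Any trailing windows implied by the larger count are clamped by get_index,
# which caps begin_idx at total_samples - chunk_size - 1 (89 here).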