fix(data_util): 修复数据集索引计算逻辑并提取通用方法

This commit is contained in:
ViperEkura 2025-10-29 20:58:33 +08:00
parent 12850d403c
commit 6a3135f401
1 changed file with 13 additions and 12 deletions

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset, Sampler
from typing import Callable, List, Dict, Literal, Optional, Tuple, Union
# Maps a string key to a list of tensors
# (presumably one tensor per stored segment — TODO confirm against the loader).
MutiSeg = Dict[str, List[Tensor]]
# Maps a string key to a single tensor.
Seg = Dict[str, Tensor]
@ -99,22 +99,27 @@ class BaseDataset(Dataset, ABC):
for i in range(segment_size):
formated_segment = {key: self.segments[key][i] for key in keys}
pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb"))
def load(self, load_path: Union[str, List[str]]):
    """Populate the dataset from one or more pickle files.

    Args:
        load_path: A single path, or a list of paths. A bare string is
            wrapped into a one-element list before being handed to
            ``load_pkl_files``.

    Side effects:
        Sets ``self.segments`` and ``self.total_samples`` from the pair
        returned by ``load_pkl_files``, then rebuilds ``self.fetcher`` as a
        ``MutiSegmentFetcher`` over the freshly loaded segments.
    """
    if isinstance(load_path, str):
        path_list = [load_path]
    else:
        path_list = load_path

    loaded_segments, sample_count = load_pkl_files(path_list)
    self.segments = loaded_segments
    self.total_samples = sample_count
    self.fetcher = MutiSegmentFetcher(self.segments)
def get_index(self, index: int) -> Tuple[int, int]:
    """Translate a sample index into the ``(begin, end)`` bounds of its chunk.

    The chunk nominally starts at ``index * self.step_size``. The start is
    capped at ``self.total_samples - self.chunk_size - 1``, which leaves one
    position of headroom after ``end``; the sequence datasets in this file
    also fetch the window shifted by one (``begin + 1`` to ``end + 1``).

    Args:
        index: Zero-based sample index.

    Returns:
        A ``(begin_idx, end_idx)`` tuple with
        ``end_idx == begin_idx + self.chunk_size``.

    Note:
        No lower bound is applied. If ``self.total_samples`` is not larger
        than ``self.chunk_size`` the cap is negative and so is the returned
        ``begin_idx``; callers are expected not to index such a dataset.
    """
    # Largest start that still leaves chunk_size + 1 positions in range.
    last_begin = self.total_samples - self.chunk_size - 1
    begin_idx = min(index * self.step_size, last_begin)
    return begin_idx, begin_idx + self.chunk_size
@abstractmethod
def __getitem__(self, index: int) -> Dict[str, Tensor]:
    """Return the sample at ``index`` as a mapping from name to tensor.

    This is an abstract hook: concrete dataset subclasses must override it.
    Calling the base implementation directly always raises.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError()
def __len__(self) -> int:
# Reports how many items the dataset exposes, derived from total_samples,
# chunk_size and step_size; reports 0 when the guard below triggers.
# total_samples must already be set (it is assigned by load()).
assert self.total_samples is not None
# NOTE(review): the next two `if` lines differ only in `<` vs `<=` and appear
# to be the removed and added versions of the same guard in this diff view —
# confirm which one is current before relying on the boundary behaviour.
if self.total_samples < self.chunk_size:
if self.total_samples <= self.chunk_size:
return 0
# NOTE(review): likewise, the two `return` lines below look like the old and
# new length formulas shown back to back. The first counts strides over
# (total_samples - chunk_size); the second counts strides over total_samples.
# Confirm which one is in effect.
return (self.total_samples - self.chunk_size) // self.step_size + 1
return self.total_samples // self.step_size + 1
class SeqDataset(BaseDataset):
@ -127,8 +132,7 @@ class SeqDataset(BaseDataset):
def __getitem__(self, index):
# fix the range index bug
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
@ -145,8 +149,7 @@ class SftDataset(BaseDataset):
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index):
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
@ -164,8 +167,7 @@ class DpoDataset(BaseDataset):
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int):
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
begin_idx, end_idx = self.get_index(index)
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
@ -184,8 +186,7 @@ class PpoDataset(BaseDataset):
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
begin_idx, end_idx = self.get_index(index)
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
actions = self._fetch_data(begin_idx, end_idx, "actions"),