fix(data_util): 修复数据集索引计算逻辑并提取通用方法

This commit is contained in:
ViperEkura 2025-10-29 20:58:33 +08:00
parent 12850d403c
commit 6a3135f401
1 changed files with 13 additions and 12 deletions

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset, Sampler from torch.utils.data import Dataset, Sampler
from typing import Callable, List, Dict, Literal, Optional, Union from typing import Callable, List, Dict, Literal, Optional, Union
MutiSeg = Dict[str, List[Tensor]] MutiSeg = Dict[str, List[Tensor]]
Seg = Dict[str, Tensor] Seg = Dict[str, Tensor]
@ -99,22 +99,27 @@ class BaseDataset(Dataset, ABC):
for i in range(segment_size): for i in range(segment_size):
formated_segment = {key: self.segments[key][i] for key in keys} formated_segment = {key: self.segments[key][i] for key in keys}
pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb")) pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb"))
def load(self, load_path: Union[str, List[str]]): def load(self, load_path: Union[str, List[str]]):
paths = [load_path] if isinstance(load_path, str) else load_path paths = [load_path] if isinstance(load_path, str) else load_path
self.segments, self.total_samples = load_pkl_files(paths) self.segments, self.total_samples = load_pkl_files(paths)
self.fetcher = MutiSegmentFetcher(self.segments) self.fetcher = MutiSegmentFetcher(self.segments)
def get_index(self, index: int) -> int:
begin_idx = min(index * self.step_size, self.total_samples - self.chunk_size - 1)
end_idx = begin_idx + self.chunk_size
return begin_idx, end_idx
@abstractmethod @abstractmethod
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int: def __len__(self) -> int:
assert self.total_samples is not None assert self.total_samples is not None
if self.total_samples < self.chunk_size: if self.total_samples <= self.chunk_size:
return 0 return 0
return (self.total_samples - self.chunk_size) // self.step_size + 1 return self.total_samples // self.step_size + 1
class SeqDataset(BaseDataset): class SeqDataset(BaseDataset):
@ -127,8 +132,7 @@ class SeqDataset(BaseDataset):
def __getitem__(self, index): def __getitem__(self, index):
# fix the range index bug # fix the range index bug
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) begin_idx, end_idx = self.get_index(index)
end_idx = begin_idx + self.chunk_size
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long) x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
@ -145,8 +149,7 @@ class SftDataset(BaseDataset):
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index): def __getitem__(self, index):
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) begin_idx, end_idx = self.get_index(index)
end_idx = begin_idx + self.chunk_size
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
@ -164,8 +167,7 @@ class DpoDataset(BaseDataset):
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int): def __getitem__(self, index: int):
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) begin_idx, end_idx = self.get_index(index)
end_idx = begin_idx + self.chunk_size
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long) chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long) rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
@ -184,8 +186,7 @@ class PpoDataset(BaseDataset):
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) begin_idx, end_idx = self.get_index(index)
end_idx = begin_idx + self.chunk_size
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"), input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"),
actions = self._fetch_data(begin_idx, end_idx, "actions"), actions = self._fetch_data(begin_idx, end_idx, "actions"),