From 6a3135f4011e62a05c2c1bff815ea98b83c0fe52 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 29 Oct 2025 20:58:33 +0800 Subject: [PATCH] =?UTF-8?q?fix(data=5Futil):=20=E4=BF=AE=E5=A4=8D=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E9=9B=86=E7=B4=A2=E5=BC=95=E8=AE=A1=E7=AE=97=E9=80=BB?= =?UTF-8?q?=E8=BE=91=E5=B9=B6=E6=8F=90=E5=8F=96=E9=80=9A=E7=94=A8=E6=96=B9?= =?UTF-8?q?=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/data/data_util.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/khaosz/data/data_util.py b/khaosz/data/data_util.py index 8e022c3..1053d48 100644 --- a/khaosz/data/data_util.py +++ b/khaosz/data/data_util.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from torch import Tensor from torch.utils.data import Dataset, Sampler from typing import Callable, List, Dict, Literal, Optional, Union - + MutiSeg = Dict[str, List[Tensor]] Seg = Dict[str, Tensor] @@ -99,22 +99,27 @@ class BaseDataset(Dataset, ABC): for i in range(segment_size): formated_segment = {key: self.segments[key][i] for key in keys} pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb")) - def load(self, load_path: Union[str, List[str]]): paths = [load_path] if isinstance(load_path, str) else load_path self.segments, self.total_samples = load_pkl_files(paths) self.fetcher = MutiSegmentFetcher(self.segments) + def get_index(self, index: int) -> int: + begin_idx = min(index * self.step_size, self.total_samples - self.chunk_size - 1) + end_idx = begin_idx + self.chunk_size + + return begin_idx, end_idx + @abstractmethod def __getitem__(self, index: int) -> Dict[str, Tensor]: raise NotImplementedError def __len__(self) -> int: assert self.total_samples is not None - if self.total_samples < self.chunk_size: + if self.total_samples <= self.chunk_size: return 0 - return (self.total_samples - self.chunk_size) // self.step_size + 1 + return self.total_samples // self.step_size + 1 class SeqDataset(BaseDataset): @@ -127,8 +132,7 @@ class SeqDataset(BaseDataset): def __getitem__(self, index): # fix the range index bug - begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) - end_idx = begin_idx + self.chunk_size + begin_idx, end_idx = self.get_index(index) x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long) @@ -145,8 +149,7 @@ class SftDataset(BaseDataset): return self.fetcher.key_fetch(begin_idx, end_idx, key) def __getitem__(self, index): - begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) - end_idx = begin_idx + self.chunk_size + begin_idx, end_idx = self.get_index(index) x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long) @@ -164,8 +167,7 @@ class DpoDataset(BaseDataset): return self.fetcher.key_fetch(begin_idx, end_idx, key) def __getitem__(self, index: int): - begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) - end_idx = begin_idx + self.chunk_size + begin_idx, end_idx = self.get_index(index) chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long) rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long) @@ -184,8 +186,7 @@ class PpoDataset(BaseDataset): return self.fetcher.key_fetch(begin_idx, end_idx, key) def __getitem__(self, index: int) -> Dict[str, Tensor]: - begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1) - end_idx = begin_idx + self.chunk_size + begin_idx, end_idx = self.get_index(index) input_ids = self._fetch_data(begin_idx, end_idx, "input_ids"), actions = self._fetch_data(begin_idx, end_idx, "actions"),