import bisect
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils.data import Dataset

from khaosz.data.serialization import load_h5


class BaseSegmentFetcher:
    """Expose a list of tensors as one contiguous, index-addressable sequence.

    Segment lengths are measured with ``torch.numel`` while slicing is done
    along dimension 0, so the two only agree for 1-D segments.
    NOTE(review): segments are presumably 1-D token tensors -- confirm
    against the data produced by the loader.
    """

    def __init__(self, segments: List[Tensor]):
        """Store *segments* and precompute their cumulative lengths."""
        self.segments = segments
        # cum_lengths[i] == number of elements in segments[0] .. segments[i].
        self.cum_lengths: List[int] = []
        total = 0
        for seg in segments:
            total += torch.numel(seg)
            self.cum_lengths.append(total)
        self.total_length = total

    def __len__(self) -> int:
        """Return the total number of elements across all segments."""
        return self.total_length

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Return the elements in the half-open global range [begin_idx, end_idx).

        Args:
            begin_idx: Global index of the first element to return.
            end_idx: Global index one past the last element to return.

        Returns:
            The concatenation of the covered slices, or an empty ``long``
            tensor when ``begin_idx >= end_idx``.

        Raises:
            ValueError: If ``begin_idx`` is outside ``[0, total_length)`` or
                ``end_idx`` is outside ``[0, total_length]``.
        """
        if not (
            0 <= begin_idx < self.total_length
            and 0 <= end_idx <= self.total_length
        ):
            raise ValueError("begin_idx or end_idx out of bounds")
        if begin_idx >= end_idx:
            return torch.tensor([], dtype=torch.long)

        # First segment whose cumulative length exceeds begin_idx, i.e. the
        # segment that contains element begin_idx.
        seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
        # First segment whose cumulative length reaches end_idx, i.e. the
        # segment that contains element end_idx - 1.
        seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)

        result_segments = []
        for i in range(seg_start_idx, seg_end_idx + 1):
            prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
            # Translate the global range into offsets local to segment i.
            start = max(begin_idx - prev_cum, 0)
            end = min(end_idx - prev_cum, len(self.segments[i]))
            result_segments.append(self.segments[i][start:end])
        return torch.cat(result_segments, dim=0)


class MultiSegmentFetcher:
    """Bundle one ``BaseSegmentFetcher`` per key and fetch from them by range."""

    def __init__(self, muti_segments: Dict):
        """Build a fetcher for every ``key -> list of tensors`` entry."""
        self.muti_keys = list(muti_segments.keys())
        self.muti_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in muti_segments.items()
        }

    def __len__(self) -> int:
        """Return the shortest total length among all keys (0 if no keys)."""
        # default=0 keeps len() usable on a fetcher built from an empty dict,
        # where a bare min() would raise ValueError.
        return min(
            (len(fetcher) for fetcher in self.muti_fetchers.values()),
            default=0,
        )

    def key_fetch(
        self,
        begin_idx: int,
        end_idx: int,
        keys: Union[str, List[str]],
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch the range [begin_idx, end_idx) for the requested key(s).

        Args:
            begin_idx: Global start index (inclusive).
            end_idx: Global end index (exclusive).
            keys: A single key or a list of keys.

        Returns:
            A bare tensor when exactly one key is requested (as a string or
            as a one-element list); otherwise a dict mapping each key to its
            tensor. An empty list yields an empty dict.

        Raises:
            KeyError: If a requested key has no fetcher.
            ValueError: If the range is out of bounds for a fetcher.
        """
        key_list = [keys] if isinstance(keys, str) else list(keys)
        fetched = {
            key: self.muti_fetchers[key].fetch_data(begin_idx, end_idx)
            for key in key_list
        }
        # Kept for backward compatibility: one requested key -> bare tensor.
        if len(key_list) == 1:
            return fetched[key_list[0]]
        return fetched

    def fetch_data(self, begin_idx: int, end_idx: int) -> Dict[str, Tensor]:
        """Fetch the range [begin_idx, end_idx) for every key, always as a dict."""
        # Built directly instead of via key_fetch so that a fetcher holding a
        # single key still returns a dict, as the return annotation promises.
        return {
            key: fetcher.fetch_data(begin_idx, end_idx)
            for key, fetcher in self.muti_fetchers.items()
        }


class BaseDataset(Dataset, ABC):
    """Sliding-window dataset over segments loaded from an HDF5 file.

    Each sample covers ``window_size`` positions; consecutive samples start
    ``stride`` positions apart. One trailing position is always held back so
    that subclasses can fetch a target window shifted by one.
    """

    def __init__(self, window_size: int, stride: int):
        """Create an empty dataset; call ``load`` before indexing it."""
        super().__init__()
        self.segments = {}
        self.window_size = window_size
        self.stride = stride
        # Stays None until load() has been called.
        self.total_samples: Optional[int] = None
        self.fetcher = MultiSegmentFetcher(self.segments)

    def load(self, load_path: str):
        """Load segments from *load_path* and rebuild the fetcher.

        NOTE(review): ``load_h5`` presumably returns a mapping of key to a
        list of tensors, which is what ``MultiSegmentFetcher`` consumes --
        confirm against ``khaosz.data.serialization``.
        """
        self.segments = load_h5(load_path)
        self.fetcher = MultiSegmentFetcher(self.segments)
        self.total_samples = len(self.fetcher)

    def get_index(self, index: int) -> Tuple[int, int]:
        """Return the ``(begin_idx, end_idx)`` window for sample *index*.

        The start is clamped so that ``end_idx + 1`` never exceeds the data
        length; an *index* past the last sample therefore maps onto the last
        full window instead of raising.
        """
        assert self.total_samples is not None, "call load() before indexing"
        assert self.total_samples > self.window_size
        last_pos = self.total_samples - 1
        begin_idx = min(index * self.stride, last_pos - self.window_size)
        end_idx = min(begin_idx + self.window_size, last_pos)
        return begin_idx, end_idx

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the sample at *index* as a dict of tensors."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of full windows available (0 if data is too short)."""
        assert self.total_samples is not None
        if self.total_samples <= self.window_size:
            return 0
        return (self.total_samples - 1 - self.window_size) // self.stride + 1


class SEQDataset(BaseDataset):
    """Next-token prediction dataset built from the ``"sequence"`` key."""

    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Fetch [begin_idx, end_idx) from the ``"sequence"`` segments."""
        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index):
        """Return ``input_ids`` and ``target_ids`` (the input shifted by one)."""
        begin_idx, end_idx = self.get_index(index)

        x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)

        return {"input_ids": x, "target_ids": y}


class SFTDataset(BaseDataset):
    """Supervised fine-tuning dataset: shifted sequence plus a loss mask."""

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Fetch [begin_idx, end_idx) from the segments stored under *key*."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index):
        """Return ``input_ids``, ``target_ids`` and the target-aligned ``loss_mask``."""
        begin_idx, end_idx = self.get_index(index)

        x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
        # The mask is read over the same shifted range as the targets.
        loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)

        return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}


class DPODataset(BaseDataset):
    """Preference dataset yielding chosen/rejected windows and their masks."""

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Fetch [begin_idx, end_idx) from the segments stored under *key*."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        """Return ``chosen``, ``rejected`` and their boolean masks for one window."""
        begin_idx, end_idx = self.get_index(index)

        chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
        rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
        chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
        rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)

        return {
            "chosen": chosen,
            "rejected": rejected,
            "chosen_mask": chosen_mask,
            "rejected_mask": rejected_mask,
        }


class GRPODataset(BaseDataset):
    """Dataset yielding prompts, responses, masks and rewards per window."""

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Fetch [begin_idx, end_idx) from the segments stored under *key*."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return ``prompts``, ``responses``, ``masks`` and ``rewards`` tensors.

        Values are returned with whatever dtype they were stored in; no
        conversion is applied here.
        """
        begin_idx, end_idx = self.get_index(index)

        # No trailing commas: each value must be a tensor, not a 1-tuple.
        prompts = self._fetch_data(begin_idx, end_idx, "prompts")
        responses = self._fetch_data(begin_idx, end_idx, "responses")
        masks = self._fetch_data(begin_idx, end_idx, "masks")
        rewards = self._fetch_data(begin_idx, end_idx, "rewards")

        return {
            "prompts": prompts,
            "responses": responses,
            "masks": masks,
            "rewards": rewards,
        }


class DatasetLoader:
    """Factory that builds and loads the dataset matching a training type."""

    @staticmethod
    def load(
        train_type: Literal["seq", "sft", "dpo", "grpo"],
        load_path: str,
        window_size: int,
        stride: Optional[int] = None,
    ) -> BaseDataset:
        """Create the dataset for *train_type* and load it from *load_path*.

        Args:
            train_type: One of ``"seq"``, ``"sft"``, ``"dpo"`` or ``"grpo"``.
            load_path: Path handed to the dataset's ``load`` method.
            window_size: Number of positions per sample.
            stride: Distance between consecutive sample starts; defaults to
                ``window_size`` (non-overlapping windows).

        Returns:
            The loaded dataset instance.

        Raises:
            KeyError: If *train_type* is not a supported training type.
        """
        if stride is None:
            stride = window_size

        dataset_router: Dict[str, Callable[[int, int], BaseDataset]] = {
            "seq": SEQDataset,
            "sft": SFTDataset,
            "dpo": DPODataset,
            "grpo": GRPODataset,
        }

        try:
            dataset_cls = dataset_router[train_type]
        except KeyError:
            raise KeyError(
                f"unknown train_type {train_type!r}; "
                f"expected one of {sorted(dataset_router)}"
            ) from None

        dataset = dataset_cls(window_size, stride)
        dataset.load(load_path)
        return dataset