import bisect
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils.data import Dataset

from khaosz.data.mmap import MmapFileHander

Seg = List[Tensor]
MultiSeg = Dict[str, Seg]


class BaseSegmentFetcher:
    """Random-access view over a list of tensors addressed as one concatenated sequence.

    Global position ``p`` lives in the first segment whose cumulative end
    offset is greater than ``p``.
    """

    def __init__(self, segments: Seg):
        self.segments = segments
        # cum_lengths[i] is the exclusive global end offset of segments[i].
        self.cum_lengths: List[int] = []
        total = 0
        for seg in segments:
            total += len(seg)
            self.cum_lengths.append(total)
        self.total_length = total

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Return global positions ``[begin_idx, end_idx)`` concatenated along dim 0.

        Args:
            begin_idx: Inclusive global start; must satisfy ``0 <= begin_idx < total_length``.
            end_idx: Exclusive global end; must satisfy ``0 <= end_idx <= total_length``.

        Returns:
            The concatenated slice, or an empty ``torch.long`` tensor when
            ``begin_idx >= end_idx``.

        Raises:
            ValueError: If either index is out of bounds.
        """
        if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
            raise ValueError("begin_idx or end_idx out of bounds")
        if begin_idx >= end_idx:
            return torch.tensor([], dtype=torch.long)

        # bisect_right: a begin equal to a segment's end offset belongs to the
        # NEXT segment (offsets are exclusive ends).
        seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
        # bisect_left: end is exclusive, so an end equal to a segment's end
        # offset still terminates inside THAT segment.
        seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)

        result_segments = []
        for i in range(seg_start_idx, seg_end_idx + 1):
            seg_offset = self.cum_lengths[i - 1] if i > 0 else 0
            start = max(begin_idx - seg_offset, 0)
            end = min(end_idx - seg_offset, len(self.segments[i]))
            result_segments.append(self.segments[i][start:end])
        return torch.cat(result_segments, dim=0)


class MultiSegmentFetcher:
    """One ``BaseSegmentFetcher`` per key, all addressed with the same global indices."""

    def __init__(self, muti_segments: MultiSeg):
        self.muti_keys = list(muti_segments.keys())
        self.muti_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in muti_segments.items()
        }

    def key_fetch(
        self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch ``[begin_idx, end_idx)`` for one or more keys.

        Returns the bare tensor when exactly one key is requested (whether
        passed as a str or a one-element list), otherwise a dict keyed by name.

        Raises:
            KeyError: If a key has no fetcher.
            ValueError: If the range is out of bounds for a key's data.
        """
        key_list = [keys] if isinstance(keys, str) else keys
        fetch_dict = {
            key: self.muti_fetchers[key].fetch_data(begin_idx, end_idx)
            for key in key_list
        }
        return fetch_dict if len(key_list) > 1 else fetch_dict[key_list[0]]

    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch ``[begin_idx, end_idx)`` for every stored key (see ``key_fetch``)."""
        return self.key_fetch(begin_idx, end_idx, self.muti_keys)


class BaseDataset(Dataset, ABC):
    """Sliding-window dataset over memory-mapped segment data.

    Window ``index`` starts at ``index * stride`` and covers ``window_size``
    positions. The start is clamped so that one extra position
    (``end_idx + 1``) always exists for shifted targets.
    """

    def __init__(self, window_size: int, stride: int):
        super().__init__()
        if window_size <= 0:
            raise ValueError(f"window_size must be positive, got {window_size}")
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")
        self.segments: MultiSeg = {}
        self.window_size = window_size
        self.stride = stride
        self.total_samples: Optional[int] = None
        # Placeholder over empty segments; replaced by load().
        self.fetcher = MultiSegmentFetcher(self.segments)

    def load(self, load_path: str):
        """Load segments and the total sample count via ``MmapFileHander.load``."""
        # NOTE(review): total_samples is presumably the per-key position count
        # shared by all keys — confirm against MmapFileHander.
        self.segments, self.total_samples = MmapFileHander.load(load_path)
        self.fetcher = MultiSegmentFetcher(self.segments)

    def get_index(self, index: int) -> Tuple[int, int]:
        """Map a sample index to its ``(begin_idx, end_idx)`` window.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(f"index {index} out of range for dataset of length {len(self)}")
        # -1 keeps end_idx + 1 within bounds for the shifted target fetch.
        begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
        end_idx = begin_idx + self.window_size
        return begin_idx, end_idx

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str = "sequence") -> Tensor:
        """Fetch ``[begin_idx, end_idx)`` for a single ``key``."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        raise NotImplementedError

    def __len__(self) -> int:
        assert self.total_samples is not None, "dataset not loaded; call load() first"
        # Largest legal window start (a window needs window_size + 1 positions).
        max_begin = self.total_samples - self.window_size - 1
        if max_begin < 0:
            return 0
        # Stride-aligned starts 0, s, 2s, ... up to max_begin, plus one final
        # window clamped to max_begin when the tail is not hit exactly.
        # Counting beyond that would only repeat the clamped last window.
        return (max_begin + self.stride - 1) // self.stride + 1


class SeqDataset(BaseDataset):
    """Next-token prediction windows over the ``"sequence"`` key."""

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        begin_idx, end_idx = self.get_index(index)
        # Targets are the inputs shifted one position forward.
        x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
        return {"input_ids": x, "target_ids": y}


class SftDataset(BaseDataset):
    """Shifted ``"sequence"`` windows plus a ``"loss_mask"`` aligned with the targets."""

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        begin_idx, end_idx = self.get_index(index)
        x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(dtype=torch.long)
        # Mask is read over the target range so it lines up with y.
        loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(dtype=torch.bool)
        return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}


class DpoDataset(BaseDataset):
    """Paired chosen/rejected windows sharing the same index range."""

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        begin_idx, end_idx = self.get_index(index)
        chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
        rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
        chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
        rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)
        return {
            "chosen": chosen,
            "rejected": rejected,
            "chosen_mask": chosen_mask,
            "rejected_mask": rejected_mask,
        }


class PpoDataset(BaseDataset):
    """Windows over ``input_ids``/``actions``/``logprobs``/``rewards``, returned in their stored dtype."""

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        begin_idx, end_idx = self.get_index(index)
        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids")
        actions = self._fetch_data(begin_idx, end_idx, "actions")
        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs")
        rewards = self._fetch_data(begin_idx, end_idx, "rewards")
        return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}


class DatasetLoader:
    @staticmethod
    def load(
        train_type: Literal["seq", "sft", "dpo", "ppo"],
        load_path: str,
        window_size: int,
        stride: Optional[int] = None,
    ) -> BaseDataset:
        """Build the dataset for ``train_type`` and load it from ``load_path``.

        Args:
            train_type: One of ``"seq"``, ``"sft"``, ``"dpo"``, ``"ppo"``.
            load_path: Path passed to ``BaseDataset.load``.
            window_size: Positions per sample window.
            stride: Step between window starts; defaults to ``window_size``
                (non-overlapping windows).

        Raises:
            KeyError: If ``train_type`` is not a known dataset type.
        """
        if stride is None:
            stride = window_size

        dataset_router: Dict[str, Callable[[int, int], BaseDataset]] = {
            "seq": SeqDataset,
            "sft": SftDataset,
            "dpo": DpoDataset,
            "ppo": PpoDataset,
        }
        if train_type not in dataset_router:
            raise KeyError(
                f"unknown train_type {train_type!r}; expected one of {sorted(dataset_router)}"
            )
        dataset = dataset_router[train_type](window_size, stride)
        dataset.load(load_path)
        return dataset