AstrAI/khaosz/data/dataset.py

191 lines
7.1 KiB
Python

import bisect
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils.data import Dataset

from khaosz.data.mmap import MmapFileHandler
# A list of tensors treated as one logical stream, concatenated along dim 0
# when a range spanning several of them is fetched.
Seg = List[Tensor]
# Named segment lists (e.g. "sequence", "loss_mask", "chosen"), one Seg per key.
MultiSeg = Dict[str, Seg]
class BaseSegmentFetcher:
    """Index a list of tensors as if they were one tensor concatenated along dim 0."""

    def __init__(self, segments: Seg):
        self.segments = segments
        running = 0
        boundaries = []
        for chunk in segments:
            running += len(chunk)
            boundaries.append(running)
        # boundaries[i] is the exclusive end offset of segments[i] in the joint stream.
        self.cum_lengths = boundaries
        self.total_length = running if segments else 0

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Return positions ``[begin_idx, end_idx)`` of the joint stream.

        Returns an empty ``torch.long`` tensor when ``begin_idx >= end_idx``.

        Raises:
            ValueError: if ``begin_idx`` is outside ``[0, total_length)`` or
                ``end_idx`` is outside ``[0, total_length]``.
        """
        begin_ok = 0 <= begin_idx < self.total_length
        end_ok = 0 <= end_idx <= self.total_length
        if not (begin_ok and end_ok):
            raise ValueError("begin_idx or end_idx out of bounds")
        if end_idx <= begin_idx:
            return torch.tensor([], dtype=torch.long)

        # first: segment holding begin_idx (first boundary strictly past it).
        # last: segment holding end_idx - 1 (first boundary reaching end_idx).
        first = bisect.bisect_right(self.cum_lengths, begin_idx)
        last = bisect.bisect_left(self.cum_lengths, end_idx)

        pieces = []
        seg_no = first
        while seg_no <= last:
            offset = self.cum_lengths[seg_no - 1] if seg_no else 0
            lo = max(begin_idx - offset, 0)
            hi = min(end_idx - offset, len(self.segments[seg_no]))
            pieces.append(self.segments[seg_no][lo:hi])
            seg_no += 1
        return torch.cat(pieces, dim=0)
class MultiSegmentFetcher:
    """Fetch the same ``[begin_idx, end_idx)`` range from several named segment lists.

    Each key gets its own BaseSegmentFetcher, so the range is resolved
    independently for every key.
    """

    def __init__(self, muti_segments: MultiSeg):
        # The "muti" (sic) names are public attributes/parameters; kept for compatibility.
        self.muti_keys = list(muti_segments.keys())
        self.muti_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in muti_segments.items()
        }

    def key_fetch(
        self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch ``[begin_idx, end_idx)`` for one or more keys.

        Args:
            begin_idx: inclusive start position.
            end_idx: exclusive end position.
            keys: a single key, or a list (any iterable) of keys.

        Returns:
            The bare tensor when exactly one key is requested; otherwise a dict
            mapping each requested key to its tensor (an empty dict for no keys).

        Raises:
            KeyError: if a requested key has no fetcher.
            ValueError: propagated from ``BaseSegmentFetcher.fetch_data`` for
                out-of-range indices.
        """
        key_list = [keys] if isinstance(keys, str) else list(keys)
        fetch_dict = {
            key: self.muti_fetchers[key].fetch_data(begin_idx, end_idx)
            for key in key_list
        }
        # A single key is unwrapped so callers asking for one field get a Tensor.
        if len(key_list) == 1:
            return fetch_dict[key_list[0]]
        return fetch_dict

    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch ``[begin_idx, end_idx)`` for every key held by this fetcher.

        Follows ``key_fetch``: a bare tensor if there is exactly one key, else a dict.
        """
        return self.key_fetch(begin_idx, end_idx, self.muti_keys)
class BaseDataset(Dataset, ABC):
    """Sliding-window dataset over segments loaded from disk.

    Sample ``i`` covers positions ``[begin, begin + window_size)`` with
    ``begin = i * stride``, clamped so the window shifted by one position
    (``[begin + 1, end + 1)``, as fetched for targets by SeqDataset/SftDataset)
    still fits in the data. Subclasses implement ``__getitem__``.
    """

    def __init__(self, window_size: int, stride: int):
        super().__init__()
        self.segments: MultiSeg = {}
        self.window_size = window_size
        self.stride = stride
        # Set by load(); __len__ requires it.
        self.total_samples = None

    def load(self, load_path: str):
        """Load segments and the total sample count from ``load_path`` and rebuild the fetcher."""
        self.segments, self.total_samples = MmapFileHandler.load(load_path)
        self.fetcher = MultiSegmentFetcher(self.segments)

    def get_index(self, index: int) -> Tuple[int, int]:
        """Return the ``(begin_idx, end_idx)`` window for sample ``index``.

        ``begin_idx`` is clamped to ``total_samples - window_size - 1`` so the
        one-position-shifted window still ends within the data; the last sample
        therefore sits flush against the end of the data.
        """
        begin_idx = min(index * self.stride, self.total_samples - self.window_size - 1)
        end_idx = begin_idx + self.window_size
        return begin_idx, end_idx

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of distinct windows.

        Regular starts are ``0, stride, 2 * stride, ...`` up to the clamp limit
        ``max_begin = total_samples - window_size - 1``. If the last regular start
        falls short of ``max_begin``, one extra clamped window covers the tail,
        giving ``ceil(max_begin / stride) + 1`` windows in total. Counting
        ``total_samples // stride + 1`` instead would let several trailing
        indices clamp to the same final window and yield duplicate samples.
        """
        assert self.total_samples is not None
        if self.total_samples <= self.window_size:
            return 0
        max_begin = self.total_samples - self.window_size - 1
        return (max_begin + self.stride - 1) // self.stride + 1
class SeqDataset(BaseDataset):
    """Windows over the "sequence" key, paired with the same window shifted by one."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)
        # self.segments is still empty here; load() rebuilds the fetcher.
        self.fetcher = MultiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Fetch ``[begin_idx, end_idx)`` of the "sequence" stream."""
        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index):
        start, stop = self.get_index(index)
        # Targets are the inputs shifted one position to the right.
        inputs = self._fetch_data(start, stop)
        targets = self._fetch_data(start + 1, stop + 1)
        return {
            "input_ids": inputs.to(dtype=torch.long),
            "target_ids": targets.to(dtype=torch.long),
        }
class SftDataset(BaseDataset):
    """Like SeqDataset, plus a boolean "loss_mask" aligned with the shifted targets."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)
        # self.segments is still empty here; load() rebuilds the fetcher.
        self.fetcher = MultiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Fetch ``[begin_idx, end_idx)`` of the stream stored under ``key``."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index):
        start, stop = self.get_index(index)
        next_start, next_stop = start + 1, stop + 1
        # The mask is read over the shifted range so it lines up with target_ids.
        return {
            "input_ids": self._fetch_data(start, stop, "sequence").to(dtype=torch.long),
            "target_ids": self._fetch_data(next_start, next_stop, "sequence").to(dtype=torch.long),
            "loss_mask": self._fetch_data(next_start, next_stop, "loss_mask").to(dtype=torch.bool),
        }
class DpoDataset(BaseDataset):
    """Windows over chosen/rejected sequences and their boolean masks (no target shift)."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)
        # self.segments is still empty here; load() rebuilds the fetcher.
        self.fetcher = MultiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Fetch ``[begin_idx, end_idx)`` of the stream stored under ``key``."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        start, stop = self.get_index(index)
        # (output key, target dtype), in the order the fields are fetched.
        layout = (
            ("chosen", torch.long),
            ("rejected", torch.long),
            ("chosen_mask", torch.bool),
            ("rejected_mask", torch.bool),
        )
        return {
            field: self._fetch_data(start, stop, field).to(dtype=dtype)
            for field, dtype in layout
        }
class PpoDataset(BaseDataset):
    """Windows over "input_ids", "actions", "logprobs" and "rewards", all over the same range."""

    def __init__(self, window_size: int, stride: int):
        super().__init__(window_size, stride)
        # self.segments is still empty here; load() rebuilds the fetcher.
        self.fetcher = MultiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Fetch ``[begin_idx, end_idx)`` of the stream stored under ``key``."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the four rollout fields for sample ``index``; every value is a Tensor."""
        begin_idx, end_idx = self.get_index(index)
        # No trailing commas here: a trailing comma would turn each value into a 1-tuple.
        # NOTE(review): values keep their stored dtype (no ``.to`` cast), unlike
        # the other datasets — confirm this is intended.
        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids")
        actions = self._fetch_data(begin_idx, end_idx, "actions")
        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs")
        rewards = self._fetch_data(begin_idx, end_idx, "rewards")
        return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
class DatasetLoader:
    """Factory that builds a dataset for a training type and loads it from disk."""

    @staticmethod
    def load(
        train_type: Literal["seq", "sft", "dpo", "ppo"],
        load_path: str,
        window_size: int,
        stride: Optional[int] = None,
    ) -> BaseDataset:
        """Build the dataset for ``train_type`` and load it from ``load_path``.

        Args:
            train_type: one of "seq", "sft", "dpo", "ppo".
            load_path: path handed to ``BaseDataset.load``.
            window_size: number of positions per sample.
            stride: step between window starts; defaults to ``window_size``
                (non-overlapping windows).

        Returns:
            The constructed dataset with its data loaded.

        Raises:
            KeyError: if ``train_type`` is not one of the known types.
        """
        if stride is None:
            stride = window_size
        # Every dataset class shares the (window_size, stride) constructor.
        dataset_router: Dict[str, Callable[[int, int], BaseDataset]] = {
            "seq": SeqDataset,
            "sft": SftDataset,
            "dpo": DpoDataset,
            "ppo": PpoDataset,
        }
        dataset = dataset_router[train_type](window_size, stride)
        dataset.load(load_path)
        return dataset