"""Chunked token datasets (seq / sft / dpo / ppo) backed by pickled tensor segments.

Each pickle file holds a ``Dict[str, Tensor]``; files are concatenated
logically (per key) and served in fixed-size chunks through
``BaseSegmentFetcher`` without materialising one giant tensor.
"""
import torch
import bisect
import pickle as pkl

from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset, Sampler
from typing import Callable, List, Dict, Literal, Tuple, Union

MutiSeg = Dict[str, List[Tensor]]
Seg = Dict[str, Tensor]


def load_pkl_files(paths: List[str]):
    """Load pickled segment dicts and group their tensors by key.

    Args:
        paths: Pickle files, each containing a ``Dict[str, Tensor]``.

    Returns:
        ``(segments, total_samples)`` where ``segments[key]`` is the list of
        tensors for ``key`` in file order and ``total_samples`` is the summed
        ``numel()`` of the first key's tensor in every file.
    """
    segments: MutiSeg = {}
    total_samples = 0

    for path in paths:
        # WARNING: pickle executes arbitrary code while loading; only use
        # this on files produced by a trusted source (e.g. ``BaseDataset.save``).
        with open(path, "rb") as f:
            pkl_file: Seg = pkl.load(f)

        if not pkl_file:
            # An empty dict contributes no samples (previously an IndexError).
            continue

        for key, value in pkl_file.items():
            segments.setdefault(key, []).append(value)

        # The first key is used as the length reference for the file.
        first_key = next(iter(pkl_file))
        total_samples += pkl_file[first_key].numel()

    return segments, total_samples


def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
    """Build a causal attention mask partitioned by conversation turn.

    A new turn starts at every occurrence of ``user_token_id``.

    Args:
        input_ids: 1-D tensor of token ids (length ``seq_len``).
        user_token_id: Token id that marks the start of a turn.
        multi_turn: If true, a query may attend to keys of its own and earlier
            turns; otherwise only to keys within the same turn.

    Returns:
        Boolean tensor of shape ``(1, seq_len, seq_len)``; the leading
        dimension is added so the mask can broadcast.
    """
    seq_len = input_ids.size(0)
    # Running count of user tokens == turn index of every position.
    turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)

    iq = turn_id.view(seq_len, 1)
    ik = turn_id.view(1, seq_len)

    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
    # Lower-triangular restriction makes the mask causal.
    attention_mask = torch.tril(seq_mask)

    return attention_mask.unsqueeze(0)


def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
    """Build a boolean loss mask from bos/eos markers in ``input_ids``.

    Every bos token contributes +1 and every eos token -1 to a running sum;
    the sum is shifted so its minimum is zero, and positions where the
    shifted sum is non-zero are selected.

    Args:
        input_ids: Tensor of token ids (markers are accumulated on the last dim).
        bos_token_id: Token id that opens a span.
        eos_token_id: Token id that closes a span.

    Returns:
        Boolean tensor with the same shape as ``input_ids``.
    """
    token_markers = torch.zeros_like(input_ids, dtype=torch.int8)

    # Markers are computed on input_ids (not on the shifted targets).
    token_markers[input_ids.eq(bos_token_id)] = 1
    token_markers[input_ids.eq(eos_token_id)] = -1

    cumulative_markers = torch.cumsum(token_markers, dim=-1)
    min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
    loss_mask = cumulative_markers - min_cumulative

    return loss_mask.to(dtype=torch.bool)


class BaseSegmentFetcher:
    """Range reads over a list of tensors treated as one concatenated sequence."""

    def __init__(self, segments: List[Tensor]):
        self.segments = segments
        # cum_lengths[i] == total length of segments[0..i].
        self.cum_lengths = []

        total = 0
        for seg in segments:
            total += len(seg)
            self.cum_lengths.append(total)

        self.total_length = total

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Return elements ``[begin_idx, end_idx)`` of the concatenated sequence.

        Raises:
            ValueError: If ``begin_idx`` is not in ``[0, total_length)`` or
                ``end_idx`` is not in ``[0, total_length]``.
        """
        if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
            raise ValueError("begin_idx or end_idx out of bounds")
        if begin_idx >= end_idx:
            return torch.tensor([], dtype=torch.long)

        # First segment containing begin_idx / last segment containing end_idx - 1.
        seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
        seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)

        result_segments = []
        for i in range(seg_start_idx, seg_end_idx + 1):
            prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
            # Translate global bounds into this segment's local coordinates.
            start = max(begin_idx - prev_cum, 0)
            end = min(end_idx - prev_cum, len(self.segments[i]))
            result_segments.append(self.segments[i][start:end])

        return torch.cat(result_segments, dim=0)


class MutiSegmentFetcher:
    """One ``BaseSegmentFetcher`` per key of a ``MutiSeg`` mapping."""

    def __init__(self, muti_segments: MutiSeg):
        self.muti_keys = list(muti_segments.keys())
        self.muti_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in muti_segments.items()
        }

    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
        """Fetch ``[begin_idx, end_idx)`` for one or several keys.

        Returns:
            A single tensor when exactly one key is requested, otherwise a
            dict mapping each key to its tensor.
        """
        keys = [keys] if isinstance(keys, str) else keys
        fetch_dict = {
            key: self.muti_fetchers[key].fetch_data(begin_idx, end_idx)
            for key in keys
        }

        return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]

    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
        """Fetch ``[begin_idx, end_idx)`` for every known key."""
        return self.key_fetch(begin_idx, end_idx, self.muti_keys)


class BaseDataset(Dataset, ABC):
    """Common storage, (de)serialisation and chunk arithmetic for all datasets."""

    def __init__(self, chunk_size: int):
        super().__init__()
        self.segments: MutiSeg = {}
        self.chunk_size = chunk_size
        self.total_samples = 0
        # Rebuilt by ``load``; created here so every subclass has one.
        self.fetcher = MutiSegmentFetcher(self.segments)

    def save(self, save_path: str):
        """Write segment ``i`` of every key to ``{save_path}_{i}.pkl``."""
        keys = list(self.segments.keys())
        if not keys:
            return

        segment_size = len(self.segments[keys[0]])

        for i in range(segment_size):
            formated_segment = {key: self.segments[key][i] for key in keys}
            # Context manager guarantees the handle is flushed and closed.
            with open(f"{save_path}_{i}.pkl", "wb") as f:
                pkl.dump(formated_segment, f)

    def load(self, load_path: Union[str, List[str]]):
        """Load one or several pickle files and rebuild the fetcher."""
        paths = [load_path] if isinstance(load_path, str) else load_path
        self.segments, self.total_samples = load_pkl_files(paths)
        self.fetcher = MutiSegmentFetcher(self.segments)

    def _chunk_bounds(self, index: int) -> Tuple[int, int]:
        """Return ``(begin_idx, end_idx)`` of chunk ``index``.

        The start is clamped to ``total_samples - chunk_size - 1`` so that a
        one-token-shifted window ``[begin + 1, end + 1)`` stays in range.
        When ``total_samples <= chunk_size`` the start becomes negative and
        the subsequent fetch raises ``ValueError``.
        """
        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
        return begin_idx, begin_idx + self.chunk_size

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        raise NotImplementedError

    def __len__(self) -> int:
        num_chunks = self.total_samples // self.chunk_size
        assert num_chunks > 0, "dataset holds fewer samples than chunk_size"
        return num_chunks


class SeqDataset(BaseDataset):
    """Next-token-prediction dataset reading the ``"sequence"`` key."""

    def __init__(
        self,
        chunk_size,
    ):
        super().__init__(chunk_size)

    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index):
        """Return ``input_ids`` and the one-token-shifted ``target_ids``."""
        begin_idx, end_idx = self._chunk_bounds(index)

        x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)

        return {"input_ids": x, "target_ids": y}


class SftDataset(BaseDataset):
    """Supervised fine-tuning dataset: shifted targets plus loss/attention masks."""

    def __init__(
        self,
        chunk_size,
        bos_token_id,
        eos_token_id,
        user_token_id,
        multi_turn=False,
    ):
        super().__init__(chunk_size)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.user_token_id = user_token_id
        self.multi_turn = multi_turn

    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index):
        """Return inputs, shifted targets, loss mask and attention mask."""
        begin_idx, end_idx = self._chunk_bounds(index)

        x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)

        # Both masks are derived from the inputs, not from the shifted targets.
        loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
        attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)

        return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}


class DpoDataset(BaseDataset):
    """Preference dataset reading chosen/rejected sequences and their masks."""

    def __init__(self, chunk_size: int):
        super().__init__(chunk_size)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        """Return chosen/rejected token ids (long) and their masks (bool)."""
        begin_idx, end_idx = self._chunk_bounds(index)

        chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
        rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
        chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
        rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)

        return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}


class PpoDataset(BaseDataset):
    """PPO rollout dataset reading input_ids / actions / logprobs / rewards."""

    def __init__(self, chunk_size: int):
        super().__init__(chunk_size)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the four rollout tensors of chunk ``index``."""
        begin_idx, end_idx = self._chunk_bounds(index)

        # No trailing commas here: each value must be a Tensor, not a 1-tuple.
        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids")
        actions = self._fetch_data(begin_idx, end_idx, "actions")
        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs")
        rewards = self._fetch_data(begin_idx, end_idx, "rewards")

        return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}


class DatasetLoader:
    """Factory that builds and loads the dataset matching a training type."""

    @staticmethod
    def load(
        train_type: Literal["seq", "sft", "dpo", "ppo"],
        load_path: Union[str, List[str]],
        max_len: int,
        **kwargs
    ) -> BaseDataset:
        """Create the dataset for ``train_type`` and load ``load_path`` into it.

        Args:
            train_type: One of ``"seq"``, ``"sft"``, ``"dpo"``, ``"ppo"``.
            load_path: A pickle path or a list of pickle paths.
            max_len: Chunk size handed to the dataset constructor.
            **kwargs: For ``"sft"``: ``bos_token_id``, ``eos_token_id``,
                ``user_token_id`` and optionally ``multi_turn``.

        Raises:
            KeyError: If ``train_type`` is not a supported type.
        """
        dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
            "seq": lambda max_len: SeqDataset(max_len),
            "sft": lambda max_len: SftDataset(
                max_len,
                bos_token_id=kwargs.get("bos_token_id"),
                eos_token_id=kwargs.get("eos_token_id"),
                user_token_id=kwargs.get("user_token_id"),
                multi_turn=kwargs.get("multi_turn", False),
            ),
            "dpo": lambda max_len: DpoDataset(max_len),
            "ppo": lambda max_len: PpoDataset(max_len),
        }
        dataset = dataset_router[train_type](max_len)
        dataset.load(load_path)

        return dataset


class ResumeableRandomSampler(Sampler[int]):
    """Seeded random sampler that can resume from a given epoch and iteration.

    Args:
        data_source: Sized collection to sample indices from.
        start_epoch: Number of already completed epochs; their permutations
            are drawn and discarded so the generator state matches.
        start_iter: Number of samples already consumed; the first epoch
            skips ``start_iter % len(data_source)`` indices.
        seed: Seed of the private ``torch.Generator``.
    """

    def __init__(self, data_source, start_epoch=0, start_iter=0, seed=42):
        self.num_samples = len(data_source)
        self.epoch = start_epoch
        self.iter = start_iter

        generator = torch.Generator()
        generator.manual_seed(seed)

        # Consume the permutations of previous epochs.
        for _ in range(start_epoch):
            torch.randperm(self.num_samples, generator=generator)

        self.generator = generator
        self._indices = None

    def _get_indices(self):
        """Draw this epoch's permutation and drop the already-consumed prefix."""
        current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
        # Guard the modulo: an empty data source has nothing to skip.
        offset = self.iter % self.num_samples if self.num_samples else 0
        self._indices = current_epoch_indices[offset:]

    def __iter__(self):
        if self._indices is None:
            self._get_indices()

        for i in self._indices:
            self.iter += 1
            yield i

        # Epoch finished: the next iteration draws a fresh permutation.
        self.epoch += 1
        self._indices = None

    def __len__(self):
        if self._indices is None:
            self._get_indices()
        return len(self._indices)