import bisect
import pickle as pkl
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Literal, Tuple, Union

import torch
from torch import Tensor
from torch.utils.data import Dataset, Sampler

MultiSeg = Dict[str, List[Tensor]]
Seg = Dict[str, Tensor]


def load_pkl_files(paths: List[str]) -> Tuple[MultiSeg, int]:
    segments: MultiSeg = {}
    total_samples = 0

    for path in paths:
        with open(path, "rb") as f:
            pkl_file: Seg = pkl.load(f)
        for key, value in pkl_file.items():
            segments.setdefault(key, []).append(value)
        # Every key in a shard is expected to hold a tensor with the same
        # number of elements, so the first key's numel() counts the shard.
        first_key = next(iter(pkl_file))
        total_samples += pkl_file[first_key].numel()

    return segments, total_samples
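
# Illustrative sketch (not part of the original module): two shards of four
# tokens each, saved as {"sequence": tensor}, merge per key:
#
#   segments, n = load_pkl_files(["shard_0.pkl", "shard_1.pkl"])  # hypothetical paths
#   # segments == {"sequence": [tensor_0, tensor_1]}, n == 8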


def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
    seq_len = input_ids.size(0)
    # Each user token opens a new turn; cumsum assigns a turn id per position.
    turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)

    iq = turn_id.view(seq_len, 1)
    ik = turn_id.view(1, seq_len)

    # Multi-turn: a query may attend to its own and all earlier turns
    # (iq >= ik); single-turn: only positions within the same turn (iq == ik).
    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
    # Apply the causal constraint on top of the turn constraint.
    attention_mask = torch.tril(seq_mask)

    # Unsqueeze to (1, seq_len, seq_len) so the mask broadcasts over the
    # batch dimension downstream.
    return attention_mask.unsqueeze(0)
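
# Illustrative example (not from the original file): with user_token_id=3,
# the sequence below has turn ids [1, 1, 1, 2, 2]; single-turn mode yields a
# block-diagonal causal mask of shape (1, 5, 5):
#
#   ids = torch.tensor([3, 5, 5, 3, 6])
#   build_attention_mask(ids, user_token_id=3, multi_turn=False).shape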


def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
    token_markers = torch.zeros_like(input_ids, dtype=torch.int8)

    # Mark span boundaries on input_ids: +1 opens a span at bos, -1 closes
    # it at eos.
    is_bos_token = input_ids.eq(bos_token_id)
    is_eos_token = input_ids.eq(eos_token_id)

    token_markers[is_bos_token] = 1
    token_markers[is_eos_token] = -1

    # The running sum is positive inside a bos..eos span. Subtracting the
    # minimum handles chunks that start mid-span, where the cumsum goes
    # negative after an unmatched eos.
    cumulative_markers = torch.cumsum(token_markers, dim=-1)
    min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
    loss_mask = cumulative_markers - min_cumulative

    return loss_mask.to(dtype=torch.bool)
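
# Illustrative example (not from the original file): with bos_token_id=1 and
# eos_token_id=2, the mask is True from each bos up to, but excluding, the
# matching eos:
#
#   ids = torch.tensor([1, 5, 5, 2, 7, 1, 6, 2])
#   build_loss_mask(ids, 1, 2)
#   # tensor([ True,  True,  True, False, False,  True,  True, False])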


class BaseSegmentFetcher:
    def __init__(self, segments: List[Tensor]):
        self.segments = segments
        # cum_lengths[i] is the total number of elements in segments[0..i].
        self.cum_lengths = []
        total = 0
        for seg in segments:
            total += len(seg)
            self.cum_lengths.append(total)
        self.total_length = total

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
            raise ValueError("begin_idx or end_idx out of bounds")
        if begin_idx >= end_idx:
            return torch.tensor([], dtype=torch.long)

        # Locate the first and last segments overlapping [begin_idx, end_idx)
        # via binary search over the cumulative lengths.
        seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
        seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)

        result_segments = []

        for i in range(seg_start_idx, seg_end_idx + 1):
            # Convert the global indices to offsets local to segment i.
            prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
            start = max(begin_idx - prev_cum, 0)
            end = min(end_idx - prev_cum, len(self.segments[i]))
            result_segments.append(self.segments[i][start:end])

        return torch.cat(result_segments, dim=0)
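
# Illustrative example (not from the original file): with two segments
# covering values 0..4 and 5..9, a cross-segment fetch stitches the slices:
#
#   fetcher = BaseSegmentFetcher([torch.arange(5), torch.arange(5, 10)])
#   fetcher.fetch_data(3, 7)   # tensor([3, 4, 5, 6])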


class MultiSegmentFetcher:
    def __init__(self, multi_segments: MultiSeg):
        self.multi_keys = list(multi_segments.keys())
        self.multi_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in multi_segments.items()
        }

    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
        fetch_dict = {}
        keys = [keys] if isinstance(keys, str) else keys

        for key in keys:
            fetcher = self.multi_fetchers[key]
            fetch_dict[key] = fetcher.fetch_data(begin_idx, end_idx)

        # A single key returns the bare tensor; multiple keys return a dict.
        return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]

    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
        return self.key_fetch(begin_idx, end_idx, self.multi_keys)
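
# Illustrative example (not from the original file): per-key fetchers share
# the same global index space:
#
#   mf = MultiSegmentFetcher({"chosen": [torch.arange(6)], "rejected": [torch.arange(6)]})
#   mf.key_fetch(0, 3, "chosen")   # tensor([0, 1, 2])
#   mf.fetch_data(0, 3)            # {"chosen": ..., "rejected": ...}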


class BaseDataset(Dataset, ABC):
    def __init__(self, chunk_size: int):
        super().__init__()
        self.segments: MultiSeg = {}
        self.chunk_size = chunk_size
        self.total_samples = 0
        self.fetcher = MultiSegmentFetcher(self.segments)

    def save(self, save_path: str):
        keys = list(self.segments.keys())
        if not keys:
            return

        # One pickle per segment index; every key contributes its i-th tensor.
        num_segments = len(self.segments[keys[0]])

        for i in range(num_segments):
            formatted_segment = {key: self.segments[key][i] for key in keys}
            with open(f"{save_path}_{i}.pkl", "wb") as f:
                pkl.dump(formatted_segment, f)

    def load(self, load_path: Union[str, List[str]]):
        paths = [load_path] if isinstance(load_path, str) else load_path
        self.segments, self.total_samples = load_pkl_files(paths)
        # Rebuild the fetcher so it indexes the freshly loaded segments.
        self.fetcher = MultiSegmentFetcher(self.segments)

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        raise NotImplementedError

    def __len__(self) -> int:
        assert self.total_samples // self.chunk_size > 0, "dataset holds less than one chunk"
        return self.total_samples // self.chunk_size
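
# Illustrative sketch ("data/seq" is a hypothetical path): save() writes one
# pickle per segment index, so dataset.save("data/seq") produces
# data/seq_0.pkl, data/seq_1.pkl, ...; load(["data/seq_0.pkl",
# "data/seq_1.pkl"]) restores them and rebuilds the fetcher.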


class SeqDataset(BaseDataset):
    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index):
        # Clamp the window so the shifted target fetch (end_idx + 1) stays
        # within total_samples.
        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
        end_idx = begin_idx + self.chunk_size

        x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
        # Targets are the inputs shifted one position to the right.
        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)

        return {"input_ids": x, "target_ids": y}
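
# Illustrative example (not from the original file): with chunk_size=4 and a
# single loaded "sequence" segment holding values 0..9, item 0 pairs shifted
# windows:
#
#   ds[0] == {"input_ids": tensor([0, 1, 2, 3]), "target_ids": tensor([1, 2, 3, 4])}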


class SftDataset(BaseDataset):
    def __init__(
        self,
        chunk_size,
        bos_token_id,
        eos_token_id,
        user_token_id,
        multi_turn=False,
    ):
        super().__init__(chunk_size)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.user_token_id = user_token_id
        self.multi_turn = multi_turn

    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index):
        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
        end_idx = begin_idx + self.chunk_size

        x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
        y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)

        # Both masks are built from input_ids, not target_ids, so boundary
        # tokens line up with the positions the model attends to.
        loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
        attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)

        return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}


class DpoDataset(BaseDataset):
    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
        end_idx = begin_idx + self.chunk_size

        chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
        rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
        chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(dtype=torch.bool)
        rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(dtype=torch.bool)

        return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}


class PpoDataset(BaseDataset):
    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        begin_idx = min(index * self.chunk_size, self.total_samples - self.chunk_size - 1)
        end_idx = begin_idx + self.chunk_size

        # No trailing commas here: they would wrap each tensor in a 1-tuple.
        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids")
        actions = self._fetch_data(begin_idx, end_idx, "actions")
        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs")
        rewards = self._fetch_data(begin_idx, end_idx, "rewards")

        return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}


class DatasetLoader:
    @staticmethod
    def load(
        train_type: Literal["seq", "sft", "dpo"],
        load_path: Union[str, List[str]],
        max_len: int,
        **kwargs
    ) -> BaseDataset:
        dataset_router: Dict[str, Callable[[int], BaseDataset]] = {
            "seq": SeqDataset,
            "sft": lambda chunk_size: SftDataset(
                chunk_size,
                bos_token_id=kwargs.get("bos_token_id"),
                eos_token_id=kwargs.get("eos_token_id"),
                user_token_id=kwargs.get("user_token_id"),
                multi_turn=kwargs.get("multi_turn", False),
            ),
            "dpo": DpoDataset,
        }
        dataset = dataset_router[train_type](max_len)
        dataset.load(load_path)

        return dataset
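
# Hedged usage sketch (the path and token ids are placeholders, not from the
# original file):
#
#   dataset = DatasetLoader.load(
#       "sft", ["data/sft_0.pkl"], max_len=512,
#       bos_token_id=1, eos_token_id=2, user_token_id=3,
#   )
#   batch = dataset[0]   # input_ids, target_ids, loss_mask, attn_mask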


class ResumableRandomSampler(Sampler[int]):
    def __init__(self, data_source, start_epoch=0, start_iter=0, seed=42):
        self.num_samples = len(data_source)
        self.epoch = start_epoch
        self.iter = start_iter

        generator = torch.Generator()
        generator.manual_seed(seed)

        self.generator = generator
        self._indices = None

    def _get_indices(self):
        # Replay the permutations of earlier epochs so the generator reaches
        # the same state it would have had without a restart.
        for _ in range(self.epoch):
            _ = torch.randperm(self.num_samples, generator=self.generator)

        current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
        # Skip the samples already consumed in the interrupted epoch.
        self._indices = current_epoch_indices[self.iter % self.num_samples:]

    def __iter__(self):
        if self._indices is None:
            self._get_indices()

        for i in self._indices:
            self.iter += 1
            yield i

        # The epoch is exhausted; the next __iter__ draws a fresh permutation.
        self.epoch += 1
        self._indices = None

    def __len__(self):
        if self._indices is None:
            self._get_indices()
        return len(self._indices)
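

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original file):
    # resuming mid-epoch must reproduce the tail of an uninterrupted run.
    full = list(ResumableRandomSampler(range(8), seed=0))
    resumed = list(ResumableRandomSampler(range(8), start_iter=3, seed=0))
    assert resumed == full[3:], (full, resumed)
    print("resume check passed:", resumed)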