AstrAI/khaosz/trainer/dataset.py

267 lines
9.9 KiB
Python

import torch
import bisect
import pickle as pkl
from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset
from typing import Callable, List, Dict, Literal, Union
# Mapping from a data key to a list of tensors; load_pkl_files appends one
# tensor per loaded pickle file under each key.
MutiSeg = Dict[str, List[Tensor]]
# Mapping from a data key to a single tensor (the content of one pickle file).
Seg = Dict[str, Tensor]
def load_pkl_files(paths: List[str]):
    """Load pickled segment dictionaries and group their tensors by key.

    Every file is unpickled into a mapping from key to tensor, and the tensor
    stored under each key is appended to that key's list in file order.

    NOTE(security): ``pickle.load`` can execute arbitrary code while
    unpickling, so only paths to trusted files should be passed in.

    Args:
        paths: Paths of the pickle files to read, in loading order.

    Returns:
        A tuple ``(segments, total_samples)``. ``segments`` maps each key to
        the list of tensors loaded for it; ``total_samples`` is the sum, over
        all files, of the element count of the first entry of each file.
    """
    segments: MutiSeg = {}
    total_samples = 0

    for path in paths:
        with open(path, "rb") as f:
            pkl_file: Seg = pkl.load(f)

        for key, value in pkl_file.items():
            segments.setdefault(key, []).append(value)

        # An empty mapping contributes no samples; indexing its first key
        # would otherwise fail with an IndexError.
        first_value = next(iter(pkl_file.values()), None)
        if first_value is not None:
            total_samples += first_value.numel()

    return segments, total_samples
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
    """Build a boolean loss mask from BOS/EOS markers in ``input_ids``.

    A running counter is increased at every BOS token and decreased at every
    EOS token along the last dimension. A position is ``True`` when its
    counter value is above the smallest counter value found in that row.

    Args:
        input_ids: Token ids; the scan runs along the last dimension.
        bos_token_id: Token id that raises the counter.
        eos_token_id: Token id that lowers the counter.

    Returns:
        A boolean tensor with the same shape as ``input_ids``.
    """
    bos_positions = input_ids.eq(bos_token_id)
    eos_positions = input_ids.eq(eos_token_id)

    # +1 at BOS, -1 at EOS, 0 elsewhere. EOS is written last, so it wins if
    # both ids happen to be equal.
    markers = (
        torch.zeros_like(input_ids, dtype=torch.int8)
        .masked_fill(bos_positions, 1)
        .masked_fill(eos_positions, -1)
    )

    running = torch.cumsum(markers, dim=-1)
    # Shift every row so that its lowest counter value maps to zero (False).
    lowest, _ = running.min(dim=-1, keepdim=True)
    return (running - lowest).bool()
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
    """Build a causal attention mask restricted by conversation turn.

    Every occurrence of ``user_token_id`` starts a new turn, so each position
    is labelled with the number of user tokens seen up to and including it.

    Args:
        input_ids: Token ids; its first dimension gives the sequence length.
        user_token_id: Token id that marks the beginning of a turn.
        multi_turn: When true, a query position may attend to keys of its own
            turn and of earlier turns; otherwise only to keys of its own turn.

    Returns:
        A boolean ``(seq_len, seq_len)`` tensor, indexed ``[query, key]``,
        that is ``True`` where attention is allowed.
    """
    seq_len = input_ids.size(0)

    turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
    query_turn = turn_id.reshape(seq_len, 1)
    key_turn = turn_id.reshape(1, seq_len)

    if multi_turn:
        turn_mask = query_turn >= key_turn
    else:
        turn_mask = query_turn == key_turn

    # Causal constraint: a query may only look at keys at or before itself.
    positions = torch.arange(seq_len, device=input_ids.device)
    causal_mask = positions.reshape(seq_len, 1) >= positions.reshape(1, seq_len)

    return turn_mask & causal_mask
class BaseSegmentFetcher:
    """Slice a list of tensors as if they were one concatenated tensor.

    The segments are addressed through a single global index that runs over
    the first dimension of every segment in list order.
    """

    def __init__(self, segments: List[Tensor]):
        """Record the segments and their cumulative lengths.

        Args:
            segments: Tensors to expose as one virtual sequence.
        """
        self.segments = segments
        # cum_lengths[i] is the number of elements in segments[0..i].
        self.cum_lengths = []
        total = 0
        for seg in segments:
            total += len(seg)
            self.cum_lengths.append(total)
        self.total_length = total

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Return the elements with global index in ``[begin_idx, end_idx)``.

        Args:
            begin_idx: First global index to return (inclusive).
            end_idx: Global index at which to stop (exclusive).

        Returns:
            The requested elements concatenated along dimension 0, or an
            empty ``torch.long`` tensor when ``begin_idx >= end_idx``.

        Raises:
            ValueError: If ``begin_idx`` is outside ``[0, total_length)`` or
                ``end_idx`` is outside ``[0, total_length]``.
        """
        if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
            raise ValueError("begin_idx or end_idx out of bounds")
        if begin_idx >= end_idx:
            return torch.tensor([], dtype=torch.long)

        # Segment i holds the global indices [cum_lengths[i - 1], cum_lengths[i]),
        # so index j lives in the first segment whose cumulative length is
        # strictly greater than j. bisect_right gives exactly that segment,
        # including when j is the first element of a segment.
        seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
        seg_end_idx = bisect.bisect_right(self.cum_lengths, end_idx - 1)

        result_segments = []
        for i in range(seg_start_idx, seg_end_idx + 1):
            prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
            # Translate the global bounds into offsets local to segment i.
            start = max(begin_idx - prev_cum, 0)
            end = min(end_idx - prev_cum, len(self.segments[i]))
            result_segments.append(self.segments[i][start:end])

        return torch.cat(result_segments, dim=0)
class MutiSegmentFetcher:
    """Fetch the same index range from several keyed segment lists."""

    def __init__(self, muti_segments: MutiSeg):
        """Create one ``BaseSegmentFetcher`` per key.

        Args:
            muti_segments: Mapping from key to the list of tensors for it.
        """
        self.muti_keys = list(muti_segments.keys())
        self.muti_fetchers = {}
        for name, segment_list in muti_segments.items():
            self.muti_fetchers[name] = BaseSegmentFetcher(segment_list)

    def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
        """Fetch ``[begin_idx, end_idx)`` for the requested key or keys.

        Args:
            begin_idx: First global index to return (inclusive).
            end_idx: Global index at which to stop (exclusive).
            keys: A single key or a list of keys.

        Returns:
            A dict from key to tensor when more than one key is requested,
            otherwise the tensor of the single requested key.
        """
        requested = [keys] if isinstance(keys, str) else keys
        fetched = {
            name: self.muti_fetchers[name].fetch_data(begin_idx, end_idx)
            for name in requested
        }
        if len(requested) > 1:
            return fetched
        return fetched[requested[0]]

    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
        """Fetch ``[begin_idx, end_idx)`` for every key known to the fetcher."""
        return self.key_fetch(begin_idx, end_idx, self.muti_keys)
class BaseDataset(Dataset, ABC):
    """Abstract dataset that serves fixed-size chunks of pickled segments.

    Subclasses implement ``__getitem__`` to turn a chunk index into a dict of
    tensors.
    """

    def __init__(self, chunk_size: int, device: str):
        """Initialise an empty dataset.

        Args:
            chunk_size: Number of samples covered by one item.
            device: Device identifier stored for use by subclasses.
        """
        super().__init__()
        self.segments: MutiSeg = {}
        self.chunk_size = chunk_size
        self.total_samples = 0
        self.device = device

    def save(self, save_path: str):
        """Write the segments to ``{save_path}_{i}.pkl``, one file per index.

        File ``i`` holds a dict mapping every key to its ``i``-th tensor. The
        number of files equals the number of tensors stored under the first
        key. Nothing is written when the dataset has no segments.

        Args:
            save_path: Path prefix of the files to create.
        """
        # The key list has to exist before it is used to pick the first entry.
        keys = list(self.segments.keys())
        if not keys:
            return

        segment_size = len(self.segments[keys[0]])
        for i in range(segment_size):
            formated_segment = {key: self.segments[key][i] for key in keys}
            # The context manager closes (and flushes) each file handle.
            with open(f"{save_path}_{i}.pkl", "wb") as f:
                pkl.dump(formated_segment, f)

    def load(self, load_path: Union[str, List[str]]):
        """Load segments from one or more pickle files and rebuild the fetcher.

        Args:
            load_path: A single path or a list of paths to pickle files.
        """
        paths = [load_path] if isinstance(load_path, str) else load_path
        self.segments, self.total_samples = load_pkl_files(paths)
        self.fetcher = MutiSegmentFetcher(self.segments)

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the item for chunk ``index``; implemented by subclasses."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of whole chunks contained in the loaded data."""
        num_chunks = self.total_samples // self.chunk_size
        assert num_chunks > 0
        return num_chunks
class SeqDataset(BaseDataset):
    """Dataset yielding input/target windows read from the "sequence" key."""

    def __init__(self, chunk_size, device='cuda'):
        """Create an empty dataset.

        Args:
            chunk_size: Number of tokens covered by one item.
            device: Device the returned tensors are moved to.
        """
        super().__init__(chunk_size, device)
        self.fetcher = MutiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Read ``[begin_idx, end_idx)`` from the "sequence" segments."""
        return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")

    def __getitem__(self, index):
        """Return the ``index``-th window and the same window shifted by one.

        Returns:
            A dict with "input_ids" and "target_ids", both ``torch.long``
            tensors on ``self.device``.
        """
        start = index * self.chunk_size
        # Stop one element before the end of the data so that the window
        # shifted by one position is still in range.
        stop = min(start + self.chunk_size, self.total_samples - 1)

        cast = {"device": self.device, "dtype": torch.long}
        inputs = self._fetch_data(start, stop).to(**cast)
        targets = self._fetch_data(start + 1, stop + 1).to(**cast)

        return {"input_ids": inputs, "target_ids": targets}
class SftDataset(BaseDataset):
    """Dataset yielding shifted windows together with loss and attention masks."""

    def __init__(
        self,
        chunk_size,
        bos_token_id,
        eos_token_id,
        user_token_id,
        multi_turn=False,
        device='cuda'
    ):
        """Create an empty dataset.

        Args:
            chunk_size: Number of tokens covered by one item.
            bos_token_id: Token id passed to ``build_loss_mask`` as BOS.
            eos_token_id: Token id passed to ``build_loss_mask`` as EOS.
            user_token_id: Token id passed to ``build_attention_mask``.
            multi_turn: Flag forwarded to ``build_attention_mask``.
            device: Device the returned tensors are moved to.
        """
        super().__init__(chunk_size, device)
        self.fetcher = MutiSegmentFetcher(self.segments)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.user_token_id = user_token_id
        self.multi_turn = multi_turn

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Read ``[begin_idx, end_idx)`` from the segments stored under ``key``."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index):
        """Return the ``index``-th window, its shifted targets and both masks.

        Returns:
            A dict with "input_ids", "target_ids", "loss_mask" (computed on
            the targets) and "attn_mask" (computed on the inputs).
        """
        start = index * self.chunk_size
        # Stop one element before the end of the data so that the window
        # shifted by one position is still in range.
        stop = min(start + self.chunk_size, self.total_samples - 1)

        cast = {"device": self.device, "dtype": torch.long}
        inputs = self._fetch_data(start, stop, "sequence").to(**cast)
        targets = self._fetch_data(start + 1, stop + 1, "sequence").to(**cast)

        item = {"input_ids": inputs, "target_ids": targets}
        item["loss_mask"] = build_loss_mask(targets, self.bos_token_id, self.eos_token_id)
        item["attn_mask"] = build_attention_mask(inputs, self.user_token_id, self.multi_turn)
        return item
class DpoDataset(BaseDataset):
    """Dataset yielding chosen/rejected windows and their masks."""

    def __init__(self, chunk_size: int, device="cuda"):
        """Create an empty dataset.

        Args:
            chunk_size: Number of elements covered by one item.
            device: Device the returned tensors are moved to.
        """
        super().__init__(chunk_size, device)
        self.fetcher = MutiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Read ``[begin_idx, end_idx)`` from the segments stored under ``key``."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int):
        """Return the ``index``-th window of every preference field.

        Returns:
            A dict with "chosen" and "rejected" as ``torch.long`` tensors and
            "chosen_mask" and "rejected_mask" as ``torch.bool`` tensors, all
            on ``self.device``.
        """
        window_start = index * self.chunk_size
        window_stop = min(window_start + self.chunk_size, self.total_samples - 1)

        # Field name -> dtype of the returned tensor, in output order.
        field_dtypes = {
            "chosen": torch.long,
            "rejected": torch.long,
            "chosen_mask": torch.bool,
            "rejected_mask": torch.bool,
        }
        return {
            field: self._fetch_data(window_start, window_stop, field).to(device=self.device, dtype=dtype)
            for field, dtype in field_dtypes.items()
        }
class PpoDataset(BaseDataset):
    """Dataset yielding windows of input ids, actions, log-probs and rewards."""

    def __init__(self, chunk_size: int, device="cuda"):
        """Create an empty dataset.

        Args:
            chunk_size: Number of elements covered by one item.
            device: Device the returned tensors are moved to.
        """
        super().__init__(chunk_size, device)
        self.fetcher = MutiSegmentFetcher(self.segments)

    def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
        """Read ``[begin_idx, end_idx)`` from the segments stored under ``key``."""
        return self.fetcher.key_fetch(begin_idx, end_idx, key)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the ``index``-th window of every PPO field.

        Returns:
            A dict with "input_ids", "actions", "logprobs" and "rewards",
            each a tensor on ``self.device``.
        """
        begin_idx = index * self.chunk_size
        end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)

        # No trailing commas here: a trailing comma after the expression
        # would wrap the tensor in a 1-tuple and break the declared
        # Dict[str, Tensor] return type.
        input_ids = self._fetch_data(begin_idx, end_idx, "input_ids").to(self.device)
        actions = self._fetch_data(begin_idx, end_idx, "actions").to(self.device)
        logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device)
        rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)

        return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
class DatasetLoader:
    """Factory that builds and loads the dataset matching a training type."""

    @staticmethod
    def load(
        train_type: Literal["seq", "sft", "dpo", "ppo"],
        load_path: Union[str, List[str]],
        max_len: int,
        device: str,
        **kwargs
    ) -> BaseDataset:
        """Create the dataset for ``train_type`` and load its data.

        Args:
            train_type: Which dataset class to build: "seq", "sft", "dpo"
                or "ppo".
            load_path: A single path or a list of paths to pickle files.
            max_len: Chunk size handed to the dataset.
            device: Device the dataset moves its tensors to.
            **kwargs: Extra options for "sft": ``bos_token_id``,
                ``eos_token_id``, ``user_token_id`` and ``multi_turn``
                (defaults to ``False``).

        Returns:
            The loaded dataset.

        Raises:
            KeyError: If ``train_type`` is not one of the supported values.
        """
        # Constructors are wrapped in lambdas so that only the selected
        # dataset is instantiated.
        dataset_router: Dict[str, Callable[[int, str], BaseDataset]] = {
            "seq": lambda m_len, device: SeqDataset(m_len, device=device),
            "sft": lambda m_len, device: SftDataset(
                m_len,
                device=device,
                bos_token_id=kwargs.get("bos_token_id"),
                eos_token_id=kwargs.get("eos_token_id"),
                user_token_id=kwargs.get("user_token_id"),
                # Fall back to the constructor default rather than None.
                multi_turn=kwargs.get("multi_turn", False),
            ),
            "dpo": lambda m_len, device: DpoDataset(m_len, device=device),
            "ppo": lambda m_len, device: PpoDataset(m_len, device=device),
        }
        dataset = dataset_router[train_type](max_len, device)
        dataset.load(load_path)

        return dataset