diff --git a/khaosz/trainer/dataset.py b/khaosz/trainer/dataset.py index 439567d..7dcd326 100644 --- a/khaosz/trainer/dataset.py +++ b/khaosz/trainer/dataset.py @@ -25,6 +25,35 @@ def load_pkl_files(paths: List[str]): return segments, total_samples +def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor: + token_markers = torch.zeros_like(input_ids, dtype=torch.int8) + + is_bos_token = input_ids.eq(bos_token_id) + is_eos_token = input_ids.eq(eos_token_id) + + token_markers[is_bos_token] = 1 + token_markers[is_eos_token] = -1 + + cumulative_markers = torch.cumsum(token_markers, dim=-1) + min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values + loss_mask = cumulative_markers - min_cumulative + + return loss_mask.to(dtype=torch.bool) + +def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor: + seq_len = input_ids.size(0) + is_user_token = input_ids.eq(user_token_id) + turn_id = is_user_token.cumsum(dim=-1) + + iq = turn_id.view(seq_len, 1) + ik = turn_id.view(1, seq_len) + + seq_mask = (iq <= ik) if multi_turn else (iq == ik) + causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool() + attention_mask = seq_mask & causal_mask + + return attention_mask + class BaseSegmentFetcher: def __init__(self, segments: List[Tensor]): @@ -103,7 +132,7 @@ class BaseDataset(Dataset, ABC): self.fetcher = MutiSegmentFetcher(self.segments) @abstractmethod - def __getitem__(self, index: int): + def __getitem__(self, index: int) -> Dict[str, Tensor]: raise NotImplementedError def __len__(self) -> int: @@ -112,7 +141,11 @@ class BaseDataset(Dataset, ABC): class SeqDataset(BaseDataset): - def __init__(self, chunk_size , device='cuda'): + def __init__( + self, + chunk_size, + device='cuda' + ): super().__init__(chunk_size, device) self.fetcher = MutiSegmentFetcher(self.segments) @@ -126,13 +159,26 @@ class SeqDataset(BaseDataset): x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long) - return x, y + return {"input_ids": x, "target_ids": y} + class SftDataset(BaseDataset): - def __init__(self, chunk_size, device='cuda'): + def __init__( + self, + chunk_size, + bos_token_id, + eos_token_id, + user_token_id, + multi_turn=False, + device='cuda' + ): super().__init__(chunk_size, device) self.fetcher = MutiSegmentFetcher(self.segments) + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.user_token_id = user_token_id + self.multi_turn = multi_turn def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: return self.fetcher.key_fetch(begin_idx, end_idx, key) @@ -143,9 +189,11 @@ class SftDataset(BaseDataset): x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long) y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long) - loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "mask").to(device=self.device, dtype=torch.bool) - return x, y, loss_mask + loss_mask = build_loss_mask(y, self.bos_token_id, self.eos_token_id) + attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn) + + return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask} class DpoDataset(BaseDataset): @@ -165,7 +213,7 @@ class DpoDataset(BaseDataset): chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool) rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool) - return chosen, rejected, chosen_mask, rejected_mask + return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask} class PpoDataset(BaseDataset): @@ -187,7 +235,7 @@ class PpoDataset(BaseDataset): logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device), rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device) - return input_ids, actions, logprobs, rewards + return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards} class DatasetLoader: @@ -196,12 +244,20 @@ class DatasetLoader: train_type: Literal["seq", "sft", "dpo"], load_path: Union[str, List[str]], max_len: int, - device: str + device: str, + **kwargs ) -> BaseDataset: dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = { "seq": lambda m_len, device: SeqDataset(m_len, device=device), - "sft": lambda m_len, device: SftDataset(m_len, device=device), + "sft": lambda m_len, device: SftDataset( + m_len, + device=device, + bos_token_id=kwargs.get("bos_token_id"), + eos_token_id=kwargs.get("eos_token_id"), + user_token_id=kwargs.get("user_token_id"), + multi_turn=kwargs.get("multi_turn", False) + ), "dpo": lambda m_len, device: DpoDataset(m_len, device=device), } dataset = dataset_router[train_type](max_len, device) diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 7112684..7f1f7e8 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -32,42 +32,13 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: return (token_logprobs * valid_mask).sum(dim=-1) -def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor: - token_markers = torch.zeros_like(input_ids, dtype=torch.int8) - - is_bos_token = input_ids.eq(bos_token_id) - is_eos_token = input_ids.eq(eos_token_id) - - token_markers[is_bos_token] = 1 - token_markers[is_eos_token] = -1 - - cumulative_markers = torch.cumsum(token_markers, dim=-1) - min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values - loss_mask = cumulative_markers - min_cumulative - - return loss_mask.to(dtype=torch.bool) - -def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor: - bsz, seq_len = input_ids.size() - is_user_token = input_ids.eq(user_token_id) - turn_id = is_user_token.cumsum(dim=-1) - - iq = turn_id.view(bsz, seq_len, 1) - ik = turn_id.view(bsz, 1, seq_len) - - seq_mask = (iq <= ik) if multi_turn else (iq == ik) - causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool() - attention_mask = seq_mask & causal_mask - - return attention_mask - class BaseStrategy(ABC): def __init__(self, model: nn.Module): self.model = model @abstractmethod - def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: + def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: raise NotImplementedError def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor: @@ -78,47 +49,45 @@ class SeqStrategy(BaseStrategy): def __init__(self, model): super().__init__(model) - def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: - x, y = batch - B, L = x.size() - logits: Tensor = self.model(x)["logits"] + def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: + input_ids, target_ids = batch["input_ids"], batch["target_ids"] + B, L = input_ids.size() + logits: Tensor = self.model(input_ids=input_ids)["logits"] loss = F.cross_entropy( - logits.view(B * L, -1), y.flatten() + input=logits.view(B * L, -1), + target=target_ids.flatten() ) return loss class SftStrategy(BaseStrategy): - def __init__( - self, - model: nn.Module, - bos_id: int, - eos_id: int, - user_token_id: int, - multi_turn: bool - ): + def __init__(self, model: nn.Module): super().__init__(model) - - self.loss_mask_builder = lambda x: build_loss_mask(x, bos_id, eos_id) - self.attn_mask_builder = lambda x: build_attention_mask(x, user_token_id, multi_turn) - def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: - x, y, loss_mask = batch - B, L = x.size() - ignore_idx = -1 + def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: + input_ids, target_ids = batch["input_ids"], batch["target_ids"] + loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"] - logits: Tensor = self.model(x)["logits"] - masked_y = y.masked_fill(loss_mask == 0, ignore_idx) + ignore_index = -100 + B, L = input_ids.size() + + logits: Tensor = self.model( + input_ids=input_ids, + input_mask=attn_mask + )["logits"] + + target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) loss = F.cross_entropy( - logits.view(B * L, -1), - masked_y.flatten(), - ignore_index=ignore_idx + input=logits.view(B * L, -1), + target=target_ids.flatten(), + ignore_index=ignore_index ) - + return loss + class DpoStrategy(BaseStrategy): def __init__(self, model, pad_token_id, beta): super().__init__(model) @@ -131,7 +100,8 @@ class DpoStrategy(BaseStrategy): self.beta = beta def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: - good_ids, bad_ids, good_mask, bad_mask = batch + good_ids, bad_ids = batch["chosen"], batch["rejected"] + good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"] log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id) log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 1c23544..e44d698 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -46,10 +46,10 @@ def test_env(): return self.length def __getitem__(self, idx): - return ( - torch.randint(0, 1000, (64,)), - torch.randint(0, 1000, (64,)) - ) + return { + "input_ids": torch.randint(0, 1000, (64,)), + "target_ids": torch.randint(0, 1000, (64,)) + } dataset = DummyDataset()