diff --git a/khaosz/trainer/mask.py b/khaosz/trainer/mask.py
deleted file mode 100644
index 9de0472..0000000
--- a/khaosz/trainer/mask.py
+++ /dev/null
@@ -1,55 +0,0 @@
-
-import torch
-from abc import abstractmethod
-from torch import Tensor
-
-
-
-class MaskBuilder:
-    def __init__(
-        self,
-        bos_token_id: int,
-        eos_token_id: int,
-        user_token_id: int,
-        system_token_id: int,
-
-    ):
-        self.bos_token_id = bos_token_id
-        self.eos_token_id = eos_token_id
-        self.user_token_id = user_token_id
-        self.system_token_id = system_token_id
-
-    @abstractmethod
-    def build(input_ids: Tensor) -> Tensor:
-        raise NotImplementedError
-
-
-
-class LossMaskBuilder(MaskBuilder):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-    def build(self, input_ids: Tensor) -> Tensor:
-        token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
-
-        is_user_token = input_ids.eq(self.user_token_id)
-        is_system_token = input_ids.eq(self.system_token_id)
-
-        token_markers[is_user_token] = 1
-        token_markers[is_system_token] = -1
-
-        cumulative_markers = torch.cumsum(token_markers, dim=-1)
-        min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
-        loss_mask = cumulative_markers - min_cumulative
-
-        return loss_mask
-
-
-
-
-class AttentionMaskBuilder:
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-    def build(input_ids: Tensor):
-        bsz = input_ids.size(0)
\ No newline at end of file
diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py
index 123b43f..275f933 100644
--- a/khaosz/trainer/strategy.py
+++ b/khaosz/trainer/strategy.py
@@ -32,67 +32,34 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id)
     return (token_logprobs * valid_mask).sum(dim=-1)
 
 
-
-class MaskBuilder:
-    def __init__(
-        self,
-        bos_token_id: int,
-        eos_token_id: int,
-        user_token_id: int,
-        system_token_id: int,
-
-    ):
-        self.bos_token_id = bos_token_id
-        self.eos_token_id = eos_token_id
-        self.user_token_id = user_token_id
-        self.system_token_id = system_token_id
+def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
+    token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
 
-    @abstractmethod
-    def build(input_ids: Tensor) -> Tensor:
-        raise NotImplementedError
-
-
-
-class LossMaskBuilder(MaskBuilder):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    is_bos_token = input_ids.eq(bos_token_id)
+    is_eos_token = input_ids.eq(eos_token_id)
 
-    def build(self, input_ids: Tensor) -> Tensor:
-        token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
-
-        is_user_token = input_ids.eq(self.user_token_id)
-        is_system_token = input_ids.eq(self.system_token_id)
-
-        token_markers[is_user_token] = 1
-        token_markers[is_system_token] = -1
-
-        cumulative_markers = torch.cumsum(token_markers, dim=-1)
-        min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
-        loss_mask = cumulative_markers - min_cumulative
+    token_markers[is_bos_token] = 1
+    token_markers[is_eos_token] = -1
 
-        return loss_mask.to(dtype=torch.bool)
-
-
-
+    cumulative_markers = torch.cumsum(token_markers, dim=-1)
+    min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
+    loss_mask = cumulative_markers - min_cumulative
 
-class AttentionMaskBuilder(MaskBuilder):
-    def __init__(self, multi_turn=False, **kwargs):
-        super().__init__(**kwargs)
-        self.multi_turn = multi_turn
+    return loss_mask.to(dtype=torch.bool)
+
+def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool = False) -> Tensor:
+    bsz, seq_len = input_ids.size()
+    is_user_token = input_ids.eq(user_token_id)
+    turn_id = is_user_token.cumsum(dim=-1)
 
-    def build(self, input_ids: Tensor):
-        bsz = input_ids.size(0)
-
+    iq = turn_id.view(bsz, seq_len, 1)
+    ik = turn_id.view(bsz, 1, seq_len)
 
-    def _build_batch(self, input_ids: Tensor):
-        is_user_token = input_ids.eq(self.user_token_id)
-
-        token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
-        token_markers[is_user_token] = 1
-        cumulative_markers = torch.cumsum(token_markers, dim=-1)
-
-
-
+    seq_mask = (ik <= iq) if multi_turn else (iq == ik)
+    causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
+    attention_mask = seq_mask & causal_mask
+
+    return attention_mask
 
 
 class BaseStrategy(ABC):
@@ -209,11 +176,11 @@ class PpoStrategy(BaseStrategy):
 
 class StrategyFactory:
 
-    def load(model, train_type, pad_token_id, dpo_beta):
+    def load(model, train_type, **kwargs):
         train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
             "seq": lambda: SeqStrategy(model),
             "sft": lambda: SftStrategy(model),
-            "dpo": lambda: DpoStrategy(model, pad_token_id, dpo_beta)
+            "dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id"), kwargs.pop("dpo_beta"))
         }
         strategy = train_strategy[train_type]()
         return strategy
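
Usage sketch for the new mask helpers (not part of the patch). The special-token ids (1/2/3) and the sample sequence below are made up for illustration; the real values come from the project's tokenizer.

import torch
from khaosz.trainer.strategy import build_loss_mask, build_attention_mask

BOS, EOS, USER = 1, 2, 3  # hypothetical ids, for illustration only

# <user> u u <bos> a a <eos> <user> u <bos> b <eos>
input_ids = torch.tensor([[3, 5, 5, 1, 6, 6, 2, 3, 7, 1, 8, 2]])

# True from each <bos> up to (but not including) its matching <eos>:
# the cumulative marker rises to 1 at <bos> and drops back to 0 at <eos>.
loss_mask = build_loss_mask(input_ids, bos_token_id=BOS, eos_token_id=EOS)
# loss_mask.int() -> tensor([[0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0]])

# Both masks are (bsz, seq_len, seq_len) booleans; multi_turn=True lets each
# turn attend to all earlier turns instead of only its own.
attn_single = build_attention_mask(input_ids, user_token_id=USER)
attn_multi = build_attention_mask(input_ids, user_token_id=USER, multi_turn=True)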
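
And a sketch of the updated factory call sites, assuming `model` is an already-constructed nn.Module; the pad_token_id/dpo_beta values are placeholders. Because the lambdas are evaluated lazily, only the "dpo" branch pops the extra keyword arguments, so "seq"/"sft" callers can omit them entirely.

from khaosz.trainer.strategy import StrategyFactory

sft_strategy = StrategyFactory.load(model, "sft")
dpo_strategy = StrategyFactory.load(model, "dpo", pad_token_id=0, dpo_beta=0.1)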