From 676fdd59d7aad36c4ba1c31e3bf1d969648dbe7d Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 27 Sep 2025 13:12:57 +0800 Subject: [PATCH] =?UTF-8?q?feat(strategy):=20=E9=87=8D=E6=9E=84mask?= =?UTF-8?q?=E6=9E=84=E5=BB=BA=E9=80=BB=E8=BE=91=E5=B9=B6=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E7=AD=96=E7=95=A5=E5=B7=A5=E5=8E=82=E5=8F=82=E6=95=B0=E4=BC=A0?= =?UTF-8?q?=E9=80=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/mask.py | 55 -------------------------- khaosz/trainer/strategy.py | 81 +++++++++++--------------------------- 2 files changed, 24 insertions(+), 112 deletions(-) delete mode 100644 khaosz/trainer/mask.py diff --git a/khaosz/trainer/mask.py b/khaosz/trainer/mask.py deleted file mode 100644 index 9de0472..0000000 --- a/khaosz/trainer/mask.py +++ /dev/null @@ -1,55 +0,0 @@ - -import torch -from abc import abstractmethod -from torch import Tensor - - - -class MaskBuilder: - def __init__( - self, - bos_token_id: int, - eos_token_id: int, - user_token_id: int, - system_token_id: int, - - ): - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.user_token_id = user_token_id - self.system_token_id = system_token_id - - @abstractmethod - def build(input_ids: Tensor) -> Tensor: - raise NotImplementedError - - - -class LossMaskBuilder(MaskBuilder): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def build(self, input_ids: Tensor) -> Tensor: - token_markers = torch.zeros_like(input_ids, dtype=torch.int8) - - is_user_token = input_ids.eq(self.user_token_id) - is_system_token = input_ids.eq(self.system_token_id) - - token_markers[is_user_token] = 1 - token_markers[is_system_token] = -1 - - cumulative_markers = torch.cumsum(token_markers, dim=-1) - min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values - loss_mask = cumulative_markers - min_cumulative - - return loss_mask - - - - -class AttentionMaskBuilder: - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def build(input_ids: Tensor): - bsz = input_ids.size(0) \ No newline at end of file diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 123b43f..275f933 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -32,67 +32,34 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id) return (token_logprobs * valid_mask).sum(dim=-1) - -class MaskBuilder: - def __init__( - self, - bos_token_id: int, - eos_token_id: int, - user_token_id: int, - system_token_id: int, - - ): - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.user_token_id = user_token_id - self.system_token_id = system_token_id +def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor: + token_markers = torch.zeros_like(input_ids, dtype=torch.int8) - @abstractmethod - def build(input_ids: Tensor) -> Tensor: - raise NotImplementedError - - - -class LossMaskBuilder(MaskBuilder): - def __init__(self, **kwargs): - super().__init__(**kwargs) + is_bos_token = input_ids.eq(bos_token_id) + is_eos_token = input_ids.eq(eos_token_id) - def build(self, input_ids: Tensor) -> Tensor: - token_markers = torch.zeros_like(input_ids, dtype=torch.int8) - - is_user_token = input_ids.eq(self.user_token_id) - is_system_token = input_ids.eq(self.system_token_id) - - token_markers[is_user_token] = 1 - token_markers[is_system_token] = -1 - - cumulative_markers = torch.cumsum(token_markers, dim=-1) - min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values - loss_mask = cumulative_markers - min_cumulative + token_markers[is_bos_token] = 1 + token_markers[is_eos_token] = -1 - return loss_mask.to(dtype=torch.bool) - - - + cumulative_markers = torch.cumsum(token_markers, dim=-1) + min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values + loss_mask = cumulative_markers - min_cumulative -class AttentionMaskBuilder(MaskBuilder): - def __init__(self, multi_turn=False, **kwargs): - super().__init__(**kwargs) - self.multi_turn = multi_turn + return loss_mask.to(dtype=torch.bool) + +def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool = False) -> Tensor: + bsz, seq_len = input_ids.size() + is_user_token = input_ids.eq(user_token_id) + turn_id = is_user_token.cumsum(dim=-1) - def build(self, input_ids: Tensor): - bsz = input_ids.size(0) - + iq = turn_id.view(bsz, seq_len, 1) + ik = turn_id.view(bsz, 1, seq_len) - def _build_batch(self, input_ids: Tensor): - is_user_token = input_ids.eq(self.user_token_id) - - token_markers = torch.zeros_like(input_ids, dtype=torch.int8) - token_markers[is_user_token] = 1 - cumulative_markers = torch.cumsum(token_markers, dim=-1) - - - + seq_mask = (iq <= ik) if multi_turn else (iq == ik) + causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool() + attention_mask = seq_mask & causal_mask + + return attention_mask class BaseStrategy(ABC): @@ -209,11 +176,11 @@ class PpoStrategy(BaseStrategy): class StrategyFactory: - def load(model, train_type, pad_token_id, dpo_beta): + def load(model, train_type, **kwargs): train_strategy: Dict[str, Callable[[], BaseStrategy]] = { "seq": lambda: SeqStrategy(model), "sft": lambda: SftStrategy(model), - "dpo": lambda: DpoStrategy(model, pad_token_id, dpo_beta) + "dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta")) } strategy = train_strategy[train_type]() return strategy