From 1169cfad8226a44c6afe85bde096bf329e2ff85d Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 28 Sep 2025 15:15:19 +0800
Subject: [PATCH] fix(trainer): fix the causal attention mask computation
 logic in multi-turn dialogue, etc.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/core/transformer.py |  2 ++
 khaosz/trainer/dataset.py  |  3 ++-
 khaosz/trainer/strategy.py | 13 +++++++++++--
 3 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/khaosz/core/transformer.py b/khaosz/core/transformer.py
index ae2d9c3..3ec47f3 100644
--- a/khaosz/core/transformer.py
+++ b/khaosz/core/transformer.py
@@ -99,6 +99,8 @@ def process_attention_mask(
         return None
 
     if seq_mask.dim() > 2:
+        # seq_mask is (bsz, seq_len) or (bsz, n_heads, seq_len, seq_len + start_pos);
+        # more than 2 dims means it is already a full 4D mask, so return it as-is
         return seq_mask
 
     batch_size = seq_mask.size(0)
diff --git a/khaosz/trainer/dataset.py b/khaosz/trainer/dataset.py
index 7dcd326..908447e 100644
--- a/khaosz/trainer/dataset.py
+++ b/khaosz/trainer/dataset.py
@@ -48,7 +48,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
     iq = turn_id.view(seq_len, 1)
     ik = turn_id.view(1, seq_len)
 
-    seq_mask = (iq <= ik) if multi_turn else (iq == ik)
+    # fix the causal mask: in multi-turn mode a query attends to its own and all earlier turns
+    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
     causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
     attention_mask = seq_mask & causal_mask
 
diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py
index 7f1f7e8..e94b122 100644
--- a/khaosz/trainer/strategy.py
+++ b/khaosz/trainer/strategy.py
@@ -159,8 +159,17 @@ class StrategyFactory:
     def load(model, train_type, **kwargs):
         train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
             "seq": lambda: SeqStrategy(model),
-            "sft": lambda: SftStrategy(model, kwargs.pop("bos_token_id"), kwargs.pop("eos_token_id"), kwargs.pop("multi_turn")),
-            "dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta"))
+            "sft": lambda: SftStrategy(
+                model,
+                kwargs.get("bos_token_id"),
+                kwargs.get("eos_token_id"),
+                kwargs.get("multi_turn")
+            ),
+            "dpo": lambda: DpoStrategy(
+                model,
+                kwargs.get("pad_token_id"),
+                kwargs.get("dpo_beta")
+            )
         }
         strategy = train_strategy[train_type]()
         return strategy
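
Note: a minimal standalone sketch of the corrected multi-turn mask semantics
in build_attention_mask, using the same names as the patched code; the toy
turn_id values below are illustrative assumptions, not part of the patch.

import torch

# Toy input: 6 tokens spanning 3 dialogue turns (turn ids never decrease).
turn_id = torch.tensor([0, 0, 1, 1, 2, 2])
seq_len = turn_id.size(0)

iq = turn_id.view(seq_len, 1)  # turn index of each query position
ik = turn_id.view(1, seq_len)  # turn index of each key position

# Patched rule: a query attends to its own turn and all earlier turns
# (iq >= ik). The old `iq <= ik` selected later turns instead, which the
# causal mask then cut down to same-turn blocks only.
seq_mask = iq >= ik
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
attention_mask = seq_mask & causal_mask

# Since turn ids never decrease along the sequence, iq >= ik already holds
# wherever the causal mask does, so the combined mask is fully causal
# across turns.
assert torch.equal(attention_mask, causal_mask)

The StrategyFactory change is orthogonal: swapping kwargs.pop for kwargs.get
means building a strategy no longer mutates the caller's kwargs, and a missing
key yields None instead of raising KeyError.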