diff --git a/khaosz/core/transformer.py b/khaosz/core/transformer.py
index ae2d9c3..3ec47f3 100644
--- a/khaosz/core/transformer.py
+++ b/khaosz/core/transformer.py
@@ -99,6 +99,8 @@ def process_attention_mask(
         return None
 
     if seq_mask.dim() > 2:
+        # mask is (bsz, seq_len) if 2D, or (bsz, n_heads, seq_len, seq_len + start_pos)
+        # if ndim > 2, assume it is already a fully expanded mask and pass it through
         return seq_mask
 
     batch_size = seq_mask.size(0)
diff --git a/khaosz/trainer/dataset.py b/khaosz/trainer/dataset.py
index 7dcd326..908447e 100644
--- a/khaosz/trainer/dataset.py
+++ b/khaosz/trainer/dataset.py
@@ -48,7 +48,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
 
     iq = turn_id.view(seq_len, 1)
     ik = turn_id.view(1, seq_len)
-    seq_mask = (iq <= ik) if multi_turn else (iq == ik)
+    # fix the causal attention mask: a query may attend to its own and earlier turns
+    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
     causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
     attention_mask = seq_mask & causal_mask
 
diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py
index 7f1f7e8..e94b122 100644
--- a/khaosz/trainer/strategy.py
+++ b/khaosz/trainer/strategy.py
@@ -159,8 +159,18 @@ class StrategyFactory:
     def load(model, train_type, **kwargs):
         train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
             "seq": lambda: SeqStrategy(model),
-            "sft": lambda: SftStrategy(model, kwargs.pop("bos_token_id"), kwargs.pop("eos_token_id"), kwargs.pop("multi_turn")),
-            "dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta"))
+            # NOTE(review): .get() yields None for a missing key where .pop() raised KeyError — confirm callers always pass the required kwargs
+            "sft": lambda: SftStrategy(
+                model,
+                kwargs.get("bos_token_id"),
+                kwargs.get("eos_token_id"),
+                kwargs.get("multi_turn")
+            ),
+            "dpo": lambda: DpoStrategy(
+                model,
+                kwargs.get("pad_token_id"),
+                kwargs.get("dpo_beta")
+            )
         }
         strategy = train_strategy[train_type]()
         return strategy