fix(trainer): 修复多轮对话中的因果注意力掩码计算逻辑等

This commit is contained in:
ViperEkura 2025-09-28 15:15:19 +08:00
parent 0b96b11a6e
commit 1169cfad82
3 changed files with 15 additions and 3 deletions

View File

@ -99,6 +99,8 @@ def process_attention_mask(
return None
if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor
return seq_mask
batch_size = seq_mask.size(0)

View File

@ -48,7 +48,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
iq = turn_id.view(seq_len, 1)
ik = turn_id.view(1, seq_len)
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
# fix the causual attention mask
seq_mask = (iq >= ik) if multi_turn else (iq == ik)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
attention_mask = seq_mask & causal_mask

View File

@ -159,8 +159,17 @@ class StrategyFactory:
def load(model, train_type, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model),
"sft": lambda: SftStrategy(model, kwargs.pop("bos_token_id"), kwargs.pop("eos_token_id"), kwargs.pop("multi_turn")),
"dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta"))
"sft": lambda: SftStrategy(
model,
kwargs.get("bos_token_id"),
kwargs.get("eos_token_id"),
kwargs.get("multi_turn")
),
"dpo": lambda: DpoStrategy(
model,
kwargs.get("pad_token_id"),
kwargs.get("dpo_beta")
)
}
strategy = train_strategy[train_type]()
return strategy