fix(trainer): fix the causal attention mask computation for multi-turn conversations, among other changes

ViperEkura 2025-09-28 15:15:19 +08:00
parent 0b96b11a6e
commit 1169cfad82
3 changed files with 15 additions and 3 deletions

@@ -99,6 +99,8 @@ def process_attention_mask(
         return None
     if seq_mask.dim() > 2:
+        # shape (bsz, seq_len) or (bsz, n_heads, seq_len, seq_len + start_pos)
+        # if ndim > 2, it's a 4D tensor
         return seq_mask
     batch_size = seq_mask.size(0)
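
The two comment lines added here document the shapes process_attention_mask accepts. As a minimal sketch of what they mean (the tensors and the 2D-to-4D broadcast below are illustrative assumptions, not the repository's code):

import torch

bsz, n_heads, seq_len, start_pos = 2, 4, 8, 0

# 2D case: per-token padding mask, shape (bsz, seq_len); falls through to the
# batch_size = seq_mask.size(0) path after the early return
mask_2d = torch.ones(bsz, seq_len, dtype=torch.bool)

# 4D case: already-expanded attention mask, shape
# (bsz, n_heads, seq_len, seq_len + start_pos); dim() > 2, so returned as-is
mask_4d = torch.ones(bsz, n_heads, seq_len, seq_len + start_pos, dtype=torch.bool)

# one conventional way to lift 2D -> 4D (an assumption, not the repo's code):
lifted = mask_2d[:, None, None, :].expand(bsz, 1, seq_len, seq_len)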

@@ -48,7 +48,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
     iq = turn_id.view(seq_len, 1)
     ik = turn_id.view(1, seq_len)
-    seq_mask = (iq <= ik) if multi_turn else (iq == ik)
+    # fix the causal attention mask
+    seq_mask = (iq >= ik) if multi_turn else (iq == ik)
     causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
     attention_mask = seq_mask & causal_mask
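
The fix matters because seq_mask is ANDed with a lower-triangular causal mask: under causality a key never comes from a later turn than its query, so the old predicate iq <= ik collapsed to iq == ik and every token could attend only within its own turn, discarding earlier-turn context. The flipped predicate iq >= ik keeps all earlier turns visible. A quick check on a hypothetical 5-token sequence split into 3 turns:

import torch

turn_id = torch.tensor([0, 0, 1, 1, 2])  # hypothetical turn ids, non-decreasing
seq_len = turn_id.size(0)
iq = turn_id.view(seq_len, 1)
ik = turn_id.view(1, seq_len)
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()

old = (iq <= ik) & causal  # collapses to (iq == ik) & causal: same-turn only
new = (iq >= ik) & causal  # each token sees its own turn and all earlier turns

assert torch.equal(old, (iq == ik) & causal)
assert new[4, 0]      # under the fix, a turn-2 token attends to a turn-0 token
assert not old[4, 0]  # ...but not under the old mask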

@@ -159,8 +159,17 @@ class StrategyFactory:
     def load(model, train_type, **kwargs):
         train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
             "seq": lambda: SeqStrategy(model),
-            "sft": lambda: SftStrategy(model, kwargs.pop("bos_token_id"), kwargs.pop("eos_token_id"), kwargs.pop("multi_turn")),
-            "dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id"), kwargs.pop("dpo_beta"))
+            "sft": lambda: SftStrategy(
+                model,
+                kwargs.get("bos_token_id"),
+                kwargs.get("eos_token_id"),
+                kwargs.get("multi_turn")
+            ),
+            "dpo": lambda: DpoStrategy(
+                model,
+                kwargs.get("pad_token_id"),
+                kwargs.get("dpo_beta")
+            )
         }
         strategy = train_strategy[train_type]()
         return strategy
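
The switch from kwargs.pop to kwargs.get changes two behaviors of the standard dict API: pop without a default raises KeyError when a key is missing and removes the key as a side effect, while get returns None and leaves kwargs untouched, so a caller only has to pass the keys its chosen strategy actually uses. A minimal illustration with hypothetical values:

kwargs = {"bos_token_id": 1, "eos_token_id": 2}

kwargs.get("multi_turn")    # -> None; kwargs is unchanged
kwargs.pop("bos_token_id")  # -> 1; the key is removed from kwargs
# kwargs.pop("multi_turn")  # would raise KeyError: 'multi_turn'

# with get, an SFT-only caller can omit the DPO keys entirely, e.g.:
# strategy = StrategyFactory.load(model, "sft",
#                                 bos_token_id=1, eos_token_id=2, multi_turn=True)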