fix(trainer): fix the causal attention mask computation for multi-turn conversations, plus related cleanups
This commit is contained in:
parent
0b96b11a6e
commit
1169cfad82
|
|
@ -99,6 +99,8 @@ def process_attention_mask(
|
|||
return None
|
||||
|
||||
if seq_mask.dim() > 2:
|
||||
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
||||
# if ndim > 2, the mask is assumed to already be a full 4D attention mask and is returned unchanged
|
||||
return seq_mask
|
||||
|
||||
batch_size = seq_mask.size(0)
|
||||
|
|
|
|||
|
|
@ -48,7 +48,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
|
|||
iq = turn_id.view(seq_len, 1)
|
||||
ik = turn_id.view(1, seq_len)
|
||||
|
||||
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
|
||||
# fix the causal attention mask
|
||||
seq_mask = (iq >= ik) if multi_turn else (iq == ik)
|
||||
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
|
||||
attention_mask = seq_mask & causal_mask
|
||||
|
||||
|
|
|
|||
|
|
@ -159,8 +159,17 @@ class StrategyFactory:
|
|||
def load(model, train_type, **kwargs):
    """Build and return the training strategy matching ``train_type``.

    Args:
        model: the model instance handed to every strategy constructor.
        train_type: one of ``"seq"``, ``"sft"``, or ``"dpo"``.
        **kwargs: strategy-specific options, read with ``kwargs.get`` so a
            missing key yields ``None`` instead of raising ``KeyError`` at
            construction time:
            - ``"sft"``: ``bos_token_id``, ``eos_token_id``, ``multi_turn``
            - ``"dpo"``: ``pad_token_id``, ``dpo_beta``

    Returns:
        The instantiated strategy (a ``BaseStrategy`` subclass).

    Raises:
        KeyError: if ``train_type`` is not a known strategy name.
    """
    # Lambdas keep construction lazy: only the selected strategy is built,
    # so "dpo" kwargs are never demanded when loading an "sft" strategy.
    train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
        "seq": lambda: SeqStrategy(model),
        "sft": lambda: SftStrategy(
            model,
            kwargs.get("bos_token_id"),
            kwargs.get("eos_token_id"),
            kwargs.get("multi_turn"),
        ),
        "dpo": lambda: DpoStrategy(
            model,
            kwargs.get("pad_token_id"),
            kwargs.get("dpo_beta"),
        ),
    }
    strategy = train_strategy[train_type]()
    return strategy
|
||||
|
|
|
|||
Loading…
Reference in New Issue