fix(trainer): 修复多轮对话中的因果注意力掩码计算逻辑等
This commit is contained in:
parent
0b96b11a6e
commit
1169cfad82
|
|
@ -99,6 +99,8 @@ def process_attention_mask(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if seq_mask.dim() > 2:
|
if seq_mask.dim() > 2:
|
||||||
|
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
||||||
|
# if ndim > 2, it's 4D tensor
|
||||||
return seq_mask
|
return seq_mask
|
||||||
|
|
||||||
batch_size = seq_mask.size(0)
|
batch_size = seq_mask.size(0)
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool
|
||||||
iq = turn_id.view(seq_len, 1)
|
iq = turn_id.view(seq_len, 1)
|
||||||
ik = turn_id.view(1, seq_len)
|
ik = turn_id.view(1, seq_len)
|
||||||
|
|
||||||
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
|
# fix the causual attention mask
|
||||||
|
seq_mask = (iq >= ik) if multi_turn else (iq == ik)
|
||||||
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
|
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
|
||||||
attention_mask = seq_mask & causal_mask
|
attention_mask = seq_mask & causal_mask
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -159,8 +159,17 @@ class StrategyFactory:
|
||||||
def load(model, train_type, **kwargs):
|
def load(model, train_type, **kwargs):
|
||||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||||
"seq": lambda: SeqStrategy(model),
|
"seq": lambda: SeqStrategy(model),
|
||||||
"sft": lambda: SftStrategy(model, kwargs.pop("bos_token_id"), kwargs.pop("eos_token_id"), kwargs.pop("multi_turn")),
|
"sft": lambda: SftStrategy(
|
||||||
"dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta"))
|
model,
|
||||||
|
kwargs.get("bos_token_id"),
|
||||||
|
kwargs.get("eos_token_id"),
|
||||||
|
kwargs.get("multi_turn")
|
||||||
|
),
|
||||||
|
"dpo": lambda: DpoStrategy(
|
||||||
|
model,
|
||||||
|
kwargs.get("pad_token_id"),
|
||||||
|
kwargs.get("dpo_beta")
|
||||||
|
)
|
||||||
}
|
}
|
||||||
strategy = train_strategy[train_type]()
|
strategy = train_strategy[train_type]()
|
||||||
return strategy
|
return strategy
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue