feat(strategy): 重构mask构建逻辑并优化策略工厂参数传递
This commit is contained in:
parent
a4443765ee
commit
676fdd59d7
|
|
@ -1,55 +0,0 @@
|
||||||
|
|
||||||
import torch
|
|
||||||
from abc import abstractmethod
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MaskBuilder:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
bos_token_id: int,
|
|
||||||
eos_token_id: int,
|
|
||||||
user_token_id: int,
|
|
||||||
system_token_id: int,
|
|
||||||
|
|
||||||
):
|
|
||||||
self.bos_token_id = bos_token_id
|
|
||||||
self.eos_token_id = eos_token_id
|
|
||||||
self.user_token_id = user_token_id
|
|
||||||
self.system_token_id = system_token_id
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def build(input_ids: Tensor) -> Tensor:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LossMaskBuilder(MaskBuilder):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def build(self, input_ids: Tensor) -> Tensor:
|
|
||||||
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
|
||||||
|
|
||||||
is_user_token = input_ids.eq(self.user_token_id)
|
|
||||||
is_system_token = input_ids.eq(self.system_token_id)
|
|
||||||
|
|
||||||
token_markers[is_user_token] = 1
|
|
||||||
token_markers[is_system_token] = -1
|
|
||||||
|
|
||||||
cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
|
||||||
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
|
|
||||||
loss_mask = cumulative_markers - min_cumulative
|
|
||||||
|
|
||||||
return loss_mask
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionMaskBuilder:
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def build(input_ids: Tensor):
|
|
||||||
bsz = input_ids.size(0)
|
|
||||||
|
|
@ -32,67 +32,34 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id)
|
||||||
|
|
||||||
return (token_logprobs * valid_mask).sum(dim=-1)
|
return (token_logprobs * valid_mask).sum(dim=-1)
|
||||||
|
|
||||||
|
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
|
||||||
|
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
||||||
|
|
||||||
class MaskBuilder:
|
is_bos_token = input_ids.eq(bos_token_id)
|
||||||
def __init__(
|
is_eos_token = input_ids.eq(eos_token_id)
|
||||||
self,
|
|
||||||
bos_token_id: int,
|
|
||||||
eos_token_id: int,
|
|
||||||
user_token_id: int,
|
|
||||||
system_token_id: int,
|
|
||||||
|
|
||||||
):
|
token_markers[is_bos_token] = 1
|
||||||
self.bos_token_id = bos_token_id
|
token_markers[is_eos_token] = -1
|
||||||
self.eos_token_id = eos_token_id
|
|
||||||
self.user_token_id = user_token_id
|
|
||||||
self.system_token_id = system_token_id
|
|
||||||
|
|
||||||
@abstractmethod
|
cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
||||||
def build(input_ids: Tensor) -> Tensor:
|
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
|
||||||
raise NotImplementedError
|
loss_mask = cumulative_markers - min_cumulative
|
||||||
|
|
||||||
|
return loss_mask.to(dtype=torch.bool)
|
||||||
|
|
||||||
|
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool = False) -> Tensor:
|
||||||
|
bsz, seq_len = input_ids.size()
|
||||||
|
is_user_token = input_ids.eq(user_token_id)
|
||||||
|
turn_id = is_user_token.cumsum(dim=-1)
|
||||||
|
|
||||||
class LossMaskBuilder(MaskBuilder):
|
iq = turn_id.view(bsz, seq_len, 1)
|
||||||
def __init__(self, **kwargs):
|
ik = turn_id.view(bsz, 1, seq_len)
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def build(self, input_ids: Tensor) -> Tensor:
|
|
||||||
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
|
||||||
|
|
||||||
is_user_token = input_ids.eq(self.user_token_id)
|
|
||||||
is_system_token = input_ids.eq(self.system_token_id)
|
|
||||||
|
|
||||||
token_markers[is_user_token] = 1
|
|
||||||
token_markers[is_system_token] = -1
|
|
||||||
|
|
||||||
cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
|
||||||
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
|
|
||||||
loss_mask = cumulative_markers - min_cumulative
|
|
||||||
|
|
||||||
return loss_mask.to(dtype=torch.bool)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionMaskBuilder(MaskBuilder):
|
|
||||||
def __init__(self, multi_turn=False, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.multi_turn = multi_turn
|
|
||||||
|
|
||||||
def build(self, input_ids: Tensor):
|
|
||||||
bsz = input_ids.size(0)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_batch(self, input_ids: Tensor):
|
|
||||||
is_user_token = input_ids.eq(self.user_token_id)
|
|
||||||
|
|
||||||
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
|
||||||
token_markers[is_user_token] = 1
|
|
||||||
cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
|
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
|
||||||
|
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
|
||||||
|
attention_mask = seq_mask & causal_mask
|
||||||
|
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
class BaseStrategy(ABC):
|
class BaseStrategy(ABC):
|
||||||
|
|
@ -209,11 +176,11 @@ class PpoStrategy(BaseStrategy):
|
||||||
|
|
||||||
class StrategyFactory:
|
class StrategyFactory:
|
||||||
|
|
||||||
def load(model, train_type, pad_token_id, dpo_beta):
|
def load(model, train_type, **kwargs):
|
||||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||||
"seq": lambda: SeqStrategy(model),
|
"seq": lambda: SeqStrategy(model),
|
||||||
"sft": lambda: SftStrategy(model),
|
"sft": lambda: SftStrategy(model),
|
||||||
"dpo": lambda: DpoStrategy(model, pad_token_id, dpo_beta)
|
"dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta"))
|
||||||
}
|
}
|
||||||
strategy = train_strategy[train_type]()
|
strategy = train_strategy[train_type]()
|
||||||
return strategy
|
return strategy
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue