feat(strategy): 重构mask构建逻辑并优化策略工厂参数传递
This commit is contained in:
parent
a4443765ee
commit
676fdd59d7
|
|
@ -1,55 +0,0 @@
|
|||
|
||||
import torch
|
||||
from abc import abstractmethod
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
|
||||
class MaskBuilder:
    """Base class for builders that derive masks from a token-id tensor.

    Holds the special-token ids that concrete subclasses inspect when
    constructing a mask over ``input_ids``.
    """

    def __init__(
        self,
        bos_token_id: int,
        eos_token_id: int,
        user_token_id: int,
        system_token_id: int,
    ):
        # Special-token ids used by subclasses to locate span boundaries.
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.user_token_id = user_token_id
        self.system_token_id = system_token_id

    @abstractmethod
    def build(self, input_ids: Tensor) -> Tensor:
        """Return a mask tensor derived from ``input_ids``.

        Concrete subclasses override this; the base implementation only
        signals that an override is required.
        """
        # Fix: the original declared ``build(input_ids)`` without ``self``,
        # which is inconsistent with the bound-method overrides in the
        # subclasses and would fail when called on an instance.
        raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
class LossMaskBuilder(MaskBuilder):
    """Builds an integer loss mask from user/system marker tokens.

    Every user token contributes +1 and every system token -1 to a
    running balance along the sequence axis; shifting each row so its
    minimum becomes zero yields the mask.
    """

    def __init__(self, **kwargs):
        # All configuration (the special-token ids) lives on the base class.
        super().__init__(**kwargs)

    def build(self, input_ids: Tensor) -> Tensor:
        user_hits = input_ids.eq(self.user_token_id)
        system_hits = input_ids.eq(self.system_token_id)

        # Encode span opens/closes; the assignment order makes the system
        # marker win if both ids happen to match the same token, matching
        # sequential masked assignment.
        deltas = torch.zeros_like(input_ids, dtype=torch.int8)
        deltas = torch.where(user_hits, torch.ones_like(deltas), deltas)
        deltas = torch.where(system_hits, -torch.ones_like(deltas), deltas)

        # Running balance of open spans, then shift each row so its
        # smallest value is zero.
        balance = deltas.cumsum(dim=-1)
        baseline = balance.min(dim=-1, keepdim=True).values
        return balance - baseline
|
||||
|
||||
|
||||
|
||||
|
||||
class AttentionMaskBuilder:
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def build(input_ids: Tensor):
|
||||
bsz = input_ids.size(0)
|
||||
|
|
@ -32,67 +32,34 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id)
|
|||
|
||||
return (token_logprobs * valid_mask).sum(dim=-1)
|
||||
|
||||
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
|
||||
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
||||
|
||||
class MaskBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
bos_token_id: int,
|
||||
eos_token_id: int,
|
||||
user_token_id: int,
|
||||
system_token_id: int,
|
||||
is_bos_token = input_ids.eq(bos_token_id)
|
||||
is_eos_token = input_ids.eq(eos_token_id)
|
||||
|
||||
):
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.user_token_id = user_token_id
|
||||
self.system_token_id = system_token_id
|
||||
token_markers[is_bos_token] = 1
|
||||
token_markers[is_eos_token] = -1
|
||||
|
||||
@abstractmethod
|
||||
def build(input_ids: Tensor) -> Tensor:
|
||||
raise NotImplementedError
|
||||
cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
||||
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
|
||||
loss_mask = cumulative_markers - min_cumulative
|
||||
|
||||
return loss_mask.to(dtype=torch.bool)
|
||||
|
||||
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool = False) -> Tensor:
|
||||
bsz, seq_len = input_ids.size()
|
||||
is_user_token = input_ids.eq(user_token_id)
|
||||
turn_id = is_user_token.cumsum(dim=-1)
|
||||
|
||||
class LossMaskBuilder(MaskBuilder):
    """Derives a boolean loss mask from user/system marker tokens.

    User tokens add +1 and system tokens add -1 to a cumulative sum
    along the sequence; after shifting each row so its minimum is zero,
    any nonzero position is marked True.
    """

    def __init__(self, **kwargs):
        # Token-id configuration is handled entirely by the base class.
        super().__init__(**kwargs)

    def build(self, input_ids: Tensor) -> Tensor:
        opens = input_ids.eq(self.user_token_id)
        closes = input_ids.eq(self.system_token_id)

        # +1 where a span opens, -1 where it closes; the second
        # assignment takes precedence if both ids match the same token.
        step = torch.zeros_like(input_ids, dtype=torch.int8)
        step[opens] = 1
        step[closes] = -1

        running = step.cumsum(dim=-1)
        floor = running.min(dim=-1, keepdim=True).values

        # Nonzero after the shift -> the position is inside a supervised
        # span, hence True in the boolean mask.
        return (running - floor).to(dtype=torch.bool)
|
||||
|
||||
|
||||
|
||||
|
||||
class AttentionMaskBuilder(MaskBuilder):
|
||||
def __init__(self, multi_turn=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.multi_turn = multi_turn
|
||||
|
||||
def build(self, input_ids: Tensor):
|
||||
bsz = input_ids.size(0)
|
||||
|
||||
|
||||
def _build_batch(self, input_ids: Tensor):
|
||||
is_user_token = input_ids.eq(self.user_token_id)
|
||||
|
||||
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
|
||||
token_markers[is_user_token] = 1
|
||||
cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
||||
|
||||
iq = turn_id.view(bsz, seq_len, 1)
|
||||
ik = turn_id.view(bsz, 1, seq_len)
|
||||
|
||||
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
|
||||
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
|
||||
attention_mask = seq_mask & causal_mask
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
|
|
@ -209,11 +176,11 @@ class PpoStrategy(BaseStrategy):
|
|||
|
||||
class StrategyFactory:
    """Factory mapping a training-type name to a strategy instance."""

    @staticmethod
    def load(model, train_type, **kwargs):
        """Instantiate the strategy registered under ``train_type``.

        Parameters
        ----------
        model : the model handed to every strategy constructor.
        train_type : one of ``"seq"``, ``"sft"`` or ``"dpo"``.
        **kwargs : extra per-strategy settings; ``"dpo"`` requires
            ``pad_token_id`` and ``dpo_beta``.

        Raises
        ------
        KeyError : if ``train_type`` is unknown, or if ``"dpo"`` is
            requested without its required keyword arguments.
        """
        # Lambdas defer construction (and the kwargs lookups) until the
        # requested entry is actually selected, so unrelated strategies
        # never touch kwargs they don't need.
        train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
            "seq": lambda: SeqStrategy(model),
            "sft": lambda: SftStrategy(model),
            # Plain indexing instead of kwargs.pop: no mutation of the
            # caller-visible mapping, same KeyError on a missing key.
            "dpo": lambda: DpoStrategy(model, kwargs["pad_token_id"], kwargs["dpo_beta"]),
        }
        return train_strategy[train_type]()
|
||||
|
|
|
|||
Loading…
Reference in New Issue