feat(strategy): 重构mask构建逻辑并优化策略工厂参数传递

This commit is contained in:
ViperEkura 2025-09-27 13:12:57 +08:00
parent a4443765ee
commit 676fdd59d7
2 changed files with 24 additions and 112 deletions

View File

@ -1,55 +0,0 @@
import torch
from abc import abstractmethod
from torch import Tensor
class MaskBuilder:
def __init__(
self,
bos_token_id: int,
eos_token_id: int,
user_token_id: int,
system_token_id: int,
):
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.user_token_id = user_token_id
self.system_token_id = system_token_id
@abstractmethod
def build(input_ids: Tensor) -> Tensor:
raise NotImplementedError
class LossMaskBuilder(MaskBuilder):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, input_ids: Tensor) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_user_token = input_ids.eq(self.user_token_id)
is_system_token = input_ids.eq(self.system_token_id)
token_markers[is_user_token] = 1
token_markers[is_system_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask
class AttentionMaskBuilder:
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(input_ids: Tensor):
bsz = input_ids.size(0)

View File

@ -32,67 +32,34 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id)
return (token_logprobs * valid_mask).sum(dim=-1)
class MaskBuilder:
def __init__(
self,
bos_token_id: int,
eos_token_id: int,
user_token_id: int,
system_token_id: int,
):
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.user_token_id = user_token_id
self.system_token_id = system_token_id
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
@abstractmethod
def build(input_ids: Tensor) -> Tensor:
raise NotImplementedError
class LossMaskBuilder(MaskBuilder):
def __init__(self, **kwargs):
super().__init__(**kwargs)
is_bos_token = input_ids.eq(bos_token_id)
is_eos_token = input_ids.eq(eos_token_id)
def build(self, input_ids: Tensor) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_user_token = input_ids.eq(self.user_token_id)
is_system_token = input_ids.eq(self.system_token_id)
token_markers[is_user_token] = 1
token_markers[is_system_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
token_markers[is_bos_token] = 1
token_markers[is_eos_token] = -1
return loss_mask.to(dtype=torch.bool)
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
class AttentionMaskBuilder(MaskBuilder):
def __init__(self, multi_turn=False, **kwargs):
super().__init__(**kwargs)
self.multi_turn = multi_turn
return loss_mask.to(dtype=torch.bool)
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool = False) -> Tensor:
bsz, seq_len = input_ids.size()
is_user_token = input_ids.eq(user_token_id)
turn_id = is_user_token.cumsum(dim=-1)
def build(self, input_ids: Tensor):
bsz = input_ids.size(0)
iq = turn_id.view(bsz, seq_len, 1)
ik = turn_id.view(bsz, 1, seq_len)
def _build_batch(self, input_ids: Tensor):
is_user_token = input_ids.eq(self.user_token_id)
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
token_markers[is_user_token] = 1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
attention_mask = seq_mask & causal_mask
return attention_mask
class BaseStrategy(ABC):
@ -209,11 +176,11 @@ class PpoStrategy(BaseStrategy):
class StrategyFactory:
def load(model, train_type, pad_token_id, dpo_beta):
def load(model, train_type, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model),
"sft": lambda: SftStrategy(model),
"dpo": lambda: DpoStrategy(model, pad_token_id, dpo_beta)
"dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta"))
}
strategy = train_strategy[train_type]()
return strategy