feat(strategy): 重构mask构建逻辑并优化策略工厂参数传递

This commit is contained in:
ViperEkura 2025-09-27 13:12:57 +08:00
parent a4443765ee
commit 676fdd59d7
2 changed files with 24 additions and 112 deletions

View File

@ -1,55 +0,0 @@
import torch
from abc import abstractmethod
from torch import Tensor
class MaskBuilder:
def __init__(
self,
bos_token_id: int,
eos_token_id: int,
user_token_id: int,
system_token_id: int,
):
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.user_token_id = user_token_id
self.system_token_id = system_token_id
@abstractmethod
def build(input_ids: Tensor) -> Tensor:
raise NotImplementedError
class LossMaskBuilder(MaskBuilder):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, input_ids: Tensor) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_user_token = input_ids.eq(self.user_token_id)
is_system_token = input_ids.eq(self.system_token_id)
token_markers[is_user_token] = 1
token_markers[is_system_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask
class AttentionMaskBuilder:
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(input_ids: Tensor):
bsz = input_ids.size(0)

View File

@ -32,67 +32,34 @@ def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id)
return (token_logprobs * valid_mask).sum(dim=-1) return (token_logprobs * valid_mask).sum(dim=-1)
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
class MaskBuilder: is_bos_token = input_ids.eq(bos_token_id)
def __init__( is_eos_token = input_ids.eq(eos_token_id)
self,
bos_token_id: int,
eos_token_id: int,
user_token_id: int,
system_token_id: int,
): token_markers[is_bos_token] = 1
self.bos_token_id = bos_token_id token_markers[is_eos_token] = -1
self.eos_token_id = eos_token_id
self.user_token_id = user_token_id
self.system_token_id = system_token_id
@abstractmethod cumulative_markers = torch.cumsum(token_markers, dim=-1)
def build(input_ids: Tensor) -> Tensor: min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
raise NotImplementedError loss_mask = cumulative_markers - min_cumulative
return loss_mask.to(dtype=torch.bool)
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool = False) -> Tensor:
bsz, seq_len = input_ids.size()
is_user_token = input_ids.eq(user_token_id)
turn_id = is_user_token.cumsum(dim=-1)
class LossMaskBuilder(MaskBuilder): iq = turn_id.view(bsz, seq_len, 1)
def __init__(self, **kwargs): ik = turn_id.view(bsz, 1, seq_len)
super().__init__(**kwargs)
def build(self, input_ids: Tensor) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_user_token = input_ids.eq(self.user_token_id)
is_system_token = input_ids.eq(self.system_token_id)
token_markers[is_user_token] = 1
token_markers[is_system_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask.to(dtype=torch.bool)
class AttentionMaskBuilder(MaskBuilder):
def __init__(self, multi_turn=False, **kwargs):
super().__init__(**kwargs)
self.multi_turn = multi_turn
def build(self, input_ids: Tensor):
bsz = input_ids.size(0)
def _build_batch(self, input_ids: Tensor):
is_user_token = input_ids.eq(self.user_token_id)
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
token_markers[is_user_token] = 1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
seq_mask = (iq <= ik) if multi_turn else (iq == ik)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool()
attention_mask = seq_mask & causal_mask
return attention_mask
class BaseStrategy(ABC): class BaseStrategy(ABC):
@ -209,11 +176,11 @@ class PpoStrategy(BaseStrategy):
class StrategyFactory: class StrategyFactory:
def load(model, train_type, pad_token_id, dpo_beta): def load(model, train_type, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = { train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model), "seq": lambda: SeqStrategy(model),
"sft": lambda: SftStrategy(model), "sft": lambda: SftStrategy(model),
"dpo": lambda: DpoStrategy(model, pad_token_id, dpo_beta) "dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta"))
} }
strategy = train_strategy[train_type]() strategy = train_strategy[train_type]()
return strategy return strategy