diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 275f933..7112684 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -12,7 +12,7 @@ from abc import ABC, abstractmethod from dataclasses import asdict, dataclass, field -def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id): +def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int): input_mask = input_ids.ne(pad_token_id) logits = model(input_ids, input_mask)["logits"] log_probs = torch.log_softmax(logits, dim=-1) @@ -47,7 +47,7 @@ def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> return loss_mask.to(dtype=torch.bool) -def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool = False) -> Tensor: +def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor: bsz, seq_len = input_ids.size() is_user_token = input_ids.eq(user_token_id) turn_id = is_user_token.cumsum(dim=-1) @@ -90,8 +90,18 @@ class SeqStrategy(BaseStrategy): class SftStrategy(BaseStrategy): - def __init__(self, model): + def __init__( + self, + model: nn.Module, + bos_id: int, + eos_id: int, + user_token_id: int, + multi_turn: bool + ): super().__init__(model) + + self.loss_mask_builder = lambda x: build_loss_mask(x, bos_id, eos_id) + self.attn_mask_builder = lambda x: build_attention_mask(x, user_token_id, multi_turn) def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: x, y, loss_mask = batch @@ -179,7 +189,7 @@ class StrategyFactory: def load(model, train_type, **kwargs): train_strategy: Dict[str, Callable[[], BaseStrategy]] = { "seq": lambda: SeqStrategy(model), - "sft": lambda: SftStrategy(model), + "sft": lambda: SftStrategy(model, kwargs.pop("bos_token_id"), kwargs.pop("eos_token_id"), kwargs.pop("multi_turn")), "dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta")) } strategy = train_strategy[train_type]() diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index a0ea0fe..85293eb 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -86,9 +86,11 @@ class Trainer: strategy = StrategyFactory.load( self.model, - train_config.train_type, - self.tokenizer.pad_id, - train_config.dpo_beta + train_type=train_config.train_type, + bos_token_id=self.tokenizer.bos_id, + eos_token_id=self.tokenizer.eos_id, + pad_token_id=self.tokenizer.pad_id, + dpo_beta=train_config.dpo_beta ) scheduler = LambdaLR(