From 053f4a4dade160480309bcf53f37aebe2c61cb6b Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 27 Sep 2025 13:24:16 +0800 Subject: [PATCH] =?UTF-8?q?feat(=20StrategyFactory):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=20SFT=20=E7=AD=96=E7=95=A5=E5=88=9D=E5=A7=8B=E5=8C=96=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E5=B9=B6=E5=AE=8C=E5=96=84=E5=B7=A5=E5=8E=82=E6=96=B9?= =?UTF-8?q?=E6=B3=95=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/strategy.py | 18 ++++++++++++++---- khaosz/trainer/trainer.py | 8 +++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 275f933..7112684 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -12,7 +12,7 @@ from abc import ABC, abstractmethod from dataclasses import asdict, dataclass, field -def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id): +def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int): input_mask = input_ids.ne(pad_token_id) logits = model(input_ids, input_mask)["logits"] log_probs = torch.log_softmax(logits, dim=-1) @@ -47,7 +47,7 @@ def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> return loss_mask.to(dtype=torch.bool) -def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool = False) -> Tensor: +def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor: bsz, seq_len = input_ids.size() is_user_token = input_ids.eq(user_token_id) turn_id = is_user_token.cumsum(dim=-1) @@ -90,8 +90,18 @@ class SeqStrategy(BaseStrategy): class SftStrategy(BaseStrategy): - def __init__(self, model): + def __init__( + self, + model: nn.Module, + bos_id: int, + eos_id: int, + user_token_id: int, + multi_turn: bool + ): super().__init__(model) + + self.loss_mask_builder = lambda x: build_loss_mask(x, bos_id, eos_id) 
+ self.attn_mask_builder = lambda x: build_attention_mask(x, user_token_id, multi_turn) def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: x, y, loss_mask = batch @@ -179,7 +189,7 @@ class StrategyFactory: def load(model, train_type, **kwargs): train_strategy: Dict[str, Callable[[], BaseStrategy]] = { "seq": lambda: SeqStrategy(model), - "sft": lambda: SftStrategy(model), + "sft": lambda: SftStrategy(model, kwargs.pop("bos_token_id"), kwargs.pop("eos_token_id"), kwargs.pop("user_token_id"), kwargs.pop("multi_turn")), "dpo": lambda: DpoStrategy(model, kwargs.pop("pad_token_id") , kwargs.pop("dpo_beta")) } strategy = train_strategy[train_type]() diff --git a/khaosz/trainer/trainer.py index a0ea0fe..85293eb 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -86,9 +86,13 @@ class Trainer: strategy = StrategyFactory.load( self.model, - train_config.train_type, - self.tokenizer.pad_id, - train_config.dpo_beta + train_type=train_config.train_type, + bos_token_id=self.tokenizer.bos_id, + eos_token_id=self.tokenizer.eos_id, + user_token_id=self.tokenizer.user_id, + pad_token_id=self.tokenizer.pad_id, + multi_turn=train_config.multi_turn, + dpo_beta=train_config.dpo_beta ) scheduler = LambdaLR(