From 780b9e1855f79b8aa11719822230e96f3ef97b24 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 31 Mar 2026 01:23:29 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E4=BC=A0=E9=80=92=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/strategy.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 1f959b3..4d973cf 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -83,10 +83,11 @@ class BaseStrategy(ABC): """Abstract base class for training strategies.""" def __init__( - self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str + self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs ): self.model = model self.device = device + self.extra_kwargs = kwargs @abstractmethod def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: @@ -191,8 +192,8 @@ class SEQStrategy(BaseStrategy): Computes cross-entropy loss for next token prediction. """ - def __init__(self, model, device, label_smoothing: float = 0.0): - super().__init__(model, device) + def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs): + super().__init__(model, device, **kwargs) self.label_smoothing = label_smoothing def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: @@ -216,8 +217,8 @@ class SFTStrategy(BaseStrategy): Applies cross-entropy loss only to tokens where loss_mask is True. """ - def __init__(self, model, device, label_smoothing: float = 0.0): - super().__init__(model, device) + def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs): + super().__init__(model, device, **kwargs) self.label_smoothing = label_smoothing def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: @@ -256,8 +257,9 @@ class DPOStrategy(BaseStrategy): device: str, beta: float = 0.1, reduction: str = "mean", + **kwargs, ): - super().__init__(model, device) + super().__init__(model, device, **kwargs) self.ref_model = create_ref_model(model) self.beta = beta self.reduction = reduction @@ -306,8 +308,9 @@ class GRPOStrategy(BaseStrategy): kl_coef: float = 0.01, group_size: int = 4, reduction: str = "mean", + **kwargs, ): - super().__init__(model, device) + super().__init__(model, device, **kwargs) self.ref_model = create_ref_model(model) self.clip_eps = clip_eps self.kl_coef = kl_coef