From 493fe4e84bf52c9b1170cf891681e13b2c4acdfd Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 6 Mar 2026 11:41:14 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=20label=20smoothing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/strategy.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index e0adbb7..c4dc7f7 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -51,8 +51,9 @@ class BaseStrategy(ABC): class SeqStrategy(BaseStrategy): - def __init__(self, model, device): + def __init__(self, model, device, label_smoothing): super().__init__(model, device) + self.label_smoothing = label_smoothing def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) @@ -68,8 +69,9 @@ class SeqStrategy(BaseStrategy): class SftStrategy(BaseStrategy): - def __init__(self, model, device): + def __init__(self, model, device, label_smoothing): super().__init__(model, device) + self.label_smoothing = label_smoothing def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) @@ -124,8 +126,16 @@ class StrategyFactory: def load(model, train_type, device, **kwargs): train_strategy: Dict[str, Callable[[], BaseStrategy]] = { - "seq": lambda: SeqStrategy(model, device), - "sft": lambda: SftStrategy(model, device), + "seq": lambda: SeqStrategy( + model, + device, + kwargs.get("label_smoothing", 0.0) + ), + "sft": lambda: SftStrategy( + model, + device, + kwargs.get("label_smoothing", 0.0) + ), "dpo": lambda: DpoStrategy( model, device, @@ -134,4 +144,4 @@ class StrategyFactory: ) } strategy = train_strategy[train_type]() - return strategy \ No newline at end of file + return strategy \ No newline at end of file