diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index e0adbb7..c4dc7f7 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -51,8 +51,9 @@ class BaseStrategy(ABC): class SeqStrategy(BaseStrategy): - def __init__(self, model, device): + def __init__(self, model, device, label_smoothing): super().__init__(model, device) + self.label_smoothing = label_smoothing def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) @@ -68,8 +69,9 @@ class SeqStrategy(BaseStrategy): class SftStrategy(BaseStrategy): - def __init__(self, model, device): + def __init__(self, model, device, label_smoothing): super().__init__(model, device) + self.label_smoothing = label_smoothing def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) @@ -124,8 +126,16 @@ class StrategyFactory: def load(model, train_type, device, **kwargs): train_strategy: Dict[str, Callable[[], BaseStrategy]] = { - "seq": lambda: SeqStrategy(model, device), - "sft": lambda: SftStrategy(model, device), + "seq": lambda: SeqStrategy( + model, + device, + kwargs.get("label_smoothing", 0.0) + ), + "sft": lambda: SftStrategy( + model, + device, + kwargs.get("label_smoothing", 0.0) + ), "dpo": lambda: DpoStrategy( model, device, @@ -134,4 +144,4 @@ class StrategyFactory: ) } strategy = train_strategy[train_type]() - return strategy \ No newline at end of file + return strategy \ No newline at end of file