diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 8df3fd8..49c1744 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -56,7 +56,7 @@ class BaseStrategy(ABC): return self.compute_loss(batch) -class SeqStrategy(BaseStrategy): +class SEQStrategy(BaseStrategy): def __init__(self, model, device, label_smoothing): super().__init__(model, device) self.label_smoothing = label_smoothing @@ -74,7 +74,7 @@ class SeqStrategy(BaseStrategy): return loss -class SftStrategy(BaseStrategy): +class SFTStrategy(BaseStrategy): def __init__(self, model, device, label_smoothing): super().__init__(model, device) self.label_smoothing = label_smoothing @@ -96,7 +96,7 @@ class SftStrategy(BaseStrategy): return loss -class DpoStrategy(BaseStrategy): +class DPOStrategy(BaseStrategy): def __init__( self, model, @@ -141,7 +141,7 @@ class DpoStrategy(BaseStrategy): return dpo_loss -class GrpoStrategy(BaseStrategy): +class GRPOStrategy(BaseStrategy): def __init__( self, @@ -211,23 +211,23 @@ class StrategyFactory: def load(model, train_type, device, **kwargs): train_strategy: Dict[str, Callable[[], BaseStrategy]] = { - "seq": lambda: SeqStrategy( + "seq": lambda: SEQStrategy( model, device, kwargs.get("label_smoothing", 0.0) ), - "sft": lambda: SftStrategy( + "sft": lambda: SFTStrategy( model, device, kwargs.get("label_smoothing", 0.0) ), - "dpo": lambda: DpoStrategy( + "dpo": lambda: DPOStrategy( model, device, kwargs.get("dpo_beta"), kwargs.get("reduction", "mean") ), - "grpo": lambda: GrpoStrategy( + "grpo": lambda: GRPOStrategy( model, device, kwargs.get("grpo_clip_eps"),