chore: 修改策略命名

This commit is contained in:
ViperEkura 2026-03-19 23:08:41 +08:00
parent 50f76cd7c7
commit 361cdeb296
1 changed file with 8 additions and 8 deletions

View File

@ -56,7 +56,7 @@ class BaseStrategy(ABC):
return self.compute_loss(batch) return self.compute_loss(batch)
class SeqStrategy(BaseStrategy): class SEQStrategy(BaseStrategy):
def __init__(self, model, device, label_smoothing): def __init__(self, model, device, label_smoothing):
super().__init__(model, device) super().__init__(model, device)
self.label_smoothing = label_smoothing self.label_smoothing = label_smoothing
@ -74,7 +74,7 @@ class SeqStrategy(BaseStrategy):
return loss return loss
class SftStrategy(BaseStrategy): class SFTStrategy(BaseStrategy):
def __init__(self, model, device, label_smoothing): def __init__(self, model, device, label_smoothing):
super().__init__(model, device) super().__init__(model, device)
self.label_smoothing = label_smoothing self.label_smoothing = label_smoothing
@ -96,7 +96,7 @@ class SftStrategy(BaseStrategy):
return loss return loss
class DpoStrategy(BaseStrategy): class DPOStrategy(BaseStrategy):
def __init__( def __init__(
self, self,
model, model,
@ -141,7 +141,7 @@ class DpoStrategy(BaseStrategy):
return dpo_loss return dpo_loss
class GrpoStrategy(BaseStrategy): class GRPOStrategy(BaseStrategy):
def __init__( def __init__(
self, self,
@ -211,23 +211,23 @@ class StrategyFactory:
def load(model, train_type, device, **kwargs): def load(model, train_type, device, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = { train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy( "seq": lambda: SEQStrategy(
model, model,
device, device,
kwargs.get("label_smoothing", 0.0) kwargs.get("label_smoothing", 0.0)
), ),
"sft": lambda: SftStrategy( "sft": lambda: SFTStrategy(
model, model,
device, device,
kwargs.get("label_smoothing", 0.0) kwargs.get("label_smoothing", 0.0)
), ),
"dpo": lambda: DpoStrategy( "dpo": lambda: DPOStrategy(
model, model,
device, device,
kwargs.get("dpo_beta"), kwargs.get("dpo_beta"),
kwargs.get("reduction", "mean") kwargs.get("reduction", "mean")
), ),
"grpo": lambda: GrpoStrategy( "grpo": lambda: GRPOStrategy(
model, model,
device, device,
kwargs.get("grpo_clip_eps"), kwargs.get("grpo_clip_eps"),