chore: 修改策略命名

This commit is contained in:
ViperEkura 2026-03-19 23:08:41 +08:00
parent 50f76cd7c7
commit 361cdeb296
1 changed file with 8 additions and 8 deletions

View File

@@ -56,7 +56,7 @@ class BaseStrategy(ABC):
return self.compute_loss(batch)
class SeqStrategy(BaseStrategy):
class SEQStrategy(BaseStrategy):
def __init__(self, model, device, label_smoothing):
super().__init__(model, device)
self.label_smoothing = label_smoothing
@@ -74,7 +74,7 @@ class SeqStrategy(BaseStrategy):
return loss
class SftStrategy(BaseStrategy):
class SFTStrategy(BaseStrategy):
def __init__(self, model, device, label_smoothing):
super().__init__(model, device)
self.label_smoothing = label_smoothing
@@ -96,7 +96,7 @@ class SftStrategy(BaseStrategy):
return loss
class DpoStrategy(BaseStrategy):
class DPOStrategy(BaseStrategy):
def __init__(
self,
model,
@@ -141,7 +141,7 @@ class DpoStrategy(BaseStrategy):
return dpo_loss
class GrpoStrategy(BaseStrategy):
class GRPOStrategy(BaseStrategy):
def __init__(
self,
@@ -211,23 +211,23 @@ class StrategyFactory:
def load(model, train_type, device, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(
"seq": lambda: SEQStrategy(
model,
device,
kwargs.get("label_smoothing", 0.0)
),
"sft": lambda: SftStrategy(
"sft": lambda: SFTStrategy(
model,
device,
kwargs.get("label_smoothing", 0.0)
),
"dpo": lambda: DpoStrategy(
"dpo": lambda: DPOStrategy(
model,
device,
kwargs.get("dpo_beta"),
kwargs.get("reduction", "mean")
),
"grpo": lambda: GrpoStrategy(
"grpo": lambda: GRPOStrategy(
model,
device,
kwargs.get("grpo_clip_eps"),