chore: 修改策略命名 (rename strategy classes: Seq/Sft/Dpo/Grpo → SEQ/SFT/DPO/GRPO)
This commit is contained in:
parent
50f76cd7c7
commit
361cdeb296
|
|
@ -56,7 +56,7 @@ class BaseStrategy(ABC):
|
|||
return self.compute_loss(batch)
|
||||
|
||||
|
||||
class SeqStrategy(BaseStrategy):
|
||||
class SEQStrategy(BaseStrategy):
|
||||
def __init__(self, model, device, label_smoothing):
|
||||
super().__init__(model, device)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
|
@ -74,7 +74,7 @@ class SeqStrategy(BaseStrategy):
|
|||
return loss
|
||||
|
||||
|
||||
class SftStrategy(BaseStrategy):
|
||||
class SFTStrategy(BaseStrategy):
|
||||
def __init__(self, model, device, label_smoothing):
|
||||
super().__init__(model, device)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
|
@ -96,7 +96,7 @@ class SftStrategy(BaseStrategy):
|
|||
return loss
|
||||
|
||||
|
||||
class DpoStrategy(BaseStrategy):
|
||||
class DPOStrategy(BaseStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
|
|
@ -141,7 +141,7 @@ class DpoStrategy(BaseStrategy):
|
|||
return dpo_loss
|
||||
|
||||
|
||||
class GrpoStrategy(BaseStrategy):
|
||||
class GRPOStrategy(BaseStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -211,23 +211,23 @@ class StrategyFactory:
|
|||
|
||||
def load(model, train_type, device, **kwargs):
|
||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||
"seq": lambda: SeqStrategy(
|
||||
"seq": lambda: SEQStrategy(
|
||||
model,
|
||||
device,
|
||||
kwargs.get("label_smoothing", 0.0)
|
||||
),
|
||||
"sft": lambda: SftStrategy(
|
||||
"sft": lambda: SFTStrategy(
|
||||
model,
|
||||
device,
|
||||
kwargs.get("label_smoothing", 0.0)
|
||||
),
|
||||
"dpo": lambda: DpoStrategy(
|
||||
"dpo": lambda: DPOStrategy(
|
||||
model,
|
||||
device,
|
||||
kwargs.get("dpo_beta"),
|
||||
kwargs.get("reduction", "mean")
|
||||
),
|
||||
"grpo": lambda: GrpoStrategy(
|
||||
"grpo": lambda: GRPOStrategy(
|
||||
model,
|
||||
device,
|
||||
kwargs.get("grpo_clip_eps"),
|
||||
|
|
|
|||
Loading…
Reference in New Issue