chore: 修改策略命名
This commit is contained in:
parent
50f76cd7c7
commit
361cdeb296
|
|
@ -56,7 +56,7 @@ class BaseStrategy(ABC):
|
||||||
return self.compute_loss(batch)
|
return self.compute_loss(batch)
|
||||||
|
|
||||||
|
|
||||||
class SeqStrategy(BaseStrategy):
|
class SEQStrategy(BaseStrategy):
|
||||||
def __init__(self, model, device, label_smoothing):
|
def __init__(self, model, device, label_smoothing):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
|
|
@ -74,7 +74,7 @@ class SeqStrategy(BaseStrategy):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
class SftStrategy(BaseStrategy):
|
class SFTStrategy(BaseStrategy):
|
||||||
def __init__(self, model, device, label_smoothing):
|
def __init__(self, model, device, label_smoothing):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
|
|
@ -96,7 +96,7 @@ class SftStrategy(BaseStrategy):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
class DpoStrategy(BaseStrategy):
|
class DPOStrategy(BaseStrategy):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
|
|
@ -141,7 +141,7 @@ class DpoStrategy(BaseStrategy):
|
||||||
return dpo_loss
|
return dpo_loss
|
||||||
|
|
||||||
|
|
||||||
class GrpoStrategy(BaseStrategy):
|
class GRPOStrategy(BaseStrategy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -211,23 +211,23 @@ class StrategyFactory:
|
||||||
|
|
||||||
def load(model, train_type, device, **kwargs):
|
def load(model, train_type, device, **kwargs):
|
||||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||||
"seq": lambda: SeqStrategy(
|
"seq": lambda: SEQStrategy(
|
||||||
model,
|
model,
|
||||||
device,
|
device,
|
||||||
kwargs.get("label_smoothing", 0.0)
|
kwargs.get("label_smoothing", 0.0)
|
||||||
),
|
),
|
||||||
"sft": lambda: SftStrategy(
|
"sft": lambda: SFTStrategy(
|
||||||
model,
|
model,
|
||||||
device,
|
device,
|
||||||
kwargs.get("label_smoothing", 0.0)
|
kwargs.get("label_smoothing", 0.0)
|
||||||
),
|
),
|
||||||
"dpo": lambda: DpoStrategy(
|
"dpo": lambda: DPOStrategy(
|
||||||
model,
|
model,
|
||||||
device,
|
device,
|
||||||
kwargs.get("dpo_beta"),
|
kwargs.get("dpo_beta"),
|
||||||
kwargs.get("reduction", "mean")
|
kwargs.get("reduction", "mean")
|
||||||
),
|
),
|
||||||
"grpo": lambda: GrpoStrategy(
|
"grpo": lambda: GRPOStrategy(
|
||||||
model,
|
model,
|
||||||
device,
|
device,
|
||||||
kwargs.get("grpo_clip_eps"),
|
kwargs.get("grpo_clip_eps"),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue