feat: 增加 label smothing
This commit is contained in:
parent
82d22c5742
commit
493fe4e84b
|
|
@ -51,8 +51,9 @@ class BaseStrategy(ABC):
|
|||
|
||||
|
||||
class SeqStrategy(BaseStrategy):
|
||||
def __init__(self, model, device):
|
||||
def __init__(self, model, device, label_smoothing):
|
||||
super().__init__(model, device)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
|
|
@ -68,8 +69,9 @@ class SeqStrategy(BaseStrategy):
|
|||
|
||||
|
||||
class SftStrategy(BaseStrategy):
|
||||
def __init__(self, model, device):
|
||||
def __init__(self, model, device, label_smoothing):
|
||||
super().__init__(model, device)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
|
|
@ -124,8 +126,16 @@ class StrategyFactory:
|
|||
|
||||
def load(model, train_type, device, **kwargs):
|
||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||
"seq": lambda: SeqStrategy(model, device),
|
||||
"sft": lambda: SftStrategy(model, device),
|
||||
"seq": lambda: SeqStrategy(
|
||||
model,
|
||||
device,
|
||||
kwargs.get("label_smoothing", 0.0)
|
||||
),
|
||||
"sft": lambda: SftStrategy(
|
||||
model,
|
||||
device,
|
||||
kwargs.get("label_smoothing", 0.0)
|
||||
),
|
||||
"dpo": lambda: DpoStrategy(
|
||||
model,
|
||||
device,
|
||||
|
|
|
|||
Loading…
Reference in New Issue