feat: 增加 label smoothing
This commit is contained in:
parent
82d22c5742
commit
493fe4e84b
|
|
@ -51,8 +51,9 @@ class BaseStrategy(ABC):
|
||||||
|
|
||||||
|
|
||||||
class SeqStrategy(BaseStrategy):
|
class SeqStrategy(BaseStrategy):
|
||||||
def __init__(self, model, device):
|
def __init__(self, model, device, label_smoothing):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
|
|
@ -68,8 +69,9 @@ class SeqStrategy(BaseStrategy):
|
||||||
|
|
||||||
|
|
||||||
class SftStrategy(BaseStrategy):
|
class SftStrategy(BaseStrategy):
|
||||||
def __init__(self, model, device):
|
def __init__(self, model, device, label_smoothing):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device)
|
||||||
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
|
|
@ -124,8 +126,16 @@ class StrategyFactory:
|
||||||
|
|
||||||
def load(model, train_type, device, **kwargs):
|
def load(model, train_type, device, **kwargs):
|
||||||
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
|
||||||
"seq": lambda: SeqStrategy(model, device),
|
"seq": lambda: SeqStrategy(
|
||||||
"sft": lambda: SftStrategy(model, device),
|
model,
|
||||||
|
device,
|
||||||
|
kwargs.get("label_smoothing", 0.0)
|
||||||
|
),
|
||||||
|
"sft": lambda: SftStrategy(
|
||||||
|
model,
|
||||||
|
device,
|
||||||
|
kwargs.get("label_smoothing", 0.0)
|
||||||
|
),
|
||||||
"dpo": lambda: DpoStrategy(
|
"dpo": lambda: DpoStrategy(
|
||||||
model,
|
model,
|
||||||
device,
|
device,
|
||||||
|
|
@ -134,4 +144,4 @@ class StrategyFactory:
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
strategy = train_strategy[train_type]()
|
strategy = train_strategy[train_type]()
|
||||||
return strategy
|
return strategy
|
||||||
Loading…
Reference in New Issue