feat: 增加 label smoothing

This commit is contained in:
ViperEkura 2026-03-06 11:41:14 +08:00
parent 82d22c5742
commit 493fe4e84b
1 changed file with 15 additions and 5 deletions

View File

@ -51,8 +51,9 @@ class BaseStrategy(ABC):
class SeqStrategy(BaseStrategy): class SeqStrategy(BaseStrategy):
def __init__(self, model, device): def __init__(self, model, device, label_smoothing):
super().__init__(model, device) super().__init__(model, device)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
@ -68,8 +69,9 @@ class SeqStrategy(BaseStrategy):
class SftStrategy(BaseStrategy): class SftStrategy(BaseStrategy):
def __init__(self, model, device): def __init__(self, model, device, label_smoothing):
super().__init__(model, device) super().__init__(model, device)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
@ -124,8 +126,16 @@ class StrategyFactory:
def load(model, train_type, device, **kwargs): def load(model, train_type, device, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = { train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model, device), "seq": lambda: SeqStrategy(
"sft": lambda: SftStrategy(model, device), model,
device,
kwargs.get("label_smoothing", 0.0)
),
"sft": lambda: SftStrategy(
model,
device,
kwargs.get("label_smoothing", 0.0)
),
"dpo": lambda: DpoStrategy( "dpo": lambda: DpoStrategy(
model, model,
device, device,
@ -134,4 +144,4 @@ class StrategyFactory:
) )
} }
strategy = train_strategy[train_type]() strategy = train_strategy[train_type]()
return strategy return strategy