feat: 增加 label smothing

This commit is contained in:
ViperEkura 2026-03-06 11:41:14 +08:00
parent 82d22c5742
commit 493fe4e84b
1 changed files with 15 additions and 5 deletions

View File

@ -51,8 +51,9 @@ class BaseStrategy(ABC):
class SeqStrategy(BaseStrategy):
def __init__(self, model, device):
def __init__(self, model, device, label_smoothing):
super().__init__(model, device)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
@ -68,8 +69,9 @@ class SeqStrategy(BaseStrategy):
class SftStrategy(BaseStrategy):
def __init__(self, model, device):
def __init__(self, model, device, label_smoothing):
super().__init__(model, device)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
@ -124,8 +126,16 @@ class StrategyFactory:
def load(model, train_type, device, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model, device),
"sft": lambda: SftStrategy(model, device),
"seq": lambda: SeqStrategy(
model,
device,
kwargs.get("label_smoothing", 0.0)
),
"sft": lambda: SftStrategy(
model,
device,
kwargs.get("label_smoothing", 0.0)
),
"dpo": lambda: DpoStrategy(
model,
device,
@ -134,4 +144,4 @@ class StrategyFactory:
)
}
strategy = train_strategy[train_type]()
return strategy
return strategy