fix: 修复参数传递问题

This commit is contained in:
ViperEkura 2026-03-31 01:23:29 +08:00
parent aef7615abd
commit 780b9e1855
1 changed files with 10 additions and 7 deletions

View File

@@ -83,10 +83,11 @@ class BaseStrategy(ABC):
     """Abstract base class for training strategies."""

     def __init__(
-        self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
+        self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
     ):
         self.model = model
         self.device = device
+        self.extra_kwargs = kwargs

     @abstractmethod
     def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
@@ -191,8 +192,8 @@ class SEQStrategy(BaseStrategy):
     Computes cross-entropy loss for next token prediction.
     """

-    def __init__(self, model, device, label_smoothing: float = 0.0):
-        super().__init__(model, device)
+    def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
+        super().__init__(model, device, **kwargs)
         self.label_smoothing = label_smoothing

     def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
@@ -216,8 +217,8 @@ class SFTStrategy(BaseStrategy):
     Applies cross-entropy loss only to tokens where loss_mask is True.
     """

-    def __init__(self, model, device, label_smoothing: float = 0.0):
-        super().__init__(model, device)
+    def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
+        super().__init__(model, device, **kwargs)
         self.label_smoothing = label_smoothing

     def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
@@ -256,8 +257,9 @@ class DPOStrategy(BaseStrategy):
         device: str,
         beta: float = 0.1,
         reduction: str = "mean",
+        **kwargs,
     ):
-        super().__init__(model, device)
+        super().__init__(model, device, **kwargs)
         self.ref_model = create_ref_model(model)
         self.beta = beta
         self.reduction = reduction
@@ -306,8 +308,9 @@ class GRPOStrategy(BaseStrategy):
         kl_coef: float = 0.01,
         group_size: int = 4,
         reduction: str = "mean",
+        **kwargs,
     ):
-        super().__init__(model, device)
+        super().__init__(model, device, **kwargs)
         self.ref_model = create_ref_model(model)
         self.clip_eps = clip_eps
         self.kl_coef = kl_coef