fix: 修复参数传递问题
This commit is contained in:
parent
aef7615abd
commit
780b9e1855
|
|
@ -83,10 +83,11 @@ class BaseStrategy(ABC):
|
|||
"""Abstract base class for training strategies."""
|
||||
|
||||
def __init__(
|
||||
self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
|
||||
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
|
||||
):
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.extra_kwargs = kwargs
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
|
|
@ -191,8 +192,8 @@ class SEQStrategy(BaseStrategy):
|
|||
Computes cross-entropy loss for next token prediction.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0):
|
||||
super().__init__(model, device)
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
|
|
@ -216,8 +217,8 @@ class SFTStrategy(BaseStrategy):
|
|||
Applies cross-entropy loss only to tokens where loss_mask is True.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0):
|
||||
super().__init__(model, device)
|
||||
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.label_smoothing = label_smoothing
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
|
|
@ -256,8 +257,9 @@ class DPOStrategy(BaseStrategy):
|
|||
device: str,
|
||||
beta: float = 0.1,
|
||||
reduction: str = "mean",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, device)
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.ref_model = create_ref_model(model)
|
||||
self.beta = beta
|
||||
self.reduction = reduction
|
||||
|
|
@ -306,8 +308,9 @@ class GRPOStrategy(BaseStrategy):
|
|||
kl_coef: float = 0.01,
|
||||
group_size: int = 4,
|
||||
reduction: str = "mean",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, device)
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.ref_model = create_ref_model(model)
|
||||
self.clip_eps = clip_eps
|
||||
self.kl_coef = kl_coef
|
||||
|
|
|
|||
Loading…
Reference in New Issue