fix: 修复参数传递问题
This commit is contained in:
parent
aef7615abd
commit
780b9e1855
|
|
@ -83,10 +83,11 @@ class BaseStrategy(ABC):
|
||||||
"""Abstract base class for training strategies."""
|
"""Abstract base class for training strategies."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
|
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.extra_kwargs = kwargs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
|
@ -191,8 +192,8 @@ class SEQStrategy(BaseStrategy):
|
||||||
Computes cross-entropy loss for next token prediction.
|
Computes cross-entropy loss for next token prediction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, device, label_smoothing: float = 0.0):
|
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device, **kwargs)
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
|
@ -216,8 +217,8 @@ class SFTStrategy(BaseStrategy):
|
||||||
Applies cross-entropy loss only to tokens where loss_mask is True.
|
Applies cross-entropy loss only to tokens where loss_mask is True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, device, label_smoothing: float = 0.0):
|
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device, **kwargs)
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
|
@ -256,8 +257,9 @@ class DPOStrategy(BaseStrategy):
|
||||||
device: str,
|
device: str,
|
||||||
beta: float = 0.1,
|
beta: float = 0.1,
|
||||||
reduction: str = "mean",
|
reduction: str = "mean",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device, **kwargs)
|
||||||
self.ref_model = create_ref_model(model)
|
self.ref_model = create_ref_model(model)
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
|
@ -306,8 +308,9 @@ class GRPOStrategy(BaseStrategy):
|
||||||
kl_coef: float = 0.01,
|
kl_coef: float = 0.01,
|
||||||
group_size: int = 4,
|
group_size: int = 4,
|
||||||
reduction: str = "mean",
|
reduction: str = "mean",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device)
|
super().__init__(model, device, **kwargs)
|
||||||
self.ref_model = create_ref_model(model)
|
self.ref_model = create_ref_model(model)
|
||||||
self.clip_eps = clip_eps
|
self.clip_eps = clip_eps
|
||||||
self.kl_coef = kl_coef
|
self.kl_coef = kl_coef
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue