fix: 修复参数传递问题

This commit is contained in:
ViperEkura 2026-03-31 01:23:29 +08:00
parent aef7615abd
commit 780b9e1855
1 changed files with 10 additions and 7 deletions

View File

@ -83,10 +83,11 @@ class BaseStrategy(ABC):
"""Abstract base class for training strategies."""
def __init__(
self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
):
self.model = model
self.device = device
self.extra_kwargs = kwargs
@abstractmethod
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
@ -191,8 +192,8 @@ class SEQStrategy(BaseStrategy):
Computes cross-entropy loss for next token prediction.
"""
def __init__(self, model, device, label_smoothing: float = 0.0):
super().__init__(model, device)
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
@ -216,8 +217,8 @@ class SFTStrategy(BaseStrategy):
Applies cross-entropy loss only to tokens where loss_mask is True.
"""
def __init__(self, model, device, label_smoothing: float = 0.0):
super().__init__(model, device)
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
@ -256,8 +257,9 @@ class DPOStrategy(BaseStrategy):
device: str,
beta: float = 0.1,
reduction: str = "mean",
**kwargs,
):
super().__init__(model, device)
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model)
self.beta = beta
self.reduction = reduction
@ -306,8 +308,9 @@ class GRPOStrategy(BaseStrategy):
kl_coef: float = 0.01,
group_size: int = 4,
reduction: str = "mean",
**kwargs,
):
super().__init__(model, device)
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model)
self.clip_eps = clip_eps
self.kl_coef = kl_coef