"""Training strategy implementations with factory pattern."""

import copy
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the underlying module if ``model`` is DDP-wrapped, else ``model``."""
    if isinstance(model, DDP):
        return model.module
    return model


def create_ref_model(model: nn.Module) -> nn.Module:
    """Create a frozen reference copy of ``model`` for DPO/GRPO training.

    The model is unwrapped from DDP first so that only the original module
    is deep-copied. The copy has gradients disabled and is put in eval mode.

    Args:
        model: Policy model, optionally wrapped in DistributedDataParallel.

    Returns:
        A deep copy of the unwrapped model with ``requires_grad`` disabled.
    """
    original_model = unwrap_model(model)
    ref_model = copy.deepcopy(original_model)
    ref_model.requires_grad_(False)
    ref_model.eval()
    return ref_model


def move_to_device(batch: Dict[str, Any], device: str) -> Dict[str, Any]:
    """Move batch entries to ``device`` with a non-blocking transfer.

    Args:
        batch: Mapping from field name to value (typically tensors).
        device: Target device passed to ``.to``.

    Returns:
        A new dict with the same keys. Values exposing a ``.to`` method are
        moved; any other value (e.g. string metadata) is passed through
        unchanged instead of raising ``AttributeError``.
    """
    return {
        key: value.to(device, non_blocking=True) if hasattr(value, "to") else value
        for key, value in batch.items()
    }


def get_logprobs(
    model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
    input_ids: Tensor,
    mask: Tensor,
    reduction: str,
) -> Tensor:
    """Compute masked next-token log probabilities of ``input_ids``.

    The model is called positionally as ``model(input_ids[:, :-1], mask[:, :-1])``
    and must return a mapping with a ``"logits"`` entry. Position ``t`` of the
    logits is scored against token ``t + 1`` of ``input_ids``.

    Args:
        model: The language model (or any callable with the same contract).
        input_ids: Input token IDs of shape [batch_size, seq_len].
        mask: Mask of shape [batch_size, seq_len]; it is forwarded to the
            model and also selects which target tokens are scored.
        reduction: How to reduce over the sequence dimension
            ("mean", "sum", "none").

    Returns:
        Log probabilities with the reduction applied over the sequence
        dimension: shape [batch_size] for "mean"/"sum", and
        [batch_size, seq_len - 1] for "none".

    Raises:
        ValueError: If ``reduction`` is not one of the allowed values.
    """
    allowed_reductions = ["mean", "sum", "none"]
    if reduction not in allowed_reductions:
        raise ValueError(
            f"reduction must be one of {allowed_reductions}, got '{reduction}'"
        )

    targets = input_ids[:, 1:]
    logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
    # Upcast before the softmax so low-precision logits do not lose accuracy.
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    token_logprobs = torch.gather(
        log_probs, dim=-1, index=targets.unsqueeze(-1)
    ).squeeze(-1)

    # Cast the mask once so bool/int masks multiply and sum in floating point.
    target_mask = mask[:, 1:].to(token_logprobs.dtype)
    masked_logprobs = token_logprobs * target_mask

    if reduction == "mean":
        # clamp avoids division by zero for fully-masked rows.
        return masked_logprobs.sum(dim=-1) / target_mask.sum(dim=-1).clamp(min=1.0)
    if reduction == "sum":
        return masked_logprobs.sum(dim=-1)
    return masked_logprobs


class BaseStrategy(ABC):
    """Abstract base class for training strategies."""

    def __init__(
        self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str
    ):
        self.model = model
        self.device = device

    @abstractmethod
    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Compute loss for the given batch.

        Args:
            batch: Dictionary containing batch tensors

        Returns:
            Computed loss tensor
        """
        raise NotImplementedError

    def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
        """Allow calling strategy directly as a callable."""
        return self.compute_loss(batch)


class StrategyFactory:
    """Factory class for creating training strategy instances.

    Supports decorator-based registration for extensible strategy types.
    All default strategies (seq, sft, dpo, grpo) are registered in this
    module at class definition time.

    Example usage:
        @StrategyFactory.register("custom")
        class CustomStrategy(BaseStrategy):
            ...

        strategy = StrategyFactory.create(model, "custom", device)
    """

    SUPPORTED_STRATEGIES = frozenset({"seq", "sft", "dpo", "grpo"})
    STRATEGY_MAP: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        """Decorator to register a new strategy class.

        Args:
            name: Registration name for the strategy

        Returns:
            Decorator function that registers the strategy class

        Raises:
            TypeError: (from the decorator) if the decorated class does not
                inherit from BaseStrategy.
        """

        def decorator(strategy_cls: type) -> type:
            if not issubclass(strategy_cls, BaseStrategy):
                raise TypeError(
                    f"{strategy_cls.__name__} must inherit from BaseStrategy"
                )
            cls.STRATEGY_MAP[name] = strategy_cls
            return strategy_cls

        return decorator

    @classmethod
    def create(cls, model, train_type: str, device: str, **kwargs) -> BaseStrategy:
        """Create a strategy instance based on training type.

        Args:
            model: Model instance for the strategy
            train_type: Name of a registered strategy (built-ins are
                "seq", "sft", "dpo", "grpo")
            device: Device to run the strategy on
            **kwargs: Additional arguments passed to strategy constructor

        Returns:
            Strategy instance

        Raises:
            ValueError: If train_type is neither registered nor a built-in name
            NotImplementedError: If train_type is a built-in name that has no
                registered implementation
        """
        # Consult the registry first so strategies added through register()
        # are creatable, not only the hard-coded built-in names.
        strategy_cls = cls.STRATEGY_MAP.get(train_type)
        if strategy_cls is None:
            if train_type in cls.SUPPORTED_STRATEGIES:
                raise NotImplementedError(
                    f"Strategy '{train_type}' is supported but not yet implemented."
                )
            known = sorted(cls.SUPPORTED_STRATEGIES | set(cls.STRATEGY_MAP))
            raise ValueError(
                f"Unknown training strategy: '{train_type}'. "
                f"Supported strategies: {known}"
            )
        return strategy_cls(model, device, **kwargs)

    @classmethod
    def available_strategies(cls) -> List[str]:
        """Return list of registered strategy names."""
        return list(cls.STRATEGY_MAP)


# ============== Strategy Classes ==============
# All strategies are registered at class definition time using the decorator


@StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy):
    """Standard next-token prediction training strategy.

    Computes cross-entropy between the model logits for ``input_ids`` and
    ``target_ids`` over every position.
    """

    def __init__(self, model, device, label_smoothing: float = 0.0):
        super().__init__(model, device)
        self.label_smoothing = label_smoothing

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the mean cross-entropy loss.

        Args:
            batch: Must contain "input_ids" and "target_ids".
        """
        batch = move_to_device(batch, self.device)
        input_ids, target_ids = batch["input_ids"], batch["target_ids"]
        logits = self.model(input_ids=input_ids)["logits"]
        # Flatten batch and sequence dims so each token is one classification.
        return F.cross_entropy(
            input=logits.flatten(0, 1).float(),
            target=target_ids.flatten(),
            label_smoothing=self.label_smoothing,
        )


@StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy):
    """Supervised fine-tuning strategy with loss masking.

    Applies cross-entropy loss only to tokens where ``loss_mask`` is non-zero.
    """

    def __init__(self, model, device, label_smoothing: float = 0.0):
        super().__init__(model, device)
        self.label_smoothing = label_smoothing

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the mean cross-entropy over unmasked target tokens.

        Args:
            batch: Must contain "input_ids", "target_ids" and "loss_mask".
        """
        batch = move_to_device(batch, self.device)
        input_ids = batch["input_ids"]
        target_ids = batch["target_ids"]
        loss_mask = batch["loss_mask"]

        ignore_index = -100
        logits = self.model(input_ids=input_ids)["logits"]
        # Masked-out targets are replaced by ignore_index so that
        # cross_entropy skips them.
        target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
        return F.cross_entropy(
            input=logits.flatten(0, 1).float(),
            target=target_ids.flatten(),
            ignore_index=ignore_index,
            label_smoothing=self.label_smoothing,
        )


@StrategyFactory.register("dpo")
class DPOStrategy(BaseStrategy):
    """Direct Preference Optimization strategy.

    Implements the DPO loss from the paper "Direct Preference Optimization".
    A frozen copy of the model, taken at construction time, serves as the
    reference policy.
    """

    def __init__(
        self,
        model: nn.Module,
        device: str,
        beta: float = 0.1,
        reduction: str = "mean",
    ):
        super().__init__(model, device)
        self.ref_model = create_ref_model(model)
        self.beta = beta
        self.reduction = reduction

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the DPO loss averaged over the batch.

        Args:
            batch: Must contain "chosen", "rejected", "chosen_mask" and
                "rejected_mask". Chosen and rejected tensors are concatenated
                along dim 0, so their remaining dims must match.
        """
        batch = move_to_device(batch, self.device)
        chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
        chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]

        # One forward pass over chosen + rejected instead of two.
        concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
        concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)

        log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
        with torch.no_grad():
            log_ref = get_logprobs(
                self.ref_model, concat_ids, concat_mask, self.reduction
            )

        num_chosen = chosen_ids.shape[0]
        log_pi_chosen, log_pi_rejected = log_pi[:num_chosen], log_pi[num_chosen:]
        log_ref_chosen, log_ref_rejected = log_ref[:num_chosen], log_ref[num_chosen:]

        pi_log_ratio = log_pi_chosen - log_pi_rejected
        ref_log_ratio = log_ref_chosen - log_ref_rejected
        ratio_diff = pi_log_ratio - ref_log_ratio

        return -F.logsigmoid(self.beta * ratio_diff).mean()


@StrategyFactory.register("grpo")
class GRPOStrategy(BaseStrategy):
    """Group Relative Policy Optimization strategy.

    Implements GRPO with a clipped surrogate objective and a penalty on the
    squared log-probability difference to a frozen reference model.
    """

    def __init__(
        self,
        model: nn.Module,
        device: str,
        clip_eps: float = 0.2,
        kl_coef: float = 0.01,
        group_size: int = 4,
        reduction: str = "mean",
    ):
        """Initialise the strategy.

        Args:
            model: Policy model; a frozen deep copy becomes the reference.
            device: Device batches are moved to.
            clip_eps: Clipping range for the probability ratio.
            kl_coef: Weight of the reference penalty term.
            group_size: Stored on the instance; ``compute_loss`` reads the
                actual group size from the shape of ``batch["responses"]``.
            reduction: Sequence reduction for log-probs, "mean" or "sum".

        Raises:
            ValueError: If ``reduction`` is not "mean" or "sum". GRPO needs
                one log-prob per sequence, which "none" cannot provide.
        """
        super().__init__(model, device)
        if reduction not in ("mean", "sum"):
            raise ValueError(
                "GRPOStrategy requires a per-sequence reduction "
                f"('mean' or 'sum'), got '{reduction}'"
            )
        self.ref_model = create_ref_model(model)
        self.clip_eps = clip_eps
        self.kl_coef = kl_coef
        self.group_size = group_size
        self.reduction = reduction

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the GRPO loss (policy loss plus reference penalty).

        Args:
            batch: Must contain
                "prompts": [batch_size, prompt_len] token IDs,
                "responses": [batch_size, group_size, response_len] token IDs,
                "masks": same shape as "responses",
                "rewards": expected [batch_size, group_size] so that it lines
                up with the per-sequence log-probs.
        """
        batch = move_to_device(batch, self.device)
        prompts = batch["prompts"]
        responses = batch["responses"]
        masks = batch["masks"]
        rewards = batch["rewards"]

        batch_size, group_size, response_len = responses.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        responses_flat = responses.reshape(-1, response_len)
        masks_flat = masks.reshape(-1, response_len)
        # Repeat each prompt once per group member, ordered to match
        # responses_flat.
        prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)

        # Shape: (batch_size * group_size, prompt_len + response_len)
        full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
        # NOTE(review): prompt positions are marked 1 here, so prompt tokens
        # also contribute to the sequence log-prob -- confirm this is intended.
        full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)

        log_probs_policy = get_logprobs(
            self.model, full_sequences, full_masks, self.reduction
        ).view(batch_size, group_size)

        with torch.no_grad():
            log_probs_ref = get_logprobs(
                self.ref_model, full_sequences, full_masks, self.reduction
            ).view(batch_size, group_size)

        # Group-normalised advantages. Rewards are cast to the log-prob dtype
        # so integer rewards work with mean/std.
        rewards = rewards.to(log_probs_policy.dtype)
        eps = torch.finfo(rewards.dtype).eps
        mean = rewards.mean(dim=-1, keepdim=True)
        if group_size > 1:
            std = rewards.std(dim=-1, keepdim=True)
        else:
            # The sample std of a single element is NaN; a lone group member
            # has no relative advantage, so use zero spread.
            std = torch.zeros_like(mean)
        advantages = (rewards - mean) / (std + eps)

        # PPO-style clipped surrogate objective.
        # On-policy: the sampling policy equals the current policy, so the
        # ratio has value 1. It is built from the detached log-probs so that
        # the gradient still flows into the policy (a constant 1 would not).
        ratio = torch.exp(log_probs_policy - log_probs_policy.detach())
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()

        kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
        return policy_loss + kl_penalty