331 lines
11 KiB
Python
331 lines
11 KiB
Python
"""Training strategy implementations with factory pattern."""
|
|
|
|
import copy
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Callable, Dict, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
from astrai.factory import BaseFactory
|
|
|
|
|
|
def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the underlying module, stripping a DDP wrapper when present."""
    return model.module if isinstance(model, DDP) else model
|
|
|
|
|
|
def create_ref_model(model: nn.Module) -> nn.Module:
    """Create a frozen reference model for DPO/GRPO training.

    A DDP wrapper, if present, is stripped first so the copy is of the
    plain module; the copy is then deep-copied, has gradients disabled,
    and is switched to eval mode. The input model is left untouched.
    """
    base = model.module if isinstance(model, DDP) else model
    frozen = copy.deepcopy(base)
    frozen.requires_grad_(False)
    frozen.eval()
    return frozen
|
|
|
|
|
|
def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]:
    """Move batch tensors to the specified device with non-blocking transfer.

    Args:
        batch: Mapping of field name to tensor.
        device: Target device identifier (e.g. "cpu", "cuda:0").

    Returns:
        A new dict with every tensor moved to ``device``; keys are unchanged.
    """
    # Fixed the return annotation: this always produces Dict[str, Tensor],
    # not Any. non_blocking=True overlaps the copy with compute when the
    # source is in pinned host memory.
    return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
|
|
|
|
|
|
def get_logprobs(
    model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
    input_ids: Tensor,
    mask: Tensor,
    reduction: str,
):
    """Compute per-token log probabilities of the next-token targets.

    The model is fed positions [0, L-1) and scored against the tokens at
    positions [1, L), so each logit predicts the following token.

    Args:
        model: The language model
        input_ids: Input token IDs of shape [batch_size, seq_len]
        mask: Attention mask of shape [batch_size, seq_len]
        reduction: How to reduce over sequence dimension ("mean", "sum", "none")

    Returns:
        Log probabilities with reduction applied over sequence dimension

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """
    valid = ["mean", "sum", "none"]
    if reduction not in valid:
        raise ValueError(f"reduction must be one of {valid}, got '{reduction}'")

    targets = input_ids[:, 1:]
    target_mask = mask[:, 1:]

    # Drop the final position: it has no next-token target.
    logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
    # float() before log_softmax for numerical stability in mixed precision.
    per_token = (
        torch.log_softmax(logits.float(), dim=-1)
        .gather(-1, targets.unsqueeze(-1))
        .squeeze(-1)
    )
    masked = per_token * target_mask

    if reduction == "none":
        return masked
    totals = masked.sum(dim=-1)
    if reduction == "sum":
        return totals
    # "mean": normalize by valid token count; clamp guards an all-zero mask.
    return totals / target_mask.sum(dim=-1).clamp(min=1.0)
|
|
|
|
|
|
class BaseStrategy(ABC):
    """Abstract base class for training strategies.

    Subclasses implement :meth:`compute_loss`; instances are directly
    callable and delegate to that method.
    """

    def __init__(
        self,
        model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
        device: str,
        **kwargs,
    ):
        """Store the model, target device, and strategy-specific options.

        Args:
            model: The model to train (an nn.Module or any callable that
                returns a dict of model outputs).
            device: Device identifier batches are moved to (e.g. "cuda").
            **kwargs: Extra options retained for subclasses in
                ``self.extra_kwargs``.
        """
        # Fixed annotation: the original `Union[Callable[...]]` had a single
        # member, which is malformed; nn.Module was evidently intended as
        # the other alternative (matching get_logprobs' signature).
        self.model = model
        self.device = device
        self.extra_kwargs = kwargs

    @abstractmethod
    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Compute loss for the given batch.

        Args:
            batch: Dictionary containing batch tensors

        Returns:
            Computed loss tensor
        """
        raise NotImplementedError

    def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
        """Allow calling strategy directly as a callable."""
        return self.compute_loss(batch)
|
|
|
|
|
|
class StrategyFactory(BaseFactory["BaseStrategy"]):
    """Factory for constructing training strategies by name.

    New strategy types plug in through the ``register`` decorator; the
    built-in strategies (seq, sft, dpo, grpo) register themselves when
    this module is imported.

    Example usage:
        @StrategyFactory.register("custom")
        class CustomStrategy(BaseStrategy):
            ...

        strategy = StrategyFactory.create("custom", model, device)
    """

    @classmethod
    def _validate_component(cls, strategy_cls: type) -> None:
        """Reject registration of classes that are not BaseStrategy subclasses."""
        if issubclass(strategy_cls, BaseStrategy):
            return
        raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")

    @classmethod
    def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
        """Instantiate the strategy registered under ``train_type``.

        Args:
            train_type: Registered strategy name ("seq", "sft", "dpo", "grpo")
            model: Model instance handed to the strategy constructor
            device: Device the strategy should run on
            **kwargs: Forwarded verbatim to the strategy constructor

        Returns:
            A ready-to-use strategy instance.
        """
        return super().create(train_type, model, device, **kwargs)

    @classmethod
    def available_strategies(cls) -> list:
        """Return list of registered strategy names."""
        return cls.list_registered()
|
|
|
|
|
|
# ============== Strategy Classes ==============
|
|
# All strategies are registered at class definition time using the decorator
|
|
|
|
|
|
@StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy):
    """Standard next-token prediction training strategy.

    Computes plain cross-entropy between model logits and target ids,
    with optional label smoothing.
    """

    def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
        super().__init__(model, device, **kwargs)
        self.label_smoothing = label_smoothing

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Cross-entropy over all positions of the batch."""
        on_device = move_to_device(batch, self.device)
        logits = self.model(input_ids=on_device["input_ids"])["logits"]

        # Flatten (batch, seq) into one axis; float() keeps the loss
        # computation in full precision.
        return F.cross_entropy(
            input=logits.flatten(0, 1).float(),
            target=on_device["target_ids"].flatten(),
            label_smoothing=self.label_smoothing,
        )
|
|
|
|
|
|
@StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy):
    """Supervised fine-tuning strategy with per-token loss masking.

    Cross-entropy is applied only where ``loss_mask`` is non-zero; masked
    positions are mapped to an ignore index so they contribute nothing.
    """

    def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
        super().__init__(model, device, **kwargs)
        self.label_smoothing = label_smoothing

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Masked cross-entropy over the batch."""
        on_device = move_to_device(batch, self.device)
        input_ids = on_device["input_ids"]
        loss_mask = on_device["loss_mask"]

        # Positions carrying this target value are skipped by cross_entropy.
        ignore_index = -100
        targets = on_device["target_ids"].masked_fill(loss_mask == 0, ignore_index)

        logits = self.model(input_ids=input_ids)["logits"]
        return F.cross_entropy(
            input=logits.flatten(0, 1).float(),
            target=targets.flatten(),
            ignore_index=ignore_index,
            label_smoothing=self.label_smoothing,
        )
|
|
|
|
|
|
@StrategyFactory.register("dpo")
class DPOStrategy(BaseStrategy):
    """Direct Preference Optimization strategy.

    Implements the DPO loss from the paper "Direct Preference Optimization".
    Uses a reference model to compute KL divergence penalty.
    """

    def __init__(
        self,
        model: nn.Module,
        device: str,
        beta: float = 0.1,
        reduction: str = "mean",
        **kwargs,
    ):
        super().__init__(model, device, **kwargs)
        # Frozen snapshot of the policy, used as the DPO reference.
        self.ref_model = create_ref_model(model)
        self.beta = beta
        self.reduction = reduction

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """DPO loss over paired (chosen, rejected) completions."""
        batch = move_to_device(batch, self.device)
        num_chosen = batch["chosen"].shape[0]

        # Stack chosen + rejected so each model runs a single forward pass.
        joint_ids = torch.cat([batch["chosen"], batch["rejected"]], dim=0)
        joint_mask = torch.cat([batch["chosen_mask"], batch["rejected_mask"]], dim=0)

        policy_logps = get_logprobs(self.model, joint_ids, joint_mask, self.reduction)
        with torch.no_grad():
            ref_logps = get_logprobs(
                self.ref_model, joint_ids, joint_mask, self.reduction
            )

        # log pi(chosen) - log pi(rejected), for policy and reference.
        policy_margin = policy_logps[:num_chosen] - policy_logps[num_chosen:]
        ref_margin = ref_logps[:num_chosen] - ref_logps[num_chosen:]

        return -F.logsigmoid(self.beta * (policy_margin - ref_margin)).mean()
|
|
|
|
|
|
@StrategyFactory.register("grpo")
class GRPOStrategy(BaseStrategy):
    """Group Relative Policy Optimization strategy.

    Implements GRPO with a PPO-style clipped surrogate objective plus a
    penalty keeping the policy close to a frozen reference model.
    Advantages are computed per prompt by normalizing rewards within each
    group of sampled responses.
    """

    def __init__(
        self,
        model: nn.Module,
        device: str,
        clip_eps: float = 0.2,
        kl_coef: float = 0.01,
        group_size: int = 4,
        reduction: str = "mean",
        **kwargs,
    ):
        """
        Args:
            model: Policy model to optimize.
            device: Device batches are moved to.
            clip_eps: Clipping range for the probability ratio.
            kl_coef: Weight of the reference-proximity penalty.
            group_size: Expected number of responses sampled per prompt.
            reduction: Sequence reduction passed to get_logprobs.
        """
        super().__init__(model, device, **kwargs)
        self.ref_model = create_ref_model(model)
        self.clip_eps = clip_eps
        self.kl_coef = kl_coef
        self.group_size = group_size
        self.reduction = reduction

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Compute the GRPO loss for one batch.

        Expected batch keys (shapes from the code below):
            prompts:   [batch, prompt_len] token ids
            responses: [batch, group, response_len] token ids
            masks:     [batch, group, response_len] response mask
            rewards:   [batch, group] scalar reward per response
        """
        batch = move_to_device(batch, self.device)
        prompts = batch["prompts"]
        responses = batch["responses"]
        masks = batch["masks"]
        rewards = batch["rewards"]

        batch_size, group_size, response_len = responses.shape
        responses_flat = responses.view(-1, response_len)
        masks_flat = masks.view(-1, response_len)
        # Repeat each prompt once per group member, then flatten groups.
        prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)

        # Shape: (batch_size * group_size, prompt_len + response_len)
        full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
        full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)

        log_probs_policy = get_logprobs(
            self.model, full_sequences, full_masks, self.reduction
        ).view(batch_size, group_size)

        with torch.no_grad():
            log_probs_ref = get_logprobs(
                self.ref_model, full_sequences, full_masks, self.reduction
            ).view(batch_size, group_size)

        # Group-relative advantages: normalize rewards within each group.
        eps = torch.finfo(log_probs_policy.dtype).eps
        mean = rewards.mean(dim=-1, keepdim=True)
        std = rewards.std(dim=-1, keepdim=True)
        advantages = (rewards - mean) / (std + eps)

        # BUG FIX: the original `ratio = torch.exp(0)` raised a TypeError
        # (torch.exp requires a Tensor) and, even as a constant 1, carried
        # no gradient into the surrogate objective. The standard on-policy
        # single-step form is exp(log_pi - log_pi_old) with
        # log_pi_old = log_pi.detach(): the forward value is exactly 1,
        # but gradients flow through log_probs_policy.
        ratio = torch.exp(log_probs_policy - log_probs_policy.detach())
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages

        policy_loss = -torch.min(surr1, surr2).mean()
        # NOTE(review): squared log-prob difference is a proximity penalty,
        # not the exact KL estimator — presumably intentional; confirm.
        kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
        return policy_loss + kl_penalty
|