262 lines
8.8 KiB
Python
262 lines
8.8 KiB
Python
import copy
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
from torch import Tensor
|
|
from typing import Any, Callable, Dict, Union, Optional
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the module held inside a DDP wrapper, or ``model`` itself.

    Only an instance of ``DistributedDataParallel`` is unwrapped (via its
    ``module`` attribute); any other module is handed back unchanged.
    """
    return model.module if isinstance(model, DDP) else model
|
|
|
|
|
|
def create_ref_model(model: nn.Module) -> nn.Module:
    """
    Build the frozen reference model used by DPO/GRPO training.

    The DDP wrapper (if any) is stripped before copying, so the result is a
    deep copy of the underlying module. The copy has gradients disabled on
    all parameters and is switched to eval mode; ``model`` is left untouched.
    """
    # requires_grad_() and eval() both return the module, so they can be chained.
    return copy.deepcopy(unwrap_model(model)).requires_grad_(False).eval()
|
|
|
|
|
|
def move_to_device(batch: Dict[str, Tensor], device: Union[str, torch.device]) -> Dict[str, Any]:
    """Return a new dict with the values of ``batch`` moved to ``device``.

    Args:
        batch: Mapping from field name to value. Values exposing a callable
            ``.to`` (tensors in particular) are transferred with
            ``non_blocking=True``; any other value (e.g. a string or a list
            of metadata) is passed through unchanged instead of raising.
        device: Target device, given as a string or a ``torch.device``.

    Returns:
        A new dict with the same keys as ``batch``; ``batch`` itself is not
        modified.
    """
    return {
        key: value.to(device, non_blocking=True)
        if callable(getattr(value, "to", None))
        else value
        for key, value in batch.items()
    }
|
|
|
|
def get_logprobs(
    model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
    input_ids: Tensor,
    mask: Tensor,
    reduction: str,
):
    """
    Score each next token under ``model`` and reduce over the sequence.

    The model is fed all positions but the last, and its output at position
    ``t`` is used to score the token at position ``t + 1``.

    Args:
        model: Callable invoked as ``model(ids, mask)`` that returns a mapping
            with a ``"logits"`` entry
        input_ids: Input token IDs of shape [batch_size, seq_len]
        mask: Attention mask of shape [batch_size, seq_len]
        reduction: How to reduce over sequence dimension ("mean", "sum", "none")

    Returns:
        Masked log probabilities: [batch_size] for "mean"/"sum",
        [batch_size, seq_len - 1] for "none"

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    valid_modes = ["mean", "sum", "none"]
    if reduction not in valid_modes:
        raise ValueError(f"reduction must be one of {valid_modes}, got '{reduction}'")

    # Targets are the inputs shifted left by one position.
    next_ids, next_mask = input_ids[:, 1:], mask[:, 1:]
    context_ids, context_mask = input_ids[:, :-1], mask[:, :-1]

    logits = model(context_ids, context_mask)["logits"]
    # Upcast before the softmax so the normalisation runs in float32.
    vocab_logprobs = torch.log_softmax(logits.float(), dim=-1)

    # Pick the log-probability of the realised next token: [batch_size, seq_len - 1]
    picked = torch.gather(vocab_logprobs, dim=-1, index=next_ids.unsqueeze(-1)).squeeze(-1)
    masked = picked * next_mask

    if reduction == "none":
        return masked

    summed = masked.sum(dim=-1)
    if reduction == "sum":
        return summed

    # "mean": the clamp avoids a division by zero on fully-masked rows.
    return summed / next_mask.sum(dim=-1).clamp(min=1.0)
|
|
|
|
|
|
class BaseStrategy(ABC):
    """Common interface shared by all training strategies.

    A strategy stores a model and a target device and turns a batch into a
    loss through :meth:`compute_loss`. Calling the strategy object directly
    is shorthand for calling ``compute_loss``.
    """

    def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
        self.model, self.device = model, device

    @abstractmethod
    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Compute the loss for one batch; concrete strategies must override."""
        raise NotImplementedError

    def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
        """Forward the batch to :meth:`compute_loss` and return its result."""
        loss = self.compute_loss(batch)
        return loss
|
|
|
|
|
|
class SEQStrategy(BaseStrategy):
    """Cross-entropy training over every position of the batch.

    The loss is the mean cross-entropy between the model's logits and
    ``batch["target_ids"]``, with optional label smoothing.
    """

    def __init__(self, model, device, label_smoothing):
        """
        Args:
            model: Callable invoked as ``model(input_ids=...)`` that returns a
                mapping with a ``"logits"`` entry.
            device: Device the batch is moved to before the forward pass.
            label_smoothing: Smoothing factor forwarded to
                ``F.cross_entropy``; ``None`` is treated as ``0.0``.
        """
        super().__init__(model, device)
        self.label_smoothing = label_smoothing

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the (label-smoothed) cross-entropy loss for ``batch``.

        Args:
            batch: Must contain ``"input_ids"`` and ``"target_ids"``.
                Assumes logits are [batch, seq, vocab] and targets are
                [batch, seq] -- TODO confirm against the data pipeline.

        Returns:
            Scalar loss tensor.
        """
        batch = move_to_device(batch, self.device)
        input_ids, target_ids = batch["input_ids"], batch["target_ids"]
        logits = self.model(input_ids=input_ids)["logits"]

        loss = F.cross_entropy(
            # Merge the first two dims and upcast so the loss runs in float32.
            input=logits.flatten(0, 1).float(),
            target=target_ids.flatten(),
            # Previously stored but never applied; now actually used.
            label_smoothing=self.label_smoothing or 0.0,
        )

        return loss
|
|
|
|
|
|
class SFTStrategy(BaseStrategy):
    """Cross-entropy training restricted to positions selected by a loss mask.

    Positions where ``batch["loss_mask"]`` is 0 are excluded from the loss by
    replacing their target with the ``ignore_index`` of ``F.cross_entropy``.
    """

    def __init__(self, model, device, label_smoothing):
        """
        Args:
            model: Callable invoked as ``model(input_ids=...)`` that returns a
                mapping with a ``"logits"`` entry.
            device: Device the batch is moved to before the forward pass.
            label_smoothing: Smoothing factor forwarded to
                ``F.cross_entropy``; ``None`` is treated as ``0.0``.
        """
        super().__init__(model, device)
        self.label_smoothing = label_smoothing

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the masked (label-smoothed) cross-entropy loss for ``batch``.

        Args:
            batch: Must contain ``"input_ids"``, ``"target_ids"`` and
                ``"loss_mask"``. Assumes targets and mask share the shape
                [batch, seq] -- TODO confirm against the data pipeline.

        Returns:
            Scalar loss tensor averaged over the unmasked positions.
        """
        batch = move_to_device(batch, self.device)
        input_ids, target_ids, loss_mask = batch["input_ids"], batch["target_ids"], batch["loss_mask"]

        ignore_index = -100
        logits = self.model(input_ids=input_ids)["logits"]
        # masked_fill returns a new tensor, so the batch's targets are not mutated.
        target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)

        loss = F.cross_entropy(
            # Merge the first two dims and upcast so the loss runs in float32.
            input=logits.flatten(0, 1).float(),
            target=target_ids.flatten(),
            ignore_index=ignore_index,
            # Previously stored but never applied; now actually used.
            label_smoothing=self.label_smoothing or 0.0,
        )

        return loss
|
|
|
|
|
|
class DPOStrategy(BaseStrategy):
    """Preference loss comparing the policy against a frozen reference copy.

    The reference model is created from ``model`` at construction time via
    ``create_ref_model``. The loss is the mean of
    ``-logsigmoid(beta * (policy_margin - reference_margin))``, where each
    margin is the chosen-minus-rejected log-probability.
    """

    def __init__(
        self,
        model: nn.Module,
        device: str,
        beta: float,
        reduction: str,
    ):
        super().__init__(model, device)
        self.ref_model = create_ref_model(model)
        self.beta = beta
        self.reduction = reduction

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the DPO loss for a batch of chosen/rejected pairs.

        Args:
            batch: Must contain ``"chosen"``, ``"rejected"``,
                ``"chosen_mask"`` and ``"rejected_mask"``.

        Returns:
            Scalar loss tensor.
        """
        batch = move_to_device(batch, self.device)
        chosen_ids, chosen_mask = batch["chosen"], batch["chosen_mask"]
        rejected_ids, rejected_mask = batch["rejected"], batch["rejected_mask"]

        # Stack chosen on top of rejected so each model runs a single forward pass.
        pair_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
        pair_mask = torch.cat([chosen_mask, rejected_mask], dim=0)

        policy_logps = get_logprobs(self.model, pair_ids, pair_mask, self.reduction)
        with torch.no_grad():
            reference_logps = get_logprobs(self.ref_model, pair_ids, pair_mask, self.reduction)

        # The first n_chosen rows belong to the chosen half, the rest to rejected.
        n_chosen = chosen_ids.shape[0]
        policy_margin = policy_logps[:n_chosen] - policy_logps[n_chosen:]
        reference_margin = reference_logps[:n_chosen] - reference_logps[n_chosen:]

        preference_logits = self.beta * (policy_margin - reference_margin)
        return -F.logsigmoid(preference_logits).mean()
|
|
|
|
|
|
class GRPOStrategy(BaseStrategy):
    """Group-relative policy optimisation loss with a reference-model penalty.

    Rewards are standardised within each group to form advantages, a clipped
    surrogate objective is built from them, and a squared log-probability
    difference to a frozen reference model is added as a penalty.
    """

    def __init__(
        self,
        model: nn.Module,
        device: str,
        clip_eps: float,
        kl_coef: float,
        group_size: int,
        reduction: str,
    ):
        """
        Args:
            model: Policy model being trained.
            device: Device the batch is moved to before the forward pass.
            clip_eps: Half-width of the clipping interval around ratio 1.
            kl_coef: Weight of the reference-model penalty term.
            group_size: Stored for reference only; ``compute_loss`` reads the
                actual group size from the shape of ``batch["responses"]``.
            reduction: Sequence reduction passed to ``get_logprobs``. It must
                collapse the sequence dimension ("mean" or "sum") because the
                result is reshaped to ``(batch_size, group_size)``.
        """
        super().__init__(model, device)
        self.ref_model = create_ref_model(model)
        self.clip_eps = clip_eps
        self.kl_coef = kl_coef
        self.group_size = group_size
        self.reduction = reduction

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the GRPO loss (clipped surrogate + reference penalty).

        Args:
            batch: Must contain ``"prompts"``, ``"responses"`` (3-D:
                batch, group, response length), ``"masks"`` and ``"rewards"``.
                Assumes rewards are floating point with shape
                (batch_size, group_size) -- TODO confirm against the caller.

        Returns:
            Scalar loss tensor.
        """
        batch = move_to_device(batch, self.device)
        prompts = batch["prompts"]
        responses = batch["responses"]
        masks = batch["masks"]
        rewards = batch["rewards"]

        batch_size, group_size, response_len = responses.shape
        responses_flat = responses.view(-1, response_len)
        masks_flat = masks.view(-1, response_len)
        # Repeat every prompt once per response in its group.
        prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)

        # Shape: (batch_size * group_size, prompt_len + response_len)
        full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
        # NOTE(review): prompt positions get mask 1, so they also contribute to
        # the reduced log-probabilities -- confirm this is intended.
        full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)

        log_probs_policy = get_logprobs(self.model, full_sequences, full_masks, self.reduction)
        log_probs_policy = log_probs_policy.view(batch_size, group_size)

        with torch.no_grad():
            log_probs_ref = get_logprobs(self.ref_model, full_sequences, full_masks, self.reduction)
            log_probs_ref = log_probs_ref.view(batch_size, group_size)

        # Standardise rewards within each group; eps guards a zero std.
        eps = torch.finfo(log_probs_policy.dtype).eps
        mean = rewards.mean(dim=-1, keepdim=True)
        std = rewards.std(dim=-1, keepdim=True)
        advantages = (rewards - mean) / (std + eps)

        # On-policy update: the sampling policy equals the current policy, so
        # the importance ratio is exactly 1 in value. Detaching the "old"
        # log-probabilities keeps that value while letting gradients reach the
        # policy; a constant ratio (or torch.exp on a Python int, which raises
        # TypeError) would leave the surrogate term without any gradient.
        ratio = torch.exp(log_probs_policy - log_probs_policy.detach())
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages

        policy_loss = -torch.min(surr1, surr2).mean()
        kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
        total_loss = policy_loss + kl_penalty

        return total_loss
|
|
|
|
|
|
class StrategyFactory:
    """Builds the training strategy that matches a ``train_type`` string."""

    @staticmethod
    def load(model, train_type, device, **kwargs) -> "BaseStrategy":
        """Instantiate the strategy registered under ``train_type``.

        Args:
            model: Model handed to the strategy constructor.
            train_type: One of ``"seq"``, ``"sft"``, ``"dpo"``, ``"grpo"``.
            device: Device handed to the strategy constructor.
            **kwargs: Strategy hyper-parameters. Recognised keys are
                ``label_smoothing`` (default 0.0), ``dpo_beta``,
                ``grpo_clip_eps``, ``grpo_kl_coef``, ``grpo_group_size`` and
                ``reduction`` (default "mean"). Keys without a default are
                passed as ``None`` when missing.

        Returns:
            The constructed strategy.

        Raises:
            KeyError: If ``train_type`` is not a registered strategy name.
        """
        # Lambdas defer construction so only the requested strategy is built
        # (DPO/GRPO deep-copy the model in their constructors).
        train_strategy: Dict[str, Callable[[], "BaseStrategy"]] = {
            "seq": lambda: SEQStrategy(
                model,
                device,
                kwargs.get("label_smoothing", 0.0),
            ),
            "sft": lambda: SFTStrategy(
                model,
                device,
                kwargs.get("label_smoothing", 0.0),
            ),
            "dpo": lambda: DPOStrategy(
                model,
                device,
                kwargs.get("dpo_beta"),
                kwargs.get("reduction", "mean"),
            ),
            "grpo": lambda: GRPOStrategy(
                model,
                device,
                kwargs.get("grpo_clip_eps"),
                kwargs.get("grpo_kl_coef"),
                kwargs.get("grpo_group_size"),
                kwargs.get("reduction", "mean"),
            ),
        }

        try:
            build = train_strategy[train_type]
        except KeyError:
            # Same exception type as before, now with a usable message.
            raise KeyError(
                f"unknown train_type {train_type!r}; expected one of {sorted(train_strategy)}"
            ) from None

        strategy = build()
        return strategy