import copy import math import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from typing import Any, Literal, Tuple, Callable, Dict, Union from abc import ABC, abstractmethod from dataclasses import dataclass, field def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int): input_mask = input_ids.ne(pad_token_id) logits = model(input_ids, input_mask)["logits"] log_probs = torch.log_softmax(logits, dim=-1) shifted_log_probs = log_probs[:, :-1, :] shifted_input_ids = input_ids[:, 1:] shifted_response_mask = mask[:, 1:] token_logprobs = torch.gather( shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1) ).squeeze(-1) prompt_mask = input_mask[:, 1:] valid_mask = (prompt_mask & shifted_response_mask).float() return (token_logprobs * valid_mask).sum(dim=-1) def move_to_device(batch:Dict[str, Tensor], device: str) -> Any: return {key: value.to(device, non_blocking=True) for key, value in batch.items()} class BaseStrategy(ABC): def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str): self.model = model self.device = device @abstractmethod def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: raise NotImplementedError def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor: return self.compute_loss(batch) class SeqStrategy(BaseStrategy): def __init__(self, model, device): super().__init__(model, device) def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) input_ids, target_ids = batch["input_ids"], batch["target_ids"] logits = self.model(input_ids=input_ids)["logits"] loss = F.cross_entropy( input=logits.flatten(0, 1), target=target_ids.flatten() ) return loss class SftStrategy(BaseStrategy): def __init__(self, model, device): super().__init__(model, device) def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) input_ids, target_ids = batch["input_ids"], batch["target_ids"] loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"] ignore_index = -100 logits = self.model(input_ids=input_ids, input_mask=attn_mask)["logits"] target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) loss = F.cross_entropy( input=logits.flatten(0, 1), target=target_ids.flatten(), ignore_index=ignore_index ) return loss class DpoStrategy(BaseStrategy): def __init__(self, model, device, pad_token_id, beta): super().__init__(model, device) ref_model = copy.deepcopy(self.model) ref_model.requires_grad_(False) ref_model.eval() self.ref_model = ref_model self.pad_token_id = pad_token_id self.beta = beta def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: batch = move_to_device(batch, self.device) good_ids, bad_ids = batch["chosen"], batch["rejected"] good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"] log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id) log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id) with torch.no_grad(): log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id) log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id) pi_log_ratio = log_pi_good - log_pi_bad ref_log_ratio = log_ref_good - log_ref_bad ratio_diff = pi_log_ratio - ref_log_ratio dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean() return dpo_loss class PpoStrategy(BaseStrategy): def __init__(self, model, pad_token_id, epsilon): super().__init__(model) ref_model = copy.deepcopy(self.model) ref_model.requires_grad_(False) ref_model.eval() self.ref_model = ref_model self.pad_token_id = pad_token_id self.epsilon = epsilon def ppo_clip_loss_masked( self, log_probs: Tensor, old_log_probs: Tensor, advantages: Tensor, values: Tensor, returns: Tensor, mask: Tensor, clip_eps: float=0.2, ): ratio = torch.exp(log_probs - old_log_probs) surr1 = ratio * advantages surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean() value_loss = F.mse_loss(values.masked_select(mask), returns.masked_select(mask)) entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean() entropy_loss = -entropy return policy_loss, value_loss, entropy_loss class StrategyFactory: def load(model, train_type, device, **kwargs): train_strategy: Dict[str, Callable[[], BaseStrategy]] = { "seq": lambda: SeqStrategy(model, device), "sft": lambda: SftStrategy(model, device), "dpo": lambda: DpoStrategy( model, device, kwargs.get("pad_token_id"), kwargs.get("dpo_beta") ) } strategy = train_strategy[train_type]() return strategy @dataclass class ScheduleConfig(ABC): schedule_type: str = field( default="cosine", metadata={ "help": "Type of learning rate schedule.", "choices": ["cosine", "sgdr"] } ) warmup_steps: int = field( default=1000, metadata={"help": "Number of warmup steps."} ) min_rate: float = field( default=0.05, metadata={"help": "Minimum learning rate multiplier."} ) @abstractmethod def get_kwargs(self) -> Dict[str, Any]: raise NotImplementedError def validate(self) -> None: """Validate configuration parameters.""" if self.warmup_steps < 0: raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}") if not 0 <= self.min_rate <= 1: raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}") @dataclass class CosineScheduleConfig(ScheduleConfig): total_steps: int = field( default=None, metadata={"help": "Total training steps for cosine schedule."} ) schedule_type: Literal["cosine"] = "cosine" def get_kwargs(self) -> Dict[str, Any]: if self.total_steps is None: raise ValueError("total_steps must be specified for cosine schedule") return { "schedule_type": self.schedule_type, "warmup_steps": self.warmup_steps, "lr_decay_steps": self.total_steps - self.warmup_steps, "min_rate": self.min_rate } def validate(self) -> None: super().validate() if self.total_steps is not None and self.total_steps <= self.warmup_steps: raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})") @dataclass class SgdrScheduleConfig(ScheduleConfig): cycle_length: int = field( default=1000, metadata={"help": "Length of the first cycle in steps."} ) t_mult: int = field( default=2, metadata={"help": "Multiplier for cycle length growth."} ) schedule_type: Literal["sgdr"] = "sgdr" def get_kwargs(self) -> Dict[str, Any]: return { "schedule_type": self.schedule_type, "warmup_steps": self.warmup_steps, "cycle_length": self.cycle_length, "min_rate": self.min_rate, "t_mult": self.t_mult } def validate(self) -> None: super().validate() if self.cycle_length <= 0: raise ValueError(f"cycle_length must be positive, got {self.cycle_length}") if self.t_mult < 1: raise ValueError(f"t_mult must be >= 1, got {self.t_mult}") class SchedulerFactory: """Factory for creating learning rate schedule functions.""" @staticmethod def get_sgdr_schedule( warmup_steps: int, cycle_length: int, min_rate: float = 0.05, t_mult: int = 2 ) -> Callable[[int], float]: """ Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule. Args: warmup_steps: Number of warmup steps cycle_length: Length of the first cycle min_rate: Minimum learning rate multiplier t_mult: Cycle length multiplier Returns: Schedule function that takes current step and returns LR multiplier """ def sgdr_schedule(current_step: int) -> float: # Warmup phase if current_step < warmup_steps: return max(min_rate, current_step / warmup_steps) # SGDR phase steps_since_warmup = current_step - warmup_steps # Find current cycle and position within cycle cycle_start = 0 current_cycle_length = cycle_length cycle_index = 0 while steps_since_warmup >= cycle_start + current_cycle_length: cycle_start += current_cycle_length current_cycle_length *= t_mult cycle_index += 1 position_in_cycle = steps_since_warmup - cycle_start progress = position_in_cycle / current_cycle_length # Cosine annealing within cycle return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress))) return sgdr_schedule @staticmethod def get_cosine_schedule( warmup_steps: int, lr_decay_steps: int, min_rate: float = 0.05 ) -> Callable[[int], float]: """ Create cosine decay schedule with warmup. Args: warmup_steps: Number of warmup steps lr_decay_steps: Number of steps for cosine decay after warmup min_rate: Minimum learning rate multiplier Returns: Schedule function that takes current step and returns LR multiplier """ def cosine_schedule(current_step: int) -> float: if current_step < warmup_steps: # Linear warmup return max(min_rate, current_step / warmup_steps) else: # Cosine decay decay_progress = (current_step - warmup_steps) / lr_decay_steps decay_progress = min(decay_progress, 1.0) # Clamp at 1.0 return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress))) return cosine_schedule @staticmethod def load_schedule_fn(scedule_config: ScheduleConfig) -> Callable[[int], float]: kwargs = scedule_config.get_kwargs() schedule_type = kwargs.pop("schedule_type") if schedule_type == "cosine": return SchedulerFactory.get_cosine_schedule(**kwargs) elif schedule_type == "sgdr": return SchedulerFactory.get_sgdr_schedule(**kwargs) else: raise ValueError(f"Unsupported schedule type: {schedule_type}")