415 lines
13 KiB
Python
415 lines
13 KiB
Python
import copy
import math
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Callable, Dict, Literal, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import Dataset
|
|
|
|
|
|
def get_logprobs(model: nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
    """Sum per-token log-probabilities of ``input_ids`` under ``model``.

    The model is called with an attention mask derived from ``pad_token_id``;
    next-token log-probabilities are gathered and summed over positions that
    are both non-padding and selected by ``mask``.

    Args:
        model: Callable taking ``(input_ids, attention_mask)`` positionally and
            returning a dict with a ``"logits"`` tensor of shape (B, L, vocab).
        input_ids: Token ids, shape (B, L).
        mask: Response mask selecting which positions count toward the sum.
        pad_token_id: Token id treated as padding.

    Returns:
        Tensor of shape (B,) holding each sequence's summed log-probability.
    """
    attention_mask = input_ids.ne(pad_token_id)
    logits = model(input_ids, attention_mask)["logits"]
    all_logprobs = torch.log_softmax(logits, dim=-1)

    # Next-token prediction: logits at position t score the token at t + 1,
    # so drop the last logit row and the first token column.
    targets = input_ids[:, 1:]
    per_token = all_logprobs[:, :-1, :].gather(
        dim=-1,
        index=targets.unsqueeze(-1)
    ).squeeze(-1)

    # Keep a position only if it is non-padding AND inside the response span.
    keep = (attention_mask[:, 1:] & mask[:, 1:]).float()

    return (per_token * keep).sum(dim=-1)
|
|
|
|
|
|
class BaseStrategy(ABC):
    """Abstract base class for training-loss strategies wrapping one model."""

    def __init__(self, model: nn.Module):
        # Model whose loss the concrete strategy computes.
        self.model = model

    @abstractmethod
    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Compute the scalar training loss for one batch dict."""
        raise NotImplementedError

    def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
        # Delegate so a strategy instance can be used as a plain callable
        # in a training loop.
        return self.compute_loss(batch)
|
|
|
|
|
|
class SeqStrategy(BaseStrategy):
    """Plain sequence-modeling loss: cross-entropy over every position."""

    def __init__(self, model):
        super().__init__(model)

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Average cross-entropy between model logits and ``target_ids``."""
        ids = batch["input_ids"]
        targets = batch["target_ids"]

        out = self.model(input_ids=ids)
        # Collapse (B, L, vocab) -> (B*L, vocab) to match flattened targets.
        flat_logits = out["logits"].flatten(0, 1)

        return F.cross_entropy(
            input=flat_logits,
            target=targets.flatten()
        )
|
|
|
|
|
|
class SftStrategy(BaseStrategy):
    """Supervised fine-tuning loss: cross-entropy restricted by a loss mask."""

    def __init__(self, model: nn.Module):
        super().__init__(model)

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Masked cross-entropy; positions with ``loss_mask == 0`` are ignored."""
        IGNORE = -100

        ids = batch["input_ids"]
        targets = batch["target_ids"]
        loss_mask = batch["loss_mask"]
        attn_mask = batch["attn_mask"]

        out = self.model(input_ids=ids, input_mask=attn_mask)
        logits: Tensor = out["logits"]

        # Exclude non-response tokens from the loss by mapping them onto the
        # ignore index understood by F.cross_entropy.
        masked_targets = targets.masked_fill(loss_mask == 0, IGNORE)

        return F.cross_entropy(
            input=logits.flatten(0, 1),
            target=masked_targets.flatten(),
            ignore_index=IGNORE
        )
|
|
|
|
|
|
class DpoStrategy(BaseStrategy):
    """Direct Preference Optimization loss against a frozen reference copy."""

    def __init__(self, model, pad_token_id, beta):
        super().__init__(model)
        # Frozen snapshot of the policy taken at construction time; it acts
        # as the DPO reference distribution and never receives gradients.
        ref_model = copy.deepcopy(self.model)
        ref_model.requires_grad_(False)
        ref_model.eval()

        self.ref_model = ref_model
        self.pad_token_id = pad_token_id
        # Inverse-temperature scaling applied to the log-ratio difference.
        self.beta = beta

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """DPO loss: -log sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r))).

        Expects ``batch`` keys "chosen"/"rejected" (token ids) and
        "chosen_mask"/"rejected_mask" (response masks) — see get_logprobs.
        """
        good_ids, bad_ids = batch["chosen"], batch["rejected"]
        good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]

        log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
        log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)

        # Reference scores are constants w.r.t. the policy parameters.
        with torch.no_grad():
            log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id)
            log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id)

        pi_log_ratio = log_pi_good - log_pi_bad
        ref_log_ratio = log_ref_good - log_ref_bad

        ratio_diff = pi_log_ratio - ref_log_ratio

        dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
        return dpo_loss
|
|
|
|
|
|
class PpoStrategy(BaseStrategy):
    """PPO clipped-objective losses with a frozen reference copy of the model.

    Note: full PPO training also requires rollout collection and advantage
    estimation; this class only provides the masked loss terms.
    """

    def __init__(self, model, pad_token_id, epsilon):
        super().__init__(model)
        # Frozen snapshot of the policy; no gradients flow into it.
        ref_model = copy.deepcopy(self.model)
        ref_model.requires_grad_(False)
        ref_model.eval()

        self.ref_model = ref_model
        self.pad_token_id = pad_token_id
        # Default clipping radius for ppo_clip_loss_masked.
        # BUG FIX: this was stored but never used — the method hard-coded 0.2.
        self.epsilon = epsilon

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """PPO has no single-batch supervised loss; use ppo_clip_loss_masked.

        BUG FIX: the abstract method was never overridden, so the class could
        not even be instantiated (ABC check fired in __init__). This concrete
        override makes the helper usable while failing loudly if called.
        """
        raise NotImplementedError(
            "PpoStrategy does not implement compute_loss; "
            "call ppo_clip_loss_masked with rollout tensors instead."
        )

    def ppo_clip_loss_masked(
        self,
        log_probs: Tensor,
        old_log_probs: Tensor,
        advantages: Tensor,
        values: Tensor,
        returns: Tensor,
        mask: Tensor,
        clip_eps=None,
    ):
        """Compute masked PPO losses.

        Args:
            log_probs: Current-policy log-probs of the sampled tokens.
            old_log_probs: Behavior-policy log-probs from rollout time.
            advantages: Advantage estimates, same shape as ``mask``.
            values: Value-head predictions.
            returns: Empirical returns (value targets).
            mask: Boolean tensor selecting valid (non-padding) positions.
            clip_eps: Clipping radius; when None, falls back to the
                ``epsilon`` passed at construction (previously a hard-coded
                0.2 that silently ignored the configured value).

        Returns:
            Tuple ``(policy_loss, value_loss, entropy_loss)``.
        """
        if clip_eps is None:
            clip_eps = self.epsilon

        # Importance ratio pi_new / pi_old.
        ratio = torch.exp(log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
        # Clipped surrogate objective, averaged over valid positions only.
        policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()

        value_loss = F.mse_loss(values.masked_select(mask),
                                returns.masked_select(mask))

        # NOTE(review): this uses only the sampled token's log-prob, so the
        # "entropy" term is -p*log p of that single token, not the full
        # distribution entropy — confirm this is intended.
        entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
        entropy_loss = -entropy
        return policy_loss, value_loss, entropy_loss
|
|
|
|
|
|
|
|
class StrategyFactory:
    """Builds a concrete BaseStrategy from a model and a train-type string."""

    @staticmethod
    def load(model, train_type, **kwargs):
        """Instantiate the strategy named by ``train_type``.

        Args:
            model: Model handed to the strategy constructor.
            train_type: One of "seq", "sft" or "dpo".
            **kwargs: Strategy-specific options (e.g. "pad_token_id",
                "dpo_beta" for DPO); unused keys are ignored.

        Returns:
            The constructed strategy instance.

        Raises:
            ValueError: If ``train_type`` is not a known strategy name
                (previously surfaced as an opaque KeyError).
        """
        # Lazy constructors so only the requested strategy is built.
        train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
            "seq": lambda: SeqStrategy(model),
            # BUG FIX: SftStrategy.__init__ accepts only the model; the old
            # call passed bos/eos/multi_turn extras and raised TypeError.
            "sft": lambda: SftStrategy(model),
            "dpo": lambda: DpoStrategy(
                model,
                kwargs.get("pad_token_id"),
                kwargs.get("dpo_beta")
            )
        }
        if train_type not in train_strategy:
            raise ValueError(
                f"Unknown train_type '{train_type}'; expected one of {sorted(train_strategy)}"
            )
        return train_strategy[train_type]()
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Bundle of objects and hyperparameters driving one training run."""

    # NOTE(review): several fields are typed non-Optional but default to
    # None; callers are expected to populate them before training starts.
    strategy: BaseStrategy = field(
        default=None,
        metadata={"help": "Training strategy."}
    )
    dataset: Dataset = field(
        default=None,
        metadata={"help": "Dataset for training."}
    )
    optimizer: Optimizer = field(
        default=None,
        metadata={"help": "Optimizer for training."}
    )
    checkpoint_dir: str = field(
        default="./checkpoint",
        metadata={"help": "Checkpoint directory."}
    )
    n_epoch: int = field(
        default=1,
        metadata={"help": "Number of epochs for training."}
    )
    batch_size: int = field(
        default=4,
        metadata={"help": "Batch size for training."}
    )
    checkpoint_interval: int = field(
        default=5000,
        metadata={"help": "Number of iterations between checkpoints."}
    )
    accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of iterations between steps."}
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Maximum gradient norm."}
    )
    random_seed: int = field(
        default=3407,
        metadata={"help": "Random seed."}
    )

    def get_kwargs(self) -> Dict[str, Any]:
        """Return all non-None fields as a plain kwargs dict.

        BUG FIX: the previous implementation used dataclasses.asdict, which
        deep-copies every value; deep-copying the strategy (and its model),
        dataset and optimizer is expensive and can fail outright. Read the
        attributes directly instead.
        """
        return {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None
        }
|
|
|
|
|
|
@dataclass
class ScheduleConfig(ABC):
    """Options common to every learning-rate schedule configuration."""

    # Which schedule family to build; concrete subclasses pin this down.
    schedule_type: str = field(
        default="cosine",
        metadata={
            "help": "Type of learning rate schedule.",
            "choices": ["cosine", "sgdr"]
        }
    )
    # Steps of linear warmup before the main schedule takes over.
    warmup_steps: int = field(
        default=1000,
        metadata={"help": "Number of warmup steps."}
    )
    # Floor for the LR multiplier returned by the schedule function.
    min_rate: float = field(
        default=0.05,
        metadata={"help": "Minimum learning rate multiplier."}
    )

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """Return the kwargs consumed by SchedulerFactory.load_schedule_fn."""
        raise NotImplementedError

    def validate(self) -> None:
        """Validate configuration parameters."""
        if self.warmup_steps < 0:
            raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
        if not (0 <= self.min_rate <= 1):
            raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
|
|
|
|
|
|
@dataclass
class CosineScheduleConfig(ScheduleConfig):
    """Linear warmup followed by a single cosine decay down to ``min_rate``."""

    # NOTE(review): typed int but defaults to None; get_kwargs rejects the
    # unset case with a ValueError.
    total_steps: int = field(
        default=None,
        metadata={"help": "Total training steps for cosine schedule."}
    )
    schedule_type: Literal["cosine"] = "cosine"

    def get_kwargs(self) -> Dict[str, Any]:
        """Kwargs for SchedulerFactory; decay runs over total - warmup steps."""
        if self.total_steps is None:
            raise ValueError("total_steps must be specified for cosine schedule")

        decay_steps = self.total_steps - self.warmup_steps
        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "lr_decay_steps": decay_steps,
            "min_rate": self.min_rate
        }

    def validate(self) -> None:
        """Base checks plus: total_steps (when set) must exceed warmup_steps."""
        super().validate()
        if self.total_steps is not None and self.total_steps <= self.warmup_steps:
            raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
|
|
|
|
|
|
@dataclass
class SgdrScheduleConfig(ScheduleConfig):
    """Warm restarts: repeated cosine cycles whose length grows by ``t_mult``."""

    cycle_length: int = field(
        default=1000,
        metadata={"help": "Length of the first cycle in steps."}
    )
    t_mult: int = field(
        default=2,
        metadata={"help": "Multiplier for cycle length growth."}
    )
    schedule_type: Literal["sgdr"] = "sgdr"

    def get_kwargs(self) -> Dict[str, Any]:
        """Kwargs understood by SchedulerFactory.get_sgdr_schedule."""
        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            "cycle_length": self.cycle_length,
            "min_rate": self.min_rate,
            "t_mult": self.t_mult
        }

    def validate(self) -> None:
        """Base checks plus: positive cycle_length and t_mult >= 1."""
        super().validate()
        if self.cycle_length <= 0:
            raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
        if self.t_mult < 1:
            raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
|
|
|
|
|
|
class SchedulerFactory:
    """Factory for creating learning rate schedule functions."""

    @staticmethod
    def get_sgdr_schedule(
        warmup_steps: int,
        cycle_length: int,
        min_rate: float = 0.05,
        t_mult: int = 2
    ) -> Callable[[int], float]:
        """
        Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule.

        Args:
            warmup_steps: Number of warmup steps
            cycle_length: Length of the first cycle
            min_rate: Minimum learning rate multiplier
            t_mult: Cycle length multiplier

        Returns:
            Schedule function that takes current step and returns LR multiplier
        """

        def sgdr_schedule(current_step: int) -> float:
            # Linear warmup phase, floored at min_rate.
            if current_step < warmup_steps:
                return max(min_rate, current_step / warmup_steps)

            # SGDR phase
            steps_since_warmup = current_step - warmup_steps

            # Walk forward cycle by cycle until we find the one containing
            # this step; each restart stretches the cycle by t_mult.
            # (Removed an unused cycle-index counter from the original.)
            cycle_start = 0
            current_cycle_length = cycle_length
            while steps_since_warmup >= cycle_start + current_cycle_length:
                cycle_start += current_cycle_length
                current_cycle_length *= t_mult

            position_in_cycle = steps_since_warmup - cycle_start
            progress = position_in_cycle / current_cycle_length

            # Cosine annealing within the cycle, floored at min_rate.
            return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress)))

        return sgdr_schedule

    @staticmethod
    def get_cosine_schedule(
        warmup_steps: int,
        lr_decay_steps: int,
        min_rate: float = 0.05
    ) -> Callable[[int], float]:
        """
        Create cosine decay schedule with warmup.

        Args:
            warmup_steps: Number of warmup steps
            lr_decay_steps: Number of steps for cosine decay after warmup
            min_rate: Minimum learning rate multiplier

        Returns:
            Schedule function that takes current step and returns LR multiplier
        """

        def cosine_schedule(current_step: int) -> float:
            if current_step < warmup_steps:
                # Linear warmup, floored at min_rate.
                return max(min_rate, current_step / warmup_steps)
            else:
                # Cosine decay from 1.0 down to min_rate.
                decay_progress = (current_step - warmup_steps) / lr_decay_steps
                decay_progress = min(decay_progress, 1.0)  # Clamp at 1.0
                return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress)))

        return cosine_schedule

    @staticmethod
    def create_schedule(config: "ScheduleConfig") -> Callable[[int], float]:
        """
        Create schedule from configuration.

        Validates the config, then dispatches on its kwargs. (Annotation is a
        forward reference so this class has no import-time dependency on the
        config classes.)

        Args:
            config: Schedule configuration instance

        Returns:
            Schedule function
        """
        config.validate()
        kwargs = config.get_kwargs()
        return SchedulerFactory.load_schedule_fn(**kwargs)

    @staticmethod
    def load_schedule_fn(**kwargs) -> Callable[[int], float]:
        """Dispatch on the 'schedule_type' kwarg; raise ValueError if unknown."""
        schedule_type = kwargs.pop("schedule_type")

        if schedule_type == "cosine":
            return SchedulerFactory.get_cosine_schedule(**kwargs)
        elif schedule_type == "sgdr":
            return SchedulerFactory.get_sgdr_schedule(**kwargs)
        else:
            raise ValueError(f"Unsupported schedule type: {schedule_type}")
|
|
|