AstrAI/khaosz/trainer/strategy.py

355 lines
12 KiB
Python

import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import Dataset
from typing import Any, Literal, Optional, Tuple, Callable, Dict
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
def get_logprobs(model: nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
    """Sum the log-probabilities the model assigns to the selected tokens.

    The model is run once on ``input_ids``; the log-probability of every token
    is read from the prediction made at the previous position, and the values
    are summed over the positions that are both non-padding and selected by
    ``mask``.

    Args:
        model: Invoked as ``model(input_ids, input_mask)``; the result must
            expose the logits under the ``"logits"`` key. The logits are
            indexed as ``(batch, seq, vocab)``.
        input_ids: Token ids, indexed as ``(batch, seq)``.
        mask: Per-token selector laid out like ``input_ids``. ``True`` /
            non-zero marks tokens whose log-probability is counted. Any dtype
            is accepted because the mask is coerced with ``.bool()``.
        pad_token_id: Token id treated as padding; padded positions are
            excluded from the attention mask passed to the model and from the
            returned sum.

    Returns:
        Tensor holding one summed log-probability per sequence in the batch.
    """
    not_pad = input_ids.ne(pad_token_id)
    logits = model(input_ids, not_pad)["logits"]
    log_probs = torch.log_softmax(logits, dim=-1)

    # The prediction at position t is scored against the token at t + 1, so the
    # last prediction and the first token have no partner and are dropped.
    next_token_log_probs = log_probs[:, :-1, :]
    next_tokens = input_ids[:, 1:]
    token_logprobs = torch.gather(
        next_token_log_probs,
        dim=-1,
        index=next_tokens.unsqueeze(-1),
    ).squeeze(-1)

    # Coerce to bool so integer or float masks combine correctly with the
    # boolean padding mask (``bool & float`` is not defined for tensors).
    selected = mask[:, 1:].bool()
    valid_mask = (not_pad[:, 1:] & selected).float()

    return (token_logprobs * valid_mask).sum(dim=-1)
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
    """Return a new dict with every value of ``batch`` moved to ``device``.

    Each value is transferred with ``.to(device, non_blocking=True)``; the keys
    and their order are preserved and the input dict is not modified.
    """
    moved = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device, non_blocking=True)
    return moved
class BaseStrategy(ABC):
    """Abstract base for training strategies.

    A strategy holds the model being trained and the target device, and turns
    a batch into a loss tensor. Instances are callable: calling one forwards
    the batch to :meth:`compute_loss`.
    """

    def __init__(self, model: nn.Module, device: str):
        self.model, self.device = model, device

    @abstractmethod
    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Compute the loss for ``batch``; must be provided by subclasses."""
        raise NotImplementedError

    def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
        """Delegate to :meth:`compute_loss` and return its result."""
        loss = self.compute_loss(batch)
        return loss
class SeqStrategy(BaseStrategy):
    """Strategy computing an unmasked token-level cross-entropy loss."""

    def __init__(self, model, device):
        super().__init__(model, device)

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the cross-entropy between the model logits and the targets.

        Reads ``"input_ids"`` and ``"target_ids"`` from ``batch`` after moving
        the batch to ``self.device``. The model is called with
        ``input_ids=...`` and must return a mapping with a ``"logits"`` entry.
        """
        tensors = move_to_device(batch, self.device)
        tokens = tensors["input_ids"]
        labels = tensors["target_ids"]
        batch_size, seq_len = tokens.size()

        model_out = self.model(input_ids=tokens)
        logits: Tensor = model_out["logits"]

        # Collapse batch and sequence axes so each row is one token prediction.
        flat_logits = logits.view(batch_size * seq_len, -1)
        flat_labels = labels.flatten()
        return F.cross_entropy(input=flat_logits, target=flat_labels)
class SftStrategy(BaseStrategy):
    """Strategy computing a cross-entropy loss restricted by a loss mask."""

    def __init__(self, model, device):
        super().__init__(model, device)

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the cross-entropy over the positions enabled by the loss mask.

        Reads ``"input_ids"``, ``"target_ids"``, ``"loss_mask"`` and
        ``"attn_mask"`` from ``batch`` after moving it to ``self.device``.
        The model is called with ``input_ids`` and ``input_mask`` keyword
        arguments and must return a mapping with a ``"logits"`` entry.
        """
        tensors = move_to_device(batch, self.device)
        tokens = tensors["input_ids"]
        labels = tensors["target_ids"]
        loss_mask = tensors["loss_mask"]
        attn_mask = tensors["attn_mask"]

        skip_label = -100
        batch_size, seq_len = tokens.size()

        model_out = self.model(input_ids=tokens, input_mask=attn_mask)
        logits: Tensor = model_out["logits"]

        # Targets where the loss mask is zero are swapped for the ignore
        # label, which cross_entropy then leaves out of the loss.
        labels = labels.masked_fill(loss_mask == 0, skip_label)

        return F.cross_entropy(
            input=logits.view(batch_size * seq_len, -1),
            target=labels.flatten(),
            ignore_index=skip_label,
        )
class DpoStrategy(BaseStrategy):
    """Strategy computing a DPO-style preference loss against a frozen copy.

    At construction time the model is deep-copied; the copy has gradients
    disabled, is put in eval mode and serves as the reference model.
    """

    def __init__(self, model, device, pad_token_id, beta):
        super().__init__(model, device)

        frozen = copy.deepcopy(self.model)
        frozen.requires_grad_(False)
        frozen.eval()

        self.ref_model = frozen
        self.pad_token_id = pad_token_id
        self.beta = beta

    def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
        """Return the mean negative log-sigmoid of the scaled log-ratio gap.

        Reads ``"chosen"``, ``"rejected"``, ``"chosen_mask"`` and
        ``"rejected_mask"`` from ``batch`` after moving it to ``self.device``.
        Sequence log-probabilities come from ``get_logprobs``; those of the
        reference model are computed without gradient tracking.
        """
        tensors = move_to_device(batch, self.device)
        chosen_ids = tensors["chosen"]
        rejected_ids = tensors["rejected"]
        chosen_mask = tensors["chosen_mask"]
        rejected_mask = tensors["rejected_mask"]
        pad_id = self.pad_token_id

        policy_chosen = get_logprobs(self.model, chosen_ids, chosen_mask, pad_id)
        policy_rejected = get_logprobs(self.model, rejected_ids, rejected_mask, pad_id)

        with torch.no_grad():
            ref_chosen = get_logprobs(self.ref_model, chosen_ids, chosen_mask, pad_id)
            ref_rejected = get_logprobs(self.ref_model, rejected_ids, rejected_mask, pad_id)

        # How much more the policy prefers "chosen" over "rejected" than the
        # reference model does.
        policy_margin = policy_chosen - policy_rejected
        ref_margin = ref_chosen - ref_rejected
        margin_gap = policy_margin - ref_margin

        return -F.logsigmoid(self.beta * margin_gap).mean()
class PpoStrategy(BaseStrategy):
    """PPO building blocks: a frozen reference model and the clipped loss terms.

    NOTE(review): ``compute_loss`` is not implemented here, so this class is
    still abstract and has to be subclassed before it can be instantiated.
    """

    def __init__(self, model, pad_token_id, epsilon, device=None):
        """Store the model, a frozen deep copy of it, and the PPO settings.

        Args:
            model: Model to train; deep-copied to create the reference model.
            pad_token_id: Padding token id, stored on the instance.
            epsilon: Stored on the instance as ``self.epsilon``.
            device: Device handed to ``BaseStrategy``. When omitted, the
                device of the model's first parameter is used, falling back
                to ``"cpu"`` for a model without parameters.
        """
        if device is None:
            first_param = next(model.parameters(), None)
            device = "cpu" if first_param is None else str(first_param.device)

        # BaseStrategy.__init__ requires both model and device; passing only
        # the model raised TypeError.
        super().__init__(model, device)

        ref_model = copy.deepcopy(self.model)
        ref_model.requires_grad_(False)
        ref_model.eval()

        self.ref_model = ref_model
        self.pad_token_id = pad_token_id
        self.epsilon = epsilon

    def ppo_clip_loss_masked(
        self,
        log_probs: Tensor,
        old_log_probs: Tensor,
        advantages: Tensor,
        values: Tensor,
        returns: Tensor,
        mask: Tensor,
        clip_eps: float = 0.2,
    ):
        """Compute the clipped policy, value and entropy loss terms.

        Only the elements selected by ``mask`` contribute to each term.

        Args:
            log_probs: Log-probabilities under the current policy.
            old_log_probs: Log-probabilities under the sampling policy.
            advantages: Advantage estimates.
            values: Value predictions.
            returns: Value targets.
            mask: Selector passed to ``masked_select``.
            clip_eps: Half-width of the clipping interval around a ratio of 1.
                NOTE(review): ``self.epsilon`` is stored but not consulted
                here; confirm whether it was meant to be this default.

        Returns:
            Tuple ``(policy_loss, value_loss, entropy_loss)``.
        """
        ratio = torch.exp(log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages

        # Taking the elementwise minimum keeps the more pessimistic surrogate.
        policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()

        value_loss = F.mse_loss(
            values.masked_select(mask),
            returns.masked_select(mask),
        )

        # Built from the supplied log-probabilities only (p * log p per element).
        entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
        entropy_loss = -entropy

        return policy_loss, value_loss, entropy_loss
class StrategyFactory:
    """Factory that builds a training strategy from its short type name."""

    @staticmethod
    def load(model, train_type, device, **kwargs):
        """Build the strategy registered under ``train_type``.

        Declared as a static method so it works both as
        ``StrategyFactory.load(...)`` and on an instance.

        Args:
            model: Model handed to the strategy constructor.
            train_type: One of ``"seq"``, ``"sft"`` or ``"dpo"``.
            device: Device handed to the strategy constructor.
            **kwargs: Extra settings. ``"dpo"`` reads ``pad_token_id`` and
                ``dpo_beta``; a missing key is passed on as ``None``.

        Returns:
            The constructed strategy.

        Raises:
            KeyError: If ``train_type`` is not a registered name.
        """
        # Builders are lambdas so that only the requested strategy is created.
        train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
            "seq": lambda: SeqStrategy(model, device),
            "sft": lambda: SftStrategy(model, device),
            "dpo": lambda: DpoStrategy(
                model,
                device,
                kwargs.get("pad_token_id"),
                kwargs.get("dpo_beta"),
            ),
        }
        strategy = train_strategy[train_type]()
        return strategy
@dataclass
class ScheduleConfig(ABC):
    """Settings shared by all learning rate schedules.

    Subclasses add their own fields and implement :meth:`get_kwargs`.
    """

    # Name of the schedule; the metadata lists the known choices.
    schedule_type: str = field(
        default="cosine",
        metadata={
            "help": "Type of learning rate schedule.",
            "choices": ["cosine", "sgdr"],
        },
    )
    # Number of warmup steps at the start of the schedule.
    warmup_steps: int = field(
        default=1000,
        metadata={"help": "Number of warmup steps."},
    )
    # Lower bound for the learning rate multiplier.
    min_rate: float = field(
        default=0.05,
        metadata={"help": "Minimum learning rate multiplier."},
    )

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """Return the settings as a dict; must be provided by subclasses."""
        raise NotImplementedError

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: If ``warmup_steps`` is negative or ``min_rate`` lies
                outside ``[0, 1]``.
        """
        warmup, floor = self.warmup_steps, self.min_rate
        if warmup < 0:
            raise ValueError(f"warmup_steps must be non-negative, got {warmup}")
        if not (0 <= floor <= 1):
            raise ValueError(f"min_rate must be between 0 and 1, got {floor}")
@dataclass
class CosineScheduleConfig(ScheduleConfig):
    """Settings for the cosine decay schedule with warmup."""

    # Defaults to None; get_kwargs() refuses to run until it has been set.
    total_steps: int = field(
        default=None,
        metadata={"help": "Total training steps for cosine schedule."},
    )
    schedule_type: Literal["cosine"] = "cosine"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments describing this cosine schedule.

        The decay length is ``total_steps`` minus ``warmup_steps``.

        Raises:
            ValueError: If ``total_steps`` has not been set.
        """
        if self.total_steps is None:
            raise ValueError("total_steps must be specified for cosine schedule")

        decay_steps = self.total_steps - self.warmup_steps
        return dict(
            schedule_type=self.schedule_type,
            warmup_steps=self.warmup_steps,
            lr_decay_steps=decay_steps,
            min_rate=self.min_rate,
        )

    def validate(self) -> None:
        """Run the base checks, then require ``total_steps > warmup_steps``.

        The extra check is skipped while ``total_steps`` is ``None``.
        """
        super().validate()
        if self.total_steps is None:
            return
        if self.total_steps <= self.warmup_steps:
            raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
@dataclass
class SgdrScheduleConfig(ScheduleConfig):
    """Settings for the SGDR (warm restarts) schedule."""

    # Length of the first cycle, in steps.
    cycle_length: int = field(
        default=1000,
        metadata={"help": "Length of the first cycle in steps."},
    )
    # Factor by which each cycle is longer than the previous one.
    t_mult: int = field(
        default=2,
        metadata={"help": "Multiplier for cycle length growth."},
    )
    schedule_type: Literal["sgdr"] = "sgdr"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments describing this SGDR schedule."""
        return dict(
            schedule_type=self.schedule_type,
            warmup_steps=self.warmup_steps,
            cycle_length=self.cycle_length,
            min_rate=self.min_rate,
            t_mult=self.t_mult,
        )

    def validate(self) -> None:
        """Run the base checks, then check ``cycle_length`` and ``t_mult``.

        Raises:
            ValueError: If ``cycle_length`` is not positive or ``t_mult`` is
                smaller than 1.
        """
        super().validate()
        first_cycle, growth = self.cycle_length, self.t_mult
        if first_cycle <= 0:
            raise ValueError(f"cycle_length must be positive, got {first_cycle}")
        if growth < 1:
            raise ValueError(f"t_mult must be >= 1, got {growth}")
class SchedulerFactory:
    """Factory for creating learning rate schedule functions."""

    @staticmethod
    def get_sgdr_schedule(
        warmup_steps: int,
        cycle_length: int,
        min_rate: float = 0.05,
        t_mult: int = 2
    ) -> Callable[[int], float]:
        """
        Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule.

        Args:
            warmup_steps: Number of warmup steps
            cycle_length: Length of the first cycle
            min_rate: Minimum learning rate multiplier
            t_mult: Cycle length multiplier

        Returns:
            Schedule function that takes current step and returns LR multiplier

        Raises:
            ValueError: If ``cycle_length`` is not positive or ``t_mult`` is
                smaller than 1; either would keep the cycle search below from
                ever terminating.
        """
        if cycle_length <= 0:
            raise ValueError(f"cycle_length must be positive, got {cycle_length}")
        if t_mult < 1:
            raise ValueError(f"t_mult must be >= 1, got {t_mult}")

        def sgdr_schedule(current_step: int) -> float:
            # Warmup phase
            if current_step < warmup_steps:
                return max(min_rate, current_step / warmup_steps)

            # SGDR phase
            steps_since_warmup = current_step - warmup_steps

            # Walk forward through the cycles until the one containing the
            # current step is found.
            cycle_start = 0
            current_cycle_length = cycle_length
            while steps_since_warmup >= cycle_start + current_cycle_length:
                cycle_start += current_cycle_length
                current_cycle_length *= t_mult

            position_in_cycle = steps_since_warmup - cycle_start
            progress = position_in_cycle / current_cycle_length

            # Cosine annealing within cycle
            return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress)))

        return sgdr_schedule

    @staticmethod
    def get_cosine_schedule(
        warmup_steps: int,
        lr_decay_steps: int,
        min_rate: float = 0.05
    ) -> Callable[[int], float]:
        """
        Create cosine decay schedule with warmup.

        Args:
            warmup_steps: Number of warmup steps
            lr_decay_steps: Number of steps for cosine decay after warmup.
                A non-positive value leaves no decay window, so the schedule
                returns ``min_rate`` for every step after warmup.
            min_rate: Minimum learning rate multiplier

        Returns:
            Schedule function that takes current step and returns LR multiplier
        """
        def cosine_schedule(current_step: int) -> float:
            if current_step < warmup_steps:
                # Linear warmup
                return max(min_rate, current_step / warmup_steps)

            if lr_decay_steps <= 0:
                # Nothing to decay over: treat the decay as already finished
                # instead of dividing by zero.
                return min_rate

            # Cosine decay
            decay_progress = (current_step - warmup_steps) / lr_decay_steps
            decay_progress = min(decay_progress, 1.0)  # Clamp at 1.0
            return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress)))

        return cosine_schedule

    @staticmethod
    def load_schedule_fn(scedule_config: "ScheduleConfig") -> Callable[[int], float]:
        """Build the schedule function described by a schedule config.

        The dict returned by ``scedule_config.get_kwargs()`` must contain a
        ``"schedule_type"`` entry; the remaining entries are forwarded as
        keyword arguments to the matching ``get_*_schedule`` builder.

        Raises:
            ValueError: If the schedule type is neither ``"cosine"`` nor
                ``"sgdr"``.
        """
        kwargs = scedule_config.get_kwargs()
        schedule_type = kwargs.pop("schedule_type")

        if schedule_type == "cosine":
            return SchedulerFactory.get_cosine_schedule(**kwargs)
        elif schedule_type == "sgdr":
            return SchedulerFactory.get_sgdr_schedule(**kwargs)
        else:
            raise ValueError(f"Unsupported schedule type: {schedule_type}")