AstrAI/khaosz/trainer/strategy.py

347 lines
12 KiB
Python

import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Any, Literal, Tuple, Callable, Dict, Union
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
def get_logprobs(model: nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
    """Sum, per row, the log-probabilities the model assigns to each observed next token.

    Logit position ``t`` is scored against token ``t + 1`` of ``input_ids``. A
    scored token contributes only if it is not ``pad_token_id`` and ``mask`` is
    nonzero at its position.

    Args:
        model: Called as ``model(input_ids, non_pad_mask)``; must return a dict
            whose ``"logits"`` entry is indexed ``[row, position, vocab]``.
        input_ids: Token ids indexed ``[row, position]``.
        mask: Selection mask with the same ``[row, position]`` indexing as
            ``input_ids``; nonzero entries mark tokens to include.
        pad_token_id: Token id treated as padding.

    Returns:
        A tensor holding one summed log-probability per row of ``input_ids``.
    """
    not_pad = input_ids.ne(pad_token_id)
    logits = model(input_ids, not_pad)["logits"]

    # The last position predicts nothing observable, and the first token is
    # never predicted, so predictions and targets are offset by one.
    pred_logprobs = torch.log_softmax(logits, dim=-1)[:, :-1, :]
    next_tokens = input_ids[:, 1:].unsqueeze(-1)
    picked = pred_logprobs.gather(dim=-1, index=next_tokens).squeeze(-1)

    keep = (not_pad[:, 1:] & mask[:, 1:]).float()
    return (picked * keep).sum(dim=-1)
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
    """Return a new dict with every tensor of ``batch`` moved to ``device``.

    Each transfer is requested with ``non_blocking=True``. Keys keep their
    original order, and the input dict itself is not modified.
    """
    moved = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device, non_blocking=True)
    return moved
class BaseStrategy(ABC):
    """Abstract training strategy that turns a batch into a scalar loss.

    Subclasses implement :meth:`compute_loss`. Calling the strategy instance
    simply delegates to it.
    """

    def __init__(self, model: Union[nn.Module, Callable[..., Dict[str, Tensor]]], device: str):
        # model: module or callable whose output dict subclasses read.
        # device: device string stored for subclasses to move batches onto.
        self.model = model
        self.device = device

    @abstractmethod
    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the training loss for ``batch``."""
        raise NotImplementedError

    def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
        """Alias for :meth:`compute_loss`."""
        # Annotated as a dict to agree with compute_loss (it was previously
        # mis-annotated as a tuple).
        return self.compute_loss(batch)
class SeqStrategy(BaseStrategy):
    """Next-token loss applied uniformly to every position of ``target_ids``."""

    def __init__(self, model, device):
        super().__init__(model, device)

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Cross-entropy between the model's logits and ``batch["target_ids"]``.

        ``batch`` must provide ``"input_ids"`` and ``"target_ids"``. The model is
        called with ``input_ids=`` only, so no attention/padding mask is passed.
        """
        tensors = move_to_device(batch, self.device)
        token_ids = tensors["input_ids"]
        labels = tensors["target_ids"]
        predictions = self.model(input_ids=token_ids)["logits"]
        # flatten(0, 1) merges the first two logits axes (presumably batch and
        # sequence) so each position becomes one classification row.
        return F.cross_entropy(
            input=predictions.flatten(0, 1),
            target=labels.flatten(),
        )
class SftStrategy(BaseStrategy):
    """Supervised fine-tuning loss counted only where ``loss_mask`` is nonzero."""

    def __init__(self, model, device):
        super().__init__(model, device)

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Masked cross-entropy between the model's logits and ``target_ids``.

        ``batch`` must provide ``"input_ids"``, ``"target_ids"``, ``"loss_mask"``
        and ``"attn_mask"``. ``attn_mask`` is forwarded to the model as
        ``input_mask``. Targets whose ``loss_mask`` entry is 0 are excluded from
        the loss via ``ignore_index``.
        """
        tensors = move_to_device(batch, self.device)
        token_ids, labels, loss_mask, attn_mask = (
            tensors[key] for key in ("input_ids", "target_ids", "loss_mask", "attn_mask")
        )
        ignore_index = -100
        predictions = self.model(input_ids=token_ids, input_mask=attn_mask)["logits"]
        # Mark positions outside the loss mask so cross_entropy skips them.
        labels = labels.masked_fill(loss_mask == 0, ignore_index)
        return F.cross_entropy(
            input=predictions.flatten(0, 1),
            target=labels.flatten(),
            ignore_index=ignore_index,
        )
class DpoStrategy(BaseStrategy):
    """Direct Preference Optimization loss against a frozen copy of the model.

    The reference model is a deep copy of ``model`` taken at construction time,
    with gradients disabled and switched to eval mode.
    """

    def __init__(self, model, device, pad_token_id, beta):
        super().__init__(model, device)
        # Fail fast: both values are used in every loss computation, and a
        # missing one otherwise surfaces later as an opaque TypeError.
        if pad_token_id is None:
            raise ValueError("DpoStrategy requires pad_token_id")
        if beta is None:
            raise ValueError("DpoStrategy requires beta")
        ref_model = copy.deepcopy(self.model)
        ref_model.requires_grad_(False)
        ref_model.eval()
        self.ref_model = ref_model
        self.pad_token_id = pad_token_id
        self.beta = beta

    def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return the mean DPO loss for a batch of chosen/rejected pairs.

        ``batch`` must provide ``"chosen"``, ``"rejected"``, ``"chosen_mask"``
        and ``"rejected_mask"``. The loss is
        ``-mean(logsigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r))))``, where each
        term is a per-sequence summed log-probability from ``get_logprobs``.
        """
        batch = move_to_device(batch, self.device)
        good_ids, bad_ids = batch["chosen"], batch["rejected"]
        good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]

        log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
        log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)

        # The reference model is frozen; skip building a graph for it.
        with torch.no_grad():
            log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id)
            log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id)

        pi_log_ratio = log_pi_good - log_pi_bad
        ref_log_ratio = log_ref_good - log_ref_bad
        ratio_diff = pi_log_ratio - ref_log_ratio
        return -F.logsigmoid(self.beta * ratio_diff).mean()
class PpoStrategy(BaseStrategy):
    """PPO building blocks: a frozen reference model and the clipped loss terms.

    ``compute_loss`` is not implemented here, so this class remains abstract
    and a subclass must supply it.
    """

    def __init__(self, model, pad_token_id, epsilon, device: str = "cpu"):
        # BaseStrategy.__init__ requires a device; the previous call omitted it
        # and always raised TypeError. ``device`` is appended with a default so
        # the existing positional order (model, pad_token_id, epsilon) still works.
        super().__init__(model, device)
        ref_model = copy.deepcopy(self.model)
        ref_model.requires_grad_(False)
        ref_model.eval()
        self.ref_model = ref_model
        self.pad_token_id = pad_token_id
        self.epsilon = epsilon

    def ppo_clip_loss_masked(
        self,
        log_probs: Tensor,
        old_log_probs: Tensor,
        advantages: Tensor,
        values: Tensor,
        returns: Tensor,
        mask: Tensor,
        clip_eps: Union[float, None] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the clipped policy loss, value loss and entropy loss.

        All reductions are means over positions where ``mask`` is nonzero.

        Args:
            log_probs: Current-policy log-probabilities.
            old_log_probs: Log-probabilities from the policy that produced the data.
            advantages: Advantage estimates.
            values: Predicted values.
            returns: Value targets.
            mask: Positions to include; any dtype, interpreted via ``.bool()``.
            clip_eps: Ratio clipping range. Defaults to ``self.epsilon``, or to
                the former hard-coded 0.2 when ``self.epsilon`` is ``None``.

        Returns:
            ``(policy_loss, value_loss, entropy_loss)``.
            ``entropy_loss`` is the negated masked mean of ``-p * log(p)``,
            computed from ``log_probs``.
        """
        if clip_eps is None:
            clip_eps = 0.2 if self.epsilon is None else self.epsilon
        # masked_select requires a boolean mask; accept 0/1 int or float masks too.
        mask = mask.bool()
        ratio = torch.exp(log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
        # Pessimistic (elementwise min) surrogate, as in the clipped PPO objective.
        policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()
        value_loss = F.mse_loss(values.masked_select(mask), returns.masked_select(mask))
        entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
        entropy_loss = -entropy
        return policy_loss, value_loss, entropy_loss
class StrategyFactory:
    """Builds a training strategy from its short name."""

    @staticmethod
    def load(model, train_type, device, **kwargs):
        """Create the strategy registered under ``train_type``.

        Args:
            model: Model handed to the strategy.
            train_type: One of ``"seq"``, ``"sft"`` or ``"dpo"``.
            device: Device string handed to the strategy.
            **kwargs: For ``"dpo"``, ``pad_token_id`` and ``dpo_beta`` are read.

        Returns:
            The constructed strategy instance.

        Raises:
            KeyError: If ``train_type`` is not a known strategy name.
        """
        # Lambdas defer construction so only the requested strategy is built
        # (DpoStrategy deep-copies the model when constructed).
        train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
            "seq": lambda: SeqStrategy(model, device),
            "sft": lambda: SftStrategy(model, device),
            "dpo": lambda: DpoStrategy(
                model,
                device,
                kwargs.get("pad_token_id"),
                kwargs.get("dpo_beta"),
            ),
        }
        try:
            builder = train_strategy[train_type]
        except KeyError:
            # Keep KeyError for existing callers, but say what was expected.
            raise KeyError(
                f"Unknown train_type {train_type!r}; expected one of {sorted(train_strategy)}"
            ) from None
        return builder()
@dataclass
class ScheduleConfig(ABC):
    """Shared settings for learning-rate schedules.

    Concrete subclasses implement :meth:`get_kwargs`, which turns the
    configuration into a keyword dict describing the schedule.
    """

    schedule_type: str = field(
        default="cosine",
        metadata={
            "help": "Type of learning rate schedule.",
            "choices": ["cosine", "sgdr"],
        },
    )
    warmup_steps: int = field(
        default=1000,
        metadata={"help": "Number of warmup steps."},
    )
    min_rate: float = field(
        default=0.05,
        metadata={"help": "Minimum learning rate multiplier."},
    )

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """Return the schedule parameters as a keyword dict."""
        raise NotImplementedError

    def validate(self) -> None:
        """Raise ``ValueError`` if ``warmup_steps`` is negative or ``min_rate`` is outside [0, 1]."""
        if self.warmup_steps < 0:
            raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
        rate_in_range = 0 <= self.min_rate <= 1
        if not rate_in_range:
            raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
@dataclass
class CosineScheduleConfig(ScheduleConfig):
    """Configuration for linear warmup followed by cosine decay."""

    # May be left unset at construction, but get_kwargs requires it, so the
    # annotation admits None (it was previously annotated as plain int).
    total_steps: Union[int, None] = field(
        default=None,
        metadata={"help": "Total training steps for cosine schedule."},
    )
    schedule_type: Literal["cosine"] = "cosine"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the cosine schedule parameters.

        Returns:
            Dict with ``schedule_type``, ``warmup_steps``, ``lr_decay_steps``
            (``total_steps - warmup_steps``) and ``min_rate``.

        Raises:
            ValueError: If ``total_steps`` is unset.
        """
        if self.total_steps is None:
            raise ValueError("total_steps must be specified for cosine schedule")
        return {
            "schedule_type": self.schedule_type,
            "warmup_steps": self.warmup_steps,
            # Decay covers whatever remains after warmup.
            "lr_decay_steps": self.total_steps - self.warmup_steps,
            "min_rate": self.min_rate,
        }

    def validate(self) -> None:
        """Run the base checks, then require ``total_steps > warmup_steps`` when set."""
        super().validate()
        if self.total_steps is not None and self.total_steps <= self.warmup_steps:
            raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
@dataclass
class SgdrScheduleConfig(ScheduleConfig):
    """Configuration for SGDR: warmup, then cosine cycles that restart and lengthen."""

    cycle_length: int = field(
        default=1000,
        metadata={"help": "Length of the first cycle in steps."},
    )
    t_mult: int = field(
        default=2,
        metadata={"help": "Multiplier for cycle length growth."},
    )
    schedule_type: Literal["sgdr"] = "sgdr"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the SGDR schedule parameters as a keyword dict."""
        return dict(
            schedule_type=self.schedule_type,
            warmup_steps=self.warmup_steps,
            cycle_length=self.cycle_length,
            min_rate=self.min_rate,
            t_mult=self.t_mult,
        )

    def validate(self) -> None:
        """Run the base checks, then require a positive cycle length and ``t_mult >= 1``."""
        super().validate()
        if self.cycle_length <= 0:
            raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
        if self.t_mult < 1:
            raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
class SchedulerFactory:
    """Factory for creating learning rate schedule functions.

    Each returned schedule maps a step index to a multiplier for the base
    learning rate, floored at ``min_rate``.
    """

    @staticmethod
    def get_sgdr_schedule(
        warmup_steps: int,
        cycle_length: int,
        min_rate: float = 0.05,
        t_mult: int = 2,
    ) -> Callable[[int], float]:
        """
        Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule.

        Args:
            warmup_steps: Number of warmup steps
            cycle_length: Length of the first cycle
            min_rate: Minimum learning rate multiplier
            t_mult: Cycle length multiplier

        Returns:
            Schedule function that takes current step and returns LR multiplier

        Raises:
            ValueError: If ``cycle_length`` is not positive or ``t_mult`` < 1;
                either would make the cycle search never terminate.
        """
        if cycle_length <= 0:
            raise ValueError(f"cycle_length must be positive, got {cycle_length}")
        if t_mult < 1:
            raise ValueError(f"t_mult must be >= 1, got {t_mult}")

        def sgdr_schedule(current_step: int) -> float:
            # Linear warmup, floored at min_rate.
            if current_step < warmup_steps:
                return max(min_rate, current_step / warmup_steps)
            steps_since_warmup = current_step - warmup_steps
            # Walk through cycles (each t_mult times longer than the previous)
            # until reaching the one that contains steps_since_warmup.
            cycle_start = 0
            current_cycle_length = cycle_length
            while steps_since_warmup >= cycle_start + current_cycle_length:
                cycle_start += current_cycle_length
                current_cycle_length *= t_mult
            progress = (steps_since_warmup - cycle_start) / current_cycle_length
            # Cosine annealing within the cycle; clamped (not rescaled) at min_rate.
            return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress)))

        return sgdr_schedule

    @staticmethod
    def get_cosine_schedule(
        warmup_steps: int,
        lr_decay_steps: int,
        min_rate: float = 0.05,
    ) -> Callable[[int], float]:
        """
        Create cosine decay schedule with warmup.

        Args:
            warmup_steps: Number of warmup steps
            lr_decay_steps: Number of steps for cosine decay after warmup
            min_rate: Minimum learning rate multiplier

        Returns:
            Schedule function that takes current step and returns LR multiplier

        Raises:
            ValueError: If ``lr_decay_steps`` is not positive (zero would divide
                by zero after warmup; negative values make the decay oscillate).
        """
        if lr_decay_steps <= 0:
            raise ValueError(f"lr_decay_steps must be positive, got {lr_decay_steps}")

        def cosine_schedule(current_step: int) -> float:
            if current_step < warmup_steps:
                # Linear warmup
                return max(min_rate, current_step / warmup_steps)
            # Cosine decay, held at its end value once decay is complete.
            decay_progress = min((current_step - warmup_steps) / lr_decay_steps, 1.0)
            return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress)))

        return cosine_schedule

    @staticmethod
    def load_schedule_fn(scedule_config: "ScheduleConfig") -> Callable[[int], float]:
        """Build the schedule function described by ``scedule_config``.

        The config's ``validate()`` is called first, so out-of-range settings
        raise ``ValueError`` here rather than misbehaving when the schedule
        is evaluated.

        Raises:
            ValueError: On invalid settings or an unsupported schedule type.
        """
        # The parameter's spelling is kept unchanged for keyword callers.
        scedule_config.validate()
        kwargs = scedule_config.get_kwargs()
        schedule_type = kwargs.pop("schedule_type")
        if schedule_type == "cosine":
            return SchedulerFactory.get_cosine_schedule(**kwargs)
        if schedule_type == "sgdr":
            return SchedulerFactory.get_sgdr_schedule(**kwargs)
        raise ValueError(f"Unsupported schedule type: {schedule_type}")