388 lines
12 KiB
Python
388 lines
12 KiB
Python
import copy
|
|
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from torch import Tensor
|
|
from torch.optim import Optimizer
|
|
from torch.utils.data import Dataset
|
|
from typing import Any, Literal, Tuple, Callable, Dict
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import asdict, dataclass, field
|
|
|
|
|
|
def get_logprobs(model: nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id):
    """Return the summed log-probability each sequence assigns to its masked tokens.

    The model is called as ``model(input_ids, attention_mask)["logits"]``.
    Positions are scored in next-token fashion: logits at step t score the
    token at step t+1. Only positions where both the (shifted) padding mask
    and the caller-supplied ``mask`` are nonzero contribute to the sum.

    Args:
        model: callable producing a dict with a "logits" tensor of shape
            (batch, seq, vocab) — assumed; confirm against the model wrapper.
        input_ids: (batch, seq) token ids.
        mask: (batch, seq) selector for the tokens to score (e.g. response tokens).
        pad_token_id: id treated as padding when building the attention mask.

    Returns:
        (batch,) tensor of summed token log-probabilities.
    """
    attention_mask = input_ids.ne(pad_token_id)
    logits = model(input_ids, attention_mask)["logits"]
    all_logprobs = torch.log_softmax(logits, dim=-1)

    # Shift so position t predicts token t+1.
    pred_logprobs = all_logprobs[:, :-1, :]
    target_ids = input_ids[:, 1:]
    target_mask = mask[:, 1:]

    # Pick out the log-prob of each realized next token.
    per_token = pred_logprobs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)

    # Keep only non-padding positions that are also selected by the caller.
    keep = (attention_mask[:, 1:] & target_mask).float()

    return (per_token * keep).sum(dim=-1)
|
|
|
|
|
|
class MaskBuilder:
    """Base class for builders that derive masks from token-id tensors.

    Holds the special-token ids shared by all concrete builders. Subclasses
    override :meth:`build`.

    NOTE(review): ``@abstractmethod`` has no runtime effect here because the
    class does not use ABCMeta (compare BaseStrategy, which inherits ABC);
    the NotImplementedError body is the effective guard. Left unchanged so
    direct instantiation keeps working for existing callers.
    """

    def __init__(
        self,
        bos_token_id: int,
        eos_token_id: int,
        user_token_id: int,
        system_token_id: int,
    ):
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.user_token_id = user_token_id
        self.system_token_id = system_token_id

    @abstractmethod
    def build(self, input_ids: Tensor) -> Tensor:
        # Fix: the original signature was missing ``self``, so any
        # instance call (builder.build(ids)) raised a confusing TypeError.
        """Build a mask for ``input_ids``. Must be overridden by subclasses."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
class LossMaskBuilder(MaskBuilder):
    """Builds a boolean loss mask from special-token markers.

    Each occurrence of the user token contributes +1 and each occurrence of
    the system token contributes -1 to a running sum; after shifting the sum
    so its per-row minimum sits at zero, positions with a positive value are
    marked True.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_ids: Tensor) -> Tensor:
        """Return a bool mask, True where the running marker sum exceeds its row minimum."""
        markers = torch.zeros_like(input_ids, dtype=torch.int8)
        markers[input_ids.eq(self.user_token_id)] = 1
        markers[input_ids.eq(self.system_token_id)] = -1

        running = markers.cumsum(dim=-1)
        floor = running.min(dim=-1, keepdim=True).values

        # Nonzero (above-floor) entries become True.
        return (running - floor).to(dtype=torch.bool)
|
|
|
|
|
|
|
|
|
|
class AttentionMaskBuilder(MaskBuilder):
    # Builder intended to produce attention masks from token ids.
    #
    # NOTE(review): this class looks unfinished — both `build` and
    # `_build_batch` compute intermediates and then implicitly return None.
    # Left byte-identical; confirm intended behavior with the author before
    # relying on it.

    def __init__(self, multi_turn=False, **kwargs):
        super().__init__(**kwargs)
        # Presumably toggles per-turn (block-diagonal?) attention masking for
        # multi-turn conversations — TODO confirm; the flag is never read here.
        self.multi_turn = multi_turn

    def build(self, input_ids: Tensor):
        # NOTE(review): incomplete — computes the batch size and returns None.
        bsz = input_ids.size(0)

    def _build_batch(self, input_ids: Tensor):
        # NOTE(review): incomplete — marks user-token positions and takes a
        # running count (same marker/cumsum pattern as LossMaskBuilder.build),
        # but the result is discarded and None is returned.
        is_user_token = input_ids.eq(self.user_token_id)

        token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
        token_markers[is_user_token] = 1
        cumulative_markers = torch.cumsum(token_markers, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
class BaseStrategy(ABC):
    """Abstract loss-computation strategy wrapping a model.

    Concrete subclasses implement :meth:`compute_loss`; instances are
    callable and delegate straight to it.
    """

    def __init__(self, model: nn.Module):
        self.model = model

    @abstractmethod
    def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
        """Compute the training loss for one batch of tensors."""
        raise NotImplementedError

    def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
        # Calling the strategy object is sugar for compute_loss.
        return self.compute_loss(batch)
|
|
|
|
|
|
class SeqStrategy(BaseStrategy):
    """Plain sequence-modeling loss: cross-entropy over every position."""

    def __init__(self, model):
        super().__init__(model)

    def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
        """Mean cross-entropy between the model's logits and the targets.

        Expects ``batch`` as ``(inputs, targets)`` with matching
        (batch, seq) shapes; every position contributes to the average.
        """
        inputs, targets = batch
        batch_size, seq_len = inputs.size()
        logits: Tensor = self.model(inputs)["logits"]

        flat_logits = logits.view(batch_size * seq_len, -1)
        flat_targets = targets.flatten()
        return F.cross_entropy(flat_logits, flat_targets)
|
|
|
|
|
|
class SftStrategy(BaseStrategy):
    """Supervised fine-tuning loss: cross-entropy restricted by a loss mask."""

    def __init__(self, model):
        super().__init__(model)

    def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
        """Mean cross-entropy over positions where ``loss_mask`` is nonzero.

        Targets at masked-out positions are rewritten to the ignore index so
        that ``F.cross_entropy`` excludes them from the average.
        """
        inputs, targets, loss_mask = batch
        batch_size, seq_len = inputs.size()
        ignore = -1  # must match the ignore_index passed below

        logits: Tensor = self.model(inputs)["logits"]
        ignored_targets = targets.masked_fill(loss_mask == 0, ignore)

        return F.cross_entropy(
            logits.view(batch_size * seq_len, -1),
            ignored_targets.flatten(),
            ignore_index=ignore,
        )
|
|
|
|
class DpoStrategy(BaseStrategy):
    """Direct Preference Optimization loss against a frozen reference model.

    The reference model is a deep copy of the policy taken at construction
    time, with gradients disabled and eval mode set.
    """

    def __init__(self, model, pad_token_id, beta):
        super().__init__(model)
        frozen = copy.deepcopy(self.model)
        frozen.requires_grad_(False)
        frozen.eval()

        self.ref_model = frozen
        self.pad_token_id = pad_token_id
        self.beta = beta

    def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
        """Mean of -log sigmoid(beta * margin) over the batch.

        ``batch`` is (chosen_ids, rejected_ids, chosen_mask, rejected_mask);
        the margin is the policy's chosen/rejected log-ratio minus the
        reference model's.
        """
        chosen_ids, rejected_ids, chosen_mask, rejected_mask = batch

        policy_chosen = get_logprobs(self.model, chosen_ids, chosen_mask, self.pad_token_id)
        policy_rejected = get_logprobs(self.model, rejected_ids, rejected_mask, self.pad_token_id)

        # The frozen reference contributes no gradients.
        with torch.no_grad():
            ref_chosen = get_logprobs(self.ref_model, chosen_ids, chosen_mask, self.pad_token_id)
            ref_rejected = get_logprobs(self.ref_model, rejected_ids, rejected_mask, self.pad_token_id)

        margin = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
        return -F.logsigmoid(self.beta * margin).mean()
|
|
|
|
|
|
class PpoStrategy(BaseStrategy):
    """Clipped-surrogate PPO losses against a frozen reference model.

    NOTE(review): BaseStrategy declares ``compute_loss`` abstract and this
    class does not implement it, so it cannot be instantiated as written —
    presumably the rollout/loss assembly lives elsewhere or is unfinished;
    confirm before use.
    """

    def __init__(self, model, pad_token_id, epsilon):
        super().__init__(model)
        ref_model = copy.deepcopy(self.model)
        ref_model.requires_grad_(False)
        ref_model.eval()

        self.ref_model = ref_model
        self.pad_token_id = pad_token_id
        # PPO clipping threshold; serves as the default for ppo_clip_loss_masked.
        self.epsilon = epsilon

    def ppo_clip_loss_masked(
        self,
        log_probs: Tensor,
        old_log_probs: Tensor,
        advantages: Tensor,
        values: Tensor,
        returns: Tensor,
        mask: Tensor,
        clip_eps: float = None,
    ):
        """Return ``(policy_loss, value_loss, entropy_loss)`` over masked positions.

        Args:
            log_probs: current-policy per-token log-probabilities.
            old_log_probs: behavior-policy per-token log-probabilities.
            advantages: advantage estimates, broadcastable to ``log_probs``.
            values: critic value predictions.
            returns: value targets for the critic.
            mask: bool tensor selecting valid (non-padding) positions.
            clip_eps: clipping threshold. Fix: defaults to the ``epsilon``
                given at construction — previously it was hard-coded to 0.2
                and ``self.epsilon`` was stored but never used. Passing an
                explicit value still overrides.
        """
        if clip_eps is None:
            clip_eps = self.epsilon

        # Importance ratio between current and behavior policy.
        ratio = torch.exp(log_probs - old_log_probs)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
        # Pessimistic (min) surrogate, negated because we minimize.
        policy_loss = -torch.min(unclipped, clipped).masked_select(mask).mean()

        value_loss = F.mse_loss(
            values.masked_select(mask),
            returns.masked_select(mask),
        )

        # Entropy bonus enters the total objective negated (entropy is maximized).
        entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
        entropy_loss = -entropy
        return policy_loss, value_loss, entropy_loss
|
|
|
|
|
|
|
|
class StrategyFactory:
    """Maps a train-type string to a constructed loss strategy."""

    @staticmethod  # fix: decorator was missing; now consistent with SchedulerFactory
    def load(model, train_type, pad_token_id, dpo_beta):
        """Instantiate the strategy for ``train_type``.

        Args:
            model: the model the strategy will wrap.
            train_type: one of "seq", "sft", "dpo".
            pad_token_id: padding id forwarded to DPO only.
            dpo_beta: beta forwarded to DPO only.

        Raises:
            KeyError: if ``train_type`` is unknown (behavior unchanged).
        """
        # Lambdas defer construction so only the selected strategy is built.
        train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
            "seq": lambda: SeqStrategy(model),
            "sft": lambda: SftStrategy(model),
            "dpo": lambda: DpoStrategy(model, pad_token_id, dpo_beta),
        }
        strategy = train_strategy[train_type]()
        return strategy
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Hyper-parameters and collaborators for one training run.

    ``get_kwargs`` drops None-valued fields so the result can be splatted
    into a trainer call.
    """

    # Fix: ``default_factory`` must be a zero-arg callable; the original
    # passed the list of valid choices (["seq", "sft", "dpo"]), which made
    # every ``TrainConfig()`` raise "'list' object is not callable".
    train_type: str = field(
        default="seq",
        metadata={"help": "Type of training."}
    )
    dataset: Dataset = field(
        default=None,
        metadata={"help": "Dataset for training."}
    )
    optimizer: Optimizer = field(
        default=None,
        metadata={"help": "Optimizer for training."}
    )
    ckpt_dir: str = field(
        default="./checkpoint",
        metadata={"help": "Checkpoint directory."}
    )
    n_epoch: int = field(
        default=1,
        metadata={"help": "Number of epochs for training."}
    )
    batch_size: int = field(
        default=4,
        metadata={"help": "Batch size for training."}
    )
    n_iter_ckpt: int = field(
        default=5000,
        metadata={"help": "Number of iterations between checkpoints."}
    )
    n_iter_step: int = field(
        default=1,
        metadata={"help": "Number of iterations between steps."}
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Maximum gradient norm."}
    )
    random_seed: int = field(
        default=3407,
        metadata={"help": "Random seed."}
    )
    dpo_beta: float = field(
        default=0.1,
        metadata={"help": "DPO beta."}
    )

    def get_kwargs(self) -> Dict[str, Any]:
        """Return the config as a dict, omitting fields left as None.

        NOTE(review): ``asdict`` deep-copies field values recursively; fine
        for these scalars, but confirm it is acceptable for a live Dataset
        or Optimizer instance.
        """
        config_dict = asdict(self)
        return {k: v for k, v in config_dict.items() if v is not None}
|
|
|
|
|
|
@dataclass
class ScheduleConfig:
    """Base config for learning-rate schedules.

    Subclasses pin ``schedule_type`` with a Literal default and implement
    ``get_kwargs`` for SchedulerFactory.load_schedule_fn.
    """

    # Fix: ``default_factory`` must be a zero-arg callable; the original
    # passed the list of valid choices (["cosine", "sgdr"]), which made
    # instantiation raise "'list' object is not callable".
    schedule_type: str = field(
        default="cosine",
        metadata={"help": "Type of learning rate schedule."}
    )
    # NOTE(review): "warning_step" reads like a typo for "warmup_step", but
    # the name is part of the subclass/factory keyword interface, so it is
    # kept unchanged.
    warning_step: int = field(
        default=1000,
        metadata={"help": "Warning up step."}
    )

    @abstractmethod
    def get_kwargs(self) -> Dict[str, Any]:
        """Return keyword arguments for SchedulerFactory.load_schedule_fn."""
        raise NotImplementedError
|
|
|
|
|
|
@dataclass
class CosineScheduleConfig(ScheduleConfig):
    """Parameters for the cosine-with-warmup schedule."""

    # Total training iterations; required before get_kwargs can be called.
    total_iters: int = field(
        default=None,
        metadata={"help": "Total iterations for cosine schedule."}
    )
    min_rate: float = field(
        default=0.05,
        metadata={"help": "Minimum rate for cosine schedule."}
    )
    schedule_type: Literal["cosine"] = "cosine"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return kwargs for SchedulerFactory.get_cosine_warmup_schedule.

        Raises:
            ValueError: if ``total_iters`` was left unset. Fix: previously
                this surfaced as an opaque ``TypeError`` from ``None - int``.
        """
        if self.total_iters is None:
            raise ValueError("CosineScheduleConfig.total_iters must be set")
        return {
            "schedule_type": self.schedule_type,
            "warning_step": self.warning_step,
            # Decay horizon net of the warmup steps.
            "lr_decay_iters": self.total_iters - self.warning_step,
            "min_rate": self.min_rate
        }
|
|
|
|
@dataclass
class SgdrScheduleConfig(ScheduleConfig):
    """Parameters for the SGDR (warm-restart) schedule."""

    cycle_length: int = field(
        default=1000,
        metadata={"help": "Cycle length for sgdr schedule."}
    )
    min_rate: float = field(
        default=0.05,
        metadata={"help": "Minimum rate for sgdr schedule."}
    )
    T_mult: int = field(
        default=2,
        metadata={"help": "T_mult for sgdr schedule."}
    )
    schedule_type: Literal["sgdr"] = "sgdr"

    def get_kwargs(self) -> Dict[str, Any]:
        """Return kwargs for SchedulerFactory.get_sgdr_schedule."""
        keys = ("schedule_type", "warning_step", "cycle_length", "min_rate", "T_mult")
        return {key: getattr(self, key) for key in keys}
|
|
|
|
|
|
class SchedulerFactory:
|
|
|
|
@staticmethod
|
|
def get_sgdr_schedule(
|
|
warning_step: int,
|
|
cycle_length: int,
|
|
min_rate: float = 0.1,
|
|
T_mult: int = 2
|
|
) -> Callable[[int], float]:
|
|
|
|
def sgdr_schedule(now_iter: int) -> float:
|
|
if now_iter < warning_step:
|
|
return max(min_rate, now_iter / warning_step)
|
|
|
|
adjusted_iter = now_iter - warning_step
|
|
total_cycles, current_cycle = 0, 0
|
|
while adjusted_iter >= cycle_length * (T_mult ** total_cycles):
|
|
current_cycle += 1
|
|
total_cycles += 1
|
|
|
|
cycle_start = sum(cycle_length * (T_mult ** i) for i in range(current_cycle))
|
|
cycle_pos = adjusted_iter - cycle_start
|
|
|
|
cycle_length_current = cycle_length * (T_mult ** current_cycle)
|
|
return max(min_rate, 0.5 * (1 + math.cos(math.pi * cycle_pos / cycle_length_current)))
|
|
|
|
return sgdr_schedule
|
|
|
|
@staticmethod
|
|
def get_cosine_warmup_schedule(
|
|
warning_step: int,
|
|
lr_decay_iters: int,
|
|
min_rate: float = 0.1
|
|
) -> Callable[[int], float]:
|
|
|
|
def cosine_warmup_schedule(now_iter: int) -> float:
|
|
if now_iter <= warning_step:
|
|
return max(min_rate, now_iter / warning_step)
|
|
else:
|
|
rate = (now_iter - warning_step) / (lr_decay_iters - warning_step)
|
|
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * rate)))
|
|
|
|
return cosine_warmup_schedule
|
|
|
|
@staticmethod
|
|
def load_schedule_fn(**kwargs):
|
|
strategy = kwargs.pop("schedule_type")
|
|
if strategy == "cosine":
|
|
return SchedulerFactory.get_cosine_warmup_schedule(**kwargs)
|
|
elif strategy == "sgdr":
|
|
return SchedulerFactory.get_sgdr_schedule(**kwargs)
|
|
else:
|
|
raise ValueError(f"Invalid schedule type: {strategy}")
|
|
|