AstrAI/scripts/tools/train.py

278 lines
7.3 KiB
Python

import argparse
import os
from functools import partial
import safetensors.torch as st
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import ModelConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import Transformer
from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser.add_argument(
"--train_type",
type=str,
required=True,
choices=["seq", "sft", "dpo"],
help="Train type.",
)
parser.add_argument(
"--data_root_path",
type=str,
required=True,
help="Path to the root directory of the dataset.",
)
parser.add_argument(
"--param_path",
type=str,
required=True,
help="Path to the model parameters or resume checkpoint.",
)
parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train."
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for training."
)
parser.add_argument(
"--accumulation_steps",
type=int,
default=1,
help="Number of iterations between each optimizer step.",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=1000,
help="Number of iters between warnings.",
)
parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
)
parser.add_argument(
"--max_grad_norm",
type=float,
default=1.0,
help="Max gradient norm for clipping.",
)
parser.add_argument(
"--adamw_beta1",
type=float,
default=0.9,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_beta2",
type=float,
default=0.95,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_weight_decay",
type=float,
default=0.01,
help="Weight decay for AdamW optimizer.",
)
parser.add_argument(
"--random_seed", type=int, default=3407, help="Random seed for reproducibility."
)
parser.add_argument(
"--num_workers", type=int, default=4, help="Number of workers for data loading."
)
parser.add_argument(
"--no_pin_memory",
action="store_false",
dest="pin_memory",
help="Disable pin memory",
)
parser.add_argument(
"--window_size",
type=int,
default=None,
help="the max length of the input sequence.",
)
parser.add_argument(
"--stride", type=int, default=None, help="the step size of the input sequence."
)
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument(
"--label_smoothing",
type=float,
default=0.1,
help="cross_entropy function label smoothing parameter",
)
parser.add_argument(
"--ckpt_interval",
type=int,
default=5000,
help="Number of iters between checkpoints.",
)
parser.add_argument(
"--ckpt_dir",
type=str,
default="checkpoint",
help="Directory to save checkpoints.",
)
parser.add_argument(
"--start_epoch", type=int, default=0, help="Start epoch for training."
)
parser.add_argument(
"--start_batch", type=int, default=0, help="Start batch for training."
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
)
args = parser.parse_args()
return args
def ddp_wrap(model: nn.Module):
local_rank = get_rank()
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
ddp_model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
find_unused_parameters=False,
)
return ddp_model
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), **kwargs)
def create_scheduler(
optimizer: optim.Optimizer, **kwargs
) -> optim.lr_scheduler.LRScheduler:
return SchedulerFactory.create(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict:
return model.module.state_dict()
def train(
train_type: str,
param_path: str,
data_root_path: str,
max_lr: float,
n_epoch: int,
batch_size: int,
start_epoch: int,
start_batch: int,
accumulation_steps: int,
warmup_steps: int,
ckpt_interval: int,
ckpt_dir: str,
dpo_beta: float,
adamw_beta1: float,
adamw_beta2: float,
adamw_weight_decay: float,
max_grad_norm: float,
label_smoothing: float,
random_seed: int,
num_workers: int,
pin_memory: bool,
window_size: int,
stride: int,
nprocs: int,
device_type: str,
):
assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path)
# Load config
config = ModelConfig()
config_path = os.path.join(param_path, "config.json")
if os.path.exists(config_path):
config.load(config_path)
if window_size is None:
window_size = config.max_len
# Create bare Transformer (for training, no tokenizer needed)
model = Transformer(config)
# Load weights if available
weights_path = os.path.join(param_path, "model.safetensors")
if os.path.exists(weights_path):
state_dict = st.load_file(weights_path)
model.load_state_dict(state_dict, strict=False)
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
dataset = DatasetFactory.load(
train_type=train_type,
load_path=data_root_path,
window_size=window_size,
stride=stride,
)
optimizer_fn = partial(
create_optimizer,
**{
"lr": max_lr,
"betas": (adamw_beta1, adamw_beta2),
"weight_decay": adamw_weight_decay,
},
)
toltal_steps = len(dataset) * n_epoch // (batch_size * nprocs)
scheduler_fn = partial(
create_scheduler,
**{
"schedule_type": "cosine",
"warmup_steps": warmup_steps,
"lr_decay_steps": toltal_steps - warmup_steps,
},
)
train_config = TrainConfig(
model=model,
strategy=train_type,
dataset=dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
ckpt_dir=ckpt_dir,
n_epoch=n_epoch,
batch_size=batch_size,
start_epoch=start_epoch,
start_batch=start_batch,
ckpt_interval=ckpt_interval,
accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
num_workers=num_workers,
pin_memory=pin_memory,
nprocs=nprocs,
parallel_wrapper=ddp_wrap,
state_dict_fn=prepare_checkpoint,
device_type=device_type,
extra_kwargs=strategy_kwargs,
)
trainer = Trainer(train_config)
trainer.train()
if __name__ == "__main__":
args = parse_args()
train(**vars(args))