# AstrAI/train.py
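"""Training entry point for the Khaosz transformer.

Supports three train types: "seq", "sft" (supervised fine-tuning), and
"dpo" (direct preference optimization). An illustrative command line is
given at the bottom of the file.
"""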
import os
import argparse
import torch
from torch.optim import AdamW
from khaosz.core import ParameterLoader
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig

PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))


def get_files(root_path: str) -> list[str]:
    """Recursively collect every file path under root_path."""
    paths = []
    for root, _, files in os.walk(root_path):
        paths.extend([os.path.join(root, file) for file in files])
    return paths


def train(
    train_type: str,
    param_path: str,
    data_root_path: str,
    n_epoch: int,
    batch_size: int,
    n_iter_step: int,
    warning_step: int,
    max_lr: float,
    n_iter_ckpt: int,
    ckpt_dir: str,
    dpo_beta: float,
    adamw_betas: tuple[float, float],
    adamw_weight_decay: float,
    max_grad_norm: float,
    embedding_lr_rate: float,
    random_seed: int,
    multi_turn: bool,
):
    assert train_type in ["seq", "sft", "dpo"], f"unknown train_type: {train_type}"
    assert os.path.exists(param_path), f"param_path does not exist: {param_path}"

    parameter = ParameterLoader.load(param_path)
    model = parameter.model
    # Fall back to CPU when CUDA is unavailable (bfloat16 works on both).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device=device, dtype=torch.bfloat16)

    cache_files = get_files(data_root_path)
    dataset_kwargs = {
        "multi_turn": multi_turn,
        "bos_token_id": parameter.tokenizer.bos_id,
        "eos_token_id": parameter.tokenizer.eos_id,
        "user_token_id": parameter.tokenizer.encode("<|user|>")[0],
    }
    dataset = DatasetLoader.load(
        train_type=train_type,
        load_path=cache_files,
        max_len=parameter.config.m_len,
        device=device,
        dataset_kwargs=dataset_kwargs,
    )

    # Embedding layers get their own parameter group so their learning rate
    # can be scaled by embedding_lr_rate relative to max_lr.
    param_groups = [
        {"params": [p for n, p in model.named_parameters() if "embedding" in n], "lr": max_lr * embedding_lr_rate},
        {"params": [p for n, p in model.named_parameters() if "embedding" not in n], "lr": max_lr},
    ]
    optim = AdamW(
        param_groups,
        betas=adamw_betas,
        weight_decay=adamw_weight_decay,
    )

    train_config = TrainConfig(
        train_type=train_type,
        dataset=dataset,
        optimizer=optim,
        ckpt_dir=ckpt_dir,
        n_epoch=n_epoch,
        batch_size=batch_size,
        n_iter_ckpt=n_iter_ckpt,
        n_iter_step=n_iter_step,
        max_grad_norm=max_grad_norm,
        random_seed=random_seed,
        dpo_beta=dpo_beta,
    )
    # "warning_step" is the keyword CosineScheduleConfig expects; from its use
    # here it appears to be the warmup length of the cosine schedule.
    schedule_config = CosineScheduleConfig(
        warning_step=warning_step,
        total_iters=len(dataset) * n_epoch // batch_size,
    )

    trainer = Trainer(parameter)
    trainer.train(
        train_config=train_config,
        schedule_config=schedule_config,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the Transformer model.")
    # train args
    parser.add_argument("--train_type", choices=["seq", "sft", "dpo"], required=True, help="Train type.")
    parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
    parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or a resume checkpoint.")
    parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
    parser.add_argument("--n_iter_step", type=int, default=1, help="Gradient-accumulation interval: iterations between optimizer steps.")
    parser.add_argument("--warning_step", type=int, default=1000, help="Warmup length (in iterations) for the cosine schedule.")
    parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
    parser.add_argument("--n_iter_ckpt", type=int, default=5000, help="Number of iterations between checkpoints.")
    parser.add_argument("--ckpt_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
    parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
    # type=tuple would split the argument string into characters; parse two floats instead.
    parser.add_argument("--adamw_betas", type=float, nargs=2, default=(0.9, 0.95), help="Beta values for the AdamW optimizer.")
    parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for the AdamW optimizer.")
    parser.add_argument("--embedding_lr_rate", type=float, default=1.0, help="Ratio of the embedding layers' learning rate to max_lr.")
    parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
    # other configs
    # type=bool would treat any non-empty string (including "False") as True.
    parser.add_argument("--multi_turn", action="store_true", help="Whether to use multi-turn conversation training.")
    args = parser.parse_args()

    train(
        train_type=args.train_type,
        param_path=args.param_path,
        data_root_path=args.data_root_path,
        n_epoch=args.n_epoch,
        batch_size=args.batch_size,
        n_iter_step=args.n_iter_step,
        warning_step=args.warning_step,
        max_lr=args.max_lr,
        n_iter_ckpt=args.n_iter_ckpt,
        ckpt_dir=args.ckpt_dir,
        dpo_beta=args.dpo_beta,
        adamw_betas=tuple(args.adamw_betas),
        adamw_weight_decay=args.adamw_weight_decay,
        max_grad_norm=args.max_grad_norm,
        embedding_lr_rate=args.embedding_lr_rate,
        random_seed=args.random_seed,
        multi_turn=args.multi_turn,
    )
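
# Illustrative invocation (paths are hypothetical, not from the repo):
#   python train.py --train_type sft \
#       --param_path ./params \
#       --data_root_path ./data/sft \
#       --batch_size 4 --n_iter_step 8 --max_lr 3e-4 --multi_turn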