diff --git a/train.py b/train.py
index cf777b1..6249b91 100644
--- a/train.py
+++ b/train.py
@@ -3,7 +3,7 @@ import argparse
 
 import torch
 from torch.optim import AdamW
-from khaosz.core import ParameterLoader
+from khaosz.core import ParameterLoader, Checkpoint
 from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import StrategyFactory
 
@@ -35,11 +35,17 @@ def train(
     embdeding_lr_rate: int,
     random_seed: int,
     multi_turn: bool,
+    resume_from_checkpoint: bool
 ):
     assert train_type in ["seq", "sft", "dpo"]
     assert os.path.exists(param_path)
 
     parameter = ParameterLoader.load(param_path)
+    checkpoint = None
+
+    if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
+        checkpoint = parameter
+
     model = parameter.model
     device = torch.device("cuda")
 
@@ -92,6 +98,8 @@ def train(
         accumulation_steps=accumulation_steps,
         max_grad_norm=max_grad_norm,
         random_seed=random_seed,
+        num_workers=4,
+        pin_memory=True
     )
 
     schedule_config = CosineScheduleConfig(
@@ -104,7 +112,7 @@ def train(
         train_config=train_config,
         schedule_config=schedule_config,
     )
-    trainer.train()
+    trainer.train(checkpoint)
 
 
 if __name__ == "__main__":
@@ -127,27 +135,10 @@ if __name__ == "__main__":
     parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
 
     # other configs
+    parser.add_argument("--resume_from_checkpoint", action="store_true", help="Resume training from a checkpoint.")
     parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
     parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
 
     args = parser.parse_args()
 
-    train(
-        param_path=args.param_path,
-        data_root_path=args.data_root_path,
-        n_epoch=args.n_epoch,
-        batch_size=args.batch_size,
-        accumulation_steps=args.accumulation_steps,
-        warmup_steps=args.warmup_steps,
-        max_lr=args.max_lr,
-        dpo_beta=args.dpo_beta,
-        adamw_betas=args.adamw_betas,
-        adamw_weight_decay=args.adamw_weight_decay,
-        max_grad_norm=args.max_grad_norm,
-        embdeding_lr_rate=args.embdeding_lr_rate,
-        checkpoint_interval=args.checkpoint_interval,
-        checkpoint_dir=args.checkpoint_dir,
-        train_type=args.train_type,
-        random_seed=args.random_seed,
-        multi_turn=args.multi_turn
-    )
\ No newline at end of file
+    train(**vars(args))
\ No newline at end of file
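
Note on the resume path: the diff implies, but does not show, that ParameterLoader.load() returns a Checkpoint instance when param_path points at saved checkpoint state rather than plain weights, so the isinstance check is what gates resumption. A minimal sketch of the flow under that assumption, with a hypothetical checkpoint path:

# Sketch only: assumes ParameterLoader.load() yields a Checkpoint for
# checkpoint directories and a plain parameter bundle otherwise.
from khaosz.core import ParameterLoader, Checkpoint

parameter = ParameterLoader.load("checkpoints/step_1000")  # hypothetical path

checkpoint = None
if isinstance(parameter, Checkpoint):
    # Passing the Checkpoint through lets Trainer.train() restore optimizer,
    # scheduler, and step state instead of starting from step 0.
    checkpoint = parameter

model = parameter.model  # both return types expose .model, as used above
# ... construct the Trainer exactly as in train() ...
# trainer.train(checkpoint)  # None starts fresh; a Checkpoint resumes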
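
The train(**vars(args)) replacement works because vars() on an argparse Namespace returns its attribute dict, keyed by each argument's dest. It is only safe while every dest matches a keyword parameter of train() exactly; adding a flag the function does not accept raises a TypeError at call time. A minimal, self-contained illustration with hypothetical arguments:

import argparse

def train(n_epoch: int, batch_size: int):
    print(n_epoch, batch_size)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epoch", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=8)

args = parser.parse_args(["--n_epoch", "2"])
train(**vars(args))  # equivalent to train(n_epoch=2, batch_size=8)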