From 68a15005cb0f352eb487cbe86b2964683960808e Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 7 Oct 2025 22:02:50 +0800 Subject: [PATCH] =?UTF-8?q?feat(train.py):=20=E6=94=AF=E6=8C=81=E4=BB=8E?= =?UTF-8?q?=E6=A3=80=E6=9F=A5=E7=82=B9=E6=81=A2=E5=A4=8D=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E5=B9=B6=E4=BC=98=E5=8C=96=E6=95=B0=E6=8D=AE=E5=8A=A0=E8=BD=BD?= =?UTF-8?q?=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/train.py b/train.py index cf777b1..6249b91 100644 --- a/train.py +++ b/train.py @@ -3,7 +3,7 @@ import argparse import torch from torch.optim import AdamW -from khaosz.core import ParameterLoader +from khaosz.core import ParameterLoader, Checkpoint from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig from khaosz.trainer import StrategyFactory @@ -35,11 +35,17 @@ def train( embdeding_lr_rate: int, random_seed: int, multi_turn: bool, + resume_from_checkpoint: bool ): assert train_type in ["seq", "sft", "dpo"] assert os.path.exists(param_path) parameter = ParameterLoader.load(param_path) + checkpoint = None + + if isinstance(parameter, Checkpoint) and resume_from_checkpoint: + checkpoint = parameter + model = parameter.model device = torch.device("cuda") @@ -92,6 +98,8 @@ def train( accumulation_steps=accumulation_steps, max_grad_norm=max_grad_norm, random_seed=random_seed, + num_workers=4, + pin_memory=True ) schedule_config = CosineScheduleConfig( @@ -104,7 +112,7 @@ def train( train_config=train_config, schedule_config=schedule_config, ) - trainer.train() + trainer.train(checkpoint) if __name__ == "__main__": @@ -127,27 +135,10 @@ if __name__ == "__main__": parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.") # other configs + parser.add_argument("--resume_from_checkpoint", action="store_true", 
default=False, help="train from checkpoint or not.") parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") args = parser.parse_args() - train( - param_path=args.param_path, - data_root_path=args.data_root_path, - n_epoch=args.n_epoch, - batch_size=args.batch_size, - accumulation_steps=args.accumulation_steps, - warmup_steps=args.warmup_steps, - max_lr=args.max_lr, - dpo_beta=args.dpo_beta, - adamw_betas=args.adamw_betas, - adamw_weight_decay=args.adamw_weight_decay, - max_grad_norm=args.max_grad_norm, - embdeding_lr_rate=args.embdeding_lr_rate, - checkpoint_interval=args.checkpoint_interval, - checkpoint_dir=args.checkpoint_dir, - train_type=args.train_type, - random_seed=args.random_seed, - multi_turn=args.multi_turn - ) \ No newline at end of file + train(**vars(args)) \ No newline at end of file