feat(train.py): support resuming training from a checkpoint and tune the data loading configuration
parent efbe3de9d3
commit 68a15005cb

train.py (33 lines changed)
@@ -3,7 +3,7 @@ import argparse
 import torch
 from torch.optim import AdamW
-from khaosz.core import ParameterLoader
+from khaosz.core import ParameterLoader, Checkpoint
 from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import StrategyFactory
 
 
@@ -35,11 +35,17 @@ def train(
     embdeding_lr_rate: int,
     random_seed: int,
     multi_turn: bool,
+    resume_from_checkpoint: bool
 ):
     assert train_type in ["seq", "sft", "dpo"]
     assert os.path.exists(param_path)
 
     parameter = ParameterLoader.load(param_path)
+    checkpoint = None
+
+    if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
+        checkpoint = parameter
+
     model = parameter.model
 
     device = torch.device("cuda")
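Note on the resume branch above: ParameterLoader.load apparently returns either a plain parameter bundle or a Checkpoint, and the run only resumes when the loaded object is a checkpoint and the caller explicitly opted in. A minimal self-contained sketch of that dispatch pattern (Parameter and Checkpoint here are simplified stand-ins, not the khaosz.core definitions):

from dataclasses import dataclass, field

@dataclass
class Parameter:
    model: object = None                 # bare weights, fresh run

@dataclass
class Checkpoint(Parameter):
    step: int = 0                        # progress saved mid-training
    optimizer_state: dict = field(default_factory=dict)

def resolve(parameter: Parameter, resume_from_checkpoint: bool):
    # Resume only when the loaded object really is a checkpoint
    # AND the user asked for it; otherwise train from scratch.
    checkpoint = parameter if (
        isinstance(parameter, Checkpoint) and resume_from_checkpoint
    ) else None
    return parameter.model, checkpoint

model, checkpoint = resolve(Checkpoint(step=1200), resume_from_checkpoint=True)
assert checkpoint is not None and checkpoint.step == 1200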
@@ -92,6 +98,8 @@ def train(
         accumulation_steps=accumulation_steps,
         max_grad_norm=max_grad_norm,
         random_seed=random_seed,
+        num_workers=4,
+        pin_memory=True
     )
 
     schedule_config = CosineScheduleConfig(
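The two new fields are the usual data-loading knobs: num_workers=4 prefetches batches in four worker processes, and pin_memory=True stages batches in page-locked host memory so copies to the GPU can run asynchronously. How khaosz's DatasetLoader forwards them is not shown in this diff, so the sketch below uses a plain torch.utils.data.DataLoader to illustrate what the options typically do:

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":   # guard needed because workers are subprocesses
    dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 2, (1024,)))
    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,    # four worker processes prefetch batches
        pin_memory=True,  # page-locked host buffers enable async H2D copies
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for x, y in loader:
        # non_blocking=True only overlaps the copy when the source is pinned
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        break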
@@ -104,7 +112,7 @@ def train(
         train_config=train_config,
         schedule_config=schedule_config,
     )
-    trainer.train()
+    trainer.train(checkpoint)
 
 
 if __name__ == "__main__":
@@ -127,27 +135,10 @@ if __name__ == "__main__":
     parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
 
     # other configs
+    parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
     parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
     parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
 
     args = parser.parse_args()
 
-    train(
-        param_path=args.param_path,
-        data_root_path=args.data_root_path,
-        n_epoch=args.n_epoch,
-        batch_size=args.batch_size,
-        accumulation_steps=args.accumulation_steps,
-        warmup_steps=args.warmup_steps,
-        max_lr=args.max_lr,
-        dpo_beta=args.dpo_beta,
-        adamw_betas=args.adamw_betas,
-        adamw_weight_decay=args.adamw_weight_decay,
-        max_grad_norm=args.max_grad_norm,
-        embdeding_lr_rate=args.embdeding_lr_rate,
-        checkpoint_interval=args.checkpoint_interval,
-        checkpoint_dir=args.checkpoint_dir,
-        train_type=args.train_type,
-        random_seed=args.random_seed,
-        multi_turn=args.multi_turn
-    )
+    train(**vars(args))
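Two things worth noting about the argparse side of this hunk. First, train(**vars(args)) requires every argparse destination name to match a parameter of train() exactly, which is why --resume_from_checkpoint had to be registered here. Second, type=bool does not parse booleans the way it reads: bool("False") is True, so any explicit value on the command line becomes True and only omitting the flag yields the default. A common workaround (a sketch, not part of this commit) is a store_true flag or argparse.BooleanOptionalAction:

import argparse

parser = argparse.ArgumentParser()
# store_true: absent -> False, "--resume_from_checkpoint" -> True
parser.add_argument("--resume_from_checkpoint", action="store_true",
                    help="Resume training from a checkpoint.")
# Python 3.9+: also accepts the negated form --no-multi_turn
parser.add_argument("--multi_turn", action=argparse.BooleanOptionalAction,
                    default=False, help="Use multi-turn conversation training.")

args = parser.parse_args(["--resume_from_checkpoint", "--no-multi_turn"])
assert args.resume_from_checkpoint is True and args.multi_turn is False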