feat(train.py): 支持从检查点恢复训练并优化数据加载配置
This commit is contained in:
parent
efbe3de9d3
commit
68a15005cb
33
train.py
33
train.py
|
|
@ -3,7 +3,7 @@ import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from khaosz.core import ParameterLoader
|
from khaosz.core import ParameterLoader, Checkpoint
|
||||||
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
|
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
|
||||||
from khaosz.trainer import StrategyFactory
|
from khaosz.trainer import StrategyFactory
|
||||||
|
|
||||||
|
|
@ -35,11 +35,17 @@ def train(
|
||||||
embdeding_lr_rate: int,
|
embdeding_lr_rate: int,
|
||||||
random_seed: int,
|
random_seed: int,
|
||||||
multi_turn: bool,
|
multi_turn: bool,
|
||||||
|
resume_from_checkpoint: bool
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
||||||
parameter = ParameterLoader.load(param_path)
|
parameter = ParameterLoader.load(param_path)
|
||||||
|
checkpoint = None
|
||||||
|
|
||||||
|
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
|
||||||
|
checkpoint = parameter
|
||||||
|
|
||||||
model = parameter.model
|
model = parameter.model
|
||||||
|
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
|
@ -92,6 +98,8 @@ def train(
|
||||||
accumulation_steps=accumulation_steps,
|
accumulation_steps=accumulation_steps,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
|
num_workers=4,
|
||||||
|
pin_memory=True
|
||||||
)
|
)
|
||||||
|
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(
|
||||||
|
|
@ -104,7 +112,7 @@ def train(
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
schedule_config=schedule_config,
|
schedule_config=schedule_config,
|
||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train(checkpoint)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
@ -127,27 +135,10 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
||||||
|
|
||||||
# other configs
|
# other configs
|
||||||
|
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
|
||||||
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
|
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
|
||||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
train(
|
train(**vars(args))
|
||||||
param_path=args.param_path,
|
|
||||||
data_root_path=args.data_root_path,
|
|
||||||
n_epoch=args.n_epoch,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
accumulation_steps=args.accumulation_steps,
|
|
||||||
warmup_steps=args.warmup_steps,
|
|
||||||
max_lr=args.max_lr,
|
|
||||||
dpo_beta=args.dpo_beta,
|
|
||||||
adamw_betas=args.adamw_betas,
|
|
||||||
adamw_weight_decay=args.adamw_weight_decay,
|
|
||||||
max_grad_norm=args.max_grad_norm,
|
|
||||||
embdeding_lr_rate=args.embdeding_lr_rate,
|
|
||||||
checkpoint_interval=args.checkpoint_interval,
|
|
||||||
checkpoint_dir=args.checkpoint_dir,
|
|
||||||
train_type=args.train_type,
|
|
||||||
random_seed=args.random_seed,
|
|
||||||
multi_turn=args.multi_turn
|
|
||||||
)
|
|
||||||
Loading…
Reference in New Issue