feat(train.py): 支持从检查点恢复训练并优化数据加载配置

This commit is contained in:
ViperEkura 2025-10-07 22:02:50 +08:00
parent efbe3de9d3
commit 68a15005cb
1 changed files with 12 additions and 21 deletions

View File

@ -3,7 +3,7 @@ import argparse
import torch import torch
from torch.optim import AdamW from torch.optim import AdamW
from khaosz.core import ParameterLoader from khaosz.core import ParameterLoader, Checkpoint
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
from khaosz.trainer import StrategyFactory from khaosz.trainer import StrategyFactory
@ -35,11 +35,17 @@ def train(
embdeding_lr_rate: int, embdeding_lr_rate: int,
random_seed: int, random_seed: int,
multi_turn: bool, multi_turn: bool,
resume_from_checkpoint: bool
): ):
assert train_type in ["seq", "sft", "dpo"] assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
parameter = ParameterLoader.load(param_path) parameter = ParameterLoader.load(param_path)
checkpoint = None
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
checkpoint = parameter
model = parameter.model model = parameter.model
device = torch.device("cuda") device = torch.device("cuda")
@ -92,6 +98,8 @@ def train(
accumulation_steps=accumulation_steps, accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm, max_grad_norm=max_grad_norm,
random_seed=random_seed, random_seed=random_seed,
num_workers=4,
pin_memory=True
) )
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(
@ -104,7 +112,7 @@ def train(
train_config=train_config, train_config=train_config,
schedule_config=schedule_config, schedule_config=schedule_config,
) )
trainer.train() trainer.train(checkpoint)
if __name__ == "__main__": if __name__ == "__main__":
@ -127,27 +135,10 @@ if __name__ == "__main__":
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.") parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
# other configs # other configs
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.") parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
args = parser.parse_args() args = parser.parse_args()
train( train(**vars(args))
param_path=args.param_path,
data_root_path=args.data_root_path,
n_epoch=args.n_epoch,
batch_size=args.batch_size,
accumulation_steps=args.accumulation_steps,
warmup_steps=args.warmup_steps,
max_lr=args.max_lr,
dpo_beta=args.dpo_beta,
adamw_betas=args.adamw_betas,
adamw_weight_decay=args.adamw_weight_decay,
max_grad_norm=args.max_grad_norm,
embdeding_lr_rate=args.embdeding_lr_rate,
checkpoint_interval=args.checkpoint_interval,
checkpoint_dir=args.checkpoint_dir,
train_type=args.train_type,
random_seed=args.random_seed,
multi_turn=args.multi_turn
)