feat(train.py): 支持从检查点恢复训练并优化数据加载配置

This commit is contained in:
ViperEkura 2025-10-07 22:02:50 +08:00
parent efbe3de9d3
commit 68a15005cb
1 changed files with 12 additions and 21 deletions

View File

@ -3,7 +3,7 @@ import argparse
import torch
from torch.optim import AdamW
from khaosz.core import ParameterLoader
from khaosz.core import ParameterLoader, Checkpoint
from khaosz.trainer import Trainer, DatasetLoader, TrainConfig, CosineScheduleConfig
from khaosz.trainer import StrategyFactory
@ -35,11 +35,17 @@ def train(
embdeding_lr_rate: int,
random_seed: int,
multi_turn: bool,
resume_from_checkpoint: bool
):
assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path)
parameter = ParameterLoader.load(param_path)
checkpoint = None
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
checkpoint = parameter
model = parameter.model
device = torch.device("cuda")
@ -92,6 +98,8 @@ def train(
accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
num_workers=4,
pin_memory=True
)
schedule_config = CosineScheduleConfig(
@ -104,7 +112,7 @@ def train(
train_config=train_config,
schedule_config=schedule_config,
)
trainer.train()
trainer.train(checkpoint)
if __name__ == "__main__":
@ -127,27 +135,10 @@ if __name__ == "__main__":
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
# other configs
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
args = parser.parse_args()
train(
param_path=args.param_path,
data_root_path=args.data_root_path,
n_epoch=args.n_epoch,
batch_size=args.batch_size,
accumulation_steps=args.accumulation_steps,
warmup_steps=args.warmup_steps,
max_lr=args.max_lr,
dpo_beta=args.dpo_beta,
adamw_betas=args.adamw_betas,
adamw_weight_decay=args.adamw_weight_decay,
max_grad_norm=args.max_grad_norm,
embdeding_lr_rate=args.embdeding_lr_rate,
checkpoint_interval=args.checkpoint_interval,
checkpoint_dir=args.checkpoint_dir,
train_type=args.train_type,
random_seed=args.random_seed,
multi_turn=args.multi_turn
)
train(**vars(args))