diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index eb0412d..a94b1d2 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -78,14 +78,13 @@ def test_dataset_loader(test_env):
 def test_training_config(test_env):
     optimizer = torch.optim.AdamW(test_env["model"].parameters())
     train_config = TrainConfig(
-        train_type="seq",
         dataset=test_env["dataset"],
         optimizer=optimizer,
-        ckpt_dir=test_env["test_dir"],
+        checkpoint_dir=test_env["test_dir"],
         n_epoch=1,
         batch_size=2,
-        n_iter_ckpt=5,
-        n_iter_step=1,
+        checkpoint_interval=5,
+        accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
     )
@@ -94,51 +93,56 @@ def test_training_config(test_env):
 def test_cosine_schedule(test_env):
     assert test_env is not None
     schedule_config = CosineScheduleConfig(
-        warning_step=100,
-        total_iters=1000
+        warmup_steps=100,
+        total_steps=1000
     )
     kwargs = schedule_config.get_kwargs()
-    assert kwargs["warning_step"] == 100
-    assert kwargs["lr_decay_iters"] == 900
+    assert kwargs["warmup_steps"] == 100
+    assert kwargs["lr_decay_steps"] == 900
 
 def test_sgdr_schedule(test_env):
     assert test_env is not None
     schedule_config = SgdrScheduleConfig(
-        warning_step=100,
+        warmup_steps=100,
         cycle_length=200,
-        T_mult=2
+        t_mult=2
     )
     kwargs = schedule_config.get_kwargs()
-    assert kwargs["warning_step"] == 100
+    assert kwargs["warmup_steps"] == 100
     assert kwargs["cycle_length"] == 200
-    assert kwargs["T_mult"] == 2
+    assert kwargs["t_mult"] == 2
 
 def test_trainer_train(test_env):
     optimizer = torch.optim.AdamW(test_env["model"].parameters())
     train_config = TrainConfig(
-        train_type="seq",
         dataset=test_env["dataset"],
         optimizer=optimizer,
-        ckpt_dir=test_env["test_dir"],
+        checkpoint_dir=test_env["test_dir"],
         n_epoch=1,
         batch_size=2,
-        n_iter_ckpt=5,
-        n_iter_step=1,
+        checkpoint_interval=5,
+        accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
     )
     schedule_config = CosineScheduleConfig(
-        warning_step=100,
-        total_iters=1000
+        warmup_steps=100,
+        total_steps=1000
+    )
+
+    train_config.strategy = StrategyFactory.load(
+        test_env["model"],
+        "seq",
+        pad_token_id=test_env["tokenizer"].pad_id
     )
 
     model_parameter = ModelParameter(
         test_env["model"],
         test_env["tokenizer"],
         test_env["transformer_config"]
     )
-    trainer = Trainer(model_parameter)
-    trainer.train(train_config, schedule_config)
+    trainer = Trainer(model_parameter, train_config, schedule_config)
+    trainer.train()
 
 def test_checkpoint(test_env):
     temp_dir = test_env["test_dir"]
@@ -195,36 +199,32 @@ def test_checkpoint_train(test_env):
 
     dataset = InterruptDataset(length=10, interrupt_idx=3)
     param = ModelParameter(model, tokenizer, config)
-    trainer = Trainer(param)
-
     optimizer = torch.optim.AdamW(test_env["model"].parameters())
     train_config = TrainConfig(
-        train_type="seq",
         dataset=dataset,
         optimizer=optimizer,
-        ckpt_dir=test_env["test_dir"],
+        checkpoint_dir=test_env["test_dir"],
         n_epoch=2,
         batch_size=2,
-        n_iter_ckpt=5,
-        n_iter_step=1,
+        checkpoint_interval=5,
+        accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
     )
-    schedule_config = CosineScheduleConfig(
-        warning_step=100,
-        total_iters=1000
+
+    train_config.strategy = StrategyFactory.load(
+        test_env["model"],
+        "seq",
+        pad_token_id=test_env["tokenizer"].pad_id
     )
+    schedule_config = CosineScheduleConfig(
+        warmup_steps=100,
+        total_steps=1000
+    )
+    trainer = Trainer(param, train_config, schedule_config)
+
     try:
-        trainer.train(
-            train_config=train_config,
-            schedule_config=schedule_config,
-        )
+        trainer.train()
     except Exception:
         checkpoint = trainer.checkpoint
-        trainer.train(
-            train_config=train_config,
-            schedule_config=schedule_config,
-            train_checkpoint=checkpoint
-        )
-
-
\ No newline at end of file
+        trainer.train(train_checkpoint=checkpoint)
\ No newline at end of file