diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index e44d698..227c9c3 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -145,14 +145,16 @@ def test_checkpoint(test_env):
     config = test_env["transformer_config"]
     model = test_env["model"]
     tokenizer = test_env["tokenizer"]
+    optimizer = torch.optim.AdamW(model.parameters())
+    for _ in range(3):
+        optimizer.step()
 
-    param = ModelParameter(model, tokenizer, config)
     checkpoint = Checkpoint(
-        model=param.model,
-        tokenizer=param.tokenizer,
-        config=param.config,
-        loss_list=[1.0, 2.0, 3.0],
-        current_iter=3
+        model=model,
+        tokenizer=tokenizer,
+        config=config,
+        loss_list=[1.0, 2.0, 3.0],
+        optim_state=optimizer.state_dict()
     )
     ckpt_dir = os.path.join(temp_dir, "ckpt")
     checkpoint.save(ckpt_dir)
@@ -160,8 +162,8 @@ def test_checkpoint(test_env):
 
     loaded_ckpt = Checkpoint()
     loaded_ckpt.load(ckpt_dir)
-    assert loaded_ckpt.current_iter == 3
     assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0]
+    assert loaded_ckpt.optim_state == optimizer.state_dict()
 
     for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()):
         assert torch.allclose(p1, p2)
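
For reference, a minimal sketch of the Checkpoint interface this test exercises. The constructor keywords, the save(dir)/load(dir) round trip, and the model/loss_list/optim_state attributes are all taken from the diff above; the single-file torch.save layout and the "checkpoint.pt" filename are assumptions for illustration, not the repository's actual implementation.

    import os
    import torch

    class Checkpoint:
        """Bundles model, tokenizer, config, loss history, and optimizer state."""

        def __init__(self, model=None, tokenizer=None, config=None,
                     loss_list=None, optim_state=None):
            self.model = model
            self.tokenizer = tokenizer
            self.config = config
            self.loss_list = loss_list if loss_list is not None else []
            self.optim_state = optim_state

        def save(self, ckpt_dir):
            # Persist every field the test round-trips into one file.
            # Filename is an assumption; the real code may shard fields.
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(
                {
                    "model": self.model,
                    "tokenizer": self.tokenizer,
                    "config": self.config,
                    "loss_list": self.loss_list,
                    "optim_state": self.optim_state,
                },
                os.path.join(ckpt_dir, "checkpoint.pt"),
            )

        def load(self, ckpt_dir):
            # weights_only=False because the payload contains arbitrary
            # Python objects (tokenizer, config), not just tensors.
            payload = torch.load(
                os.path.join(ckpt_dir, "checkpoint.pt"), weights_only=False
            )
            self.model = payload["model"]
            self.tokenizer = payload["tokenizer"]
            self.config = payload["config"]
            self.loss_list = payload["loss_list"]
            self.optim_state = payload["optim_state"]

One detail worth noting about the test itself: since no backward pass runs before optimizer.step(), AdamW accumulates no per-parameter state, so the optim_state equality assertion effectively compares param_groups only; the parameter-by-parameter torch.allclose loop is what actually verifies the weights survived the round trip.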