From c8a38743a4a2a62a2a0958ca5541ec156a1d1c48 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 28 Sep 2025 14:00:38 +0800
Subject: [PATCH] fix(tests): update test code to verify saving and loading
 of optimizer state
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/test_trainer.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index e44d698..227c9c3 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -145,14 +145,16 @@ def test_checkpoint(test_env):
     config = test_env["transformer_config"]
     model = test_env["model"]
     tokenizer = test_env["tokenizer"]
+    optimizer = torch.optim.AdamW(model.parameters())
+    for _ in range(3):
+        optimizer.step()
 
-    param = ModelParameter(model, tokenizer, config)
     checkpoint = Checkpoint(
-        model=param.model,
-        tokenizer=param.tokenizer,
-        config=param.config,
-        loss_list=[1.0, 2.0, 3.0],
-        current_iter=3
+        model=model,
+        tokenizer=tokenizer,
+        config=config,
+        loss_list=[1.0, 2.0, 3.0],
+        optim_state=optimizer.state_dict()
     )
     ckpt_dir = os.path.join(temp_dir, "ckpt")
     checkpoint.save(ckpt_dir)
@@ -160,8 +162,8 @@ def test_checkpoint(test_env):
     loaded_ckpt = Checkpoint()
     loaded_ckpt.load(ckpt_dir)
 
-    assert loaded_ckpt.current_iter == 3
     assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0]
+    assert loaded_ckpt.optim_state == optimizer.state_dict()
     for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()):
         assert torch.allclose(p1, p2)
 
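Note: the repository's Checkpoint class is not shown in this patch, so the
following is only a minimal sketch of the optimizer-state round trip the
updated test exercises, assuming a conventional torch.save/torch.load layout.
The save_checkpoint/load_checkpoint names and the checkpoint.pt filename are
hypothetical, not the project's actual API.

    import os
    import torch

    def save_checkpoint(ckpt_dir, model, optimizer, loss_list):
        # Persist model weights, optimizer state, and loss history in one file
        # (hypothetical layout; the real Checkpoint class may differ).
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(
            {
                "model_state": model.state_dict(),
                "optim_state": optimizer.state_dict(),  # state the new assertion checks
                "loss_list": loss_list,
            },
            os.path.join(ckpt_dir, "checkpoint.pt"),
        )

    def load_checkpoint(ckpt_dir, model, optimizer):
        # Restore weights and optimizer state; return the stored loss history.
        payload = torch.load(os.path.join(ckpt_dir, "checkpoint.pt"), map_location="cpu")
        model.load_state_dict(payload["model_state"])
        optimizer.load_state_dict(payload["optim_state"])
        return payload["loss_list"]

The `==` assertion on the two state dicts holds as long as serialization
preserves their plain Python values: stepping AdamW without any backward pass
leaves the per-parameter state empty, so state_dict() contains only
param-group hyperparameters rather than tensors.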