fix(tests): 更新测试代码以验证优化器状态的保存与加载

This commit is contained in:
ViperEkura 2025-09-28 14:00:38 +08:00
parent f25a249291
commit c8a38743a4
1 changed files with 9 additions and 7 deletions

View File

@ -145,14 +145,16 @@ def test_checkpoint(test_env):
config = test_env["transformer_config"] config = test_env["transformer_config"]
model = test_env["model"] model = test_env["model"]
tokenizer = test_env["tokenizer"] tokenizer = test_env["tokenizer"]
optimizer = torch.optim.AdamW(model.parameters())
for _ in range(3):
optimizer.step()
param = ModelParameter(model, tokenizer, config)
checkpoint = Checkpoint( checkpoint = Checkpoint(
model=param.model, model=model,
tokenizer=param.tokenizer, tokenizer=tokenizer,
config=param.config, config=config,
loss_list=[1.0, 2.0, 3.0], loss_list=[1.0, 2.0, 3.0],
current_iter=3 optim_state=optimizer.state_dict()
) )
ckpt_dir = os.path.join(temp_dir, "ckpt") ckpt_dir = os.path.join(temp_dir, "ckpt")
checkpoint.save(ckpt_dir) checkpoint.save(ckpt_dir)
@ -160,8 +162,8 @@ def test_checkpoint(test_env):
loaded_ckpt = Checkpoint() loaded_ckpt = Checkpoint()
loaded_ckpt.load(ckpt_dir) loaded_ckpt.load(ckpt_dir)
assert loaded_ckpt.current_iter == 3
assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0] assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0]
assert loaded_ckpt.optim_state == optimizer.state_dict()
for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()): for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()):
assert torch.allclose(p1, p2) assert torch.allclose(p1, p2)