fix(tests): 更新测试代码以验证优化器状态的保存与加载
This commit is contained in:
parent
f25a249291
commit
c8a38743a4
|
|
@ -145,14 +145,16 @@ def test_checkpoint(test_env):
|
||||||
config = test_env["transformer_config"]
|
config = test_env["transformer_config"]
|
||||||
model = test_env["model"]
|
model = test_env["model"]
|
||||||
tokenizer = test_env["tokenizer"]
|
tokenizer = test_env["tokenizer"]
|
||||||
|
optimizer = torch.optim.AdamW(model.parameters())
|
||||||
|
for _ in range(3):
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
param = ModelParameter(model, tokenizer, config)
|
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
model=param.model,
|
model=model,
|
||||||
tokenizer=param.tokenizer,
|
tokenizer=tokenizer,
|
||||||
config=param.config,
|
config=config,
|
||||||
loss_list=[1.0, 2.0, 3.0],
|
loss_list=[1.0, 2.0, 3.0],
|
||||||
current_iter=3
|
optim_state=optimizer.state_dict()
|
||||||
)
|
)
|
||||||
ckpt_dir = os.path.join(temp_dir, "ckpt")
|
ckpt_dir = os.path.join(temp_dir, "ckpt")
|
||||||
checkpoint.save(ckpt_dir)
|
checkpoint.save(ckpt_dir)
|
||||||
|
|
@ -160,8 +162,8 @@ def test_checkpoint(test_env):
|
||||||
loaded_ckpt = Checkpoint()
|
loaded_ckpt = Checkpoint()
|
||||||
loaded_ckpt.load(ckpt_dir)
|
loaded_ckpt.load(ckpt_dir)
|
||||||
|
|
||||||
assert loaded_ckpt.current_iter == 3
|
|
||||||
assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0]
|
assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0]
|
||||||
|
assert loaded_ckpt.optim_state == optimizer.state_dict()
|
||||||
|
|
||||||
for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()):
|
for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()):
|
||||||
assert torch.allclose(p1, p2)
|
assert torch.allclose(p1, p2)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue