fix(tests): 更新测试代码以验证优化器状态的保存与加载
This commit is contained in:
parent
f25a249291
commit
c8a38743a4
|
|
@ -145,14 +145,16 @@ def test_checkpoint(test_env):
|
|||
config = test_env["transformer_config"]
|
||||
model = test_env["model"]
|
||||
tokenizer = test_env["tokenizer"]
|
||||
optimizer = torch.optim.AdamW(model.parameters())
|
||||
for _ in range(3):
|
||||
optimizer.step()
|
||||
|
||||
param = ModelParameter(model, tokenizer, config)
|
||||
checkpoint = Checkpoint(
|
||||
model=param.model,
|
||||
tokenizer=param.tokenizer,
|
||||
config=param.config,
|
||||
loss_list=[1.0, 2.0, 3.0],
|
||||
current_iter=3
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
loss_list=[1.0, 2.0, 3.0],
|
||||
optim_state=optimizer.state_dict()
|
||||
)
|
||||
ckpt_dir = os.path.join(temp_dir, "ckpt")
|
||||
checkpoint.save(ckpt_dir)
|
||||
|
|
@ -160,8 +162,8 @@ def test_checkpoint(test_env):
|
|||
loaded_ckpt = Checkpoint()
|
||||
loaded_ckpt.load(ckpt_dir)
|
||||
|
||||
assert loaded_ckpt.current_iter == 3
|
||||
assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0]
|
||||
assert loaded_ckpt.optim_state == optimizer.state_dict()
|
||||
|
||||
for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()):
|
||||
assert torch.allclose(p1, p2)
|
||||
|
|
|
|||
Loading…
Reference in New Issue