test(test_tie_weight): 添加测试以验证权重绑定后的数据修改行为
This commit is contained in:
parent
1c3a693d79
commit
3c7ed84516
|
|
@ -77,6 +77,11 @@ def test_tie_weight_init(transformer_test_env):
|
||||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||||
|
|
||||||
|
original_weight = model.embed_tokens.weight.clone()
|
||||||
|
model.embed_tokens.weight.data[0, 0] = 100.0
|
||||||
|
|
||||||
|
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
|
assert not torch.equal(model.lm_head.weight, original_weight)
|
||||||
|
|
||||||
def test_model_save_load_with_tie_weight(transformer_test_env):
|
def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
test_dir = transformer_test_env["test_dir"]
|
test_dir = transformer_test_env["test_dir"]
|
||||||
|
|
@ -104,16 +109,11 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||||
assert "lm_head.weight" not in model.state_dict()
|
assert "lm_head.weight" not in model.state_dict()
|
||||||
|
|
||||||
# case 2: not tie weight
|
# case 2: not tie weight (loading from a tie-weight state dict)
|
||||||
config_data["tie_weight"] = False
|
config_data["tie_weight"] = False
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, 'w') as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig().load(config_path)
|
|
||||||
original_model = Transformer(config)
|
|
||||||
|
|
||||||
st.save_file(original_model.state_dict(), model_path)
|
|
||||||
|
|
||||||
loaded_config = ModelConfig().load(config_path)
|
loaded_config = ModelConfig().load(config_path)
|
||||||
model = Transformer(loaded_config)
|
model = Transformer(loaded_config)
|
||||||
model.load_state_dict(st.load_file(model_path))
|
model.load_state_dict(st.load_file(model_path))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue