diff --git a/tests/test_tie_weight.py b/tests/test_tie_weight.py index cc629db..c2d8911 100644 --- a/tests/test_tie_weight.py +++ b/tests/test_tie_weight.py @@ -77,6 +77,11 @@ def test_tie_weight_init(transformer_test_env): assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() + original_weight = model.embed_tokens.weight.clone() + model.embed_tokens.weight.data[0, 0] = 100.0 + + assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) + assert not torch.equal(model.lm_head.weight, original_weight) def test_model_save_load_with_tie_weight(transformer_test_env): test_dir = transformer_test_env["test_dir"] @@ -104,16 +109,11 @@ def test_model_save_load_with_tie_weight(transformer_test_env): assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr() assert "lm_head.weight" not in model.state_dict() - # case 2: not tie weight + # case 2: not tie weight (from tie-weight state dict load) config_data["tie_weight"] = False with open(config_path, 'w') as f: json.dump(config_data, f) - config = ModelConfig().load(config_path) - original_model = Transformer(config) - - st.save_file(original_model.state_dict(), model_path) - loaded_config = ModelConfig().load(config_path) model = Transformer(loaded_config) model.load_state_dict(st.load_file(model_path))