From 3c7ed84516bb494dcaa1ff25524ad741f947639a Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Wed, 19 Nov 2025 17:47:22 +0800
Subject: [PATCH] =?UTF-8?q?test(test=5Ftie=5Fweight):=20=E6=B7=BB=E5=8A=A0?=
 =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BB=A5=E9=AA=8C=E8=AF=81=E6=9D=83=E9=87=8D?=
 =?UTF-8?q?=E7=BB=91=E5=AE=9A=E5=90=8E=E7=9A=84=E6=95=B0=E6=8D=AE=E4=BF=AE?=
 =?UTF-8?q?=E6=94=B9=E8=A1=8C=E4=B8=BA?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/test_tie_weight.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/test_tie_weight.py b/tests/test_tie_weight.py
index cc629db..c2d8911 100644
--- a/tests/test_tie_weight.py
+++ b/tests/test_tie_weight.py
@@ -77,6 +77,11 @@ def test_tie_weight_init(transformer_test_env):
     assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
     assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
 
+    original_weight = model.embed_tokens.weight.clone()
+    model.embed_tokens.weight.data[0, 0] = 100.0
+
+    assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
+    assert not torch.equal(model.lm_head.weight, original_weight)
 
 def test_model_save_load_with_tie_weight(transformer_test_env):
     test_dir = transformer_test_env["test_dir"]
@@ -104,16 +109,11 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
     assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
     assert "lm_head.weight" not in model.state_dict()
 
-    # case 2: not tie weight
+    # case 2: not tie weight (from tie-weight state dict load)
     config_data["tie_weight"] = False
     with open(config_path, 'w') as f:
         json.dump(config_data, f)
 
-    config = ModelConfig().load(config_path)
-    original_model = Transformer(config)
-
-    st.save_file(original_model.state_dict(), model_path)
-
     loaded_config = ModelConfig().load(config_path)
     model = Transformer(loaded_config)
     model.load_state_dict(st.load_file(model_path))