From f31bf5a959b38de521116b50359e2ce3013675b2 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 9 Nov 2025 17:23:33 +0800 Subject: [PATCH] =?UTF-8?q?test(transformer):=20=E6=9B=B4=E6=96=B0=20tie?= =?UTF-8?q?=5Fweight=20=E7=9B=B8=E5=85=B3=E6=B5=8B=E8=AF=95=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_tie_weight.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_tie_weight.py b/tests/test_tie_weight.py index fbebd3e..cc629db 100644 --- a/tests/test_tie_weight.py +++ b/tests/test_tie_weight.py @@ -14,7 +14,6 @@ def transformer_test_env(): test_dir = tempfile.mkdtemp(prefix="transformer_test_") config_path = os.path.join(test_dir, "config.json") - # 测试配置参数 config = { "vocab_size": 1000, "n_dim": 128, @@ -44,10 +43,11 @@ def transformer_test_env(): pass -def test_tie_weight_shared_logic(transformer_test_env): - +def test_tie_weight_init(transformer_test_env): config_path = transformer_test_env["config_path"] config_data = transformer_test_env["config"].copy() + + # case 1: tie weight config_data["tie_weight"] = True with open(config_path, 'w') as f: @@ -65,6 +65,7 @@ def test_tie_weight_shared_logic(transformer_test_env): assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, original_weight) + # case 2: not tie weight config_data["tie_weight"] = False with open(config_path, 'w') as f: @@ -82,6 +83,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env): model_path = os.path.join(test_dir, "model.safetensors") config_data = transformer_test_env["config"].copy() + + # case 1: tie weight config_data["tie_weight"] = True config_path = os.path.join(test_dir, "config.json") @@ -101,6 +104,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env): assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr() assert "lm_head.weight" not in model.state_dict() + # case 2: not tie weight config_data["tie_weight"] = False with open(config_path, 'w') as f: json.dump(config_data, f) @@ -114,7 +118,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env): model = Transformer(loaded_config) model.load_state_dict(st.load_file(model_path)) - assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) + assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() assert "lm_head.weight" in model.state_dict() \ No newline at end of file