diff --git a/tests/test_tie_weight.py b/tests/test_tie_weight.py index fbebd3e..cc629db 100644 --- a/tests/test_tie_weight.py +++ b/tests/test_tie_weight.py @@ -14,7 +14,6 @@ def transformer_test_env(): test_dir = tempfile.mkdtemp(prefix="transformer_test_") config_path = os.path.join(test_dir, "config.json") - # 测试配置参数 config = { "vocab_size": 1000, "n_dim": 128, @@ -44,10 +43,11 @@ def transformer_test_env(): pass -def test_tie_weight_shared_logic(transformer_test_env): - +def test_tie_weight_init(transformer_test_env): config_path = transformer_test_env["config_path"] config_data = transformer_test_env["config"].copy() + + # case 1: tie weight config_data["tie_weight"] = True with open(config_path, 'w') as f: @@ -65,6 +65,7 @@ def test_tie_weight_shared_logic(transformer_test_env): assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, original_weight) + # case 2: not tie weight config_data["tie_weight"] = False with open(config_path, 'w') as f: @@ -82,6 +83,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env): model_path = os.path.join(test_dir, "model.safetensors") config_data = transformer_test_env["config"].copy() + + # case 1: tie weight config_data["tie_weight"] = True config_path = os.path.join(test_dir, "config.json") @@ -101,6 +104,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env): assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr() assert "lm_head.weight" not in model.state_dict() + # case 2: not tie weight config_data["tie_weight"] = False with open(config_path, 'w') as f: json.dump(config_data, f) @@ -114,7 +118,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env): model = Transformer(loaded_config) model.load_state_dict(st.load_file(model_path)) - assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) + assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() assert "lm_head.weight" in model.state_dict() \ No newline at end of file