test(transformer): 更新 tie_weight 相关测试逻辑

This commit is contained in:
ViperEkura 2025-11-09 17:23:33 +08:00
parent 7a21f5d72e
commit f31bf5a959
1 changed file with 8 additions and 4 deletions

View File

@@ -14,7 +14,6 @@ def transformer_test_env():
test_dir = tempfile.mkdtemp(prefix="transformer_test_")
config_path = os.path.join(test_dir, "config.json")
# 测试配置参数
config = {
"vocab_size": 1000,
"n_dim": 128,
@@ -44,10 +43,11 @@ def transformer_test_env():
pass
def test_tie_weight_shared_logic(transformer_test_env):
def test_tie_weight_init(transformer_test_env):
config_path = transformer_test_env["config_path"]
config_data = transformer_test_env["config"].copy()
# case 1: tie weight
config_data["tie_weight"] = True
with open(config_path, 'w') as f:
@@ -65,6 +65,7 @@ def test_tie_weight_shared_logic(transformer_test_env):
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert not torch.equal(model.lm_head.weight, original_weight)
# case 2: not tie weight
config_data["tie_weight"] = False
with open(config_path, 'w') as f:
@@ -82,6 +83,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
model_path = os.path.join(test_dir, "model.safetensors")
config_data = transformer_test_env["config"].copy()
# case 1: tie weight
config_data["tie_weight"] = True
config_path = os.path.join(test_dir, "config.json")
@@ -101,6 +104,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
assert "lm_head.weight" not in model.state_dict()
# case 2: not tie weight
config_data["tie_weight"] = False
with open(config_path, 'w') as f:
json.dump(config_data, f)
@@ -114,7 +118,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
assert "lm_head.weight" in model.state_dict()