test(transformer): 更新 tie_weight 相关测试逻辑
This commit is contained in:
parent
7a21f5d72e
commit
f31bf5a959
|
|
@ -14,7 +14,6 @@ def transformer_test_env():
|
|||
test_dir = tempfile.mkdtemp(prefix="transformer_test_")
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
|
||||
# 测试配置参数
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
"n_dim": 128,
|
||||
|
|
@ -44,10 +43,11 @@ def transformer_test_env():
|
|||
pass
|
||||
|
||||
|
||||
def test_tie_weight_shared_logic(transformer_test_env):
|
||||
|
||||
def test_tie_weight_init(transformer_test_env):
|
||||
config_path = transformer_test_env["config_path"]
|
||||
config_data = transformer_test_env["config"].copy()
|
||||
|
||||
# case 1: tie weight
|
||||
config_data["tie_weight"] = True
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
|
|
@ -65,6 +65,7 @@ def test_tie_weight_shared_logic(transformer_test_env):
|
|||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert not torch.equal(model.lm_head.weight, original_weight)
|
||||
|
||||
# case 2: not tie weight
|
||||
config_data["tie_weight"] = False
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
|
|
@ -82,6 +83,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
model_path = os.path.join(test_dir, "model.safetensors")
|
||||
|
||||
config_data = transformer_test_env["config"].copy()
|
||||
|
||||
# case 1: tie weight
|
||||
config_data["tie_weight"] = True
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
|
||||
|
|
@ -101,6 +104,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||
assert "lm_head.weight" not in model.state_dict()
|
||||
|
||||
# case 2: not tie weight
|
||||
config_data["tie_weight"] = False
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(config_data, f)
|
||||
|
|
@ -114,7 +118,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
model = Transformer(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||
assert "lm_head.weight" in model.state_dict()
|
||||
|
||||
Loading…
Reference in New Issue