test(transformer): 更新 tie_weight 相关测试逻辑
This commit is contained in:
parent
7a21f5d72e
commit
f31bf5a959
|
|
@ -14,7 +14,6 @@ def transformer_test_env():
|
||||||
test_dir = tempfile.mkdtemp(prefix="transformer_test_")
|
test_dir = tempfile.mkdtemp(prefix="transformer_test_")
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
|
||||||
# 测试配置参数
|
|
||||||
config = {
|
config = {
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"n_dim": 128,
|
"n_dim": 128,
|
||||||
|
|
@ -44,10 +43,11 @@ def transformer_test_env():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_tie_weight_shared_logic(transformer_test_env):
|
def test_tie_weight_init(transformer_test_env):
|
||||||
|
|
||||||
config_path = transformer_test_env["config_path"]
|
config_path = transformer_test_env["config_path"]
|
||||||
config_data = transformer_test_env["config"].copy()
|
config_data = transformer_test_env["config"].copy()
|
||||||
|
|
||||||
|
# case 1: tie weight
|
||||||
config_data["tie_weight"] = True
|
config_data["tie_weight"] = True
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, 'w') as f:
|
||||||
|
|
@ -65,6 +65,7 @@ def test_tie_weight_shared_logic(transformer_test_env):
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert not torch.equal(model.lm_head.weight, original_weight)
|
assert not torch.equal(model.lm_head.weight, original_weight)
|
||||||
|
|
||||||
|
# case 2: not tie weight
|
||||||
config_data["tie_weight"] = False
|
config_data["tie_weight"] = False
|
||||||
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, 'w') as f:
|
||||||
|
|
@ -82,6 +83,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
model_path = os.path.join(test_dir, "model.safetensors")
|
model_path = os.path.join(test_dir, "model.safetensors")
|
||||||
|
|
||||||
config_data = transformer_test_env["config"].copy()
|
config_data = transformer_test_env["config"].copy()
|
||||||
|
|
||||||
|
# case 1: tie weight
|
||||||
config_data["tie_weight"] = True
|
config_data["tie_weight"] = True
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
|
||||||
|
|
@ -101,6 +104,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||||
assert "lm_head.weight" not in model.state_dict()
|
assert "lm_head.weight" not in model.state_dict()
|
||||||
|
|
||||||
|
# case 2: not tie weight
|
||||||
config_data["tie_weight"] = False
|
config_data["tie_weight"] = False
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, 'w') as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
@ -114,7 +118,7 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
model = Transformer(loaded_config)
|
model = Transformer(loaded_config)
|
||||||
model.load_state_dict(st.load_file(model_path))
|
model.load_state_dict(st.load_file(model_path))
|
||||||
|
|
||||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||||
assert "lm_head.weight" in model.state_dict()
|
assert "lm_head.weight" in model.state_dict()
|
||||||
|
|
||||||
Loading…
Reference in New Issue