120 lines
3.6 KiB
Python
120 lines
3.6 KiB
Python
import json
import os
import shutil
import tempfile

import pytest
import safetensors.torch as st
import torch

from khaosz.config.model_config import ModelConfig
from khaosz.model.transformer import Transformer
|
|
|
|
|
|
@pytest.fixture
def transformer_test_env():
    """Create an isolated temporary environment for Transformer tests.

    Writes a small model configuration to ``config.json`` inside a fresh
    temporary directory and yields a dict with:

    - ``"test_dir"``: path of the temporary directory,
    - ``"config_path"``: path of the written ``config.json``,
    - ``"config"``: the config dict that was written to that file.

    On teardown the directory and everything inside it is removed. Cleanup
    runs even if the test fails or if writing the config raises.
    """
    test_dir = tempfile.mkdtemp(prefix="transformer_test_")
    config_path = os.path.join(test_dir, "config.json")

    # Test configuration parameters (a small model).
    config = {
        "vocab_size": 1000,
        "n_dim": 128,
        "n_head": 4,
        "n_kvhead": 2,
        "d_ffn": 256,
        "m_len": 64,
        "n_layer": 2,
        "norm_eps": 1e-5,
    }

    try:
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f)

        yield {
            "test_dir": test_dir,
            "config_path": config_path,
            "config": config,
        }
    finally:
        # Best-effort cleanup. rmtree also removes sub-directories a test may
        # create, which the old listdir/os.remove loop could not. It only
        # suppresses filesystem errors, unlike the previous bare ``except:``,
        # so teardown cannot mask the real test outcome.
        shutil.rmtree(test_dir, ignore_errors=True)
|
|
|
|
|
|
def test_tie_weight_shared_logic(transformer_test_env):
    """Check that ``tie_weight`` controls whether lm_head shares embed_tokens' weight.

    With ``tie_weight=True`` both modules must expose the same storage, so an
    in-place write through ``embed_tokens`` is visible through ``lm_head``.
    With ``tie_weight=False`` the two weights must be distinct tensors.
    """
    config_path = transformer_test_env["config_path"]
    config_data = transformer_test_env["config"].copy()

    def build_model(tie_weight):
        # Rewrite the config file with the requested flag, then build a
        # model from it via ModelConfig.load.
        config_data["tie_weight"] = tie_weight
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_data, f)
        return Transformer(ModelConfig().load(config_path))

    # Tied: same values and same underlying storage.
    model = build_model(True)
    assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
    assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()

    original_weight = model.embed_tokens.weight.clone()
    # Edit the parameter in place under no_grad rather than through ``.data``,
    # which bypasses autograd's version tracking and is discouraged.
    with torch.no_grad():
        model.embed_tokens.weight[0, 0] = 100.0

    assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
    assert not torch.equal(model.lm_head.weight, original_weight)
    assert model.lm_head.weight[0, 0].item() == 100.0

    # Untied: independent tensors.
    model = build_model(False)
    assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
    assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
|
|
|
|
|
def test_model_save_load_with_tie_weight(transformer_test_env):
    """Round-trip a Transformer through safetensors with and without tied weights.

    For each ``tie_weight`` setting, a model is saved with ``st.save_file``,
    rebuilt from the same config, and reloaded with ``load_state_dict``. The
    test checks four things:

    - the tie (or its absence) survives the round trip;
    - ``lm_head.weight`` appears in ``state_dict()`` only when untied;
    - every reloaded tensor equals the saved one;
    - the round trip does not merely reproduce a fresh initialization.
    """
    test_dir = transformer_test_env["test_dir"]
    config_path = transformer_test_env["config_path"]
    model_path = os.path.join(test_dir, "model.safetensors")
    config_data = transformer_test_env["config"].copy()

    def round_trip(tie_weight):
        # Write the config with the requested flag, save a freshly built
        # model, then load it into a second model. Return (saved, reloaded).
        config_data["tie_weight"] = tie_weight
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_data, f)

        saved = Transformer(ModelConfig().load(config_path))
        st.save_file(saved.state_dict(), model_path)

        reloaded = Transformer(ModelConfig().load(config_path))
        reloaded.load_state_dict(st.load_file(model_path))
        return saved, reloaded

    def assert_same_weights(expected, actual):
        # Without this check, a load that silently kept the new model's
        # random init would still pass the tie assertions.
        expected_sd = expected.state_dict()
        actual_sd = actual.state_dict()
        assert expected_sd.keys() == actual_sd.keys()
        for key, tensor in expected_sd.items():
            assert torch.equal(tensor, actual_sd[key]), key

    # Tied: sharing is restored and lm_head is not serialized separately.
    original_model, model = round_trip(True)
    assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
    assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
    assert "lm_head.weight" not in model.state_dict()
    assert_same_weights(original_model, model)

    # Untied: independent tensors, and lm_head is serialized.
    original_model, model = round_trip(False)
    assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
    assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
    assert "lm_head.weight" in model.state_dict()
    assert_same_weights(original_model, model)
|
|
|