diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py
index a557b86..f67edc9 100644
--- a/khaosz/model/transformer.py
+++ b/khaosz/model/transformer.py
@@ -71,8 +71,11 @@ class Transformer(nn.Module):
             DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
             for layer_id in range(config.n_layer)
         ])
-        lm_head_init_weight = self.embed_tokens.weight if config.tie_weight == True else None
         self.norm = RMSNorm(config.n_dim, config.norm_eps)
-        self.lm_head = Linear(config.n_dim, config.vocab_size, weight_param=lm_head_init_weight)
+        self.lm_head = Linear(config.n_dim, config.vocab_size)
+
+        # Tie input/output embeddings: lm_head reuses the embedding Parameter.
+        if self.config.tie_weight:
+            self.lm_head.weight = self.embed_tokens.weight
 
         self._init_parameters()
diff --git a/tests/test_tie_weight.py b/tests/test_tie_weight.py
new file mode 100644
index 0000000..fbebd3e
--- /dev/null
+++ b/tests/test_tie_weight.py
@@ -0,0 +1,111 @@
+import os
+import json
+import shutil
+import torch
+import pytest
+import tempfile
+import safetensors.torch as st
+from khaosz.model.transformer import Transformer
+from khaosz.config.model_config import ModelConfig
+
+
+@pytest.fixture
+def transformer_test_env():
+    """Create an isolated temp dir with a small base model config."""
+    test_dir = tempfile.mkdtemp(prefix="transformer_test_")
+    config_path = os.path.join(test_dir, "config.json")
+
+    # Tiny model so construction stays fast.
+    config = {
+        "vocab_size": 1000,
+        "n_dim": 128,
+        "n_head": 4,
+        "n_kvhead": 2,
+        "d_ffn": 256,
+        "m_len": 64,
+        "n_layer": 2,
+        "norm_eps": 1e-5
+    }
+
+    with open(config_path, 'w') as f:
+        json.dump(config, f)
+
+    yield {
+        "test_dir": test_dir,
+        "config_path": config_path,
+        "config": config
+    }
+
+    # Best-effort cleanup without masking failures behind a bare except.
+    shutil.rmtree(test_dir, ignore_errors=True)
+
+
+def _build_model(config_path, config_data):
+    """Write config_data to disk and construct a Transformer from it."""
+    with open(config_path, 'w') as f:
+        json.dump(config_data, f)
+    config = ModelConfig().load(config_path)
+    return Transformer(config)
+
+
+def test_tie_weight_shared_logic(transformer_test_env):
+    config_path = transformer_test_env["config_path"]
+    config_data = transformer_test_env["config"].copy()
+
+    # tie_weight=True: lm_head must share the embedding Parameter storage.
+    config_data["tie_weight"] = True
+    model = _build_model(config_path, config_data)
+
+    assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
+    assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
+
+    # A mutation through the embedding must be visible through lm_head.
+    original_weight = model.embed_tokens.weight.clone()
+    model.embed_tokens.weight.data[0, 0] = 100.0
+
+    assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
+    assert not torch.equal(model.lm_head.weight, original_weight)
+
+    # tie_weight=False: parameters must be independent tensors.
+    config_data["tie_weight"] = False
+    model = _build_model(config_path, config_data)
+
+    assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
+    assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
+
+
+def test_model_save_load_with_tie_weight(transformer_test_env):
+    test_dir = transformer_test_env["test_dir"]
+    config_path = transformer_test_env["config_path"]
+    model_path = os.path.join(test_dir, "model.safetensors")
+    config_data = transformer_test_env["config"].copy()
+
+    # --- tied weights ---
+    config_data["tie_weight"] = True
+    original_model = _build_model(config_path, config_data)
+
+    # NOTE: with attribute-assignment tying, state_dict() contains BOTH
+    # "embed_tokens.weight" and "lm_head.weight" (shared storage), and
+    # safetensors rejects tensors that share memory — clone before saving.
+    st.save_file(
+        {k: v.clone() for k, v in original_model.state_dict().items()},
+        model_path,
+    )
+
+    model = _build_model(config_path, config_data)
+    model.load_state_dict(st.load_file(model_path))
+
+    # Tying must survive a save/load round trip.
+    assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
+    assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
+
+    # --- independent weights ---
+    config_data["tie_weight"] = False
+    original_model = _build_model(config_path, config_data)
+
+    st.save_file(original_model.state_dict(), model_path)
+
+    model = _build_model(config_path, config_data)
+    model.load_state_dict(st.load_file(model_path))
+
+    assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
+    assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
+    assert "lm_head.weight" in model.state_dict()