From 254ec934be8df801365ded3a8ab6966c301e95be Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Fri, 7 Nov 2025 15:14:54 +0800
Subject: [PATCH] =?UTF-8?q?feat(transformer):=20=E7=AE=80=E5=8C=96?=
 =?UTF-8?q?=E6=9D=83=E9=87=8D=E7=BB=91=E5=AE=9A=E9=80=BB=E8=BE=91=E5=B9=B6?=
 =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95=E5=8D=95=E5=85=83?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 khaosz/model/transformer.py |   6 +-
 tests/test_tie_weight.py    | 112 ++++++++++++++++++++++++++++++++++++
 2 files changed, 116 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_tie_weight.py

diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py
index a557b86..f67edc9 100644
--- a/khaosz/model/transformer.py
+++ b/khaosz/model/transformer.py
@@ -71,10 +71,12 @@ class Transformer(nn.Module):
             DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
             for layer_id in range(config.n_layer)
         ])
-        lm_head_init_weight = self.embed_tokens.weight if config.tie_weight == True else None
         self.norm = RMSNorm(config.n_dim, config.norm_eps)
-        self.lm_head = Linear(config.n_dim, config.vocab_size, weight_param=lm_head_init_weight)
+        self.lm_head = Linear(config.n_dim, config.vocab_size)
+
+        if config.tie_weight:
+            self.lm_head.weight = self.embed_tokens.weight
 
         self._init_parameters()
 
diff --git a/tests/test_tie_weight.py b/tests/test_tie_weight.py
new file mode 100644
index 0000000..fbebd3e
--- /dev/null
+++ b/tests/test_tie_weight.py
@@ -0,0 +1,112 @@
+import os
+import json
+import shutil
+import torch
+import pytest
+import tempfile
+import safetensors.torch as st
+from khaosz.model.transformer import Transformer
+from khaosz.config.model_config import ModelConfig
+
+
+@pytest.fixture
+def transformer_test_env():
+    """Create an isolated temp dir holding a base model config for Transformer tests."""
+    test_dir = tempfile.mkdtemp(prefix="transformer_test_")
+    config_path = os.path.join(test_dir, "config.json")
+
+    config = {
+        "vocab_size": 1000,
+        "n_dim": 128,
+        "n_head": 4,
+        "n_kvhead": 2,
+        "d_ffn": 256,
+        "m_len": 64,
+        "n_layer": 2,
+        "norm_eps": 1e-5
+    }
+
+    with open(config_path, 'w') as f:
+        json.dump(config, f)
+
+    yield {
+        "test_dir": test_dir,
+        "config_path": config_path,
+        "config": config
+    }
+
+    # Best-effort cleanup: never mask a test failure with a cleanup error.
+    shutil.rmtree(test_dir, ignore_errors=True)
+
+
+def test_tie_weight_shared_logic(transformer_test_env):
+    config_path = transformer_test_env["config_path"]
+    config_data = transformer_test_env["config"].copy()
+    config_data["tie_weight"] = True
+
+    with open(config_path, 'w') as f:
+        json.dump(config_data, f)
+
+    config = ModelConfig().load(config_path)
+    model = Transformer(config)
+
+    # Tied weights share storage, so a write through one must be seen by the other.
+    assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
+
+    original_weight = model.embed_tokens.weight.clone()
+    model.embed_tokens.weight.data[0, 0] = 100.0
+
+    assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
+    assert not torch.equal(model.lm_head.weight, original_weight)
+
+    config_data["tie_weight"] = False
+
+    with open(config_path, 'w') as f:
+        json.dump(config_data, f)
+
+    config = ModelConfig().load(config_path)
+    model = Transformer(config)
+
+    assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
+
+
+def test_model_save_load_with_tie_weight(transformer_test_env):
+    test_dir = transformer_test_env["test_dir"]
+    model_path = os.path.join(test_dir, "model.safetensors")
+
+    config_data = transformer_test_env["config"].copy()
+    config_data["tie_weight"] = True
+    config_path = os.path.join(test_dir, "config.json")
+
+    with open(config_path, 'w') as f:
+        json.dump(config_data, f)
+
+    config = ModelConfig().load(config_path)
+    original_model = Transformer(config)
+
+    # save_file() rejects state dicts containing shared tensors, so tied
+    # weights must round-trip through save_model()/load_model() instead.
+    st.save_model(original_model, model_path)
+
+    loaded_config = ModelConfig().load(config_path)
+    model = Transformer(loaded_config)
+    st.load_model(model, model_path)
+
+    assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
+    assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
+
+    config_data["tie_weight"] = False
+    with open(config_path, 'w') as f:
+        json.dump(config_data, f)
+
+    config = ModelConfig().load(config_path)
+    original_model = Transformer(config)
+
+    st.save_model(original_model, model_path)
+
+    loaded_config = ModelConfig().load(config_path)
+    model = Transformer(loaded_config)
+    st.load_model(model, model_path)
+
+    assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
+    assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()