From 0a03a15679a664d7ae7a316a5f5272f00746ec76 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 2 Oct 2025 14:34:02 +0800 Subject: [PATCH] =?UTF-8?q?test(trainer):=20=E8=B0=83=E6=95=B4=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E5=8F=82=E6=95=B0=E4=BB=A5=E6=8F=90=E9=AB=98=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E5=92=8C=E6=96=AD=E8=A8=80=E7=9A=84=E7=A8=B3=E5=AE=9A?= =?UTF-8?q?=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_trainer.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index cdaa856..82dbfb0 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -22,13 +22,13 @@ def test_env(): test_dir = tempfile.mkdtemp() config_path = os.path.join(test_dir, "config.json") - n_dim_choices = [16, 32, 64] - n_head_choices = [4, 8, 16] + n_dim_choices = [8, 16, 32] + n_head_choices = [2, 4] n_dim = int(np.random.choice(n_dim_choices)) n_head = int(np.random.choice(n_head_choices)) - n_kvhead = n_head // 4 - d_ffn = n_dim * 4 + n_kvhead = n_head // 2 + d_ffn = n_dim * 2 config = { "vocab_size": 1000, @@ -80,7 +80,6 @@ def test_env(): target_ids = torch.randint(0, self.vocab_size, (self.max_length,)) loss_mask = build_loss_mask(input_ids, 0, 1) attn_mask = build_attention_mask(input_ids, 2, True) - return { "input_ids": input_ids, @@ -220,12 +219,12 @@ def test_multi_turn_training(test_env): dataset=test_env["multi_turn_dataset"], optimizer=optimizer, checkpoint_dir=test_env["test_dir"], - n_epoch=1, + n_epoch=2, batch_size=2, checkpoint_interval=3, accumulation_steps=1, max_grad_norm=1.0, - random_seed=np.random.randint(1000) + random_seed=int(np.random.randint(1000)) ) schedule_config = CosineScheduleConfig( @@ -289,7 +288,7 @@ def test_gradient_accumulation(test_env): ) trainer = Trainer(model_parameter, train_config, schedule_config) - checkpoint = trainer.train() + trainer.train() assert train_config.accumulation_steps == accumulation_steps @@ -432,7 +431,7 @@ def test_early_stopping_simulation(test_env): dataset=dataset, optimizer=optimizer, checkpoint_dir=test_env["test_dir"], - n_epoch=1, + n_epoch=2, batch_size=2, checkpoint_interval=1, accumulation_steps=1, @@ -459,7 +458,7 @@ def test_early_stopping_simulation(test_env): pass checkpoint = trainer.train(checkpoint) - assert len(checkpoint.loss_list) == 5 + 1 + assert len(checkpoint.loss_list) == 10 + 1 if __name__ == "__main__":