diff --git a/tests/test_trainer.py b/tests/test_trainer.py index cdaa856..82dbfb0 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -22,13 +22,13 @@ def test_env(): test_dir = tempfile.mkdtemp() config_path = os.path.join(test_dir, "config.json") - n_dim_choices = [16, 32, 64] - n_head_choices = [4, 8, 16] + n_dim_choices = [8, 16, 32] + n_head_choices = [2, 4] n_dim = int(np.random.choice(n_dim_choices)) n_head = int(np.random.choice(n_head_choices)) - n_kvhead = n_head // 4 - d_ffn = n_dim * 4 + n_kvhead = n_head // 2 + d_ffn = n_dim * 2 config = { "vocab_size": 1000, @@ -80,7 +80,6 @@ def test_env(): target_ids = torch.randint(0, self.vocab_size, (self.max_length,)) loss_mask = build_loss_mask(input_ids, 0, 1) attn_mask = build_attention_mask(input_ids, 2, True) - return { "input_ids": input_ids, @@ -220,12 +219,12 @@ def test_multi_turn_training(test_env): dataset=test_env["multi_turn_dataset"], optimizer=optimizer, checkpoint_dir=test_env["test_dir"], - n_epoch=1, + n_epoch=2, batch_size=2, checkpoint_interval=3, accumulation_steps=1, max_grad_norm=1.0, - random_seed=np.random.randint(1000) + random_seed=int(np.random.randint(1000)) ) schedule_config = CosineScheduleConfig( @@ -289,7 +288,7 @@ def test_gradient_accumulation(test_env): ) trainer = Trainer(model_parameter, train_config, schedule_config) - checkpoint = trainer.train() + trainer.train() assert train_config.accumulation_steps == accumulation_steps @@ -432,7 +431,7 @@ def test_early_stopping_simulation(test_env): dataset=dataset, optimizer=optimizer, checkpoint_dir=test_env["test_dir"], - n_epoch=1, + n_epoch=2, batch_size=2, checkpoint_interval=1, accumulation_steps=1, @@ -459,7 +458,7 @@ def test_early_stopping_simulation(test_env): pass checkpoint = trainer.train(checkpoint) - assert len(checkpoint.loss_list) == 5 + 1 + assert len(checkpoint.loss_list) == 10 + 1 if __name__ == "__main__":