diff --git a/tests/test_module.py b/tests/test_module.py index 9b341ab..36010c0 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -101,7 +101,16 @@ def test_generator_core(test_env): test_env["transformer_config"] ) generator = GeneratorCore(parameter) - logits, incr = generator.compute_logits(torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))) + input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)) + next_token_id, cache_increase = generator.generate_iterator( + input_ids=input_ids, + temperature=0.8, + top_k=50, + top_p=0.95, + attn_mask=None, + kv_caches=None, + start_pos=0 + ) - assert logits.shape == (4, test_env["transformer_config"].vocab_size) - assert incr == 10 \ No newline at end of file + assert next_token_id.shape == (4, 1) + assert cache_increase == 10