From e0510053342221455700f2bbefb7327c60258063 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 20 Oct 2025 13:52:31 +0800 Subject: [PATCH] =?UTF-8?q?test(test=5Fmodule):=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B=E4=BB=A5=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E6=96=B0=E7=9A=84generate=5Fiterator=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_module.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/test_module.py b/tests/test_module.py index 9b341ab..36010c0 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -101,7 +101,16 @@ def test_generator_core(test_env): test_env["transformer_config"] ) generator = GeneratorCore(parameter) - logits, incr = generator.compute_logits(torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))) + input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)) + next_token_id, cache_increase = generator.generate_iterator( + input_ids=input_ids, + temperature=0.8, + top_k=50, + top_p=0.95, + attn_mask=None, + kv_caches=None, + start_pos=0 + ) - assert logits.shape == (4, test_env["transformer_config"].vocab_size) - assert incr == 10 \ No newline at end of file + assert next_token_id.shape == (4, 1) + assert cache_increase == 10