test(test_module): 更新测试用例以使用新的generate_iterator接口

This commit is contained in:
ViperEkura 2025-10-20 13:52:31 +08:00
parent 0db046f8d9
commit e051005334
1 changed files with 12 additions and 3 deletions

View File

@ -101,7 +101,16 @@ def test_generator_core(test_env):
test_env["transformer_config"]
)
generator = GeneratorCore(parameter)
logits, incr = generator.compute_logits(torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)))
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
next_token_id, cache_increase = generator.generate_iterator(
input_ids=input_ids,
temperature=0.8,
top_k=50,
top_p=0.95,
attn_mask=None,
kv_caches=None,
start_pos=0
)
assert logits.shape == (4, test_env["transformer_config"].vocab_size)
assert incr == 10
assert next_token_id.shape == (4, 1)
assert cache_increase == 10