test(test_module): 更新测试用例以使用新的generate_iterator接口
This commit is contained in:
parent
0db046f8d9
commit
e051005334
|
|
@ -101,7 +101,16 @@ def test_generator_core(test_env):
|
|||
test_env["transformer_config"]
|
||||
)
|
||||
generator = GeneratorCore(parameter)
|
||||
logits, incr = generator.compute_logits(torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)))
|
||||
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
|
||||
next_token_id, cache_increase = generator.generate_iterator(
|
||||
input_ids=input_ids,
|
||||
temperature=0.8,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
attn_mask=None,
|
||||
kv_caches=None,
|
||||
start_pos=0
|
||||
)
|
||||
|
||||
assert logits.shape == (4, test_env["transformer_config"].vocab_size)
|
||||
assert incr == 10
|
||||
assert next_token_id.shape == (4, 1)
|
||||
assert cache_increase == 10
|
||||
|
|
|
|||
Loading…
Reference in New Issue