test(test_module): 更新测试用例以使用新的generate_iterator接口
This commit is contained in:
parent
0db046f8d9
commit
e051005334
|
|
@ -101,7 +101,16 @@ def test_generator_core(test_env):
|
||||||
test_env["transformer_config"]
|
test_env["transformer_config"]
|
||||||
)
|
)
|
||||||
generator = GeneratorCore(parameter)
|
generator = GeneratorCore(parameter)
|
||||||
logits, incr = generator.compute_logits(torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10)))
|
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, (4, 10))
|
||||||
|
next_token_id, cache_increase = generator.generate_iterator(
|
||||||
|
input_ids=input_ids,
|
||||||
|
temperature=0.8,
|
||||||
|
top_k=50,
|
||||||
|
top_p=0.95,
|
||||||
|
attn_mask=None,
|
||||||
|
kv_caches=None,
|
||||||
|
start_pos=0
|
||||||
|
)
|
||||||
|
|
||||||
assert logits.shape == (4, test_env["transformer_config"].vocab_size)
|
assert next_token_id.shape == (4, 1)
|
||||||
assert incr == 10
|
assert cache_increase == 10
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue