fix(benchmark): 优化 KV 缓存初始化并更正基准测试类型标识
This commit is contained in:
parent
0a754e3341
commit
b260f5581d
29
benchmark.py
29
benchmark.py
|
|
@ -27,14 +27,10 @@ class GenerationBenchmark:
|
||||||
|
|
||||||
def _initialize_kv_cache(self, batch_size: int) -> list:
|
def _initialize_kv_cache(self, batch_size: int) -> list:
|
||||||
"""初始化KV缓存"""
|
"""初始化KV缓存"""
|
||||||
k_cache = torch.zeros(
|
config = self.config
|
||||||
(batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
|
shape = (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head)
|
||||||
device=self.device, dtype=self.dtype
|
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||||
)
|
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||||
v_cache = torch.zeros(
|
|
||||||
(batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
|
|
||||||
device=self.device, dtype=self.dtype
|
|
||||||
)
|
|
||||||
return (k_cache, v_cache)
|
return (k_cache, v_cache)
|
||||||
|
|
||||||
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
||||||
|
|
@ -148,7 +144,7 @@ class GenerationBenchmark:
|
||||||
total_time=total_time,
|
total_time=total_time,
|
||||||
tokens_per_second=total_tokens / total_time,
|
tokens_per_second=total_tokens / total_time,
|
||||||
metadata={
|
metadata={
|
||||||
"benchmark_type": "generation",
|
"benchmark_type": "decoding",
|
||||||
"batch_size": batch_size,
|
"batch_size": batch_size,
|
||||||
"prompt_length": prompt_length,
|
"prompt_length": prompt_length,
|
||||||
"gen_length": gen_length,
|
"gen_length": gen_length,
|
||||||
|
|
@ -169,7 +165,7 @@ def print_benchmark_result(result: BenchmarkResult):
|
||||||
|
|
||||||
if benchmark_type == "prefill":
|
if benchmark_type == "prefill":
|
||||||
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
|
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
|
||||||
elif benchmark_type == "generation":
|
elif benchmark_type == "decoding":
|
||||||
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")
|
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")
|
||||||
|
|
||||||
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
||||||
|
|
@ -194,18 +190,9 @@ if __name__ == "__main__":
|
||||||
print("Running Transformer Generation Benchmark")
|
print("Running Transformer Generation Benchmark")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
prefill_result = benchmark.run_prefill_benchmark(
|
prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5)
|
||||||
batch_size=4,
|
|
||||||
prompt_length=512,
|
|
||||||
num_trials=5
|
|
||||||
)
|
|
||||||
print_benchmark_result(prefill_result)
|
print_benchmark_result(prefill_result)
|
||||||
|
|
||||||
gen_result = benchmark.run_decoding_benchmark(
|
gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5)
|
||||||
batch_size=4,
|
|
||||||
prompt_length=512,
|
|
||||||
gen_length=128,
|
|
||||||
num_trials=5
|
|
||||||
)
|
|
||||||
print_benchmark_result(gen_result)
|
print_benchmark_result(gen_result)
|
||||||
|
|
||||||
Loading…
Reference in New Issue