From b260f5581dfe560639fb8096c17fa005284f070c Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 5 Nov 2025 15:44:29 +0800 Subject: [PATCH] =?UTF-8?q?fix(benchmark):=20=E4=BC=98=E5=8C=96=20KV=20?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E5=88=9D=E5=A7=8B=E5=8C=96=E5=B9=B6=E6=9B=B4?= =?UTF-8?q?=E6=AD=A3=E5=9F=BA=E5=87=86=E6=B5=8B=E8=AF=95=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E6=A0=87=E8=AF=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- benchmark.py | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/benchmark.py b/benchmark.py index 36b465e..f8358e3 100644 --- a/benchmark.py +++ b/benchmark.py @@ -27,14 +27,10 @@ class GenerationBenchmark: def _initialize_kv_cache(self, batch_size: int) -> list: """初始化KV缓存""" - k_cache = torch.zeros( - (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head), - device=self.device, dtype=self.dtype - ) - v_cache = torch.zeros( - (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head), - device=self.device, dtype=self.dtype - ) + config = self.config + shape = (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head) + k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) + v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) return (k_cache, v_cache) def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): @@ -148,7 +144,7 @@ class GenerationBenchmark: total_time=total_time, tokens_per_second=total_tokens / total_time, metadata={ - "benchmark_type": "generation", + "benchmark_type": "decoding", "batch_size": batch_size, "prompt_length": prompt_length, "gen_length": gen_length, @@ -169,7 +165,7 @@ def print_benchmark_result(result: BenchmarkResult): if benchmark_type == "prefill": print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}") - elif benchmark_type == "generation": + elif benchmark_type == "decoding": print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}") print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}") @@ -194,18 +190,9 @@ if __name__ == "__main__": print("Running Transformer Generation Benchmark") print("=" * 80) - prefill_result = benchmark.run_prefill_benchmark( - batch_size=4, - prompt_length=512, - num_trials=5 - ) + prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5) print_benchmark_result(prefill_result) - gen_result = benchmark.run_decoding_benchmark( - batch_size=4, - prompt_length=512, - gen_length=128, - num_trials=5 - ) + gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5) print_benchmark_result(gen_result) \ No newline at end of file