diff --git a/benchmark.py b/benchmark.py
index 36b465e..f8358e3 100644
--- a/benchmark.py
+++ b/benchmark.py
@@ -27,14 +27,10 @@ class GenerationBenchmark:
 
     def _initialize_kv_cache(self, batch_size: int) -> list:
         """初始化KV缓存"""
-        k_cache = torch.zeros(
-            (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
-            device=self.device, dtype=self.dtype
-        )
-        v_cache = torch.zeros(
-            (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
-            device=self.device, dtype=self.dtype
-        )
+        config = self.config
+        shape = (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head)
+        k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
+        v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
         return (k_cache, v_cache)
 
     def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
@@ -148,7 +144,7 @@ class GenerationBenchmark:
             total_time=total_time,
             tokens_per_second=total_tokens / total_time,
             metadata={
-                "benchmark_type": "generation",
+                "benchmark_type": "decoding",
                 "batch_size": batch_size,
                 "prompt_length": prompt_length,
                 "gen_length": gen_length,
@@ -169,7 +165,7 @@ def print_benchmark_result(result: BenchmarkResult):
 
     if benchmark_type == "prefill":
         print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
-    elif benchmark_type == "generation":
+    elif benchmark_type == "decoding":
         print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")
 
     print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
@@ -194,18 +190,9 @@ if __name__ == "__main__":
     print("Running Transformer Generation Benchmark")
     print("=" * 80)
 
-    prefill_result = benchmark.run_prefill_benchmark(
-        batch_size=4,
-        prompt_length=512,
-        num_trials=5
-    )
+    prefill_result = benchmark.run_prefill_benchmark(batch_size=4, prompt_length=512, num_trials=5)
     print_benchmark_result(prefill_result)
 
 
-    gen_result = benchmark.run_decoding_benchmark(
-        batch_size=4,
-        prompt_length=512,
-        gen_length=128,
-        num_trials=5
-    )
+    gen_result = benchmark.run_decoding_benchmark(batch_size=4, prompt_length=512, gen_length=128, num_trials=5)
     print_benchmark_result(gen_result)
\ No newline at end of file