From f2448a51470945799fbe7b915e70d136d29e4819 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 29 Oct 2025 12:41:32 +0800 Subject: [PATCH] =?UTF-8?q?feat(benchmark):=20=E4=BC=98=E5=8C=96KV?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E5=88=9D=E5=A7=8B=E5=8C=96=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- benchmark.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/benchmark.py b/benchmark.py index f2db16e..36b465e 100644 --- a/benchmark.py +++ b/benchmark.py @@ -25,21 +25,17 @@ class GenerationBenchmark: self.model = Transformer(config).to(device=device, dtype=dtype) self.model.eval() - def _initialize_kv_cache(self, batch_size: int, max_len: int) -> list: + def _initialize_kv_cache(self, batch_size: int) -> list: """初始化KV缓存""" - kv_cache = [] - head_dim = self.config.n_dim // self.config.n_head - for _ in range(self.config.n_layer): - k_cache = torch.zeros( - (batch_size, max_len, self.config.n_kvhead, head_dim), - device=self.device, dtype=self.dtype - ) - v_cache = torch.zeros( - (batch_size, max_len, self.config.n_kvhead, head_dim), - device=self.device, dtype=self.dtype - ) - kv_cache.append((k_cache, v_cache)) - return kv_cache + k_cache = torch.zeros( + (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head), + device=self.device, dtype=self.dtype + ) + v_cache = torch.zeros( + (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head), + device=self.device, dtype=self.dtype + ) + return (k_cache, v_cache) def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): prompt_ids = torch.randint( @@ -121,7 +117,7 @@ class GenerationBenchmark: for trial in range(num_trials): prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length) - kv_cache = self._initialize_kv_cache(batch_size, self.config.m_len) + kv_cache = self._initialize_kv_cache(batch_size) _ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0) torch.cuda.synchronize()