diff --git a/benchmark.py b/benchmark.py index f2db16e..36b465e 100644 --- a/benchmark.py +++ b/benchmark.py @@ -25,21 +25,17 @@ class GenerationBenchmark: self.model = Transformer(config).to(device=device, dtype=dtype) self.model.eval() - def _initialize_kv_cache(self, batch_size: int, max_len: int) -> list: + def _initialize_kv_cache(self, batch_size: int) -> list: """初始化KV缓存""" - kv_cache = [] - head_dim = self.config.n_dim // self.config.n_head - for _ in range(self.config.n_layer): - k_cache = torch.zeros( - (batch_size, max_len, self.config.n_kvhead, head_dim), - device=self.device, dtype=self.dtype - ) - v_cache = torch.zeros( - (batch_size, max_len, self.config.n_kvhead, head_dim), - device=self.device, dtype=self.dtype - ) - kv_cache.append((k_cache, v_cache)) - return kv_cache + k_cache = torch.zeros( + (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head), + device=self.device, dtype=self.dtype + ) + v_cache = torch.zeros( + (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head), + device=self.device, dtype=self.dtype + ) + return (k_cache, v_cache) def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): prompt_ids = torch.randint( @@ -121,7 +117,7 @@ class GenerationBenchmark: for trial in range(num_trials): prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length) - kv_cache = self._initialize_kv_cache(batch_size, self.config.m_len) + kv_cache = self._initialize_kv_cache(batch_size) _ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0) torch.cuda.synchronize()