feat(benchmark): 优化KV缓存初始化逻辑

This commit is contained in:
ViperEkura 2025-10-29 12:41:32 +08:00
parent 46b2a0f86f
commit f2448a5147
1 changed files with 11 additions and 15 deletions

View File

@ -25,21 +25,17 @@ class GenerationBenchmark:
self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval()
def _initialize_kv_cache(self, batch_size: int, max_len: int) -> list:
def _initialize_kv_cache(self, batch_size: int) -> list:
"""初始化KV缓存"""
kv_cache = []
head_dim = self.config.n_dim // self.config.n_head
for _ in range(self.config.n_layer):
k_cache = torch.zeros(
(batch_size, max_len, self.config.n_kvhead, head_dim),
device=self.device, dtype=self.dtype
)
v_cache = torch.zeros(
(batch_size, max_len, self.config.n_kvhead, head_dim),
device=self.device, dtype=self.dtype
)
kv_cache.append((k_cache, v_cache))
return kv_cache
k_cache = torch.zeros(
(batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
device=self.device, dtype=self.dtype
)
v_cache = torch.zeros(
(batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
device=self.device, dtype=self.dtype
)
return (k_cache, v_cache)
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint(
@ -121,7 +117,7 @@ class GenerationBenchmark:
for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
kv_cache = self._initialize_kv_cache(batch_size, self.config.m_len)
kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize()