feat(benchmark): 优化KV缓存初始化逻辑

This commit is contained in:
ViperEkura 2025-10-29 12:41:32 +08:00
parent 46b2a0f86f
commit f2448a5147
1 changed files with 11 additions and 15 deletions

View File

@ -25,21 +25,17 @@ class GenerationBenchmark:
self.model = Transformer(config).to(device=device, dtype=dtype) self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval() self.model.eval()
def _initialize_kv_cache(self, batch_size: int, max_len: int) -> list: def _initialize_kv_cache(self, batch_size: int) -> list:
"""初始化KV缓存""" """初始化KV缓存"""
kv_cache = [] k_cache = torch.zeros(
head_dim = self.config.n_dim // self.config.n_head (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
for _ in range(self.config.n_layer): device=self.device, dtype=self.dtype
k_cache = torch.zeros( )
(batch_size, max_len, self.config.n_kvhead, head_dim), v_cache = torch.zeros(
device=self.device, dtype=self.dtype (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
) device=self.device, dtype=self.dtype
v_cache = torch.zeros( )
(batch_size, max_len, self.config.n_kvhead, head_dim), return (k_cache, v_cache)
device=self.device, dtype=self.dtype
)
kv_cache.append((k_cache, v_cache))
return kv_cache
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint( prompt_ids = torch.randint(
@ -121,7 +117,7 @@ class GenerationBenchmark:
for trial in range(num_trials): for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length) prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
kv_cache = self._initialize_kv_cache(batch_size, self.config.m_len) kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0) _ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize() torch.cuda.synchronize()