feat(benchmark): 优化KV缓存初始化逻辑
This commit is contained in:
parent
46b2a0f86f
commit
f2448a5147
14
benchmark.py
14
benchmark.py
|
|
@ -25,21 +25,17 @@ class GenerationBenchmark:
|
|||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
||||
self.model.eval()
|
||||
|
||||
def _initialize_kv_cache(self, batch_size: int, max_len: int) -> list:
|
||||
def _initialize_kv_cache(self, batch_size: int) -> list:
|
||||
"""初始化KV缓存"""
|
||||
kv_cache = []
|
||||
head_dim = self.config.n_dim // self.config.n_head
|
||||
for _ in range(self.config.n_layer):
|
||||
k_cache = torch.zeros(
|
||||
(batch_size, max_len, self.config.n_kvhead, head_dim),
|
||||
(batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(batch_size, max_len, self.config.n_kvhead, head_dim),
|
||||
(batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
kv_cache.append((k_cache, v_cache))
|
||||
return kv_cache
|
||||
return (k_cache, v_cache)
|
||||
|
||||
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
||||
prompt_ids = torch.randint(
|
||||
|
|
@ -121,7 +117,7 @@ class GenerationBenchmark:
|
|||
for trial in range(num_trials):
|
||||
|
||||
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
|
||||
kv_cache = self._initialize_kv_cache(batch_size, self.config.m_len)
|
||||
kv_cache = self._initialize_kv_cache(batch_size)
|
||||
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
|
|
|||
Loading…
Reference in New Issue