feat(benchmark): 优化KV缓存初始化逻辑
This commit is contained in:
parent
46b2a0f86f
commit
f2448a5147
26
benchmark.py
26
benchmark.py
|
|
@ -25,21 +25,17 @@ class GenerationBenchmark:
|
||||||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
self.model = Transformer(config).to(device=device, dtype=dtype)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
def _initialize_kv_cache(self, batch_size: int, max_len: int) -> list:
|
def _initialize_kv_cache(self, batch_size: int) -> list:
|
||||||
"""初始化KV缓存"""
|
"""初始化KV缓存"""
|
||||||
kv_cache = []
|
k_cache = torch.zeros(
|
||||||
head_dim = self.config.n_dim // self.config.n_head
|
(batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
|
||||||
for _ in range(self.config.n_layer):
|
device=self.device, dtype=self.dtype
|
||||||
k_cache = torch.zeros(
|
)
|
||||||
(batch_size, max_len, self.config.n_kvhead, head_dim),
|
v_cache = torch.zeros(
|
||||||
device=self.device, dtype=self.dtype
|
(batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head),
|
||||||
)
|
device=self.device, dtype=self.dtype
|
||||||
v_cache = torch.zeros(
|
)
|
||||||
(batch_size, max_len, self.config.n_kvhead, head_dim),
|
return (k_cache, v_cache)
|
||||||
device=self.device, dtype=self.dtype
|
|
||||||
)
|
|
||||||
kv_cache.append((k_cache, v_cache))
|
|
||||||
return kv_cache
|
|
||||||
|
|
||||||
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
||||||
prompt_ids = torch.randint(
|
prompt_ids = torch.randint(
|
||||||
|
|
@ -121,7 +117,7 @@ class GenerationBenchmark:
|
||||||
for trial in range(num_trials):
|
for trial in range(num_trials):
|
||||||
|
|
||||||
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
|
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
|
||||||
kv_cache = self._initialize_kv_cache(batch_size, self.config.m_len)
|
kv_cache = self._initialize_kv_cache(batch_size)
|
||||||
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue