fix(model): 调整 KV Cache 的维度顺序以匹配新的索引逻辑
This commit is contained in:
parent
e12ed0a72b
commit
d9ff662e3a
|
|
@@ -206,11 +206,11 @@ class KVCacheManager:
|
|||
|
||||
def _initialize(self):
|
||||
k_cache = torch.zeros(
|
||||
(self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim),
|
||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim),
|
||||
(self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
self._kv_cache = (k_cache, v_cache)
|
||||
|
|
|
|||
|
|
@@ -182,12 +182,12 @@ class GQA(nn.Module):
|
|||
k_cache, v_cache = kv_cache
|
||||
|
||||
# copy to cache
|
||||
k_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = k
|
||||
v_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = v
|
||||
k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
|
||||
v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
|
||||
|
||||
# get cache
|
||||
k = k_cache[:bsz, self.layer_id, :start_pos + seq_len]
|
||||
v = v_cache[:bsz, self.layer_id, :start_pos + seq_len]
|
||||
k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
||||
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
||||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
||||
|
|
|
|||
|
|
@@ -28,7 +28,7 @@ class GenerationBenchmark:
|
|||
def _initialize_kv_cache(self, batch_size: int) -> list:
|
||||
"""初始化KV缓存"""
|
||||
config = self.config
|
||||
shape = (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head)
|
||||
shape = (batch_size, config.m_len, config.n_layer, config.n_kvhead, config.n_dim // config.n_head)
|
||||
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||
return (k_cache, v_cache)
|
||||
|
|
|
|||
Loading…
Reference in New Issue