fix(model): 调整 KV Cache 的维度顺序以匹配新的索引逻辑

This commit is contained in:
ViperEkura 2025-11-19 18:26:15 +08:00
parent e12ed0a72b
commit d9ff662e3a
3 changed files with 7 additions and 7 deletions

View File

@@ -206,11 +206,11 @@ class KVCacheManager:
     def _initialize(self):
         k_cache = torch.zeros(
-            (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim),
+            (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
             device=self.device, dtype=self.dtype
         )
         v_cache = torch.zeros(
-            (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim),
+            (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim),
             device=self.device, dtype=self.dtype
         )
         self._kv_cache = (k_cache, v_cache)

View File

@@ -182,12 +182,12 @@ class GQA(nn.Module):
         k_cache, v_cache = kv_cache
         # copy to cache
-        k_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = k
-        v_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = v
+        k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
+        v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
         # get cache
-        k = k_cache[:bsz, self.layer_id, :start_pos + seq_len]
-        v = v_cache[:bsz, self.layer_id, :start_pos + seq_len]
+        k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
+        v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
         k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)

View File

@@ -28,7 +28,7 @@ class GenerationBenchmark:
     def _initialize_kv_cache(self, batch_size: int) -> list:
         """初始化KV缓存"""
         config = self.config
-        shape = (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head)
+        shape = (batch_size, config.m_len, config.n_layer, config.n_kvhead, config.n_dim // config.n_head)
         k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
         v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
         return (k_cache, v_cache)