diff --git a/khaosz/inference/core.py b/khaosz/inference/core.py index f8c182b..2bd6e51 100644 --- a/khaosz/inference/core.py +++ b/khaosz/inference/core.py @@ -206,11 +206,11 @@ class KVCacheManager: def _initialize(self): k_cache = torch.zeros( - (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim), + (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), device=self.device, dtype=self.dtype ) v_cache = torch.zeros( - (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim), + (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), device=self.device, dtype=self.dtype ) self._kv_cache = (k_cache, v_cache) diff --git a/khaosz/model/module.py b/khaosz/model/module.py index 748c53e..31c26ab 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -182,12 +182,12 @@ class GQA(nn.Module): k_cache, v_cache = kv_cache # copy to cache - k_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = k - v_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = v + k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k + v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v # get cache - k = k_cache[:bsz, self.layer_id, :start_pos + seq_len] - v = v_cache[:bsz, self.layer_id, :start_pos + seq_len] + k = k_cache[:bsz, :start_pos + seq_len, self.layer_id] + v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) diff --git a/tools/benchmark.py b/tools/benchmark.py index 683a357..d588620 100644 --- a/tools/benchmark.py +++ b/tools/benchmark.py @@ -28,7 +28,7 @@ class GenerationBenchmark: def _initialize_kv_cache(self, batch_size: int) -> list: """初始化KV缓存""" config = self.config - shape = (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head) + shape = (batch_size, config.m_len, config.n_layer, config.n_kvhead, config.n_dim // config.n_head) k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) return (k_cache, v_cache)