From d9ff662e3a1fbbdaca5f36671d271614de38af3c Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 19 Nov 2025 18:26:15 +0800 Subject: [PATCH] =?UTF-8?q?fix(model):=20=E8=B0=83=E6=95=B4=20KV=20Cache?= =?UTF-8?q?=20=E7=9A=84=E7=BB=B4=E5=BA=A6=E9=A1=BA=E5=BA=8F=E4=BB=A5?= =?UTF-8?q?=E5=8C=B9=E9=85=8D=E6=96=B0=E7=9A=84=E7=B4=A2=E5=BC=95=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/inference/core.py | 4 ++-- khaosz/model/module.py | 8 ++++---- tools/benchmark.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/khaosz/inference/core.py b/khaosz/inference/core.py index f8c182b..2bd6e51 100644 --- a/khaosz/inference/core.py +++ b/khaosz/inference/core.py @@ -206,11 +206,11 @@ class KVCacheManager: def _initialize(self): k_cache = torch.zeros( - (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim), + (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), device=self.device, dtype=self.dtype ) v_cache = torch.zeros( - (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim), + (self.batch_size, self.max_len, self.num_layers, self.num_heads, self.head_dim), device=self.device, dtype=self.dtype ) self._kv_cache = (k_cache, v_cache) diff --git a/khaosz/model/module.py b/khaosz/model/module.py index 748c53e..31c26ab 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -182,12 +182,12 @@ class GQA(nn.Module): k_cache, v_cache = kv_cache # copy to cache - k_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = k - v_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = v + k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k + v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v # get cache - k = k_cache[:bsz, self.layer_id, :start_pos + seq_len] - v = v_cache[:bsz, self.layer_id, :start_pos + seq_len] + k = k_cache[:bsz, :start_pos + seq_len, self.layer_id] + v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) diff --git a/tools/benchmark.py b/tools/benchmark.py index 683a357..d588620 100644 --- a/tools/benchmark.py +++ b/tools/benchmark.py @@ -28,7 +28,7 @@ class GenerationBenchmark: def _initialize_kv_cache(self, batch_size: int) -> list: """初始化KV缓存""" config = self.config - shape = (batch_size, config.n_layer, config.m_len, config.n_kvhead, config.n_dim // config.n_head) + shape = (batch_size, config.m_len, config.n_layer, config.n_kvhead, config.n_dim // config.n_head) k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) return (k_cache, v_cache)