From 38b2725cd1f558bf85de7426c0211d6d1075ff7f Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 29 Oct 2025 12:01:28 +0800 Subject: [PATCH] =?UTF-8?q?feat(KVCacheManager):=20=E4=BC=98=E5=8C=96KV?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E7=BB=93=E6=9E=84=E4=B8=BA=E5=85=83=E7=BB=84?= =?UTF-8?q?=E5=BD=A2=E5=BC=8F=E4=BB=A5=E6=8F=90=E5=8D=87=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/inference/core.py | 37 ++++++++++++++----------------------- khaosz/model/module.py | 16 +++++++++------- khaosz/model/transformer.py | 19 ++++++------------- 3 files changed, 29 insertions(+), 43 deletions(-) diff --git a/khaosz/inference/core.py b/khaosz/inference/core.py index c6b4b8d..f019f86 100644 --- a/khaosz/inference/core.py +++ b/khaosz/inference/core.py @@ -173,34 +173,25 @@ class KVCacheManager: self.device = device self.dtype = dtype - self._kv_cache: List[Tuple[Tensor, Tensor]] = None + self._kv_cache: Tuple[Tensor, Tensor] = None self._seq_mask: Tensor = None self._initialize() def _initialize(self): - self._kv_cache = [] - for _ in range(self.num_layers): - k_cache = torch.zeros( - (self.batch_size, self.max_len, self.num_heads, self.head_dim), - device=self.device, dtype=self.dtype - ) - v_cache = torch.zeros( - (self.batch_size, self.max_len, self.num_heads, self.head_dim), - device=self.device, dtype=self.dtype - ) - self._kv_cache.append((k_cache, v_cache)) - - self._seq_mask = torch.ones( - (self.batch_size, self.max_len), - device=self.device, dtype=torch.bool + k_cache = torch.zeros( + (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim), + device=self.device, dtype=self.dtype ) + v_cache = torch.zeros( + (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim), + device=self.device, dtype=self.dtype + ) + self._kv_cache = (k_cache, v_cache) + self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool) - def update(self, active_mask: Tensor): - for i in range(self.num_layers): - k_cache, v_cache = self._kv_cache[i] - new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask] - self._kv_cache[i] = (new_k_cache, new_v_cache) - + def update(self, active_mask: Tensor): + k_cache, v_cache = self._kv_cache + self._kv_cache = (k_cache[active_mask], v_cache[active_mask]) self._seq_mask = self._seq_mask[active_mask] def reset(self, full_reset=False): @@ -215,7 +206,7 @@ class KVCacheManager: bool_mask = (input_ids != pad_id) self._seq_mask[: batch_size, : seq_len] = bool_mask - def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]: + def get_kvcache(self) -> Tuple[Tensor, Tensor]: return self._kv_cache def get_seq_mask(self) -> Tensor: diff --git a/khaosz/model/module.py b/khaosz/model/module.py index bba68ce..643dd0d 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -117,12 +117,14 @@ class GQA(nn.Module): n_dim: int, n_head: int, n_kvhead: int, + layer_id: int ): super().__init__() assert n_dim % n_head == 0 assert n_head % n_kvhead == 0 self.head_dim = n_dim // n_head + self.layer_id = layer_id self.n_dim = n_dim self.n_heads = n_head self.n_kvheads = n_kvhead @@ -150,14 +152,14 @@ class GQA(nn.Module): if kv_cache is not None: k_cache, v_cache = kv_cache - + # copy to cache - k_cache[:bsz, start_pos:start_pos + seq_len] = k - v_cache[:bsz, start_pos:start_pos + seq_len] = v + k_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = k + v_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = v # get cache - k = k_cache[:bsz, :start_pos + seq_len] - v = v_cache[:bsz, :start_pos + seq_len] + k = k_cache[:bsz, self.layer_id, :start_pos + seq_len] + v = v_cache[:bsz, self.layer_id, :start_pos + seq_len] k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) @@ -175,9 +177,9 @@ class GQA(nn.Module): class DecoderBlock(nn.Module): - def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps): + def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps, layer_id): super().__init__() - self.attention = GQA(n_dim, n_head, n_kvhead) + self.attention = GQA(n_dim, n_head, n_kvhead, layer_id) self.norm_attn = RMSNorm(n_dim, norm_eps) self.ffn = MLP(n_dim, d_ffn) self.norm_ffn = RMSNorm(n_dim, norm_eps) diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py index 6290fe6..2c68d6b 100644 --- a/khaosz/model/transformer.py +++ b/khaosz/model/transformer.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import init -from typing import List, Optional, Tuple +from typing import Optional, Tuple from khaosz.config.model_config import TransformerConfig from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb @@ -68,14 +68,8 @@ class Transformer(nn.Module): super().__init__() self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim)) self.layers = nn.ModuleList([ - DecoderBlock( - config.n_dim, - config.n_head, - config.d_ffn, - config.n_kvhead, - config.norm_eps - ) - for _ in range(config.n_layer) + DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id) + for layer_id in range(config.n_layer) ]) self.norm = RMSNorm(config.n_dim, config.norm_eps) self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len) @@ -85,7 +79,7 @@ class Transformer(nn.Module): self, input_ids: Tensor, input_mask: Optional[Tensor]=None, - persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None, + persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None, start_pos: int = 0 ) -> Tensor: assert input_ids.ndim == 2 @@ -105,9 +99,8 @@ class Transformer(nn.Module): dtype=x.dtype ) - for i, layer in enumerate(self.layers): - kv_cache = persistent_key_values[i] if persistent_key_values else None - x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos) + for layer in self.layers: + x = layer(x, freqs_cis, attn_mask, persistent_key_values, start_pos) hidden_states = self.norm(x) logits = F.linear(hidden_states, self.embedding)