feat(KVCacheManager): 优化KV缓存结构为元组形式以提升性能

This commit is contained in:
ViperEkura 2025-10-29 12:01:28 +08:00
parent bc5ef72001
commit 38b2725cd1
3 changed files with 29 additions and 43 deletions

View File

@ -173,34 +173,25 @@ class KVCacheManager:
self.device = device
self.dtype = dtype
self._kv_cache: List[Tuple[Tensor, Tensor]] = None
self._kv_cache: Tuple[Tensor, Tensor] = None
self._seq_mask: Tensor = None
self._initialize()
def _initialize(self):
self._kv_cache = []
for _ in range(self.num_layers):
k_cache = torch.zeros(
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
)
v_cache = torch.zeros(
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
)
self._kv_cache.append((k_cache, v_cache))
self._seq_mask = torch.ones(
(self.batch_size, self.max_len),
device=self.device, dtype=torch.bool
k_cache = torch.zeros(
(self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
)
v_cache = torch.zeros(
(self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
)
self._kv_cache = (k_cache, v_cache)
self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool)
def update(self, active_mask: Tensor):
for i in range(self.num_layers):
k_cache, v_cache = self._kv_cache[i]
new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
self._kv_cache[i] = (new_k_cache, new_v_cache)
def update(self, active_mask: Tensor):
k_cache, v_cache = self._kv_cache
self._kv_cache = (k_cache[active_mask], v_cache[active_mask])
self._seq_mask = self._seq_mask[active_mask]
def reset(self, full_reset=False):
@ -215,7 +206,7 @@ class KVCacheManager:
bool_mask = (input_ids != pad_id)
self._seq_mask[: batch_size, : seq_len] = bool_mask
def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
def get_kvcache(self) -> Tuple[Tensor, Tensor]:
return self._kv_cache
def get_seq_mask(self) -> Tensor:

View File

@ -117,12 +117,14 @@ class GQA(nn.Module):
n_dim: int,
n_head: int,
n_kvhead: int,
layer_id: int
):
super().__init__()
assert n_dim % n_head == 0
assert n_head % n_kvhead == 0
self.head_dim = n_dim // n_head
self.layer_id = layer_id
self.n_dim = n_dim
self.n_heads = n_head
self.n_kvheads = n_kvhead
@ -150,14 +152,14 @@ class GQA(nn.Module):
if kv_cache is not None:
k_cache, v_cache = kv_cache
# copy to cache
k_cache[:bsz, start_pos:start_pos + seq_len] = k
v_cache[:bsz, start_pos:start_pos + seq_len] = v
k_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = k
v_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = v
# get cache
k = k_cache[:bsz, :start_pos + seq_len]
v = v_cache[:bsz, :start_pos + seq_len]
k = k_cache[:bsz, self.layer_id, :start_pos + seq_len]
v = v_cache[:bsz, self.layer_id, :start_pos + seq_len]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
@ -175,9 +177,9 @@ class GQA(nn.Module):
class DecoderBlock(nn.Module):
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps):
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps, layer_id):
super().__init__()
self.attention = GQA(n_dim, n_head, n_kvhead)
self.attention = GQA(n_dim, n_head, n_kvhead, layer_id)
self.norm_attn = RMSNorm(n_dim, norm_eps)
self.ffn = MLP(n_dim, d_ffn)
self.norm_ffn = RMSNorm(n_dim, norm_eps)

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import init
from typing import List, Optional, Tuple
from typing import Optional, Tuple
from khaosz.config.model_config import TransformerConfig
from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb
@ -68,14 +68,8 @@ class Transformer(nn.Module):
super().__init__()
self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
self.layers = nn.ModuleList([
DecoderBlock(
config.n_dim,
config.n_head,
config.d_ffn,
config.n_kvhead,
config.norm_eps
)
for _ in range(config.n_layer)
DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
for layer_id in range(config.n_layer)
])
self.norm = RMSNorm(config.n_dim, config.norm_eps)
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
@ -85,7 +79,7 @@ class Transformer(nn.Module):
self,
input_ids: Tensor,
input_mask: Optional[Tensor]=None,
persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None,
start_pos: int = 0
) -> Tensor:
assert input_ids.ndim == 2
@ -105,9 +99,8 @@ class Transformer(nn.Module):
dtype=x.dtype
)
for i, layer in enumerate(self.layers):
kv_cache = persistent_key_values[i] if persistent_key_values else None
x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
for layer in self.layers:
x = layer(x, freqs_cis, attn_mask, persistent_key_values, start_pos)
hidden_states = self.norm(x)
logits = F.linear(hidden_states, self.embedding)