feat(KVCacheManager): 优化KV缓存结构为元组形式以提升性能

This commit is contained in:
ViperEkura 2025-10-29 12:01:28 +08:00
parent bc5ef72001
commit 38b2725cd1
3 changed files with 29 additions and 43 deletions

View File

@@ -173,34 +173,25 @@ class KVCacheManager:
self.device = device self.device = device
self.dtype = dtype self.dtype = dtype
self._kv_cache: List[Tuple[Tensor, Tensor]] = None self._kv_cache: Tuple[Tensor, Tensor] = None
self._seq_mask: Tensor = None self._seq_mask: Tensor = None
self._initialize() self._initialize()
def _initialize(self): def _initialize(self):
self._kv_cache = []
for _ in range(self.num_layers):
k_cache = torch.zeros( k_cache = torch.zeros(
(self.batch_size, self.max_len, self.num_heads, self.head_dim), (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype device=self.device, dtype=self.dtype
) )
v_cache = torch.zeros( v_cache = torch.zeros(
(self.batch_size, self.max_len, self.num_heads, self.head_dim), (self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype device=self.device, dtype=self.dtype
) )
self._kv_cache.append((k_cache, v_cache)) self._kv_cache = (k_cache, v_cache)
self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool)
self._seq_mask = torch.ones(
(self.batch_size, self.max_len),
device=self.device, dtype=torch.bool
)
def update(self, active_mask: Tensor): def update(self, active_mask: Tensor):
for i in range(self.num_layers): k_cache, v_cache = self._kv_cache
k_cache, v_cache = self._kv_cache[i] self._kv_cache = (k_cache[active_mask], v_cache[active_mask])
new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
self._kv_cache[i] = (new_k_cache, new_v_cache)
self._seq_mask = self._seq_mask[active_mask] self._seq_mask = self._seq_mask[active_mask]
def reset(self, full_reset=False): def reset(self, full_reset=False):
@@ -215,7 +206,7 @@ class KVCacheManager:
bool_mask = (input_ids != pad_id) bool_mask = (input_ids != pad_id)
self._seq_mask[: batch_size, : seq_len] = bool_mask self._seq_mask[: batch_size, : seq_len] = bool_mask
def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]: def get_kvcache(self) -> Tuple[Tensor, Tensor]:
return self._kv_cache return self._kv_cache
def get_seq_mask(self) -> Tensor: def get_seq_mask(self) -> Tensor:

View File

@@ -117,12 +117,14 @@ class GQA(nn.Module):
n_dim: int, n_dim: int,
n_head: int, n_head: int,
n_kvhead: int, n_kvhead: int,
layer_id: int
): ):
super().__init__() super().__init__()
assert n_dim % n_head == 0 assert n_dim % n_head == 0
assert n_head % n_kvhead == 0 assert n_head % n_kvhead == 0
self.head_dim = n_dim // n_head self.head_dim = n_dim // n_head
self.layer_id = layer_id
self.n_dim = n_dim self.n_dim = n_dim
self.n_heads = n_head self.n_heads = n_head
self.n_kvheads = n_kvhead self.n_kvheads = n_kvhead
@@ -152,12 +154,12 @@ class GQA(nn.Module):
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
# copy to cache # copy to cache
k_cache[:bsz, start_pos:start_pos + seq_len] = k k_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = k
v_cache[:bsz, start_pos:start_pos + seq_len] = v v_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = v
# get cache # get cache
k = k_cache[:bsz, :start_pos + seq_len] k = k_cache[:bsz, self.layer_id, :start_pos + seq_len]
v = v_cache[:bsz, :start_pos + seq_len] v = v_cache[:bsz, self.layer_id, :start_pos + seq_len]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
@@ -175,9 +177,9 @@ class GQA(nn.Module):
class DecoderBlock(nn.Module): class DecoderBlock(nn.Module):
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps): def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps, layer_id):
super().__init__() super().__init__()
self.attention = GQA(n_dim, n_head, n_kvhead) self.attention = GQA(n_dim, n_head, n_kvhead, layer_id)
self.norm_attn = RMSNorm(n_dim, norm_eps) self.norm_attn = RMSNorm(n_dim, norm_eps)
self.ffn = MLP(n_dim, d_ffn) self.ffn = MLP(n_dim, d_ffn)
self.norm_ffn = RMSNorm(n_dim, norm_eps) self.norm_ffn = RMSNorm(n_dim, norm_eps)

View File

@@ -4,7 +4,7 @@ import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.nn import init from torch.nn import init
from typing import List, Optional, Tuple from typing import Optional, Tuple
from khaosz.config.model_config import TransformerConfig from khaosz.config.model_config import TransformerConfig
from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb
@@ -68,14 +68,8 @@ class Transformer(nn.Module):
super().__init__() super().__init__()
self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim)) self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
DecoderBlock( DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
config.n_dim, for layer_id in range(config.n_layer)
config.n_head,
config.d_ffn,
config.n_kvhead,
config.norm_eps
)
for _ in range(config.n_layer)
]) ])
self.norm = RMSNorm(config.n_dim, config.norm_eps) self.norm = RMSNorm(config.n_dim, config.norm_eps)
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len) self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
@@ -85,7 +79,7 @@ class Transformer(nn.Module):
self, self,
input_ids: Tensor, input_ids: Tensor,
input_mask: Optional[Tensor]=None, input_mask: Optional[Tensor]=None,
persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None, persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None,
start_pos: int = 0 start_pos: int = 0
) -> Tensor: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
@@ -105,9 +99,8 @@ class Transformer(nn.Module):
dtype=x.dtype dtype=x.dtype
) )
for i, layer in enumerate(self.layers): for layer in self.layers:
kv_cache = persistent_key_values[i] if persistent_key_values else None x = layer(x, freqs_cis, attn_mask, persistent_key_values, start_pos)
x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
hidden_states = self.norm(x) hidden_states = self.norm(x)
logits = F.linear(hidden_states, self.embedding) logits = F.linear(hidden_states, self.embedding)