feat(KVCacheManager): 优化KV缓存结构为元组形式以提升性能
This commit is contained in:
parent
bc5ef72001
commit
38b2725cd1
|
|
@ -173,34 +173,25 @@ class KVCacheManager:
|
|||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self._kv_cache: List[Tuple[Tensor, Tensor]] = None
|
||||
self._kv_cache: Tuple[Tensor, Tensor] = None
|
||||
self._seq_mask: Tensor = None
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self):
|
||||
self._kv_cache = []
|
||||
for _ in range(self.num_layers):
|
||||
k_cache = torch.zeros(
|
||||
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
self._kv_cache.append((k_cache, v_cache))
|
||||
|
||||
self._seq_mask = torch.ones(
|
||||
(self.batch_size, self.max_len),
|
||||
device=self.device, dtype=torch.bool
|
||||
k_cache = torch.zeros(
|
||||
(self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
(self.batch_size, self.num_layers, self.max_len, self.num_heads, self.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
self._kv_cache = (k_cache, v_cache)
|
||||
self._seq_mask = torch.ones((self.batch_size, self.max_len), device=self.device, dtype=torch.bool)
|
||||
|
||||
def update(self, active_mask: Tensor):
|
||||
for i in range(self.num_layers):
|
||||
k_cache, v_cache = self._kv_cache[i]
|
||||
new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
|
||||
self._kv_cache[i] = (new_k_cache, new_v_cache)
|
||||
|
||||
def update(self, active_mask: Tensor):
|
||||
k_cache, v_cache = self._kv_cache
|
||||
self._kv_cache = (k_cache[active_mask], v_cache[active_mask])
|
||||
self._seq_mask = self._seq_mask[active_mask]
|
||||
|
||||
def reset(self, full_reset=False):
|
||||
|
|
@ -215,7 +206,7 @@ class KVCacheManager:
|
|||
bool_mask = (input_ids != pad_id)
|
||||
self._seq_mask[: batch_size, : seq_len] = bool_mask
|
||||
|
||||
def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
|
||||
def get_kvcache(self) -> Tuple[Tensor, Tensor]:
|
||||
return self._kv_cache
|
||||
|
||||
def get_seq_mask(self) -> Tensor:
|
||||
|
|
|
|||
|
|
@ -117,12 +117,14 @@ class GQA(nn.Module):
|
|||
n_dim: int,
|
||||
n_head: int,
|
||||
n_kvhead: int,
|
||||
layer_id: int
|
||||
):
|
||||
super().__init__()
|
||||
assert n_dim % n_head == 0
|
||||
assert n_head % n_kvhead == 0
|
||||
|
||||
self.head_dim = n_dim // n_head
|
||||
self.layer_id = layer_id
|
||||
self.n_dim = n_dim
|
||||
self.n_heads = n_head
|
||||
self.n_kvheads = n_kvhead
|
||||
|
|
@ -150,14 +152,14 @@ class GQA(nn.Module):
|
|||
|
||||
if kv_cache is not None:
|
||||
k_cache, v_cache = kv_cache
|
||||
|
||||
|
||||
# copy to cache
|
||||
k_cache[:bsz, start_pos:start_pos + seq_len] = k
|
||||
v_cache[:bsz, start_pos:start_pos + seq_len] = v
|
||||
k_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = k
|
||||
v_cache[:bsz, self.layer_id, start_pos:start_pos + seq_len] = v
|
||||
|
||||
# get cache
|
||||
k = k_cache[:bsz, :start_pos + seq_len]
|
||||
v = v_cache[:bsz, :start_pos + seq_len]
|
||||
k = k_cache[:bsz, self.layer_id, :start_pos + seq_len]
|
||||
v = v_cache[:bsz, self.layer_id, :start_pos + seq_len]
|
||||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
||||
|
|
@ -175,9 +177,9 @@ class GQA(nn.Module):
|
|||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps):
|
||||
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps, layer_id):
|
||||
super().__init__()
|
||||
self.attention = GQA(n_dim, n_head, n_kvhead)
|
||||
self.attention = GQA(n_dim, n_head, n_kvhead, layer_id)
|
||||
self.norm_attn = RMSNorm(n_dim, norm_eps)
|
||||
self.ffn = MLP(n_dim, d_ffn)
|
||||
self.norm_ffn = RMSNorm(n_dim, norm_eps)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|||
|
||||
from torch import Tensor
|
||||
from torch.nn import init
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from khaosz.config.model_config import TransformerConfig
|
||||
from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb
|
||||
|
|
@ -68,14 +68,8 @@ class Transformer(nn.Module):
|
|||
super().__init__()
|
||||
self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
|
||||
self.layers = nn.ModuleList([
|
||||
DecoderBlock(
|
||||
config.n_dim,
|
||||
config.n_head,
|
||||
config.d_ffn,
|
||||
config.n_kvhead,
|
||||
config.norm_eps
|
||||
)
|
||||
for _ in range(config.n_layer)
|
||||
DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
|
||||
for layer_id in range(config.n_layer)
|
||||
])
|
||||
self.norm = RMSNorm(config.n_dim, config.norm_eps)
|
||||
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
|
||||
|
|
@ -85,7 +79,7 @@ class Transformer(nn.Module):
|
|||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor]=None,
|
||||
persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
|
||||
persistent_key_values: Optional[Tuple[Tensor, Tensor]]=None,
|
||||
start_pos: int = 0
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
|
|
@ -105,9 +99,8 @@ class Transformer(nn.Module):
|
|||
dtype=x.dtype
|
||||
)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
kv_cache = persistent_key_values[i] if persistent_key_values else None
|
||||
x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
|
||||
for layer in self.layers:
|
||||
x = layer(x, freqs_cis, attn_mask, persistent_key_values, start_pos)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
logits = F.linear(hidden_states, self.embedding)
|
||||
|
|
|
|||
Loading…
Reference in New Issue