from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
    """
    Repeat the key/value heads n_rep times along the head dimension
    (equivalent to torch.repeat_interleave(x, n_rep, dim=2)).

    Args:
        x (Tensor): Input tensor of shape (bs, seq_len, n_kv_heads, head_dim).
        n_rep (int): The number of repetitions.

    Returns:
        Tensor: Tensor of shape (bs, seq_len, n_kv_heads * n_rep, head_dim).
    """
    bs, slen, n_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_heads, n_rep, head_dim)
        .reshape(bs, slen, n_heads * n_rep, head_dim)
    )

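# Illustrative sketch (assumed sizes, not used elsewhere in this module): with 8 query
# heads and 2 KV heads, each KV head is repeated 4 times so grouped-query attention can
# fall back to a plain multi-head matmul.
#
#   kv = torch.randn(1, 16, 2, 64)   # (bs, seq_len, n_kv_heads, head_dim)
#   kv = repeat_kv(kv, n_rep=4)      # -> (1, 16, 8, 64)
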
def get_rotary_emb(
    dim: int,
    max_len: int,
    base: float = 10000,
    device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Build the rotary embedding tables for the given head dimension and maximum length.

    Args:
        dim (int): The head dimension the rotary embedding is applied to.
        max_len (int): The maximum sequence length.
        base (float, optional): The base for the frequencies. Defaults to 10000.
        device (optional): The device to create tensors on. Defaults to None.

    Returns:
        Tuple[Tensor, Tensor]: The cosine and sine tables, each of shape (max_len, dim // 2).
    """
    theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
    t = torch.arange(0, max_len, dtype=torch.float64, device=device)
    freqs = torch.outer(t, theta)

    return torch.cos(freqs).float(), torch.sin(freqs).float()

def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
    """
    Apply rotary embedding to the input tensor using the cos/sin form.

    Args:
        x (Tensor): The input tensor of shape (bsz, seq_len, n_heads, head_dim).
        rotary_emb (Tuple[Tensor, Tensor]): The cos and sin tables, each of shape
            (seq_len, head_dim // 2).

    Returns:
        Tensor: The rotated tensor, same shape and dtype as the input.
    """
    dtype = x.dtype
    cos, sin = rotary_emb

    cos = cos.unsqueeze(0).unsqueeze(2)  # [1, seq_len, 1, head_dim//2]
    sin = sin.unsqueeze(0).unsqueeze(2)  # [1, seq_len, 1, head_dim//2]

    x_real = x[..., 0::2]  # [bsz, seq_len, n_heads, head_dim//2]
    x_imag = x[..., 1::2]  # [bsz, seq_len, n_heads, head_dim//2]

    # Rotate each (real, imag) pair by the position-dependent angle.
    x_real_rot = x_real * cos - x_imag * sin
    x_imag_rot = x_real * sin + x_imag * cos

    # Re-interleave the rotated pairs back into the last dimension.
    x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)  # [bsz, seq_len, n_heads, head_dim//2, 2]
    x_out = x_out.view(*x_out.shape[:-2], -1)  # [bsz, seq_len, n_heads, head_dim]

    return x_out.to(dtype)

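# Illustrative sketch (assumed shapes, not used elsewhere in this module): the cos/sin
# tables broadcast over the batch and head dimensions, so a full query tensor can be
# rotated directly.
#
#   cos, sin = get_rotary_emb(dim=64, max_len=2048)
#   q = torch.randn(2, 16, 8, 64)                     # (bsz, seq_len, n_heads, head_dim)
#   q_rot = apply_rotary_emb(q, (cos[:16], sin[:16]))
#   assert q_rot.shape == q.shape
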
class RotaryEmbedding(nn.Module):
    """Caches the cos/sin rotary tables and serves the slice for the requested positions."""

    def __init__(self, dim: int, max_len: int, base: float = 10000):
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.base = base
        self.max_len_cached = None
        self._set_rotary_buffer(self.max_len, None)

    def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
        cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)
        self.max_len_cached = max_len

    def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
        seq_len = x.size(1)

        # Grow the cached tables if the requested positions run past them.
        if self.max_len_cached < seq_len + start_pos:
            self._set_rotary_buffer(max(self.max_len_cached * 2, seq_len + start_pos), x.device)

        cos = self.cos_cached[start_pos : start_pos + seq_len]
        sin = self.sin_cached[start_pos : start_pos + seq_len]

        return (cos, sin)

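# Illustrative sketch (assumed sizes): the module returns the cos/sin slice for the
# positions [start_pos, start_pos + seq_len), which is how decode steps with a KV cache
# pick up the correct rotary phase for each new token.
#
#   rope = RotaryEmbedding(dim=64, max_len=128)
#   x = torch.randn(1, 1, 8, 64)        # a single decode-step token
#   cos, sin = rope(x, start_pos=37)    # tables for position 37 only
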
class Linear(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
        self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None

    def forward(self, x: Tensor) -> Tensor:
        return F.linear(x, self.weight, self.bias)

class RMSNorm(nn.Module):
    def __init__(self, dim: int, norm_eps: float):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.normalized_shape = (dim,)
        self.norm_eps = norm_eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)

class MLP(nn.Module):
    """SwiGLU feed-forward block: down(up(x) * silu(gate(x)))."""

    def __init__(self, dim: int, dim_feed_forward: int):
        super().__init__()
        self.up = Linear(dim, dim_feed_forward)
        self.gate = Linear(dim, dim_feed_forward)
        self.down = Linear(dim_feed_forward, dim)

    def forward(self, x: Tensor) -> Tensor:
        gated = self.up(x) * F.silu(self.gate(x))
        out = self.down(gated)
        return out

class GQA(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        use_qk_norm: bool,
        norm_eps: float,
        use_gated_attention: bool,
        layer_id: int,
    ):
        super().__init__()
        assert dim % n_heads == 0
        assert n_heads % n_kv_heads == 0

        self.head_dim = dim // n_heads
        self.layer_id = layer_id
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads
        self.use_qk_norm = use_qk_norm
        self.use_gated_attention = use_gated_attention

        self.q_proj = Linear(dim, n_heads * self.head_dim)
        self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
        self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
        self.o_proj = Linear(dim, dim)

        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim, norm_eps)
            self.k_norm = RMSNorm(self.head_dim, norm_eps)

        if self.use_gated_attention:
            self.gate = Linear(dim, dim)

    def _split_heads(self, x: Tensor, n_heads: int) -> Tensor:
        batch_size, seq_len, _ = x.shape
        x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
        return x

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tuple[Tensor, Tensor],
        mask: Optional[Tensor] = None,
        kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
        start_pos: int = 0,
    ) -> Tensor:
        bsz, seq_len, _ = x.size()
        is_causal = mask is None

        # x (bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
        q = self._split_heads(self.q_proj(x), self.n_heads)
        k = self._split_heads(self.k_proj(x), self.n_kv_heads)
        v = self._split_heads(self.v_proj(x), self.n_kv_heads)
        q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)

        if self.use_qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        if kv_cache is not None:
            k_cache, v_cache = kv_cache

            # copy the new keys/values into the cache
            k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
            v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v

            # read the full prefix back from the cache
            k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
            v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]

        k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)

        # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
        q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)

        # (bsz, n_heads, seq_len, head_dim) -> (bsz, seq_len, n_heads * head_dim)
        sdpa_out = (
            F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
            .permute(0, 2, 1, 3)
            .contiguous()
            .flatten(2)
        )

        if self.use_gated_attention:
            sdpa_out = sdpa_out * torch.sigmoid(self.gate(x))

        out = self.o_proj(sdpa_out)

        return out

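# Illustrative sketch (assumed sizes; the cache layout is inferred from the indexing
# above): the caller owns one pair of cache tensors shaped
# (max_batch, max_seq_len, n_layers, n_kv_heads, head_dim) and passes the same pair to
# every layer together with start_pos.
#
#   attn = GQA(dim=512, n_heads=8, n_kv_heads=2, use_qk_norm=True,
#              norm_eps=1e-6, use_gated_attention=False, layer_id=0)
#   rope = RotaryEmbedding(dim=512 // 8, max_len=256)
#   k_cache = torch.zeros(1, 256, 1, 2, 64)
#   v_cache = torch.zeros(1, 256, 1, 2, 64)
#   x = torch.randn(1, 16, 512)
#   out = attn(x, rope(x), kv_cache=(k_cache, v_cache), start_pos=0)
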
class MLA(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        norm_eps: float,
        use_gated_attention: bool,
        layer_id: int,
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.layer_id = layer_id
        self.n_rep = n_heads // n_kv_heads
        self.use_gated_attention = use_gated_attention

        self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
        self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
        self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)

        # KV up-projection producing (k_nope, k_rope, v) per KV head.
        self.kv_b_proj = Linear(
            kv_lora_rank,
            n_kv_heads * (qk_nope_head_dim + qk_rope_head_dim + self.head_dim),
        )

        self.o_proj = Linear(n_heads * self.head_dim, dim, bias=False)

        if use_gated_attention:
            self.gate = Linear(dim, n_heads * self.head_dim, bias=False)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tuple[Tensor, Tensor],
        mask: Optional[Tensor] = None,
        kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
        start_pos: int = 0,
    ) -> Tensor:
        bsz, seq_len, _ = x.size()
        is_causal = mask is None

        q = self.q_proj(x)
        q = q.view(bsz, seq_len, self.n_heads, self.head_dim)

        # Down-project to the compressed KV latent, then normalize.
        kv_compressed = self.kv_a_proj(x)
        kv_compressed = self.kv_norm(kv_compressed)

        # Up-project the latent and split into no-position keys, rotary keys, and values.
        kv = self.kv_b_proj(kv_compressed)
        kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)

        k_nope, k_rope, v = torch.split(
            kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
        )

        q_nope, q_rope = (
            q[..., : self.qk_nope_head_dim],
            q[..., self.qk_nope_head_dim :],
        )
        q_rope = apply_rotary_emb(q_rope, rotary_emb)
        k_rope = apply_rotary_emb(k_rope, rotary_emb)

        q = torch.cat([q_nope, q_rope], dim=-1)
        k = torch.cat([k_nope, k_rope], dim=-1)

        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
            v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
            k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
            v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]

        # Match the number of query heads before attention.
        k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
        attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)

        if self.use_gated_attention:
            attn_out = attn_out * torch.sigmoid(self.gate(x))

        out = self.o_proj(attn_out)

        return out

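# Illustrative sketch (assumed sizes): with MLA the per-head dimension is
# qk_nope_head_dim + qk_rope_head_dim rather than dim // n_heads, so the rotary tables
# are built for qk_rope_head_dim and the KV cache for self.head_dim.
#
#   mla = MLA(dim=512, n_heads=8, n_kv_heads=8, kv_lora_rank=128,
#             qk_nope_head_dim=64, qk_rope_head_dim=32, norm_eps=1e-6,
#             use_gated_attention=False, layer_id=0)
#   rope = RotaryEmbedding(dim=32, max_len=256)   # tables for the rotary part only
#   x = torch.randn(1, 16, 512)
#   out = mla(x, rope(x))                         # -> (1, 16, 512)
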
class DecoderBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        dim_ffn: int,
        n_kv_heads: int,
        norm_eps: float,
        use_qk_norm: bool,
        use_gated_attention: bool,
        layer_id: int,
    ):
        super().__init__()
        self.attention = GQA(
            dim,
            n_heads,
            n_kv_heads,
            use_qk_norm,
            norm_eps,
            use_gated_attention,
            layer_id,
        )
        self.input_norm = RMSNorm(dim, norm_eps)
        self.mlp = MLP(dim, dim_ffn)
        self.post_attention_norm = RMSNorm(dim, norm_eps)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tuple[Tensor, Tensor],
        attention_mask: Optional[Tensor] = None,
        kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
        start_pos: int = 0,
    ) -> Tensor:
        # attention with pre-norm and residual connection
        attn_output = self.attention(
            self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
        )
        x = attn_output + x

        # feed forward with pre-norm and residual connection
        x = self.mlp(self.post_attention_norm(x)) + x

        return x

class Embedding(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))

    def forward(self, x: Tensor) -> Tensor:
        return F.embedding(x, self.weight)
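
# Illustrative sketch (assumed sizes; a full model wrapper is not part of this file):
# embedding -> decoder blocks sharing one RotaryEmbedding. Parameters created with
# torch.empty are assumed to be initialized or loaded from a checkpoint before use.
#
#   emb = Embedding(vocab_size=1000, embedding_dim=512)
#   nn.init.normal_(emb.weight, std=0.02)
#   block = DecoderBlock(dim=512, n_heads=8, dim_ffn=2048, n_kv_heads=2,
#                        norm_eps=1e-6, use_qk_norm=True,
#                        use_gated_attention=False, layer_id=0)
#   rope = RotaryEmbedding(dim=512 // 8, max_len=256)
#   tokens = torch.randint(0, 1000, (1, 16))
#   h = emb(tokens)
#   h = block(h, rope(h))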