feat: add MLA module
commit abcedf892e
parent abc3a06266
@@ -135,15 +135,6 @@ class MLP(nn.Module):
         out = self.down(gated)
         return out
 
 
-class Attention(nn.Module):
-
-    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool= False):
-        # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
-        q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
-        # (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
-        sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2)
-        return sdqa_out
-
 class GQA(nn.Module):
 
@@ -170,8 +161,6 @@ class GQA(nn.Module):
         self.use_qk_norm = use_qk_norm
         self.use_gated_attention = use_gated_attention
 
-        self.attention = Attention()
-
         self.q_proj = Linear(dim, n_heads * self.head_dim)
         self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
         self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
@@ -198,6 +187,8 @@ class GQA(nn.Module):
         start_pos: int = 0
     ) -> Tensor:
         bsz, seq_len, _ = x.size()
+        is_causal = mask is None
+
         # x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
         q = self._split_heads(self.q_proj(x), self.n_heads)
         k = self._split_heads(self.k_proj(x), self.n_kv_heads)
@@ -219,7 +210,11 @@ class GQA(nn.Module):
             v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
 
         k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
-        sdqa_out = self.attention(q, k, v, mask, is_causal=(mask == None))
+
+        # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
+        q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
+        # (bsz, n_heads, seq_len, head_dim) -> (bsz, seq_len, n_heads*head_dim)
+        sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2)
 
         if self.use_gated_attention:
             sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
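
Reviewer note on the hunk above: the standalone Attention wrapper is inlined, and is_causal is now computed once as "mask is None" (added in the previous hunk) instead of the old inline "mask == None", which both compared against None with == and was easy to desync from the mask argument. F.scaled_dot_product_attention's documentation says attn_mask and is_causal may not both be set, so deriving one from the other keeps the call safe. A minimal sketch of the contract, with made-up shapes:

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 8, 16, 64)  # (bsz, n_heads, seq_len, head_dim)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    mask = None
    # Request causal masking only when no explicit mask is supplied,
    # mirroring GQA.forward after this change.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=(mask is None))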
@@ -229,6 +224,109 @@ class GQA(nn.Module):
         return out
 
 
+class MLA(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        n_heads: int,
+        n_kv_heads: int,
+        kv_lora_rank: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        norm_eps: float,
+        use_gated_attention: bool,
+        layer_id: int
+    ):
+        super().__init__()
+        self.dim = dim
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.kv_lora_rank = kv_lora_rank
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.layer_id = layer_id
+        self.n_rep = n_heads // n_kv_heads
+        self.use_gated_attention = use_gated_attention
+
+        self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
+        self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
+        self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)
+
+        # KV up-projection; per head: k_nope, k_rope, v
+        self.kv_b_proj = Linear(
+            kv_lora_rank,
+            n_kv_heads * (qk_nope_head_dim + qk_rope_head_dim + self.head_dim),
+            bias=False,
+        )
+
+        self.o_proj = Linear(n_heads * self.head_dim, dim, bias=False)
+
+        if use_gated_attention:
+            self.gate = Linear(dim, n_heads * self.head_dim, bias=False)
+
+    def forward(
+        self,
+        x: Tensor,
+        rotary_emb: Tuple[Tensor, Tensor],
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
+        start_pos: int = 0
+    ) -> Tensor:
+        bsz, seq_len, _ = x.size()
+        is_causal = mask is None
+
+        q = self.q_proj(x)
+        q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
+
+        # Compress to the KV latent, normalize, then expand per head.
+        kv_compressed = self.kv_a_proj(x)
+        kv_compressed = self.kv_norm(kv_compressed)
+
+        kv = self.kv_b_proj(kv_compressed)
+        kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
+
+        k_nope, k_rope, v = torch.split(
+            kv,
+            [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim],
+            dim=-1
+        )
+
+        # RoPE is applied only to the rope halves of q and k.
+        q_nope, q_rope = q[..., :self.qk_nope_head_dim], q[..., self.qk_nope_head_dim:]
+        q_rope = apply_rotary_emb(q_rope, rotary_emb)
+        k_rope = apply_rotary_emb(k_rope, rotary_emb)
+
+        q = torch.cat([q_nope, q_rope], dim=-1)
+        k = torch.cat([k_nope, k_rope], dim=-1)
+
+        if kv_cache is not None:
+            k_cache, v_cache = kv_cache
+            k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
+            v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
+            k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
+            v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
+
+        # Expand KV heads to match the query heads, as in GQA.
+        k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
+
+        # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
+        q = q.permute(0, 2, 1, 3)
+        k = k.permute(0, 2, 1, 3)
+        v = v.permute(0, 2, 1, 3)
+
+        attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
+        attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
+
+        if self.use_gated_attention:
+            attn_out = attn_out * F.sigmoid(self.gate(x))
+
+        out = self.o_proj(attn_out)
+
+        return out
+
+
 class DecoderBlock(nn.Module):
     def __init__(
         self,
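
A note on the new MLA module: kv_a_proj compresses the hidden state down to kv_lora_rank, kv_norm normalizes that latent, and kv_b_proj expands it back out to the per-head k_nope / k_rope / v components, so the KV projection is factored through a low-rank bottleneck instead of one dense matrix. A back-of-the-envelope parameter count (the hyperparameter values here are illustrative assumptions, not values from this repo):

    # Assumed sizes, for illustration only.
    dim = 2048
    n_kv_heads = 8
    qk_nope_head_dim, qk_rope_head_dim = 128, 64
    head_dim = qk_nope_head_dim + qk_rope_head_dim       # 192, as in MLA.__init__
    kv_lora_rank = 512

    per_head = qk_nope_head_dim + qk_rope_head_dim + head_dim   # k_nope + k_rope + v

    dense = dim * n_kv_heads * per_head                          # one dense KV projection
    low_rank = dim * kv_lora_rank + kv_lora_rank * n_kv_heads * per_head

    print(f"dense:    {dense:,}")     # 6,291,456
    print(f"low-rank: {low_rank:,}")  # 2,621,440

Note that forward writes the fully expanded k and v into kv_cache, so as written the saving is in projection parameters rather than KV-cache memory; caching the compressed latent instead is what recovers the memory saving in the usual MLA formulation.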