From abcedf892e3aacf5ca7eb8a69542d8a1b3f2dab0 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 18 Mar 2026 16:41:46 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=20MLA=20=E6=A8=A1?= =?UTF-8?q?=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/model/module.py | 113 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 101 insertions(+), 12 deletions(-) diff --git a/khaosz/model/module.py b/khaosz/model/module.py index a6c7fef..14d5dc6 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -135,15 +135,6 @@ class MLP(nn.Module): out = self.down(gated) return out -class Attention(nn.Module): - - def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool= False): - # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) - q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) - # (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim) - sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2) - - return sdqa_out class GQA(nn.Module): @@ -169,8 +160,6 @@ class GQA(nn.Module): self.n_rep = n_heads // n_kv_heads self.use_qk_norm = use_qk_norm self.use_gated_attention = use_gated_attention - - self.attention = Attention() self.q_proj = Linear(dim, n_heads * self.head_dim) self.k_proj = Linear(dim, n_kv_heads * self.head_dim) @@ -198,6 +187,8 @@ class GQA(nn.Module): start_pos: int = 0 ) -> Tensor: bsz, seq_len, _ = x.size() + is_causal = mask is None + # x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim) q = self._split_heads(self.q_proj(x), self.n_heads) k = self._split_heads(self.k_proj(x), self.n_kv_heads) @@ -219,7 +210,11 @@ class GQA(nn.Module): v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) - sdqa_out = 
class MLA(nn.Module):
    """Multi-head Latent Attention (MLA).

    K/V are compressed through a low-rank bottleneck (``kv_lora_rank``) and
    projected back up per kv-head into a non-positional part
    (``qk_nope_head_dim``), a rotary part (``qk_rope_head_dim``) and a value
    part. Mirrors the sibling ``GQA`` module's interface (rotary embeddings,
    optional mask, optional layered KV cache, optional output gating).
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        norm_eps: float,
        use_gated_attention: bool,
        layer_id: int
    ):
        # BUG FIX: nn.Module.__init__ must run before any submodule is
        # assigned, otherwise attribute registration raises at runtime.
        super().__init__()

        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Per-head width of q/k: non-positional + rotary sub-dims.
        self.head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.layer_id = layer_id
        # Query heads per kv head, used to expand K/V before attention.
        self.n_rep = n_heads // n_kv_heads
        self.use_gated_attention = use_gated_attention

        self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
        self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
        self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)

        # Up-projection emits (k_nope, k_rope, v) per kv head.
        # BUG FIX: width must equal the torch.split sizes used in forward()
        # (qk_nope_head_dim + qk_rope_head_dim + head_dim); the original
        # head_dim + qk_rope_head_dim + head_dim was qk_rope_head_dim too wide.
        self.kv_b_proj = Linear(
            kv_lora_rank,
            n_kv_heads * (qk_nope_head_dim + qk_rope_head_dim + self.head_dim),
        )

        # BUG FIX: the flattened attention output has width n_heads*head_dim,
        # which is not `dim` in general (head_dim carries the extra rope part),
        # so the output projection must map n_heads*head_dim -> dim.
        self.o_proj = Linear(n_heads * self.head_dim, dim, bias=False)

        if use_gated_attention:
            # BUG FIX: the gate multiplies the attention output elementwise,
            # so it must produce n_heads*head_dim features, not `dim`.
            self.gate = Linear(dim, n_heads * self.head_dim, bias=False)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tuple[Tensor, Tensor],
        mask: Optional[Tensor] = None,
        kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
        start_pos: int = 0
    ) -> Tensor:
        """Run latent attention over `x` (bsz, seq_len, dim).

        `mask` of None means causal attention (same convention as GQA).
        `kv_cache`, when given, is a (k_cache, v_cache) pair indexed by
        batch, absolute position and layer; `start_pos` is the absolute
        offset of this chunk inside the cache.
        """
        bsz, seq_len, _ = x.size()
        is_causal = mask is None

        # (bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim)

        # Low-rank KV path: compress, normalize, then expand per kv head.
        kv_compressed = self.kv_norm(self.kv_a_proj(x))
        kv = self.kv_b_proj(kv_compressed).view(bsz, seq_len, self.n_kv_heads, -1)

        k_nope, k_rope, v = torch.split(
            kv,
            [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim],
            dim=-1
        )

        # BUG FIX: the rotary slice of q must start at qk_nope_head_dim
        # (the original sliced from qk_rope_head_dim, which mis-splits the
        # head whenever the two sub-dims differ).
        q_nope = q[..., :self.qk_nope_head_dim]
        q_rope = q[..., self.qk_nope_head_dim:]

        # Rotary position information is applied only to the rope sub-dims.
        q_rope = apply_rotary_emb(q_rope, rotary_emb)
        k_rope = apply_rotary_emb(k_rope, rotary_emb)

        q = torch.cat([q_nope, q_rope], dim=-1)
        k = torch.cat([k_nope, k_rope], dim=-1)

        if kv_cache is not None:
            # Caches are written per kv-head (pre-expansion), like GQA.
            k_cache, v_cache = kv_cache
            k_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = k
            v_cache[:bsz, start_pos:start_pos + seq_len, self.layer_id] = v
            k = k_cache[:bsz, :start_pos + seq_len, self.layer_id]
            v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]

        # BUG FIX: expand kv heads to match the query heads — self.n_rep was
        # computed but never used, and scaled_dot_product_attention requires
        # equal head counts (GQA in this file does the same repeat_kv step).
        k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)

        # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
        # (bsz, n_heads, seq_len, head_dim) -> (bsz, seq_len, n_heads*head_dim)
        attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)

        if self.use_gated_attention:
            attn_out = attn_out * F.sigmoid(self.gate(x))

        out = self.o_proj(attn_out)

        return out