feat(model): 添加QK归一化和门控注意力支持

This commit is contained in:
ViperEkura 2026-01-05 16:14:44 +08:00
parent fd7ee2895a
commit eba99e1f5e
10 changed files with 151 additions and 109 deletions

View File

@ -1,37 +1,43 @@
import json
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional, Self
from typing import Optional, Self
@dataclass
class ModelConfig:
# basic config
vocab_size: Optional[int] = None
n_dim: Optional[int] = None
n_head: Optional[int] = None
n_layer: Optional[int] = None
m_len: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
d_ffn: Optional[int] = None
dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None
# RoPE
max_len: Optional[int] = None
rope_theta: Optional[float] = None
# GQA
n_kvhead: Optional[int] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
def load(self, config_path: str) -> Self:
config = {}
with open(config_path, 'r') as f:
config: Dict[str, Any] = json.load(f)
config.update(json.load(f))
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str) -> None:
config_dict = asdict(self)
config_dict = {k: v for k, v in config_dict.items() if v is not None}
def save(self, config_path: str):
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
with open(config_path, 'w') as f:
json.dump(config_dict, f, indent=4)

View File

@ -100,7 +100,7 @@ class GeneratorCore:
) -> List[int]:
cur_cache_pos = start_pos
for _ in range(len(ids), self.config.m_len):
for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
@ -127,7 +127,7 @@ class EmbeddingEncoderCore:
with_batch = isinstance(sentence, list)
ids = self.tokenizer.encode(sentence)
batch_ids = ids if with_batch else [ids]
max_model_len = self.config.m_len
max_model_len = self.config.max_len
all_fragments = []
fragment_origin_idx = []
@ -195,10 +195,10 @@ class KVCacheManager:
self.batch_size = batch_size
self.device = device
self.dtype = dtype
self.num_layers = config.n_layer
self.max_len = config.m_len
self.num_heads = config.n_kvhead
self.head_dim = config.n_dim //config.n_head
self.num_layers = config.n_layers
self.max_len = config.max_len
self.num_heads = config.n_kv_heads
self.head_dim = config.dim //config.n_heads
self._kv_cache: Tuple[Tensor, Tensor] = None
self._seq_mask: Tensor = None

View File

@ -167,7 +167,7 @@ class StreamGenerator(GeneratorCore):
self.model.eval()
kv_caches = cache_manager.get_kvcache()
for _ in range(len(ids), self.config.m_len):
for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
@ -219,7 +219,7 @@ class BatchGenerator(GeneratorCore):
start_cache_pos = max_ids_len
cur_cache_pos = 0
while max_ids_len < self.config.m_len and sum(activate_task_mask) != 0:
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
kv_caches = cache_manager.get_kvcache()
attn_mask =cache_manager.get_seq_mask()

View File

@ -102,23 +102,20 @@ class RotaryEmbedding(nn.Module):
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False, weight_param=None, bias_param=None):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
super().__init__()
weight_param = torch.empty((out_dim, in_dim)) if weight_param is None else weight_param
bias_param = torch.zeros(out_dim) if bias_param is None else bias_param
self.weight = nn.Parameter(weight_param)
self.bias = nn.Parameter(bias_param) if bias else None
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, n_dim, norm_eps):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(n_dim))
self.normalized_shape = (n_dim, )
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim, )
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
@ -127,41 +124,70 @@ class RMSNorm(nn.Module):
class MLP(nn.Module):
def __init__(self, n_dim: int, d_ffn: int):
def __init__(self, dim: int, dim_feed_forward: int):
super().__init__()
self.up = Linear(n_dim, d_ffn)
self.gate = Linear(n_dim, d_ffn)
self.down = Linear(d_ffn, n_dim)
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
class Attention(nn.Module):
def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool= False):
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2)
return sdqa_out
class GQA(nn.Module):
def __init__(
self,
n_dim: int,
n_head: int,
n_kvhead: int,
dim: int,
n_heads: int,
n_kv_heads: int,
use_qk_norm: bool,
norm_eps: float,
use_gated_attention: bool,
layer_id: int
):
super().__init__()
assert n_dim % n_head == 0
assert n_head % n_kvhead == 0
assert dim % n_heads == 0
assert n_heads % n_kv_heads == 0
self.head_dim = n_dim // n_head
self.head_dim = dim // n_heads
self.layer_id = layer_id
self.n_dim = n_dim
self.n_heads = n_head
self.n_kvheads = n_kvhead
self.n_rep = n_head // n_kvhead
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads
self.use_qk_norm = use_qk_norm
self.use_gated_attention = use_gated_attention
self.q_proj = Linear(n_dim, n_head * self.head_dim)
self.k_proj = Linear(n_dim, n_kvhead * self.head_dim)
self.v_proj = Linear(n_dim, n_kvhead * self.head_dim)
self.o_proj = Linear(n_dim, n_dim)
self.attention = Attention()
self.q_proj = Linear(dim, n_heads * self.head_dim)
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
self.o_proj = Linear(dim, dim)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps)
if self.use_gated_attention:
self.gate = Linear(dim, dim)
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x
def forward(
self,
@ -174,10 +200,13 @@ class GQA(nn.Module):
bsz, seq_len, _ = x.size()
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kvheads)
v = self._split_heads(self.v_proj(x), self.n_kvheads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)
if kv_cache is not None:
k_cache, v_cache = kv_cache
@ -190,27 +219,34 @@ class GQA(nn.Module):
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
sdqa_out = self.attention(q, k, v, mask, is_causal=(mask == None))
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask == None)).permute(0, 2, 1, 3)
out = self.o_proj(sdqa_out.contiguous().view(bsz, seq_len, -1))
if self.use_gated_attention:
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out)
return out
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x
class DecoderBlock(nn.Module):
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps, layer_id):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int
):
super().__init__()
self.attention = GQA(n_dim, n_head, n_kvhead, layer_id)
self.norm_attn = RMSNorm(n_dim, norm_eps)
self.ffn = MLP(n_dim, d_ffn)
self.norm_ffn = RMSNorm(n_dim, norm_eps)
self.attention = GQA(dim, n_heads, n_kv_heads,
use_qk_norm, norm_eps, use_gated_attention, layer_id)
self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps)
def forward(
self,
@ -222,7 +258,7 @@ class DecoderBlock(nn.Module):
) -> Tensor:
# attention
attn_output = self.attention(
self.norm_attn(x),
self.input_norm(x),
rotary_emb,
attention_mask,
kv_cache,
@ -231,16 +267,15 @@ class DecoderBlock(nn.Module):
x = attn_output + x
# feed forward
x = self.ffn(self.norm_ffn(x)) + x
x = self.mlp(self.post_attention_norm(x)) + x
return x
class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
weight_param = torch.empty((vocab_size, embedding_dim)) if weight_param is None else weight_param
self.weight = nn.Parameter(weight_param)
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -59,16 +59,17 @@ class Transformer(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len)
self.embed_tokens = Embedding(config.vocab_size, config.n_dim)
self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList([
DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
for layer_id in range(config.n_layer)
DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads,
config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id)
for layer_id in range(config.n_layers)
])
self.norm = RMSNorm(config.n_dim, config.norm_eps)
self.lm_head = Linear(config.n_dim, config.vocab_size)
self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight == True:
self.lm_head.weight = self.embed_tokens.weight

View File

@ -83,19 +83,19 @@ def base_test_env(request: pytest.FixtureRequest):
n_dim_choices = [8, 16, 32]
n_head_choices = [2, 4]
n_dim = int(np.random.choice(n_dim_choices))
n_head = int(np.random.choice(n_head_choices))
n_kvhead = n_head // 2
d_ffn = n_dim * 2
dim = int(np.random.choice(n_dim_choices))
n_heads = int(np.random.choice(n_head_choices))
n_kv_heads = n_heads // 2
dim_ffn = dim * 2
config = {
"vocab_size": 1000,
"n_dim": n_dim,
"n_head": n_head,
"n_kvhead": n_kvhead,
"d_ffn": d_ffn,
"m_len": 1024,
"n_layer": 4,
"dim": dim,
"n_heads": n_heads,
"n_kv_heads": n_kv_heads,
"dim_ffn": dim_ffn,
"max_len": 1024,
"n_layers": 4,
"norm_eps": 1e-5
}

View File

@ -22,12 +22,12 @@ def test_env(request: pytest.FixtureRequest):
config = {
"vocab_size": 1000,
"n_dim": 128,
"n_head": 4,
"n_kvhead": 2,
"d_ffn": 256,
"m_len": 64,
"n_layer": 2,
"dim": 128,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 256,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5
}
with open(config_path, 'w') as f:
@ -64,9 +64,9 @@ def test_model_parameter(test_env):
def test_transformer(test_env):
model = test_env["model"]
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
(4, test_env["transformer_config"].m_len))
(4, test_env["transformer_config"].max_len))
output_logits = model(input_ids)["logits"]
target_shape = (4, test_env["transformer_config"].m_len, test_env["transformer_config"].vocab_size)
target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size)
assert output_logits.shape == target_shape
# generator
@ -80,7 +80,7 @@ def test_embedding_encoder_core(test_env):
single_emb = encoder.encode("测试文本")
assert isinstance(single_emb, torch.Tensor)
assert single_emb.shape[-1] == test_env["transformer_config"].n_dim
assert single_emb.shape[-1] == test_env["transformer_config"].dim
batch_emb = encoder.encode(["测试1", "测试2"])

View File

@ -16,12 +16,12 @@ def transformer_test_env():
config = {
"vocab_size": 1000,
"n_dim": 128,
"n_head": 4,
"n_kvhead": 2,
"d_ffn": 256,
"m_len": 64,
"n_layer": 2,
"dim": 128,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 256,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5
}

View File

@ -28,7 +28,7 @@ class GenerationBenchmark:
def _initialize_kv_cache(self, batch_size: int) -> list:
"""初始化KV缓存"""
config = self.config
shape = (batch_size, config.m_len, config.n_layer, config.n_kvhead, config.n_dim // config.n_head)
shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads)
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache)
@ -175,12 +175,12 @@ def print_benchmark_result(result: BenchmarkResult):
if __name__ == "__main__":
config = ModelConfig(
vocab_size=10000,
n_dim=1536,
n_head=24,
n_kvhead=4,
d_ffn=6912,
m_len=2048,
n_layer=24,
dim=1536,
n_heads=24,
n_kv_heads=4,
dim_ffn=6912,
max_len=2048,
n_layers=24,
norm_eps=1e-5,
)

View File

@ -111,7 +111,7 @@ def train(
parameter.load(param_path)
if window_size is None:
window_size = parameter.config.m_len
window_size = parameter.config.max_len
model = parameter.model