feat(model): 添加QK归一化和门控注意力支持

This commit is contained in:
ViperEkura 2026-01-05 16:14:44 +08:00
parent fd7ee2895a
commit eba99e1f5e
10 changed files with 151 additions and 109 deletions

View File

@ -1,37 +1,43 @@
import json import json
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional, Self from typing import Optional, Self
@dataclass @dataclass
class ModelConfig: class ModelConfig:
# basic config # basic config
vocab_size: Optional[int] = None vocab_size: Optional[int] = None
n_dim: Optional[int] = None dim: Optional[int] = None
n_head: Optional[int] = None
n_layer: Optional[int] = None n_layers: Optional[int] = None
m_len: Optional[int] = None
norm_eps: Optional[float] = None norm_eps: Optional[float] = None
d_ffn: Optional[int] = None dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None tie_weight: Optional[bool] = None
# RoPE
max_len: Optional[int] = None
rope_theta: Optional[float] = None
# GQA # GQA
n_kvhead: Optional[int] = None n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
def load(self, config_path: str) -> Self: def load(self, config_path: str) -> Self:
config = {}
with open(config_path, 'r') as f: with open(config_path, 'r') as f:
config: Dict[str, Any] = json.load(f) config.update(json.load(f))
for key, value in config.items(): for key, value in config.items():
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
return self return self
def save(self, config_path: str) -> None: def save(self, config_path: str):
config_dict = asdict(self) config_dict = {k: v for k, v in asdict(self).items() if v is not None}
config_dict = {k: v for k, v in config_dict.items() if v is not None}
with open(config_path, 'w') as f: with open(config_path, 'w') as f:
json.dump(config_dict, f, indent=4) json.dump(config_dict, f, indent=4)

View File

@ -100,7 +100,7 @@ class GeneratorCore:
) -> List[int]: ) -> List[int]:
cur_cache_pos = start_pos cur_cache_pos = start_pos
for _ in range(len(ids), self.config.m_len): for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator( next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos) input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
@ -127,7 +127,7 @@ class EmbeddingEncoderCore:
with_batch = isinstance(sentence, list) with_batch = isinstance(sentence, list)
ids = self.tokenizer.encode(sentence) ids = self.tokenizer.encode(sentence)
batch_ids = ids if with_batch else [ids] batch_ids = ids if with_batch else [ids]
max_model_len = self.config.m_len max_model_len = self.config.max_len
all_fragments = [] all_fragments = []
fragment_origin_idx = [] fragment_origin_idx = []
@ -195,10 +195,10 @@ class KVCacheManager:
self.batch_size = batch_size self.batch_size = batch_size
self.device = device self.device = device
self.dtype = dtype self.dtype = dtype
self.num_layers = config.n_layer self.num_layers = config.n_layers
self.max_len = config.m_len self.max_len = config.max_len
self.num_heads = config.n_kvhead self.num_heads = config.n_kv_heads
self.head_dim = config.n_dim //config.n_head self.head_dim = config.dim //config.n_heads
self._kv_cache: Tuple[Tensor, Tensor] = None self._kv_cache: Tuple[Tensor, Tensor] = None
self._seq_mask: Tensor = None self._seq_mask: Tensor = None

View File

@ -167,7 +167,7 @@ class StreamGenerator(GeneratorCore):
self.model.eval() self.model.eval()
kv_caches = cache_manager.get_kvcache() kv_caches = cache_manager.get_kvcache()
for _ in range(len(ids), self.config.m_len): for _ in range(len(ids), self.config.max_len):
next_token_id, cache_increase = self.generate_iterator( next_token_id, cache_increase = self.generate_iterator(
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos) input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
@ -219,7 +219,7 @@ class BatchGenerator(GeneratorCore):
start_cache_pos = max_ids_len start_cache_pos = max_ids_len
cur_cache_pos = 0 cur_cache_pos = 0
while max_ids_len < self.config.m_len and sum(activate_task_mask) != 0: while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
kv_caches = cache_manager.get_kvcache() kv_caches = cache_manager.get_kvcache()
attn_mask =cache_manager.get_seq_mask() attn_mask =cache_manager.get_seq_mask()

View File

@ -102,23 +102,20 @@ class RotaryEmbedding(nn.Module):
class Linear(nn.Module):
    """Bias-optional linear projection: ``y = x @ W^T (+ b)``.

    Args:
        in_dim: size of the input feature dimension.
        out_dim: size of the output feature dimension.
        bias: when True, adds a zero-initialized bias vector.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
        # torch.empty leaves uninitialized memory (possibly NaN/inf); give the
        # weight a sane default init so a freshly built model is usable even
        # before a checkpoint is loaded. Overwritten by any state_dict load.
        nn.init.normal_(self.weight, mean=0.0, std=0.02)
        self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None

    def forward(self, x: Tensor) -> Tensor:
        return F.linear(x, self.weight, self.bias)
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, n_dim, norm_eps): def __init__(self, dim, norm_eps):
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(n_dim)) self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (n_dim, ) self.normalized_shape = (dim, )
self.norm_eps = norm_eps self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
@ -127,41 +124,70 @@ class RMSNorm(nn.Module):
class MLP(nn.Module):
    """SwiGLU feed-forward block: ``down(up(x) * silu(gate(x)))``."""

    def __init__(self, dim: int, dim_feed_forward: int):
        super().__init__()
        self.up = Linear(dim, dim_feed_forward)
        self.gate = Linear(dim, dim_feed_forward)
        self.down = Linear(dim_feed_forward, dim)

    def forward(self, x: Tensor) -> Tensor:
        # Gate the up projection with the silu-activated gate projection,
        # then project back to the model dimension.
        activated = F.silu(self.gate(x))
        return self.down(self.up(x) * activated)
class Attention(nn.Module):
    """Thin wrapper around ``F.scaled_dot_product_attention``.

    Takes q/k/v laid out as (bsz, seq_len, n_heads, head_dim) and returns
    the attention output flattened to (bsz, seq_len, n_heads * head_dim).
    """

    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False):
        # SDPA wants (bsz, n_heads, seq_len, head_dim); transpose(1, 2) is
        # equivalent to the permute(0, 2, 1, 3) layout swap.
        q_t = q.transpose(1, 2)
        k_t = k.transpose(1, 2)
        v_t = v.transpose(1, 2)
        attn_out = F.scaled_dot_product_attention(q_t, k_t, v_t, mask, is_causal=is_causal)
        # Restore (bsz, seq_len, n_heads, head_dim) and merge the head dims.
        return attn_out.transpose(1, 2).contiguous().flatten(2)
class GQA(nn.Module): class GQA(nn.Module):
def __init__( def __init__(
self, self,
n_dim: int, dim: int,
n_head: int, n_heads: int,
n_kvhead: int, n_kv_heads: int,
use_qk_norm: bool,
norm_eps: float,
use_gated_attention: bool,
layer_id: int layer_id: int
): ):
super().__init__() super().__init__()
assert n_dim % n_head == 0 assert dim % n_heads == 0
assert n_head % n_kvhead == 0 assert n_heads % n_kv_heads == 0
self.head_dim = n_dim // n_head self.head_dim = dim // n_heads
self.layer_id = layer_id self.layer_id = layer_id
self.n_dim = n_dim self.dim = dim
self.n_heads = n_head self.n_heads = n_heads
self.n_kvheads = n_kvhead self.n_kv_heads = n_kv_heads
self.n_rep = n_head // n_kvhead self.n_rep = n_heads // n_kv_heads
self.use_qk_norm = use_qk_norm
self.use_gated_attention = use_gated_attention
self.q_proj = Linear(n_dim, n_head * self.head_dim) self.attention = Attention()
self.k_proj = Linear(n_dim, n_kvhead * self.head_dim)
self.v_proj = Linear(n_dim, n_kvhead * self.head_dim) self.q_proj = Linear(dim, n_heads * self.head_dim)
self.o_proj = Linear(n_dim, n_dim) self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
self.o_proj = Linear(dim, dim)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps)
if self.use_gated_attention:
self.gate = Linear(dim, dim)
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x
def forward( def forward(
self, self,
@ -174,10 +200,13 @@ class GQA(nn.Module):
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim) # x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads) q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kvheads) k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kvheads) v = self._split_heads(self.v_proj(x), self.n_kv_heads)
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb) q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)
if kv_cache is not None: if kv_cache is not None:
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
@ -190,27 +219,34 @@ class GQA(nn.Module):
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
sdqa_out = self.attention(q, k, v, mask, is_causal=(mask == None))
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) if self.use_gated_attention:
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask == None)).permute(0, 2, 1, 3)
out = self.o_proj(sdqa_out.contiguous().view(bsz, seq_len, -1)) out = self.o_proj(sdqa_out)
return out return out
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x
class DecoderBlock(nn.Module): class DecoderBlock(nn.Module):
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps, layer_id): def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int
):
super().__init__() super().__init__()
self.attention = GQA(n_dim, n_head, n_kvhead, layer_id) self.attention = GQA(dim, n_heads, n_kv_heads,
self.norm_attn = RMSNorm(n_dim, norm_eps) use_qk_norm, norm_eps, use_gated_attention, layer_id)
self.ffn = MLP(n_dim, d_ffn) self.input_norm = RMSNorm(dim, norm_eps)
self.norm_ffn = RMSNorm(n_dim, norm_eps) self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps)
def forward( def forward(
self, self,
@ -222,7 +258,7 @@ class DecoderBlock(nn.Module):
) -> Tensor: ) -> Tensor:
# attention # attention
attn_output = self.attention( attn_output = self.attention(
self.norm_attn(x), self.input_norm(x),
rotary_emb, rotary_emb,
attention_mask, attention_mask,
kv_cache, kv_cache,
@ -231,16 +267,15 @@ class DecoderBlock(nn.Module):
x = attn_output + x x = attn_output + x
# feed forward # feed forward
x = self.ffn(self.norm_ffn(x)) + x x = self.mlp(self.post_attention_norm(x)) + x
return x return x
class Embedding(nn.Module):
    """Token-embedding lookup table of shape (vocab_size, embedding_dim)."""

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
        # torch.empty is uninitialized memory; initialize so an untrained
        # model produces finite activations. Any state_dict load overwrites it.
        nn.init.normal_(self.weight, mean=0.0, std=0.02)

    def forward(self, x: Tensor) -> Tensor:
        # x: integer token ids; returns x's shape with embedding_dim appended.
        return F.embedding(x, self.weight)

View File

@ -59,16 +59,17 @@ class Transformer(nn.Module):
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__() super().__init__()
self.config = config self.config = config
self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len) self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len)
self.embed_tokens = Embedding(config.vocab_size, config.n_dim) self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id) DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads,
for layer_id in range(config.n_layer) config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id)
for layer_id in range(config.n_layers)
]) ])
self.norm = RMSNorm(config.n_dim, config.norm_eps) self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.n_dim, config.vocab_size) self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight == True: if self.config.tie_weight == True:
self.lm_head.weight = self.embed_tokens.weight self.lm_head.weight = self.embed_tokens.weight

View File

@ -83,19 +83,19 @@ def base_test_env(request: pytest.FixtureRequest):
n_dim_choices = [8, 16, 32] n_dim_choices = [8, 16, 32]
n_head_choices = [2, 4] n_head_choices = [2, 4]
n_dim = int(np.random.choice(n_dim_choices)) dim = int(np.random.choice(n_dim_choices))
n_head = int(np.random.choice(n_head_choices)) n_heads = int(np.random.choice(n_head_choices))
n_kvhead = n_head // 2 n_kv_heads = n_heads // 2
d_ffn = n_dim * 2 dim_ffn = dim * 2
config = { config = {
"vocab_size": 1000, "vocab_size": 1000,
"n_dim": n_dim, "dim": dim,
"n_head": n_head, "n_heads": n_heads,
"n_kvhead": n_kvhead, "n_kv_heads": n_kv_heads,
"d_ffn": d_ffn, "dim_ffn": dim_ffn,
"m_len": 1024, "max_len": 1024,
"n_layer": 4, "n_layers": 4,
"norm_eps": 1e-5 "norm_eps": 1e-5
} }

View File

@ -22,12 +22,12 @@ def test_env(request: pytest.FixtureRequest):
config = { config = {
"vocab_size": 1000, "vocab_size": 1000,
"n_dim": 128, "dim": 128,
"n_head": 4, "n_heads": 4,
"n_kvhead": 2, "n_kv_heads": 2,
"d_ffn": 256, "dim_ffn": 256,
"m_len": 64, "max_len": 64,
"n_layer": 2, "n_layers": 2,
"norm_eps": 1e-5 "norm_eps": 1e-5
} }
with open(config_path, 'w') as f: with open(config_path, 'w') as f:
@ -64,9 +64,9 @@ def test_model_parameter(test_env):
def test_transformer(test_env): def test_transformer(test_env):
model = test_env["model"] model = test_env["model"]
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
(4, test_env["transformer_config"].m_len)) (4, test_env["transformer_config"].max_len))
output_logits = model(input_ids)["logits"] output_logits = model(input_ids)["logits"]
target_shape = (4, test_env["transformer_config"].m_len, test_env["transformer_config"].vocab_size) target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size)
assert output_logits.shape == target_shape assert output_logits.shape == target_shape
# generator # generator
@ -80,7 +80,7 @@ def test_embedding_encoder_core(test_env):
single_emb = encoder.encode("测试文本") single_emb = encoder.encode("测试文本")
assert isinstance(single_emb, torch.Tensor) assert isinstance(single_emb, torch.Tensor)
assert single_emb.shape[-1] == test_env["transformer_config"].n_dim assert single_emb.shape[-1] == test_env["transformer_config"].dim
batch_emb = encoder.encode(["测试1", "测试2"]) batch_emb = encoder.encode(["测试1", "测试2"])

View File

@ -16,12 +16,12 @@ def transformer_test_env():
config = { config = {
"vocab_size": 1000, "vocab_size": 1000,
"n_dim": 128, "dim": 128,
"n_head": 4, "n_heads": 4,
"n_kvhead": 2, "n_kv_heads": 2,
"d_ffn": 256, "dim_ffn": 256,
"m_len": 64, "max_len": 64,
"n_layer": 2, "n_layers": 2,
"norm_eps": 1e-5 "norm_eps": 1e-5
} }

View File

@ -28,7 +28,7 @@ class GenerationBenchmark:
def _initialize_kv_cache(self, batch_size: int) -> list: def _initialize_kv_cache(self, batch_size: int) -> list:
"""初始化KV缓存""" """初始化KV缓存"""
config = self.config config = self.config
shape = (batch_size, config.m_len, config.n_layer, config.n_kvhead, config.n_dim // config.n_head) shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads)
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache) return (k_cache, v_cache)
@ -175,12 +175,12 @@ def print_benchmark_result(result: BenchmarkResult):
if __name__ == "__main__": if __name__ == "__main__":
config = ModelConfig( config = ModelConfig(
vocab_size=10000, vocab_size=10000,
n_dim=1536, dim=1536,
n_head=24, n_heads=24,
n_kvhead=4, n_kv_heads=4,
d_ffn=6912, dim_ffn=6912,
m_len=2048, max_len=2048,
n_layer=24, n_layers=24,
norm_eps=1e-5, norm_eps=1e-5,
) )

View File

@ -111,7 +111,7 @@ def train(
parameter.load(param_path) parameter.load(param_path)
if window_size is None: if window_size is None:
window_size = parameter.config.m_len window_size = parameter.config.max_len
model = parameter.model model = parameter.model