feat(model): 添加QK归一化和门控注意力支持
This commit is contained in:
parent
fd7ee2895a
commit
eba99e1f5e
|
|
@ -1,37 +1,43 @@
|
|||
import json
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, Optional, Self
|
||||
from typing import Optional, Self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
# basic config
|
||||
vocab_size: Optional[int] = None
|
||||
n_dim: Optional[int] = None
|
||||
n_head: Optional[int] = None
|
||||
n_layer: Optional[int] = None
|
||||
m_len: Optional[int] = None
|
||||
dim: Optional[int] = None
|
||||
|
||||
n_layers: Optional[int] = None
|
||||
norm_eps: Optional[float] = None
|
||||
d_ffn: Optional[int] = None
|
||||
dim_ffn: Optional[int] = None
|
||||
tie_weight: Optional[bool] = None
|
||||
|
||||
# RoPE
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
|
||||
# GQA
|
||||
n_kvhead: Optional[int] = None
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
|
||||
def load(self, config_path: str) -> Self:
|
||||
config = {}
|
||||
with open(config_path, 'r') as f:
|
||||
config: Dict[str, Any] = json.load(f)
|
||||
config.update(json.load(f))
|
||||
|
||||
for key, value in config.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
return self
|
||||
|
||||
def save(self, config_path: str) -> None:
|
||||
config_dict = asdict(self)
|
||||
config_dict = {k: v for k, v in config_dict.items() if v is not None}
|
||||
def save(self, config_path: str):
|
||||
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ class GeneratorCore:
|
|||
) -> List[int]:
|
||||
cur_cache_pos = start_pos
|
||||
|
||||
for _ in range(len(ids), self.config.m_len):
|
||||
for _ in range(len(ids), self.config.max_len):
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
|
||||
|
||||
|
|
@ -127,7 +127,7 @@ class EmbeddingEncoderCore:
|
|||
with_batch = isinstance(sentence, list)
|
||||
ids = self.tokenizer.encode(sentence)
|
||||
batch_ids = ids if with_batch else [ids]
|
||||
max_model_len = self.config.m_len
|
||||
max_model_len = self.config.max_len
|
||||
|
||||
all_fragments = []
|
||||
fragment_origin_idx = []
|
||||
|
|
@ -195,10 +195,10 @@ class KVCacheManager:
|
|||
self.batch_size = batch_size
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.num_layers = config.n_layer
|
||||
self.max_len = config.m_len
|
||||
self.num_heads = config.n_kvhead
|
||||
self.head_dim = config.n_dim //config.n_head
|
||||
self.num_layers = config.n_layers
|
||||
self.max_len = config.max_len
|
||||
self.num_heads = config.n_kv_heads
|
||||
self.head_dim = config.dim //config.n_heads
|
||||
|
||||
self._kv_cache: Tuple[Tensor, Tensor] = None
|
||||
self._seq_mask: Tensor = None
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ class StreamGenerator(GeneratorCore):
|
|||
self.model.eval()
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
|
||||
for _ in range(len(ids), self.config.m_len):
|
||||
for _ in range(len(ids), self.config.max_len):
|
||||
next_token_id, cache_increase = self.generate_iterator(
|
||||
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
||||
|
||||
|
|
@ -219,7 +219,7 @@ class BatchGenerator(GeneratorCore):
|
|||
start_cache_pos = max_ids_len
|
||||
cur_cache_pos = 0
|
||||
|
||||
while max_ids_len < self.config.m_len and sum(activate_task_mask) != 0:
|
||||
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
|
||||
kv_caches = cache_manager.get_kvcache()
|
||||
attn_mask =cache_manager.get_seq_mask()
|
||||
|
||||
|
|
|
|||
|
|
@ -102,23 +102,20 @@ class RotaryEmbedding(nn.Module):
|
|||
|
||||
|
||||
class Linear(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = False, weight_param=None, bias_param=None):
|
||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
||||
super().__init__()
|
||||
weight_param = torch.empty((out_dim, in_dim)) if weight_param is None else weight_param
|
||||
bias_param = torch.zeros(out_dim) if bias_param is None else bias_param
|
||||
|
||||
self.weight = nn.Parameter(weight_param)
|
||||
self.bias = nn.Parameter(bias_param) if bias else None
|
||||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, n_dim, norm_eps):
|
||||
def __init__(self, dim, norm_eps):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(n_dim))
|
||||
self.normalized_shape = (n_dim, )
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.normalized_shape = (dim, )
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
|
@ -127,41 +124,70 @@ class RMSNorm(nn.Module):
|
|||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, n_dim: int, d_ffn: int):
|
||||
def __init__(self, dim: int, dim_feed_forward: int):
|
||||
super().__init__()
|
||||
self.up = Linear(n_dim, d_ffn)
|
||||
self.gate = Linear(n_dim, d_ffn)
|
||||
self.down = Linear(d_ffn, n_dim)
|
||||
self.up = Linear(dim, dim_feed_forward)
|
||||
self.gate = Linear(dim, dim_feed_forward)
|
||||
self.down = Linear(dim_feed_forward, dim)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
gated = self.up(x) * F.silu(self.gate(x))
|
||||
out = self.down(gated)
|
||||
return out
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool= False):
|
||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
|
||||
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2)
|
||||
|
||||
return sdqa_out
|
||||
|
||||
|
||||
class GQA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_dim: int,
|
||||
n_head: int,
|
||||
n_kvhead: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
use_qk_norm: bool,
|
||||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int
|
||||
):
|
||||
super().__init__()
|
||||
assert n_dim % n_head == 0
|
||||
assert n_head % n_kvhead == 0
|
||||
assert dim % n_heads == 0
|
||||
assert n_heads % n_kv_heads == 0
|
||||
|
||||
self.head_dim = n_dim // n_head
|
||||
self.head_dim = dim // n_heads
|
||||
self.layer_id = layer_id
|
||||
self.n_dim = n_dim
|
||||
self.n_heads = n_head
|
||||
self.n_kvheads = n_kvhead
|
||||
self.n_rep = n_head // n_kvhead
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
self.n_rep = n_heads // n_kv_heads
|
||||
self.use_qk_norm = use_qk_norm
|
||||
self.use_gated_attention = use_gated_attention
|
||||
|
||||
self.q_proj = Linear(n_dim, n_head * self.head_dim)
|
||||
self.k_proj = Linear(n_dim, n_kvhead * self.head_dim)
|
||||
self.v_proj = Linear(n_dim, n_kvhead * self.head_dim)
|
||||
self.o_proj = Linear(n_dim, n_dim)
|
||||
self.attention = Attention()
|
||||
|
||||
self.q_proj = Linear(dim, n_heads * self.head_dim)
|
||||
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||
self.o_proj = Linear(dim, dim)
|
||||
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
||||
|
||||
if self.use_gated_attention:
|
||||
self.gate = Linear(dim, dim)
|
||||
|
||||
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -174,10 +200,13 @@ class GQA(nn.Module):
|
|||
bsz, seq_len, _ = x.size()
|
||||
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||
k = self._split_heads(self.k_proj(x), self.n_kvheads)
|
||||
v = self._split_heads(self.v_proj(x), self.n_kvheads)
|
||||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
|
||||
|
||||
if self.use_qk_norm:
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if kv_cache is not None:
|
||||
k_cache, v_cache = kv_cache
|
||||
|
||||
|
|
@ -190,27 +219,34 @@ class GQA(nn.Module):
|
|||
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
||||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
sdqa_out = self.attention(q, k, v, mask, is_causal=(mask == None))
|
||||
|
||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask == None)).permute(0, 2, 1, 3)
|
||||
out = self.o_proj(sdqa_out.contiguous().view(bsz, seq_len, -1))
|
||||
if self.use_gated_attention:
|
||||
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
||||
|
||||
out = self.o_proj(sdqa_out)
|
||||
|
||||
return out
|
||||
|
||||
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps, layer_id):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
dim_ffn: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: int,
|
||||
use_qk_norm: bool,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = GQA(n_dim, n_head, n_kvhead, layer_id)
|
||||
self.norm_attn = RMSNorm(n_dim, norm_eps)
|
||||
self.ffn = MLP(n_dim, d_ffn)
|
||||
self.norm_ffn = RMSNorm(n_dim, norm_eps)
|
||||
self.attention = GQA(dim, n_heads, n_kv_heads,
|
||||
use_qk_norm, norm_eps, use_gated_attention, layer_id)
|
||||
self.input_norm = RMSNorm(dim, norm_eps)
|
||||
self.mlp = MLP(dim, dim_ffn)
|
||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -222,7 +258,7 @@ class DecoderBlock(nn.Module):
|
|||
) -> Tensor:
|
||||
# attention
|
||||
attn_output = self.attention(
|
||||
self.norm_attn(x),
|
||||
self.input_norm(x),
|
||||
rotary_emb,
|
||||
attention_mask,
|
||||
kv_cache,
|
||||
|
|
@ -231,16 +267,15 @@ class DecoderBlock(nn.Module):
|
|||
x = attn_output + x
|
||||
|
||||
# feed forward
|
||||
x = self.ffn(self.norm_ffn(x)) + x
|
||||
x = self.mlp(self.post_attention_norm(x)) + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None):
|
||||
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||
super().__init__()
|
||||
weight_param = torch.empty((vocab_size, embedding_dim)) if weight_param is None else weight_param
|
||||
self.weight = nn.Parameter(weight_param)
|
||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.embedding(x, self.weight)
|
||||
|
|
@ -59,16 +59,17 @@ class Transformer(nn.Module):
|
|||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.n_dim)
|
||||
self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
|
||||
for layer_id in range(config.n_layer)
|
||||
DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads,
|
||||
config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id)
|
||||
for layer_id in range(config.n_layers)
|
||||
])
|
||||
|
||||
self.norm = RMSNorm(config.n_dim, config.norm_eps)
|
||||
self.lm_head = Linear(config.n_dim, config.vocab_size)
|
||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||
|
||||
if self.config.tie_weight == True:
|
||||
self.lm_head.weight = self.embed_tokens.weight
|
||||
|
|
|
|||
|
|
@ -83,19 +83,19 @@ def base_test_env(request: pytest.FixtureRequest):
|
|||
n_dim_choices = [8, 16, 32]
|
||||
n_head_choices = [2, 4]
|
||||
|
||||
n_dim = int(np.random.choice(n_dim_choices))
|
||||
n_head = int(np.random.choice(n_head_choices))
|
||||
n_kvhead = n_head // 2
|
||||
d_ffn = n_dim * 2
|
||||
dim = int(np.random.choice(n_dim_choices))
|
||||
n_heads = int(np.random.choice(n_head_choices))
|
||||
n_kv_heads = n_heads // 2
|
||||
dim_ffn = dim * 2
|
||||
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
"n_dim": n_dim,
|
||||
"n_head": n_head,
|
||||
"n_kvhead": n_kvhead,
|
||||
"d_ffn": d_ffn,
|
||||
"m_len": 1024,
|
||||
"n_layer": 4,
|
||||
"dim": dim,
|
||||
"n_heads": n_heads,
|
||||
"n_kv_heads": n_kv_heads,
|
||||
"dim_ffn": dim_ffn,
|
||||
"max_len": 1024,
|
||||
"n_layers": 4,
|
||||
"norm_eps": 1e-5
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,12 +22,12 @@ def test_env(request: pytest.FixtureRequest):
|
|||
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
"n_dim": 128,
|
||||
"n_head": 4,
|
||||
"n_kvhead": 2,
|
||||
"d_ffn": 256,
|
||||
"m_len": 64,
|
||||
"n_layer": 2,
|
||||
"dim": 128,
|
||||
"n_heads": 4,
|
||||
"n_kv_heads": 2,
|
||||
"dim_ffn": 256,
|
||||
"max_len": 64,
|
||||
"n_layers": 2,
|
||||
"norm_eps": 1e-5
|
||||
}
|
||||
with open(config_path, 'w') as f:
|
||||
|
|
@ -64,9 +64,9 @@ def test_model_parameter(test_env):
|
|||
def test_transformer(test_env):
|
||||
model = test_env["model"]
|
||||
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
|
||||
(4, test_env["transformer_config"].m_len))
|
||||
(4, test_env["transformer_config"].max_len))
|
||||
output_logits = model(input_ids)["logits"]
|
||||
target_shape = (4, test_env["transformer_config"].m_len, test_env["transformer_config"].vocab_size)
|
||||
target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size)
|
||||
assert output_logits.shape == target_shape
|
||||
|
||||
# generator
|
||||
|
|
@ -80,7 +80,7 @@ def test_embedding_encoder_core(test_env):
|
|||
|
||||
single_emb = encoder.encode("测试文本")
|
||||
assert isinstance(single_emb, torch.Tensor)
|
||||
assert single_emb.shape[-1] == test_env["transformer_config"].n_dim
|
||||
assert single_emb.shape[-1] == test_env["transformer_config"].dim
|
||||
|
||||
|
||||
batch_emb = encoder.encode(["测试1", "测试2"])
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ def transformer_test_env():
|
|||
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
"n_dim": 128,
|
||||
"n_head": 4,
|
||||
"n_kvhead": 2,
|
||||
"d_ffn": 256,
|
||||
"m_len": 64,
|
||||
"n_layer": 2,
|
||||
"dim": 128,
|
||||
"n_heads": 4,
|
||||
"n_kv_heads": 2,
|
||||
"dim_ffn": 256,
|
||||
"max_len": 64,
|
||||
"n_layers": 2,
|
||||
"norm_eps": 1e-5
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class GenerationBenchmark:
|
|||
def _initialize_kv_cache(self, batch_size: int) -> list:
|
||||
"""初始化KV缓存"""
|
||||
config = self.config
|
||||
shape = (batch_size, config.m_len, config.n_layer, config.n_kvhead, config.n_dim // config.n_head)
|
||||
shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads)
|
||||
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||
return (k_cache, v_cache)
|
||||
|
|
@ -175,12 +175,12 @@ def print_benchmark_result(result: BenchmarkResult):
|
|||
if __name__ == "__main__":
|
||||
config = ModelConfig(
|
||||
vocab_size=10000,
|
||||
n_dim=1536,
|
||||
n_head=24,
|
||||
n_kvhead=4,
|
||||
d_ffn=6912,
|
||||
m_len=2048,
|
||||
n_layer=24,
|
||||
dim=1536,
|
||||
n_heads=24,
|
||||
n_kv_heads=4,
|
||||
dim_ffn=6912,
|
||||
max_len=2048,
|
||||
n_layers=24,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ def train(
|
|||
parameter.load(param_path)
|
||||
|
||||
if window_size is None:
|
||||
window_size = parameter.config.m_len
|
||||
window_size = parameter.config.max_len
|
||||
|
||||
model = parameter.model
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue