diff --git a/khaosz/config/model_config.py b/khaosz/config/model_config.py index 84977a8..62947ec 100644 --- a/khaosz/config/model_config.py +++ b/khaosz/config/model_config.py @@ -1,37 +1,43 @@ import json from dataclasses import asdict, dataclass -from typing import Any, Dict, Optional, Self +from typing import Optional, Self + @dataclass class ModelConfig: # basic config vocab_size: Optional[int] = None - n_dim: Optional[int] = None - n_head: Optional[int] = None - n_layer: Optional[int] = None - m_len: Optional[int] = None + dim: Optional[int] = None + + n_layers: Optional[int] = None norm_eps: Optional[float] = None - d_ffn: Optional[int] = None + dim_ffn: Optional[int] = None tie_weight: Optional[bool] = None + # RoPE + max_len: Optional[int] = None + rope_theta: Optional[float] = None + # GQA - n_kvhead: Optional[int] = None + n_heads: Optional[int] = None + n_kv_heads: Optional[int] = None + use_qk_norm: Optional[bool] = None + use_gated_attention: Optional[bool] = None def load(self, config_path: str) -> Self: + config = {} with open(config_path, 'r') as f: - config: Dict[str, Any] = json.load(f) + config.update(json.load(f)) + for key, value in config.items(): if hasattr(self, key): setattr(self, key, value) return self - def save(self, config_path: str) -> None: - config_dict = asdict(self) - config_dict = {k: v for k, v in config_dict.items() if v is not None} + def save(self, config_path: str): + config_dict = {k: v for k, v in asdict(self).items() if v is not None} with open(config_path, 'w') as f: json.dump(config_dict, f, indent=4) - - diff --git a/khaosz/inference/core.py b/khaosz/inference/core.py index 2bd6e51..ddbb24a 100644 --- a/khaosz/inference/core.py +++ b/khaosz/inference/core.py @@ -100,7 +100,7 @@ class GeneratorCore: ) -> List[int]: cur_cache_pos = start_pos - for _ in range(len(ids), self.config.m_len): + for _ in range(len(ids), self.config.max_len): next_token_id, cache_increase = self.generate_iterator( input_ids, temperature, 
top_k, top_p, attn_mask, kv_caches, cur_cache_pos) @@ -127,7 +127,7 @@ class EmbeddingEncoderCore: with_batch = isinstance(sentence, list) ids = self.tokenizer.encode(sentence) batch_ids = ids if with_batch else [ids] - max_model_len = self.config.m_len + max_model_len = self.config.max_len all_fragments = [] fragment_origin_idx = [] @@ -195,10 +195,10 @@ class KVCacheManager: self.batch_size = batch_size self.device = device self.dtype = dtype - self.num_layers = config.n_layer - self.max_len = config.m_len - self.num_heads = config.n_kvhead - self.head_dim = config.n_dim //config.n_head + self.num_layers = config.n_layers + self.max_len = config.max_len + self.num_heads = config.n_kv_heads + self.head_dim = config.dim //config.n_heads self._kv_cache: Tuple[Tensor, Tensor] = None self._seq_mask: Tensor = None diff --git a/khaosz/inference/generator.py b/khaosz/inference/generator.py index 1402a46..034a1b5 100644 --- a/khaosz/inference/generator.py +++ b/khaosz/inference/generator.py @@ -167,7 +167,7 @@ class StreamGenerator(GeneratorCore): self.model.eval() kv_caches = cache_manager.get_kvcache() - for _ in range(len(ids), self.config.m_len): + for _ in range(len(ids), self.config.max_len): next_token_id, cache_increase = self.generate_iterator( input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos) @@ -219,7 +219,7 @@ class BatchGenerator(GeneratorCore): start_cache_pos = max_ids_len cur_cache_pos = 0 - while max_ids_len < self.config.m_len and sum(activate_task_mask) != 0: + while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0: kv_caches = cache_manager.get_kvcache() attn_mask =cache_manager.get_seq_mask() diff --git a/khaosz/model/module.py b/khaosz/model/module.py index 31c26ab..a6c7fef 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -102,23 +102,20 @@ class RotaryEmbedding(nn.Module): class Linear(nn.Module): - def __init__(self, in_dim: int, out_dim: int, bias: bool = False, 
weight_param=None, bias_param=None): + def __init__(self, in_dim: int, out_dim: int, bias: bool = False): super().__init__() - weight_param = torch.empty((out_dim, in_dim)) if weight_param is None else weight_param - bias_param = torch.zeros(out_dim) if bias_param is None else bias_param - - self.weight = nn.Parameter(weight_param) - self.bias = nn.Parameter(bias_param) if bias else None + self.weight = nn.Parameter(torch.empty((out_dim, in_dim))) + self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None def forward(self, x: Tensor) -> Tensor: return F.linear(x, self.weight, self.bias) class RMSNorm(nn.Module): - def __init__(self, n_dim, norm_eps): + def __init__(self, dim, norm_eps): super().__init__() - self.weight = nn.Parameter(torch.ones(n_dim)) - self.normalized_shape = (n_dim, ) + self.weight = nn.Parameter(torch.ones(dim)) + self.normalized_shape = (dim, ) self.norm_eps = norm_eps def forward(self, x: Tensor) -> Tensor: @@ -127,41 +124,70 @@ class RMSNorm(nn.Module): class MLP(nn.Module): - def __init__(self, n_dim: int, d_ffn: int): + def __init__(self, dim: int, dim_feed_forward: int): super().__init__() - self.up = Linear(n_dim, d_ffn) - self.gate = Linear(n_dim, d_ffn) - self.down = Linear(d_ffn, n_dim) + self.up = Linear(dim, dim_feed_forward) + self.gate = Linear(dim, dim_feed_forward) + self.down = Linear(dim_feed_forward, dim) def forward(self, x: Tensor) -> Tensor: gated = self.up(x) * F.silu(self.gate(x)) out = self.down(gated) return out +class Attention(nn.Module): + + def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool= False): + # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) + q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) + # (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim) + sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2) + + return 
sdqa_out + class GQA(nn.Module): def __init__( self, - n_dim: int, - n_head: int, - n_kvhead: int, + dim: int, + n_heads: int, + n_kv_heads: int, + use_qk_norm: bool, + norm_eps: float, + use_gated_attention: bool, layer_id: int ): super().__init__() - assert n_dim % n_head == 0 - assert n_head % n_kvhead == 0 + assert dim % n_heads == 0 + assert n_heads % n_kv_heads == 0 - self.head_dim = n_dim // n_head + self.head_dim = dim // n_heads self.layer_id = layer_id - self.n_dim = n_dim - self.n_heads = n_head - self.n_kvheads = n_kvhead - self.n_rep = n_head // n_kvhead + self.dim = dim + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.n_rep = n_heads // n_kv_heads + self.use_qk_norm = use_qk_norm + self.use_gated_attention = use_gated_attention - self.q_proj = Linear(n_dim, n_head * self.head_dim) - self.k_proj = Linear(n_dim, n_kvhead * self.head_dim) - self.v_proj = Linear(n_dim, n_kvhead * self.head_dim) - self.o_proj = Linear(n_dim, n_dim) + self.attention = Attention() + + self.q_proj = Linear(dim, n_heads * self.head_dim) + self.k_proj = Linear(dim, n_kv_heads * self.head_dim) + self.v_proj = Linear(dim, n_kv_heads * self.head_dim) + self.o_proj = Linear(dim, dim) + + if self.use_qk_norm: + self.q_norm = RMSNorm(self.head_dim, norm_eps) + self.k_norm = RMSNorm(self.head_dim, norm_eps) + + if self.use_gated_attention: + self.gate = Linear(dim, dim) + + def _split_heads(self, x: Tensor, n_heads) -> Tensor: + batch_size, seq_len, _ = x.shape + x = x.reshape(batch_size, seq_len, n_heads, self.head_dim) + return x def forward( self, @@ -174,10 +200,13 @@ class GQA(nn.Module): bsz, seq_len, _ = x.size() # x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim) q = self._split_heads(self.q_proj(x), self.n_heads) - k = self._split_heads(self.k_proj(x), self.n_kvheads) - v = self._split_heads(self.v_proj(x), self.n_kvheads) + k = self._split_heads(self.k_proj(x), self.n_kv_heads) + v = self._split_heads(self.v_proj(x), self.n_kv_heads) q, k 
= apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb) + if self.use_qk_norm: + q, k = self.q_norm(q), self.k_norm(k) + if kv_cache is not None: k_cache, v_cache = kv_cache @@ -190,27 +219,34 @@ class GQA(nn.Module): v = v_cache[:bsz, :start_pos + seq_len, self.layer_id] k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) + sdqa_out = self.attention(q, k, v, mask, is_causal=(mask is None)) - # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) - q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) - sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask == None)).permute(0, 2, 1, 3) - out = self.o_proj(sdqa_out.contiguous().view(bsz, seq_len, -1)) + if self.use_gated_attention: + sdqa_out = sdqa_out * F.sigmoid(self.gate(x)) + + out = self.o_proj(sdqa_out) return out - - def _split_heads(self, x: Tensor, n_heads) -> Tensor: - batch_size, seq_len, _ = x.shape - x = x.reshape(batch_size, seq_len, n_heads, self.head_dim) - return x - + class DecoderBlock(nn.Module): - def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps, layer_id): + def __init__( + self, + dim: int, + n_heads: int, + dim_ffn: int, + n_kv_heads: int, + norm_eps: float, + use_qk_norm: bool, + use_gated_attention: bool, + layer_id: int + ): super().__init__() - self.attention = GQA(n_dim, n_head, n_kvhead, layer_id) - self.norm_attn = RMSNorm(n_dim, norm_eps) - self.ffn = MLP(n_dim, d_ffn) - self.norm_ffn = RMSNorm(n_dim, norm_eps) + self.attention = GQA(dim, n_heads, n_kv_heads, + use_qk_norm, norm_eps, use_gated_attention, layer_id) + self.input_norm = RMSNorm(dim, norm_eps) + self.mlp = MLP(dim, dim_ffn) + self.post_attention_norm = RMSNorm(dim, norm_eps) def forward( self, @@ -222,7 +258,7 @@ class DecoderBlock(nn.Module): ) -> Tensor: # attention attn_output = self.attention( - self.norm_attn(x), + self.input_norm(x), rotary_emb, attention_mask, kv_cache, @@ -231,16 +267,15 @@ class DecoderBlock(nn.Module): x =
attn_output + x # feed forward - x = self.ffn(self.norm_ffn(x)) + x + x = self.mlp(self.post_attention_norm(x)) + x return x class Embedding(nn.Module): - def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None): + def __init__(self, vocab_size: int, embedding_dim: int): super().__init__() - weight_param = torch.empty((vocab_size, embedding_dim)) if weight_param is None else weight_param - self.weight = nn.Parameter(weight_param) + self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim))) def forward(self, x: Tensor) -> Tensor: return F.embedding(x, self.weight) \ No newline at end of file diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py index 9ff30a5..840bf6b 100644 --- a/khaosz/model/transformer.py +++ b/khaosz/model/transformer.py @@ -59,16 +59,17 @@ class Transformer(nn.Module): def __init__(self, config: ModelConfig): super().__init__() self.config = config - self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len) - self.embed_tokens = Embedding(config.vocab_size, config.n_dim) + self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len) + self.embed_tokens = Embedding(config.vocab_size, config.dim) self.layers = nn.ModuleList([ - DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id) - for layer_id in range(config.n_layer) + DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads, + config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id) + for layer_id in range(config.n_layers) ]) - self.norm = RMSNorm(config.n_dim, config.norm_eps) - self.lm_head = Linear(config.n_dim, config.vocab_size) + self.norm = RMSNorm(config.dim, config.norm_eps) + self.lm_head = Linear(config.dim, config.vocab_size) if self.config.tie_weight == True: self.lm_head.weight = self.embed_tokens.weight diff --git a/tests/conftest.py b/tests/conftest.py index 093fb37..cb81fa6 100644 --- 
a/tests/conftest.py +++ b/tests/conftest.py @@ -83,19 +83,19 @@ def base_test_env(request: pytest.FixtureRequest): n_dim_choices = [8, 16, 32] n_head_choices = [2, 4] - n_dim = int(np.random.choice(n_dim_choices)) - n_head = int(np.random.choice(n_head_choices)) - n_kvhead = n_head // 2 - d_ffn = n_dim * 2 + dim = int(np.random.choice(n_dim_choices)) + n_heads = int(np.random.choice(n_head_choices)) + n_kv_heads = n_heads // 2 + dim_ffn = dim * 2 config = { "vocab_size": 1000, - "n_dim": n_dim, - "n_head": n_head, - "n_kvhead": n_kvhead, - "d_ffn": d_ffn, - "m_len": 1024, - "n_layer": 4, + "dim": dim, + "n_heads": n_heads, + "n_kv_heads": n_kv_heads, + "dim_ffn": dim_ffn, + "max_len": 1024, + "n_layers": 4, "norm_eps": 1e-5 } diff --git a/tests/test_module.py b/tests/test_module.py index 6ecb9ab..9d5857e 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -22,12 +22,12 @@ def test_env(request: pytest.FixtureRequest): config = { "vocab_size": 1000, - "n_dim": 128, - "n_head": 4, - "n_kvhead": 2, - "d_ffn": 256, - "m_len": 64, - "n_layer": 2, + "dim": 128, + "n_heads": 4, + "n_kv_heads": 2, + "dim_ffn": 256, + "max_len": 64, + "n_layers": 2, "norm_eps": 1e-5 } with open(config_path, 'w') as f: @@ -64,9 +64,9 @@ def test_model_parameter(test_env): def test_transformer(test_env): model = test_env["model"] input_ids = torch.randint(0, test_env["transformer_config"].vocab_size, - (4, test_env["transformer_config"].m_len)) + (4, test_env["transformer_config"].max_len)) output_logits = model(input_ids)["logits"] - target_shape = (4, test_env["transformer_config"].m_len, test_env["transformer_config"].vocab_size) + target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size) assert output_logits.shape == target_shape # generator @@ -80,7 +80,7 @@ def test_embedding_encoder_core(test_env): single_emb = encoder.encode("测试文本") assert isinstance(single_emb, torch.Tensor) - assert single_emb.shape[-1] == 
test_env["transformer_config"].n_dim + assert single_emb.shape[-1] == test_env["transformer_config"].dim batch_emb = encoder.encode(["测试1", "测试2"]) diff --git a/tests/test_tie_weight.py b/tests/test_tie_weight.py index c2d8911..2542f0a 100644 --- a/tests/test_tie_weight.py +++ b/tests/test_tie_weight.py @@ -16,12 +16,12 @@ def transformer_test_env(): config = { "vocab_size": 1000, - "n_dim": 128, - "n_head": 4, - "n_kvhead": 2, - "d_ffn": 256, - "m_len": 64, - "n_layer": 2, + "dim": 128, + "n_heads": 4, + "n_kv_heads": 2, + "dim_ffn": 256, + "max_len": 64, + "n_layers": 2, "norm_eps": 1e-5 } diff --git a/tools/benchmark.py b/tools/benchmark.py index d588620..5f73667 100644 --- a/tools/benchmark.py +++ b/tools/benchmark.py @@ -28,7 +28,7 @@ class GenerationBenchmark: def _initialize_kv_cache(self, batch_size: int) -> list: """初始化KV缓存""" config = self.config - shape = (batch_size, config.m_len, config.n_layer, config.n_kvhead, config.n_dim // config.n_head) + shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads) k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) return (k_cache, v_cache) @@ -175,12 +175,12 @@ def print_benchmark_result(result: BenchmarkResult): if __name__ == "__main__": config = ModelConfig( vocab_size=10000, - n_dim=1536, - n_head=24, - n_kvhead=4, - d_ffn=6912, - m_len=2048, - n_layer=24, + dim=1536, + n_heads=24, + n_kv_heads=4, + dim_ffn=6912, + max_len=2048, + n_layers=24, norm_eps=1e-5, ) diff --git a/tools/train.py b/tools/train.py index a352b3c..c3b600f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -111,7 +111,7 @@ def train( parameter.load(param_path) if window_size is None: - window_size = parameter.config.m_len + window_size = parameter.config.max_len model = parameter.model