diff --git a/khaosz/config/model_config.py b/khaosz/config/model_config.py
index b62884a..c91b44f 100644
--- a/khaosz/config/model_config.py
+++ b/khaosz/config/model_config.py
@@ -13,6 +13,7 @@ class TransformerConfig:
     m_len: Optional[int] = None
     norm_eps: Optional[float] = None
     d_ffn: Optional[int] = None
+    tie_weight: Optional[bool] = None
 
     # GQA
     n_kvhead: Optional[int] = None
diff --git a/khaosz/model/module.py b/khaosz/model/module.py
index 825d564..966cf3f 100644
--- a/khaosz/model/module.py
+++ b/khaosz/model/module.py
@@ -3,7 +3,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torch.nn import init
 from typing import Optional, Tuple
 
 
@@ -30,7 +29,6 @@ def get_rotary_emb(
     dim: int,
     max_len: int,
     base: float = 10000,
-    device: torch.device = "cuda",
 ) -> torch.Tensor:
     """
     Get the rotary embedding for the given dimension and maximum length.
@@ -43,8 +41,8 @@ def get_rotary_emb(
         Tensor: The rotary embedding tensor.
     """
 
-    theta = base ** (-torch.arange(0, dim, 2, device=device).float() / dim)
-    t = torch.arange(0, max_len, device=device).float()
+    theta = base ** (-torch.arange(0, dim, 2).float() / dim)
+    t = torch.arange(0, max_len).float()
 
     freqs = torch.outer(t, theta)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
@@ -71,17 +69,17 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 
 
 class Linear(nn.Module):
-    def __init__(self, in_dim: int, out_dim: int, bias: bool=False, weight_param=None, bias_param=None):
+    def __init__(self, in_dim: int, out_dim: int, bias: bool = False, weight_param=None, bias_param=None):
         super().__init__()
-        self.weight = nn.Parameter(weight_param or torch.empty((out_dim, in_dim)))
-        self.bias = nn.Parameter(bias_param or torch.zeros(out_dim)) if bias else None
-    
-    def _reset_parameter(self):
-        init.normal_(self.weight, mean=0, std=0.006)
+        weight_param = torch.empty((out_dim, in_dim)) if weight_param is None else weight_param
+        bias_param = torch.zeros(out_dim) if bias_param is None else bias_param
+        self.weight = nn.Parameter(weight_param)
+        self.bias = nn.Parameter(bias_param) if bias else None
+    
     def forward(self, x: Tensor) -> Tensor:
         return F.linear(x, self.weight, self.bias)
-    
+
 
 
 class RMSNorm(nn.Module):
     def __init__(self, n_dim, norm_eps):
@@ -212,10 +210,8 @@ class DecoderBlock(nn.Module):
 class Embedding(nn.Module):
     def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None):
         super().__init__()
-        self.weight = nn.Parameter(weight_param or torch.empty((vocab_size, embedding_dim)))
-    
-    def _reset_parameter(self):
-        init.normal_(self.weight, mean=0, std=0.02)
+        weight_param = torch.empty((vocab_size, embedding_dim)) if weight_param is None else weight_param
+        self.weight = nn.Parameter(weight_param)
 
     def forward(self, x: Tensor) -> Tensor:
-        return F.embedding(x, self.weight)
+        return F.embedding(x, self.weight)
\ No newline at end of file
diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py
index 2c68d6b..06dacda 100644
--- a/khaosz/model/transformer.py
+++ b/khaosz/model/transformer.py
@@ -1,22 +1,18 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from torch import Tensor
-from torch.nn import init
-from typing import Optional, Tuple
-
+from typing import Any, Mapping, Optional, Tuple
 from khaosz.config.model_config import TransformerConfig
-from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb
+from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, get_rotary_emb
 
 
 def process_attention_mask(
     seq_mask: Tensor,
+    input_tensor: Tensor,
     start_pos: int = 0,
     seq_len: int = 0,
     is_causal: bool = False,
-    device: torch.device = "cuda",
-    dtype: torch.dtype = torch.float32
 ) -> Tensor:
     """
     Create attention mask for GQA
@@ -29,6 +25,8 @@ def process_attention_mask(
     Returns:
         Tensor: The attention mask tensor.
     """
+    device = input_tensor.device
+    dtype = input_tensor.dtype
 
     if seq_mask is None:
         if start_pos != 0:
@@ -66,14 +64,43 @@ def process_attention_mask(
 class Transformer(nn.Module):
     def __init__(self, config: TransformerConfig):
         super().__init__()
-        self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
+        self.config = config
+        self.embed_tokens = Embedding(config.vocab_size, config.n_dim)
+        lm_head_init_weight = self.embed_tokens.weight if config.tie_weight == True else None
+
         self.layers = nn.ModuleList([
             DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
             for layer_id in range(config.n_layer)
         ])
         self.norm = RMSNorm(config.n_dim, config.norm_eps)
+        self.lm_head = Linear(config.n_dim, config.vocab_size, weight_param=lm_head_init_weight)
         self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
-        init.normal_(self.embedding, mean=0, std=0.02)
+        self._init_parameters()
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
+        if self.config.tie_weight == True:
+            lm_head_key = 'lm_head.weight'
+            embed_key = 'embed_tokens.weight'
+
+            if lm_head_key not in state_dict and embed_key in state_dict:
+                state_dict[lm_head_key] = state_dict[embed_key]
+
+        return super().load_state_dict(state_dict, strict, assign)
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+
+        if self.config.tie_weight == True:
+            lm_head_key = prefix + 'lm_head.weight'
+            if lm_head_key in state_dict:
+                del state_dict[lm_head_key]
+
+        return state_dict
+
+    def _init_parameters(self):
+        for param in self.parameters():
+            if param.dim() > 1:
+                nn.init.normal_(param, mean=0.0, std=0.006)
 
     def forward(
         self,
@@ -83,27 +110,24 @@ class Transformer(nn.Module):
         start_pos: int = 0
     ) -> Tensor:
         assert input_ids.ndim == 2
-        seq_len = input_ids.size(-1)
-        x = F.embedding(input_ids, self.embedding)
-        self.freq_cis = self.freq_cis.to(x.device)
-        freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
-        has_kvcache = persistent_key_values is not None
+        seq_len = input_ids.size(-1)
+        x = self.embed_tokens(input_ids)
+        freqs_cis = self.freq_cis[start_pos:start_pos+seq_len].to(x.device)
 
         attn_mask = process_attention_mask(
             input_mask,
+            x,
             start_pos=start_pos,
             seq_len=seq_len,
-            is_causal=has_kvcache,
-            device=x.device,
-            dtype=x.dtype
+            is_causal=True
         )
 
         for layer in self.layers:
             x = layer(x, freqs_cis, attn_mask, persistent_key_values, start_pos)
 
         hidden_states = self.norm(x)
-        logits = F.linear(hidden_states, self.embedding)
+        logits = self.lm_head(hidden_states)
 
         return {
             "logits": logits,
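
A minimal sketch (not part of the diff) of how the tied-weight path above is expected to behave. The TransformerConfig keyword arguments and their toy values below are assumptions made for a quick check, not values taken from the repository:

    from khaosz.config.model_config import TransformerConfig
    from khaosz.model.transformer import Transformer

    # Hypothetical small configuration; only fields referenced by the code above are set,
    # and the concrete numbers are placeholders.
    config = TransformerConfig(
        vocab_size=1000, n_dim=64, n_head=4, n_kvhead=2, n_layer=2,
        d_ffn=256, m_len=128, norm_eps=1e-5, tie_weight=True,
    )
    model = Transformer(config)

    # The state_dict() override drops the duplicated lm_head.weight entry ...
    sd = model.state_dict()
    assert "lm_head.weight" not in sd and "embed_tokens.weight" in sd

    # ... and the load_state_dict() override restores it from embed_tokens.weight,
    # so a strict round trip still succeeds.
    model.load_state_dict(dict(sd))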