feat(model): 添加 tie_weight 配置选项并优化模型模块实现
This commit is contained in:
parent
b260f5581d
commit
69d9374f51
|
|
@ -13,6 +13,7 @@ class TransformerConfig:
|
||||||
m_len: Optional[int] = None
|
m_len: Optional[int] = None
|
||||||
norm_eps: Optional[float] = None
|
norm_eps: Optional[float] = None
|
||||||
d_ffn: Optional[int] = None
|
d_ffn: Optional[int] = None
|
||||||
|
tie_weight: Optional[bool] = None
|
||||||
|
|
||||||
# GQA
|
# GQA
|
||||||
n_kvhead: Optional[int] = None
|
n_kvhead: Optional[int] = None
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import init
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -30,7 +29,6 @@ def get_rotary_emb(
|
||||||
dim: int,
|
dim: int,
|
||||||
max_len: int,
|
max_len: int,
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
device: torch.device = "cuda",
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Get the rotary embedding for the given dimension and maximum length.
|
Get the rotary embedding for the given dimension and maximum length.
|
||||||
|
|
@ -43,8 +41,8 @@ def get_rotary_emb(
|
||||||
Tensor: The rotary embedding tensor.
|
Tensor: The rotary embedding tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
theta = base ** (-torch.arange(0, dim, 2, device=device).float() / dim)
|
theta = base ** (-torch.arange(0, dim, 2).float() / dim)
|
||||||
t = torch.arange(0, max_len, device=device).float()
|
t = torch.arange(0, max_len).float()
|
||||||
freqs = torch.outer(t, theta)
|
freqs = torch.outer(t, theta)
|
||||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||||
|
|
||||||
|
|
@ -71,17 +69,17 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
class Linear(nn.Module):
|
||||||
def __init__(self, in_dim: int, out_dim: int, bias: bool=False, weight_param=None, bias_param=None):
|
def __init__(self, in_dim: int, out_dim: int, bias: bool = False, weight_param=None, bias_param=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(weight_param or torch.empty((out_dim, in_dim)))
|
weight_param = torch.empty((out_dim, in_dim)) if weight_param is None else weight_param
|
||||||
self.bias = nn.Parameter(bias_param or torch.zeros(out_dim)) if bias else None
|
bias_param = torch.zeros(out_dim) if bias_param is None else bias_param
|
||||||
|
|
||||||
def _reset_parameter(self):
|
|
||||||
init.normal_(self.weight, mean=0, std=0.006)
|
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(weight_param)
|
||||||
|
self.bias = nn.Parameter(bias_param) if bias else None
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.linear(x, self.weight, self.bias)
|
return F.linear(x, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
def __init__(self, n_dim, norm_eps):
|
def __init__(self, n_dim, norm_eps):
|
||||||
|
|
@ -212,10 +210,8 @@ class DecoderBlock(nn.Module):
|
||||||
class Embedding(nn.Module):
|
class Embedding(nn.Module):
|
||||||
def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None):
|
def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(weight_param or torch.empty((vocab_size, embedding_dim)))
|
weight_param = torch.empty((vocab_size, embedding_dim)) if weight_param is None else weight_param
|
||||||
|
self.weight = nn.Parameter(weight_param)
|
||||||
def _reset_parameter(self):
|
|
||||||
init.normal_(self.weight, mean=0, std=0.02)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.embedding(x, self.weight)
|
return F.embedding(x, self.weight)
|
||||||
|
|
@ -1,22 +1,18 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import init
|
from typing import Any, Mapping, Optional, Tuple
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from khaosz.config.model_config import TransformerConfig
|
from khaosz.config.model_config import TransformerConfig
|
||||||
from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb
|
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, get_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
def process_attention_mask(
|
def process_attention_mask(
|
||||||
seq_mask: Tensor,
|
seq_mask: Tensor,
|
||||||
|
input_tensor: Tensor,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
seq_len: int = 0,
|
seq_len: int = 0,
|
||||||
is_causal: bool = False,
|
is_causal: bool = False,
|
||||||
device: torch.device = "cuda",
|
|
||||||
dtype: torch.dtype = torch.float32
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Create attention mask for GQA
|
Create attention mask for GQA
|
||||||
|
|
@ -29,6 +25,8 @@ def process_attention_mask(
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: The attention mask tensor.
|
Tensor: The attention mask tensor.
|
||||||
"""
|
"""
|
||||||
|
device = input_tensor.device
|
||||||
|
dtype = input_tensor.dtype
|
||||||
|
|
||||||
if seq_mask is None:
|
if seq_mask is None:
|
||||||
if start_pos != 0:
|
if start_pos != 0:
|
||||||
|
|
@ -66,14 +64,43 @@ def process_attention_mask(
|
||||||
class Transformer(nn.Module):
|
class Transformer(nn.Module):
|
||||||
def __init__(self, config: TransformerConfig):
|
def __init__(self, config: TransformerConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
|
self.config = config
|
||||||
|
self.embed_tokens = Embedding(config.vocab_size, config.n_dim)
|
||||||
|
lm_head_init_weight = self.embed_tokens.weight if config.tie_weight == True else None
|
||||||
|
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
|
DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
|
||||||
for layer_id in range(config.n_layer)
|
for layer_id in range(config.n_layer)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.n_dim, config.norm_eps)
|
self.norm = RMSNorm(config.n_dim, config.norm_eps)
|
||||||
|
self.lm_head = Linear(config.n_dim, config.vocab_size, weight_param=lm_head_init_weight)
|
||||||
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
|
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
|
||||||
init.normal_(self.embedding, mean=0, std=0.02)
|
self._init_parameters()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
|
if self.config.tie_weight == True:
|
||||||
|
lm_head_key = 'lm_head.weight'
|
||||||
|
embed_key = 'embed_tokens.weight'
|
||||||
|
|
||||||
|
if lm_head_key not in state_dict and embed_key in state_dict:
|
||||||
|
state_dict[lm_head_key] = state_dict[embed_key]
|
||||||
|
|
||||||
|
return super().load_state_dict(state_dict, strict, assign)
|
||||||
|
|
||||||
|
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||||
|
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||||
|
|
||||||
|
if self.config.tie_weight == True:
|
||||||
|
lm_head_key = prefix + 'lm_head.weight'
|
||||||
|
if lm_head_key in state_dict:
|
||||||
|
del state_dict[lm_head_key]
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def _init_parameters(self):
|
||||||
|
for param in self.parameters():
|
||||||
|
if param.dim() > 1:
|
||||||
|
nn.init.normal_(param, mean=0.0, std=0.006)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|
@ -83,27 +110,24 @@ class Transformer(nn.Module):
|
||||||
start_pos: int = 0
|
start_pos: int = 0
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
seq_len = input_ids.size(-1)
|
|
||||||
x = F.embedding(input_ids, self.embedding)
|
|
||||||
|
|
||||||
self.freq_cis = self.freq_cis.to(x.device)
|
seq_len = input_ids.size(-1)
|
||||||
freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
|
x = self.embed_tokens(input_ids)
|
||||||
has_kvcache = persistent_key_values is not None
|
freqs_cis = self.freq_cis[start_pos:start_pos+seq_len].to(x.device)
|
||||||
|
|
||||||
attn_mask = process_attention_mask(
|
attn_mask = process_attention_mask(
|
||||||
input_mask,
|
input_mask,
|
||||||
|
x,
|
||||||
start_pos=start_pos,
|
start_pos=start_pos,
|
||||||
seq_len=seq_len,
|
seq_len=seq_len,
|
||||||
is_causal=has_kvcache,
|
is_causal=True
|
||||||
device=x.device,
|
|
||||||
dtype=x.dtype
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, freqs_cis, attn_mask, persistent_key_values, start_pos)
|
x = layer(x, freqs_cis, attn_mask, persistent_key_values, start_pos)
|
||||||
|
|
||||||
hidden_states = self.norm(x)
|
hidden_states = self.norm(x)
|
||||||
logits = F.linear(hidden_states, self.embedding)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"logits": logits,
|
"logits": logits,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue