diff --git a/khaosz/model/module.py b/khaosz/model/module.py index 643dd0d..825d564 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -1,4 +1,3 @@ - import torch import torch.nn as nn import torch.nn.functional as F @@ -72,10 +71,12 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: class Linear(nn.Module): - def __init__(self, in_dim: int, out_dim: int, bias: bool=False): + def __init__(self, in_dim: int, out_dim: int, bias: bool=False, weight_param=None, bias_param=None): super().__init__() - self.weight = nn.Parameter(torch.empty((out_dim, in_dim))) - self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None + self.weight = nn.Parameter(weight_param or torch.empty((out_dim, in_dim))) + self.bias = nn.Parameter(bias_param or torch.zeros(out_dim)) if bias else None + + def _reset_parameter(self): init.normal_(self.weight, mean=0, std=0.006) def forward(self, x: Tensor) -> Tensor: @@ -205,4 +206,16 @@ class DecoderBlock(nn.Module): # feed forward x = self.ffn(self.norm_ffn(x)) + x - return x \ No newline at end of file + return x + + +class Embedding(nn.Module): + def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None): + super().__init__() + self.weight = nn.Parameter(weight_param or torch.empty((vocab_size, embedding_dim))) + + def _reset_parameter(self): + init.normal_(self.weight, mean=0, std=0.02) + + def forward(self, x: Tensor) -> Tensor: + return F.embedding(x, self.weight)