feat(model): 添加 Linear 和 Embedding 模块的自定义参数初始化支持
This commit is contained in:
parent
877669b799
commit
144b9598ad
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
@ -72,10 +71,12 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
    """Minimal linear-layer parameter holder with optional pre-built tensors.

    Stores ``weight`` of shape ``(out_dim, in_dim)`` and, when ``bias`` is
    True, a ``bias`` vector of shape ``(out_dim,)``.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = False,
                 weight_param=None, bias_param=None):
        """
        Args:
            in_dim: input feature size.
            out_dim: output feature size.
            bias: when True, allocate a bias vector (zeros unless
                ``bias_param`` is supplied).
            weight_param: optional pre-built weight tensor, expected shape
                ``(out_dim, in_dim)``; used as-is when provided.
            bias_param: optional pre-built bias tensor, expected shape
                ``(out_dim,)``; only used when ``bias`` is True.
        """
        super().__init__()
        # BUG FIX: `weight_param or torch.empty(...)` calls Tensor.__bool__,
        # which raises "Boolean value of Tensor with more than one element is
        # ambiguous" for any real parameter tensor. Test against None instead.
        self.weight = nn.Parameter(
            weight_param if weight_param is not None
            else torch.empty((out_dim, in_dim))
        )
        self.bias = (
            nn.Parameter(bias_param if bias_param is not None
                         else torch.zeros(out_dim))
            if bias else None
        )

    def _reset_parameter(self):
        # In-place normal init of the weight; only meaningful when no
        # weight_param was supplied (otherwise it overwrites the given tensor).
        init.normal_(self.weight, mean=0, std=0.006)
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
|
@ -205,4 +206,16 @@ class DecoderBlock(nn.Module):
|
||||||
# feed forward
|
# feed forward
|
||||||
x = self.ffn(self.norm_ffn(x)) + x
|
x = self.ffn(self.norm_ffn(x)) + x
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(nn.Module):
    """Token-embedding lookup table with optional pre-built weights."""

    def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None):
        """
        Args:
            vocab_size: number of rows in the table.
            embedding_dim: width of each embedding vector.
            weight_param: optional pre-built tensor, expected shape
                ``(vocab_size, embedding_dim)``; used as-is when provided.
                Otherwise an uninitialised table is allocated — call
                ``_reset_parameter`` to initialise it.
        """
        super().__init__()
        # BUG FIX: `weight_param or torch.empty(...)` calls Tensor.__bool__,
        # which raises "Boolean value of Tensor with more than one element is
        # ambiguous" for any real weight tensor. Compare against None instead.
        self.weight = nn.Parameter(
            weight_param if weight_param is not None
            else torch.empty((vocab_size, embedding_dim))
        )

    def _reset_parameter(self):
        # In-place normal init; std=0.02 is the author's chosen scale.
        init.normal_(self.weight, mean=0, std=0.02)

    def forward(self, x: Tensor) -> Tensor:
        # Row lookup: output shape is x.shape + (embedding_dim,).
        return F.embedding(x, self.weight)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue