feat(model): 添加 Linear 和 Embedding 模块的自定义参数初始化支持
This commit is contained in:
parent
877669b799
commit
144b9598ad
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -72,10 +71,12 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
|||
|
||||
|
||||
class Linear(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
|
||||
def __init__(self, in_dim: int, out_dim: int, bias: bool=False, weight_param=None, bias_param=None):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||
self.weight = nn.Parameter(weight_param or torch.empty((out_dim, in_dim)))
|
||||
self.bias = nn.Parameter(bias_param or torch.zeros(out_dim)) if bias else None
|
||||
|
||||
def _reset_parameter(self):
|
||||
init.normal_(self.weight, mean=0, std=0.006)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
|
@ -206,3 +207,15 @@ class DecoderBlock(nn.Module):
|
|||
x = self.ffn(self.norm_ffn(x)) + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(weight_param or torch.empty((vocab_size, embedding_dim)))
|
||||
|
||||
def _reset_parameter(self):
|
||||
init.normal_(self.weight, mean=0, std=0.02)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.embedding(x, self.weight)
|
||||
|
|
|
|||
Loading…
Reference in New Issue