feat(model): 添加 Linear 和 Embedding 模块的自定义参数初始化支持

This commit is contained in:
ViperEkura 2025-10-31 22:43:12 +08:00
parent 877669b799
commit 144b9598ad
1 changed files with 18 additions and 5 deletions

View File

@ -1,4 +1,3 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -72,10 +71,12 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool=False): def __init__(self, in_dim: int, out_dim: int, bias: bool=False, weight_param=None, bias_param=None):
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim))) self.weight = nn.Parameter(weight_param or torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None self.bias = nn.Parameter(bias_param or torch.zeros(out_dim)) if bias else None
def _reset_parameter(self):
init.normal_(self.weight, mean=0, std=0.006) init.normal_(self.weight, mean=0, std=0.006)
    def forward(self, x: Tensor) -> Tensor:
@ -205,4 +206,16 @@ class DecoderBlock(nn.Module):
        # feed forward
        x = self.ffn(self.norm_ffn(x)) + x
        return x
class Embedding(nn.Module):
    """Token-embedding lookup table with optional externally supplied weights."""

    def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None):
        """
        Args:
            vocab_size: number of rows in the embedding table.
            embedding_dim: size of each embedding vector.
            weight_param: optional pre-built weight tensor of shape
                (vocab_size, embedding_dim); when None an uninitialized tensor
                is allocated (initialize it via ``_reset_parameter``).
        """
        super().__init__()
        # BUG FIX: `weight_param or torch.empty(...)` raises ValueError for any
        # tensor with more than one element (tensor truthiness is ambiguous),
        # so the new weight_param feature crashed whenever it was actually
        # used. Test for None explicitly instead.
        if weight_param is None:
            weight_param = torch.empty((vocab_size, embedding_dim))
        self.weight = nn.Parameter(weight_param)

    def _reset_parameter(self) -> None:
        # std=0.02 matches the embedding init scheme used in this file.
        init.normal_(self.weight, mean=0, std=0.02)

    def forward(self, x: Tensor) -> Tensor:
        """Look up the embedding rows for the integer token ids in ``x``."""
        return F.embedding(x, self.weight)