From 144b9598ad48c5137e355510b802446370c727aa Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Fri, 31 Oct 2025 22:43:12 +0800
Subject: [PATCH] =?UTF-8?q?feat(model):=20=E6=B7=BB=E5=8A=A0=20Linear=20?=
 =?UTF-8?q?=E5=92=8C=20Embedding=20=E6=A8=A1=E5=9D=97=E7=9A=84=E8=87=AA?=
 =?UTF-8?q?=E5=AE=9A=E4=B9=89=E5=8F=82=E6=95=B0=E5=88=9D=E5=A7=8B=E5=8C=96?=
 =?UTF-8?q?=E6=94=AF=E6=8C=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/model/module.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/khaosz/model/module.py b/khaosz/model/module.py
index 643dd0d..825d564 100644
--- a/khaosz/model/module.py
+++ b/khaosz/model/module.py
@@ -1,4 +1,3 @@
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -72,10 +71,12 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 
 
 class Linear(nn.Module):
-    def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
+    def __init__(self, in_dim: int, out_dim: int, bias: bool=False, weight_param=None, bias_param=None):
         super().__init__()
-        self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
-        self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
+        self.weight = nn.Parameter(weight_param if weight_param is not None else torch.empty((out_dim, in_dim)))
+        self.bias = nn.Parameter(bias_param if bias_param is not None else torch.zeros(out_dim)) if bias else None
+
+    def _reset_parameter(self):
         init.normal_(self.weight, mean=0, std=0.006)
 
     def forward(self, x: Tensor) -> Tensor:
@@ -205,4 +206,16 @@ class DecoderBlock(nn.Module):
 
         # feed forward
         x = self.ffn(self.norm_ffn(x)) + x
-        return x
\ No newline at end of file
+        return x
+
+
+class Embedding(nn.Module):
+    def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None):
+        super().__init__()
+        self.weight = nn.Parameter(weight_param if weight_param is not None else torch.empty((vocab_size, embedding_dim)))
+
+    def _reset_parameter(self):
+        init.normal_(self.weight, mean=0, std=0.02)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.embedding(x, self.weight)