From bdc3f4dc638228ca2edb6f8a4d581faedfd9fb47 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 6 Nov 2025 17:52:47 +0800 Subject: [PATCH] =?UTF-8?q?feat(module):=20=E9=87=8D=E6=9E=84=E6=97=8B?= =?UTF-8?q?=E8=BD=AC=E4=BD=8D=E7=BD=AE=E7=BC=96=E7=A0=81=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E4=BB=A5=E6=8F=90=E5=8D=87=E6=80=A7=E8=83=BD=E5=92=8C=E5=8F=AF?= =?UTF-8?q?=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/model/module.py | 64 +++++++++++++++++++++++++------------ khaosz/model/transformer.py | 22 ++++++------- 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/khaosz/model/module.py b/khaosz/model/module.py index 966cf3f..154547b 100644 --- a/khaosz/model/module.py +++ b/khaosz/model/module.py @@ -29,45 +29,67 @@ def get_rotary_emb( dim: int, max_len: int, base: float = 10000, - ) -> torch.Tensor: + ) -> Tuple[Tensor, Tensor]: """ Get the rotary embedding for the given dimension and maximum length. Args: dim (int): The dimension of the input. max_len (int): The maximum length of the input. base (float, optional): The base for the frequency. Defaults to 10000. - device (torch.device, optional): The device to use. Defaults to "cuda". Returns: Tensor: The rotary embedding tensor. """ - theta = base ** (-torch.arange(0, dim, 2).float() / dim) - t = torch.arange(0, max_len).float() + theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim) + t = torch.arange(0, max_len, dtype=torch.float64) freqs = torch.outer(t, theta) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - - return freqs_cis -def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + return torch.cos(freqs).float(), torch.sin(freqs).float() + +def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor: """ - Apply rotary embedding to the input tensor. + Apply rotary embedding to the input tensor using cos/sin form. Args: - x (Tensor): The input tensor. - freqs_cis (Tensor): The rotary embedding tensor. + x (Tensor): The input tensor (shape [..., seq_len, dim]). + rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]). Returns: - Tensor: The output tensor. + Tensor: The output tensor (rotated, same shape as input). """ dtype = x.dtype - seq_len = x.size(1) - - x_complex = torch.view_as_complex(x.view(*x.shape[:-1], -1, 2).float()) - freqs_cis = freqs_cis.reshape(1, seq_len, 1, -1) - x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3) + cos, sin = rotary_emb + + cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] + sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] + + x_real = x[..., 0::2] # [batch, seq_len, dim//2] + x_imag = x[..., 1::2] # [batch, seq_len, dim//2] + + x_real_rot = x_real * cos - x_imag * sin + x_imag_rot = x_real * sin + x_imag * cos + + x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2] + x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim] return x_out.to(dtype) +class RotaryEmbedding(nn.Module): + def __init__(self, dim: int, max_len: int, base: int=10000): + super().__init__() + cos_emb, sin_emb = get_rotary_emb(dim, max_len, base) + self.register_buffer("cos_emb", cos_emb, persistent=False) + self.register_buffer("sin_emb", sin_emb, persistent=False) + self._rotary_buffers = {"cos_emb", "sin_emb"} + + def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]: + seq_len = x.size(1) + cos = self.cos_emb[start_pos : start_pos + seq_len] + sin = self.sin_emb[start_pos : start_pos + seq_len] + + return (cos, sin) + + class Linear(nn.Module): def __init__(self, in_dim: int, out_dim: int, bias: bool = False, weight_param=None, bias_param=None): super().__init__() @@ -137,7 +159,7 @@ class GQA(nn.Module): def forward( self, x: Tensor, - freqs_cis: Tensor, + rotary_emb: Tuple[Tensor, Tensor], mask: Tensor = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None, start_pos: int = 0 @@ -147,7 +169,7 @@ class GQA(nn.Module): q = self._split_heads(self.q_proj(x), self.n_heads) k = self._split_heads(self.k_proj(x), self.n_kvheads) v = self._split_heads(self.v_proj(x), self.n_kvheads) - q, k = apply_rotary_emb(q, freqs_cis), apply_rotary_emb(k, freqs_cis) + q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb) if kv_cache is not None: k_cache, v_cache = kv_cache @@ -186,7 +208,7 @@ class DecoderBlock(nn.Module): def forward( self, x: Tensor, - freqs_cis: Tensor, + rotary_emb: Tuple[Tensor, Tensor], attention_mask: Optional[Tensor] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None, start_pos: int = 0 @@ -194,7 +216,7 @@ class DecoderBlock(nn.Module): # attention attn_output = self.attention( self.norm_attn(x), - freqs_cis, + rotary_emb, attention_mask, kv_cache, start_pos diff --git a/khaosz/model/transformer.py b/khaosz/model/transformer.py index e6af718..b0fcd9d 100644 --- a/khaosz/model/transformer.py +++ b/khaosz/model/transformer.py @@ -4,14 +4,13 @@ import torch.nn as nn from torch import Tensor from typing import Any, Mapping, Optional, Tuple from khaosz.config.model_config import TransformerConfig -from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, get_rotary_emb +from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding def process_attention_mask( seq_mask: Tensor, input_tensor: Tensor, start_pos: int = 0, - seq_len: int = 0, is_causal: bool = False, ) -> Tensor: """ @@ -20,13 +19,13 @@ def process_attention_mask( seq_mask (Tensor): A tensor indicating whether each position is valid or not. input_tensor (Tensor): The input tensor. start_pos (int): The starting position of the sequence. - seq_len (int): The length of the sequence. is_causal (bool): Whether the attention is causal or not. Returns: Tensor: The attention mask tensor. """ device = input_tensor.device dtype = input_tensor.dtype + seq_len = input_tensor.size(1) if seq_mask is None: if start_pos != 0: @@ -65,16 +64,18 @@ class Transformer(nn.Module): def __init__(self, config: TransformerConfig): super().__init__() self.config = config + self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len) self.embed_tokens = Embedding(config.vocab_size, config.n_dim) - lm_head_init_weight = self.embed_tokens.weight if config.tie_weight == True else None self.layers = nn.ModuleList([ DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id) for layer_id in range(config.n_layer) ]) + lm_head_init_weight = self.embed_tokens.weight if config.tie_weight == True else None + self.norm = RMSNorm(config.n_dim, config.norm_eps) self.lm_head = Linear(config.n_dim, config.vocab_size, weight_param=lm_head_init_weight) - self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len) + self._init_parameters() def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False): @@ -115,20 +116,15 @@ class Transformer(nn.Module): ) -> Tensor: assert input_ids.ndim == 2 - seq_len = input_ids.size(-1) x = self.embed_tokens(input_ids) - freqs_cis = self.freq_cis[start_pos:start_pos+seq_len].to(x.device) + rotary_emb = self.rotary_embeding(x, start_pos) attn_mask = process_attention_mask( - input_mask, - x, - start_pos=start_pos, - seq_len=seq_len, - is_causal=True + input_mask, x, start_pos, is_causal=True ) for layer in self.layers: - x = layer(x, freqs_cis, attn_mask, persistent_key_values, start_pos) + x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos) hidden_states = self.norm(x) logits = self.lm_head(hidden_states)