feat(module): 重构旋转位置编码实现以提升性能和可读性
This commit is contained in:
parent
805773c7fe
commit
bdc3f4dc63
|
|
@ -29,45 +29,67 @@ def get_rotary_emb(
|
||||||
dim: int,
|
dim: int,
|
||||||
max_len: int,
|
max_len: int,
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
) -> torch.Tensor:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Get the rotary embedding for the given dimension and maximum length.
|
Get the rotary embedding for the given dimension and maximum length.
|
||||||
Args:
|
Args:
|
||||||
dim (int): The dimension of the input.
|
dim (int): The dimension of the input.
|
||||||
max_len (int): The maximum length of the input.
|
max_len (int): The maximum length of the input.
|
||||||
base (float, optional): The base for the frequency. Defaults to 10000.
|
base (float, optional): The base for the frequency. Defaults to 10000.
|
||||||
device (torch.device, optional): The device to use. Defaults to "cuda".
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: The rotary embedding tensor.
|
Tensor: The rotary embedding tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
theta = base ** (-torch.arange(0, dim, 2).float() / dim)
|
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)
|
||||||
t = torch.arange(0, max_len).float()
|
t = torch.arange(0, max_len, dtype=torch.float64)
|
||||||
freqs = torch.outer(t, theta)
|
freqs = torch.outer(t, theta)
|
||||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
|
||||||
|
|
||||||
return freqs_cis
|
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||||
|
|
||||||
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Apply rotary embedding to the input tensor.
|
Apply rotary embedding to the input tensor using cos/sin form.
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): The input tensor.
|
x (Tensor): The input tensor (shape [..., seq_len, dim]).
|
||||||
freqs_cis (Tensor): The rotary embedding tensor.
|
rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]).
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: The output tensor.
|
Tensor: The output tensor (rotated, same shape as input).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
seq_len = x.size(1)
|
cos, sin = rotary_emb
|
||||||
|
|
||||||
x_complex = torch.view_as_complex(x.view(*x.shape[:-1], -1, 2).float())
|
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
||||||
freqs_cis = freqs_cis.reshape(1, seq_len, 1, -1)
|
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
||||||
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
|
|
||||||
|
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
|
||||||
|
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
|
||||||
|
|
||||||
|
x_real_rot = x_real * cos - x_imag * sin
|
||||||
|
x_imag_rot = x_real * sin + x_imag * cos
|
||||||
|
|
||||||
|
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
|
||||||
|
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
|
||||||
|
|
||||||
return x_out.to(dtype)
|
return x_out.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim: int, max_len: int, base: int=10000):
|
||||||
|
super().__init__()
|
||||||
|
cos_emb, sin_emb = get_rotary_emb(dim, max_len, base)
|
||||||
|
self.register_buffer("cos_emb", cos_emb, persistent=False)
|
||||||
|
self.register_buffer("sin_emb", sin_emb, persistent=False)
|
||||||
|
self._rotary_buffers = {"cos_emb", "sin_emb"}
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, start_pos: int=0) -> Tuple[Tensor, Tensor]:
|
||||||
|
seq_len = x.size(1)
|
||||||
|
cos = self.cos_emb[start_pos : start_pos + seq_len]
|
||||||
|
sin = self.sin_emb[start_pos : start_pos + seq_len]
|
||||||
|
|
||||||
|
return (cos, sin)
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
class Linear(nn.Module):
|
||||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = False, weight_param=None, bias_param=None):
|
def __init__(self, in_dim: int, out_dim: int, bias: bool = False, weight_param=None, bias_param=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -137,7 +159,7 @@ class GQA(nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
freqs_cis: Tensor,
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
mask: Tensor = None,
|
mask: Tensor = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0
|
||||||
|
|
@ -147,7 +169,7 @@ class GQA(nn.Module):
|
||||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||||
k = self._split_heads(self.k_proj(x), self.n_kvheads)
|
k = self._split_heads(self.k_proj(x), self.n_kvheads)
|
||||||
v = self._split_heads(self.v_proj(x), self.n_kvheads)
|
v = self._split_heads(self.v_proj(x), self.n_kvheads)
|
||||||
q, k = apply_rotary_emb(q, freqs_cis), apply_rotary_emb(k, freqs_cis)
|
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
k_cache, v_cache = kv_cache
|
k_cache, v_cache = kv_cache
|
||||||
|
|
@ -186,7 +208,7 @@ class DecoderBlock(nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
freqs_cis: Tensor,
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
attention_mask: Optional[Tensor] = None,
|
attention_mask: Optional[Tensor] = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0
|
start_pos: int = 0
|
||||||
|
|
@ -194,7 +216,7 @@ class DecoderBlock(nn.Module):
|
||||||
# attention
|
# attention
|
||||||
attn_output = self.attention(
|
attn_output = self.attention(
|
||||||
self.norm_attn(x),
|
self.norm_attn(x),
|
||||||
freqs_cis,
|
rotary_emb,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
start_pos
|
start_pos
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,13 @@ import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Any, Mapping, Optional, Tuple
|
from typing import Any, Mapping, Optional, Tuple
|
||||||
from khaosz.config.model_config import TransformerConfig
|
from khaosz.config.model_config import TransformerConfig
|
||||||
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, get_rotary_emb
|
from khaosz.model.module import Embedding, DecoderBlock, Linear, RMSNorm, RotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
def process_attention_mask(
|
def process_attention_mask(
|
||||||
seq_mask: Tensor,
|
seq_mask: Tensor,
|
||||||
input_tensor: Tensor,
|
input_tensor: Tensor,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
seq_len: int = 0,
|
|
||||||
is_causal: bool = False,
|
is_causal: bool = False,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
|
@ -20,13 +19,13 @@ def process_attention_mask(
|
||||||
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
|
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
|
||||||
input_tensor (Tensor): The input tensor.
|
input_tensor (Tensor): The input tensor.
|
||||||
start_pos (int): The starting position of the sequence.
|
start_pos (int): The starting position of the sequence.
|
||||||
seq_len (int): The length of the sequence.
|
|
||||||
is_causal (bool): Whether the attention is causal or not.
|
is_causal (bool): Whether the attention is causal or not.
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: The attention mask tensor.
|
Tensor: The attention mask tensor.
|
||||||
"""
|
"""
|
||||||
device = input_tensor.device
|
device = input_tensor.device
|
||||||
dtype = input_tensor.dtype
|
dtype = input_tensor.dtype
|
||||||
|
seq_len = input_tensor.size(1)
|
||||||
|
|
||||||
if seq_mask is None:
|
if seq_mask is None:
|
||||||
if start_pos != 0:
|
if start_pos != 0:
|
||||||
|
|
@ -65,16 +64,18 @@ class Transformer(nn.Module):
|
||||||
def __init__(self, config: TransformerConfig):
|
def __init__(self, config: TransformerConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len)
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.n_dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.n_dim)
|
||||||
lm_head_init_weight = self.embed_tokens.weight if config.tie_weight == True else None
|
|
||||||
|
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
|
DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
|
||||||
for layer_id in range(config.n_layer)
|
for layer_id in range(config.n_layer)
|
||||||
])
|
])
|
||||||
|
lm_head_init_weight = self.embed_tokens.weight if config.tie_weight == True else None
|
||||||
|
|
||||||
self.norm = RMSNorm(config.n_dim, config.norm_eps)
|
self.norm = RMSNorm(config.n_dim, config.norm_eps)
|
||||||
self.lm_head = Linear(config.n_dim, config.vocab_size, weight_param=lm_head_init_weight)
|
self.lm_head = Linear(config.n_dim, config.vocab_size, weight_param=lm_head_init_weight)
|
||||||
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
|
|
||||||
self._init_parameters()
|
self._init_parameters()
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
|
|
@ -115,20 +116,15 @@ class Transformer(nn.Module):
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
||||||
seq_len = input_ids.size(-1)
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
freqs_cis = self.freq_cis[start_pos:start_pos+seq_len].to(x.device)
|
rotary_emb = self.rotary_embeding(x, start_pos)
|
||||||
|
|
||||||
attn_mask = process_attention_mask(
|
attn_mask = process_attention_mask(
|
||||||
input_mask,
|
input_mask, x, start_pos, is_causal=True
|
||||||
x,
|
|
||||||
start_pos=start_pos,
|
|
||||||
seq_len=seq_len,
|
|
||||||
is_causal=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, freqs_cis, attn_mask, persistent_key_values, start_pos)
|
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
|
||||||
|
|
||||||
hidden_states = self.norm(x)
|
hidden_states = self.norm(x)
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue