feat(model): 添加QK归一化和门控注意力支持
This commit is contained in:
parent
fd7ee2895a
commit
eba99e1f5e
|
|
@ -1,37 +1,43 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Any, Dict, Optional, Self
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
# basic config
|
# basic config
|
||||||
vocab_size: Optional[int] = None
|
vocab_size: Optional[int] = None
|
||||||
n_dim: Optional[int] = None
|
dim: Optional[int] = None
|
||||||
n_head: Optional[int] = None
|
|
||||||
n_layer: Optional[int] = None
|
n_layers: Optional[int] = None
|
||||||
m_len: Optional[int] = None
|
|
||||||
norm_eps: Optional[float] = None
|
norm_eps: Optional[float] = None
|
||||||
d_ffn: Optional[int] = None
|
dim_ffn: Optional[int] = None
|
||||||
tie_weight: Optional[bool] = None
|
tie_weight: Optional[bool] = None
|
||||||
|
|
||||||
|
# RoPE
|
||||||
|
max_len: Optional[int] = None
|
||||||
|
rope_theta: Optional[float] = None
|
||||||
|
|
||||||
# GQA
|
# GQA
|
||||||
n_kvhead: Optional[int] = None
|
n_heads: Optional[int] = None
|
||||||
|
n_kv_heads: Optional[int] = None
|
||||||
|
use_qk_norm: Optional[bool] = None
|
||||||
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
def load(self, config_path: str) -> Self:
|
def load(self, config_path: str) -> Self:
|
||||||
|
config = {}
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
config: Dict[str, Any] = json.load(f)
|
config.update(json.load(f))
|
||||||
|
|
||||||
for key, value in config.items():
|
for key, value in config.items():
|
||||||
if hasattr(self, key):
|
if hasattr(self, key):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def save(self, config_path: str) -> None:
|
def save(self, config_path: str):
|
||||||
config_dict = asdict(self)
|
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
||||||
config_dict = {k: v for k, v in config_dict.items() if v is not None}
|
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, 'w') as f:
|
||||||
json.dump(config_dict, f, indent=4)
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -100,7 +100,7 @@ class GeneratorCore:
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
cur_cache_pos = start_pos
|
cur_cache_pos = start_pos
|
||||||
|
|
||||||
for _ in range(len(ids), self.config.m_len):
|
for _ in range(len(ids), self.config.max_len):
|
||||||
next_token_id, cache_increase = self.generate_iterator(
|
next_token_id, cache_increase = self.generate_iterator(
|
||||||
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
|
input_ids, temperature, top_k, top_p, attn_mask, kv_caches, cur_cache_pos)
|
||||||
|
|
||||||
|
|
@ -127,7 +127,7 @@ class EmbeddingEncoderCore:
|
||||||
with_batch = isinstance(sentence, list)
|
with_batch = isinstance(sentence, list)
|
||||||
ids = self.tokenizer.encode(sentence)
|
ids = self.tokenizer.encode(sentence)
|
||||||
batch_ids = ids if with_batch else [ids]
|
batch_ids = ids if with_batch else [ids]
|
||||||
max_model_len = self.config.m_len
|
max_model_len = self.config.max_len
|
||||||
|
|
||||||
all_fragments = []
|
all_fragments = []
|
||||||
fragment_origin_idx = []
|
fragment_origin_idx = []
|
||||||
|
|
@ -195,10 +195,10 @@ class KVCacheManager:
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.device = device
|
self.device = device
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.num_layers = config.n_layer
|
self.num_layers = config.n_layers
|
||||||
self.max_len = config.m_len
|
self.max_len = config.max_len
|
||||||
self.num_heads = config.n_kvhead
|
self.num_heads = config.n_kv_heads
|
||||||
self.head_dim = config.n_dim //config.n_head
|
self.head_dim = config.dim //config.n_heads
|
||||||
|
|
||||||
self._kv_cache: Tuple[Tensor, Tensor] = None
|
self._kv_cache: Tuple[Tensor, Tensor] = None
|
||||||
self._seq_mask: Tensor = None
|
self._seq_mask: Tensor = None
|
||||||
|
|
|
||||||
|
|
@ -167,7 +167,7 @@ class StreamGenerator(GeneratorCore):
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
kv_caches = cache_manager.get_kvcache()
|
kv_caches = cache_manager.get_kvcache()
|
||||||
|
|
||||||
for _ in range(len(ids), self.config.m_len):
|
for _ in range(len(ids), self.config.max_len):
|
||||||
next_token_id, cache_increase = self.generate_iterator(
|
next_token_id, cache_increase = self.generate_iterator(
|
||||||
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
input_ids, temperature, top_k, top_p, kv_caches=kv_caches, start_pos=cur_cache_pos)
|
||||||
|
|
||||||
|
|
@ -219,7 +219,7 @@ class BatchGenerator(GeneratorCore):
|
||||||
start_cache_pos = max_ids_len
|
start_cache_pos = max_ids_len
|
||||||
cur_cache_pos = 0
|
cur_cache_pos = 0
|
||||||
|
|
||||||
while max_ids_len < self.config.m_len and sum(activate_task_mask) != 0:
|
while max_ids_len < self.config.max_len and sum(activate_task_mask) != 0:
|
||||||
kv_caches = cache_manager.get_kvcache()
|
kv_caches = cache_manager.get_kvcache()
|
||||||
attn_mask =cache_manager.get_seq_mask()
|
attn_mask =cache_manager.get_seq_mask()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -102,23 +102,20 @@ class RotaryEmbedding(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
class Linear(nn.Module):
|
||||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = False, weight_param=None, bias_param=None):
|
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
weight_param = torch.empty((out_dim, in_dim)) if weight_param is None else weight_param
|
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||||
bias_param = torch.zeros(out_dim) if bias_param is None else bias_param
|
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||||
|
|
||||||
self.weight = nn.Parameter(weight_param)
|
|
||||||
self.bias = nn.Parameter(bias_param) if bias else None
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.linear(x, self.weight, self.bias)
|
return F.linear(x, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
def __init__(self, n_dim, norm_eps):
|
def __init__(self, dim, norm_eps):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.ones(n_dim))
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
self.normalized_shape = (n_dim, )
|
self.normalized_shape = (dim, )
|
||||||
self.norm_eps = norm_eps
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
|
@ -127,41 +124,70 @@ class RMSNorm(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(self, n_dim: int, d_ffn: int):
|
def __init__(self, dim: int, dim_feed_forward: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.up = Linear(n_dim, d_ffn)
|
self.up = Linear(dim, dim_feed_forward)
|
||||||
self.gate = Linear(n_dim, d_ffn)
|
self.gate = Linear(dim, dim_feed_forward)
|
||||||
self.down = Linear(d_ffn, n_dim)
|
self.down = Linear(dim_feed_forward, dim)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
gated = self.up(x) * F.silu(self.gate(x))
|
gated = self.up(x) * F.silu(self.gate(x))
|
||||||
out = self.down(gated)
|
out = self.down(gated)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool= False):
|
||||||
|
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||||
|
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||||
|
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
|
||||||
|
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal).permute(0, 2, 1, 3).contiguous().flatten(2)
|
||||||
|
|
||||||
|
return sdqa_out
|
||||||
|
|
||||||
|
|
||||||
class GQA(nn.Module):
|
class GQA(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n_dim: int,
|
dim: int,
|
||||||
n_head: int,
|
n_heads: int,
|
||||||
n_kvhead: int,
|
n_kv_heads: int,
|
||||||
|
use_qk_norm: bool,
|
||||||
|
norm_eps: float,
|
||||||
|
use_gated_attention: bool,
|
||||||
layer_id: int
|
layer_id: int
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert n_dim % n_head == 0
|
assert dim % n_heads == 0
|
||||||
assert n_head % n_kvhead == 0
|
assert n_heads % n_kv_heads == 0
|
||||||
|
|
||||||
self.head_dim = n_dim // n_head
|
self.head_dim = dim // n_heads
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.n_dim = n_dim
|
self.dim = dim
|
||||||
self.n_heads = n_head
|
self.n_heads = n_heads
|
||||||
self.n_kvheads = n_kvhead
|
self.n_kv_heads = n_kv_heads
|
||||||
self.n_rep = n_head // n_kvhead
|
self.n_rep = n_heads // n_kv_heads
|
||||||
|
self.use_qk_norm = use_qk_norm
|
||||||
|
self.use_gated_attention = use_gated_attention
|
||||||
|
|
||||||
self.q_proj = Linear(n_dim, n_head * self.head_dim)
|
self.attention = Attention()
|
||||||
self.k_proj = Linear(n_dim, n_kvhead * self.head_dim)
|
|
||||||
self.v_proj = Linear(n_dim, n_kvhead * self.head_dim)
|
self.q_proj = Linear(dim, n_heads * self.head_dim)
|
||||||
self.o_proj = Linear(n_dim, n_dim)
|
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||||
|
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
|
||||||
|
self.o_proj = Linear(dim, dim)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
|
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
|
|
||||||
|
if self.use_gated_attention:
|
||||||
|
self.gate = Linear(dim, dim)
|
||||||
|
|
||||||
|
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
|
||||||
|
batch_size, seq_len, _ = x.shape
|
||||||
|
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|
@ -174,10 +200,13 @@ class GQA(nn.Module):
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
|
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||||
k = self._split_heads(self.k_proj(x), self.n_kvheads)
|
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||||
v = self._split_heads(self.v_proj(x), self.n_kvheads)
|
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||||
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
|
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
k_cache, v_cache = kv_cache
|
k_cache, v_cache = kv_cache
|
||||||
|
|
||||||
|
|
@ -190,27 +219,34 @@ class GQA(nn.Module):
|
||||||
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
v = v_cache[:bsz, :start_pos + seq_len, self.layer_id]
|
||||||
|
|
||||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||||
|
sdqa_out = self.attention(q, k, v, mask, is_causal=(mask == None))
|
||||||
|
|
||||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
if self.use_gated_attention:
|
||||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
||||||
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask == None)).permute(0, 2, 1, 3)
|
|
||||||
out = self.o_proj(sdqa_out.contiguous().view(bsz, seq_len, -1))
|
out = self.o_proj(sdqa_out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
|
|
||||||
batch_size, seq_len, _ = x.shape
|
|
||||||
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class DecoderBlock(nn.Module):
|
class DecoderBlock(nn.Module):
|
||||||
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps, layer_id):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
dim_ffn: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
norm_eps: int,
|
||||||
|
use_qk_norm: bool,
|
||||||
|
use_gated_attention: bool,
|
||||||
|
layer_id: int
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attention = GQA(n_dim, n_head, n_kvhead, layer_id)
|
self.attention = GQA(dim, n_heads, n_kv_heads,
|
||||||
self.norm_attn = RMSNorm(n_dim, norm_eps)
|
use_qk_norm, norm_eps, use_gated_attention, layer_id)
|
||||||
self.ffn = MLP(n_dim, d_ffn)
|
self.input_norm = RMSNorm(dim, norm_eps)
|
||||||
self.norm_ffn = RMSNorm(n_dim, norm_eps)
|
self.mlp = MLP(dim, dim_ffn)
|
||||||
|
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|
@ -222,7 +258,7 @@ class DecoderBlock(nn.Module):
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
# attention
|
# attention
|
||||||
attn_output = self.attention(
|
attn_output = self.attention(
|
||||||
self.norm_attn(x),
|
self.input_norm(x),
|
||||||
rotary_emb,
|
rotary_emb,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
|
|
@ -231,16 +267,15 @@ class DecoderBlock(nn.Module):
|
||||||
x = attn_output + x
|
x = attn_output + x
|
||||||
|
|
||||||
# feed forward
|
# feed forward
|
||||||
x = self.ffn(self.norm_ffn(x)) + x
|
x = self.mlp(self.post_attention_norm(x)) + x
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Embedding(nn.Module):
|
class Embedding(nn.Module):
|
||||||
def __init__(self, vocab_size: int, embedding_dim: int, weight_param=None):
|
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
weight_param = torch.empty((vocab_size, embedding_dim)) if weight_param is None else weight_param
|
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||||
self.weight = nn.Parameter(weight_param)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.embedding(x, self.weight)
|
return F.embedding(x, self.weight)
|
||||||
|
|
@ -59,16 +59,17 @@ class Transformer(nn.Module):
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rotary_embeding = RotaryEmbedding(config.n_dim // config.n_head, config.m_len)
|
self.rotary_embeding = RotaryEmbedding(config.dim // config.n_heads, config.max_len)
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.n_dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
DecoderBlock(config.n_dim, config.n_head, config.d_ffn, config.n_kvhead, config.norm_eps, layer_id)
|
DecoderBlock(config.dim, config.n_heads, config.dim_ffn, config.n_kv_heads,
|
||||||
for layer_id in range(config.n_layer)
|
config.norm_eps, config.use_qk_norm, config.use_gated_attention, layer_id)
|
||||||
|
for layer_id in range(config.n_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
self.norm = RMSNorm(config.n_dim, config.norm_eps)
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
self.lm_head = Linear(config.n_dim, config.vocab_size)
|
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||||
|
|
||||||
if self.config.tie_weight == True:
|
if self.config.tie_weight == True:
|
||||||
self.lm_head.weight = self.embed_tokens.weight
|
self.lm_head.weight = self.embed_tokens.weight
|
||||||
|
|
|
||||||
|
|
@ -83,19 +83,19 @@ def base_test_env(request: pytest.FixtureRequest):
|
||||||
n_dim_choices = [8, 16, 32]
|
n_dim_choices = [8, 16, 32]
|
||||||
n_head_choices = [2, 4]
|
n_head_choices = [2, 4]
|
||||||
|
|
||||||
n_dim = int(np.random.choice(n_dim_choices))
|
dim = int(np.random.choice(n_dim_choices))
|
||||||
n_head = int(np.random.choice(n_head_choices))
|
n_heads = int(np.random.choice(n_head_choices))
|
||||||
n_kvhead = n_head // 2
|
n_kv_heads = n_heads // 2
|
||||||
d_ffn = n_dim * 2
|
dim_ffn = dim * 2
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"n_dim": n_dim,
|
"dim": dim,
|
||||||
"n_head": n_head,
|
"n_heads": n_heads,
|
||||||
"n_kvhead": n_kvhead,
|
"n_kv_heads": n_kv_heads,
|
||||||
"d_ffn": d_ffn,
|
"dim_ffn": dim_ffn,
|
||||||
"m_len": 1024,
|
"max_len": 1024,
|
||||||
"n_layer": 4,
|
"n_layers": 4,
|
||||||
"norm_eps": 1e-5
|
"norm_eps": 1e-5
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,12 +22,12 @@ def test_env(request: pytest.FixtureRequest):
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"n_dim": 128,
|
"dim": 128,
|
||||||
"n_head": 4,
|
"n_heads": 4,
|
||||||
"n_kvhead": 2,
|
"n_kv_heads": 2,
|
||||||
"d_ffn": 256,
|
"dim_ffn": 256,
|
||||||
"m_len": 64,
|
"max_len": 64,
|
||||||
"n_layer": 2,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5
|
"norm_eps": 1e-5
|
||||||
}
|
}
|
||||||
with open(config_path, 'w') as f:
|
with open(config_path, 'w') as f:
|
||||||
|
|
@ -64,9 +64,9 @@ def test_model_parameter(test_env):
|
||||||
def test_transformer(test_env):
|
def test_transformer(test_env):
|
||||||
model = test_env["model"]
|
model = test_env["model"]
|
||||||
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
|
input_ids = torch.randint(0, test_env["transformer_config"].vocab_size,
|
||||||
(4, test_env["transformer_config"].m_len))
|
(4, test_env["transformer_config"].max_len))
|
||||||
output_logits = model(input_ids)["logits"]
|
output_logits = model(input_ids)["logits"]
|
||||||
target_shape = (4, test_env["transformer_config"].m_len, test_env["transformer_config"].vocab_size)
|
target_shape = (4, test_env["transformer_config"].max_len, test_env["transformer_config"].vocab_size)
|
||||||
assert output_logits.shape == target_shape
|
assert output_logits.shape == target_shape
|
||||||
|
|
||||||
# generator
|
# generator
|
||||||
|
|
@ -80,7 +80,7 @@ def test_embedding_encoder_core(test_env):
|
||||||
|
|
||||||
single_emb = encoder.encode("测试文本")
|
single_emb = encoder.encode("测试文本")
|
||||||
assert isinstance(single_emb, torch.Tensor)
|
assert isinstance(single_emb, torch.Tensor)
|
||||||
assert single_emb.shape[-1] == test_env["transformer_config"].n_dim
|
assert single_emb.shape[-1] == test_env["transformer_config"].dim
|
||||||
|
|
||||||
|
|
||||||
batch_emb = encoder.encode(["测试1", "测试2"])
|
batch_emb = encoder.encode(["测试1", "测试2"])
|
||||||
|
|
|
||||||
|
|
@ -16,12 +16,12 @@ def transformer_test_env():
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"n_dim": 128,
|
"dim": 128,
|
||||||
"n_head": 4,
|
"n_heads": 4,
|
||||||
"n_kvhead": 2,
|
"n_kv_heads": 2,
|
||||||
"d_ffn": 256,
|
"dim_ffn": 256,
|
||||||
"m_len": 64,
|
"max_len": 64,
|
||||||
"n_layer": 2,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5
|
"norm_eps": 1e-5
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ class GenerationBenchmark:
|
||||||
def _initialize_kv_cache(self, batch_size: int) -> list:
|
def _initialize_kv_cache(self, batch_size: int) -> list:
|
||||||
"""初始化KV缓存"""
|
"""初始化KV缓存"""
|
||||||
config = self.config
|
config = self.config
|
||||||
shape = (batch_size, config.m_len, config.n_layer, config.n_kvhead, config.n_dim // config.n_head)
|
shape = (batch_size, config.max_len, config.n_layers, config.n_kv_heads, config.dim // config.n_heads)
|
||||||
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||||
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||||
return (k_cache, v_cache)
|
return (k_cache, v_cache)
|
||||||
|
|
@ -175,12 +175,12 @@ def print_benchmark_result(result: BenchmarkResult):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config = ModelConfig(
|
config = ModelConfig(
|
||||||
vocab_size=10000,
|
vocab_size=10000,
|
||||||
n_dim=1536,
|
dim=1536,
|
||||||
n_head=24,
|
n_heads=24,
|
||||||
n_kvhead=4,
|
n_kv_heads=4,
|
||||||
d_ffn=6912,
|
dim_ffn=6912,
|
||||||
m_len=2048,
|
max_len=2048,
|
||||||
n_layer=24,
|
n_layers=24,
|
||||||
norm_eps=1e-5,
|
norm_eps=1e-5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -111,7 +111,7 @@ def train(
|
||||||
parameter.load(param_path)
|
parameter.load(param_path)
|
||||||
|
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = parameter.config.m_len
|
window_size = parameter.config.max_len
|
||||||
|
|
||||||
model = parameter.model
|
model = parameter.model
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue