关于llama架构的分析
1. 整体架构概览
Llama 3采用了经典的纯解码器Transformer架构,整体设计围绕自回归语言生成任务进行深度优化。模型核心由词嵌入层、多个堆叠的Transformer块以及输出层构成,每个Transformer块内部包含多头注意力机制和前馈网络两部分。特别值得注意的是其位置编码方案——使用了改进版的旋转位置编码,基础频率参数设定为50万,这比原始RoPE的1万要高出许多,使得模型能够更好地处理长序列文本并具备更强的位置外推能力。
在注意力机制方面,Llama 3引入了分组查询注意力设计,允许键值头的数量少于查询头,通过重复利用键值头来匹配查询头的数量,这种设计在保持模型性能的同时显著降低了内存占用。前馈网络采用了SwiGLU激活函数,其独特的门控机制提供了比传统MLP更强的表达能力。整个模型还采用了RMSNorm进行层归一化,计算效率高于标准的LayerNorm,同时配合模型并行策略和KV缓存技术,在训练和推理过程中都实现了优异的速度表现和内存效率。
1.1 模型初始化
1 | class Transformer(nn.Module): |
关键点:
- 使用
VocabParallelEmbedding将词表分割到多个GPU - 每层都是独立的
TransformerBlock - 输出前进行 RMSNorm 归一化
- 预先计算旋转位置编码,减少运行时计算
2. 注意力机制详解
2.1 注意力头配置
1 | class Attention(nn.Module): |
关键点:
- 支持 GQA:当
n_kv_heads < n_heads时,KV头会被复用 - 每个GPU只计算部分注意力头(
n_local_heads) - KV缓存预分配内存,加速自回归生成
2.2 注意力前向传播
1 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): |
3. RoPE(旋转位置编码)实现
3.1 预计算旋转矩阵
1 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): |
数学原理:
- 对于位置
m和维度i,旋转角度为:m * theta^{-2i/d} - 复数形式:
(cos(mθ), sin(mθ))应用于 Q 和 K 的相邻维度对
3.2 应用旋转位置编码
1 | def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor): |
4. 前馈网络(SwiGLU)
4.1 FeedForward 实现
1 | class FeedForward(nn.Module): |
SwiGLU 公式:
1 | FFN(x) = W₂ · (silu(W₁·x) ⊙ W₃·x) |
其中 silu 是 Swish 激活函数:x * sigmoid(x)
5. Transformer Block 结构
5.1 单层 Transformer
1 | class TransformerBlock(nn.Module): |
架构特点:
- Pre-LayerNorm:归一化在子层之前
- 残差连接:每个子层都有残差连接
- 并行计算:注意力头和FFN都支持模型并行