1. 技术背景
注意力机制是现代深度学习的核心组件之一,特别是Transformer架构的成功,使注意力机制成为自然语言处理、计算机视觉和多模态领域的基础算子。然而,标准的注意力计算存在显著的计算瓶颈和内存访问问题,限制了模型规模和训练效率。
常见的缩放点积实现方式如下:
在做算子优化的时候, 需要分清楚这个算子是compute bound 还是 memory bound, 另外还需要考虑到GPU 的分级内存结构, 我们通常使用算数强度进行分析,其中算数强度(Arithmetic Intensity)是指在一个计算过程中,每从内存中搬运1字节数据,所执行的浮点运算次数,对于原始实现的数学计算强度分析如下:
对于计算强度:
- : FLOPS
- : FLOPS
- : FLOPS
总计算强度为 FLOPs
对于内存访问数量, 假设 FP32 精度:
- 读取:
- 写回:
- 读取:
- 写回:
- 读取:
- 写回:
总访问量为
求得算数强度的表达式为:
,
我们选用一个比较常见的训练参数
计算得到
而对比之下, 对于矩阵乘法而言, 当三个个维度相同时并使用FP32精度, 其计算强度为 , 以 为例
计算得到 , 相对而言计算强度更大。
所以就数学上分析而言,attention 算子是属于memory bound 的一类, 需要对访存进行优化并且高效地存储中间值, 通过引入 OnlineSoftmax 机制, 我们可以高效地解决这一问题。传统的Softmax 计算分为三步,分别是求的指数最大值,计算指数和, 归一化, 其中不管如何都会有两次读取输入参数,并且从HBM中读取而不是从SRAM中读取。
Online softmax 是分块计算指数和并且动态更新输出,最后除以迭代后的指数和, 从而只用读取一次HBM, 节省访问。
传统Softmax 实现伪代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
|
m = full(M, -inf) l = zeros(M)
m = zeros(M) for i in range(M): m[i] = max(S[i, :])
l = zeros(M) for i in range(M): sum_exp = 0 for j in range(N): sum_exp += exp(S[i, j] - m[i]) l[i] = sum_exp
P = zeros(M, N) for i in range(M): for j in range(N): P[i, j] = exp(S[i, j] - m[i]) / l[i]
|
OnlineSoftmax 实现代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
|
Bc = 128 Tc = ceil(N / Bc)
m_global = full(M, -inf) l_global = zeros(M)
for col_block in range(Tc): col_start = col_block * Bc col_end = col_start + Bc S_block = S[:, col_start:col_end] m_local = zeros(M) for i in range(M): m_local[i] = max(S_block[i, :]) m_new = zeros(M) for i in range(M): m_new[i] = max(m_global[i], m_local[i]) for i in range(M): if col_block > 0: scale_old = exp(m_global[i] - m_new[i]) l_global[i] = l_global[i] * scale_old local_sum = 0 for j in range(Bc): exp_val = exp(S_block[i, j] - m_new[i]) P[i, col_start + j] = exp_val local_sum += exp_val scale_new = exp(m_local[i] - m_new[i]) l_global[i] = l_global[i] + scale_new * local_sum m_global = m_new
for i in range(M): for j in range(N): P[i, j] = P[i, j] / l_global[i]
|
由此分析得到 flash attention 融合版本:
2. 伪代码实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
|
Br = block_size_q Bc = block_size_kv Tr = ceil(M / Br) Tc = ceil(N / Bc)
for q_block_idx in range(Tr): m_prev = tensor((Br,), fill=-inf) l_prev = tensor((Br,), fill=0.0) O_acc = tensor((Br, d), fill=0.0)
for kv_block_idx in range(Tc): q_start = q_block_idx * Br k_start = kv_block_idx * Bc
Q_tile = Q[batch_idx, head_idx, q_start:q_start+Br, :] K_tile = K[batch_idx, head_idx, k_start:k_start+Bc, :] V_tile = V[batch_idx, head_idx, k_start:k_start+Bc, :]
S_tile = scale * (Q_tile @ K_tile.T)
m_j = max(S_tile, dim=-1) P_j = exp(S_tile - m_j[:, None]) l_j = sum(P_j, dim=1)
m_new = maximum(m_prev, m_j) l_new = exp(m_prev - m_new) * l_prev + exp(m_j - m_new) * l_j
scale_old = exp(m_prev - m_new)[:, None] scale_new = exp(m_j - m_new)[:, None] O_acc = scale_old * O_acc + scale_new * (P_j @ V_tile)
m_prev = m_new l_prev = l_new
O_block = O_acc / l_prev[:, None] O[batch_idx, head_idx, q_start:q_start+Br, :] = O_block
|
3. 算数强度分析
之前我们已经求得传统attention 的计算访比, 现在需要计算FlashAttention 版本的计算访存比。
对于计算强度, 由于flash attention 采用的是分块策略, 并没有降低计算量:
- 计算 : , FLOPs =
- OnlineSoftmax 计算: 求max , 减去max并 exp , 求和, 归一化因子更新, 总共
- 计算 : , FLOPs =
- 更新:
对于分块内部更新 , 则对于全局更新计算 FLOPs = (低次项忽略)
对于访问强度, 只用计算HBM访问而忽略SRAM访问(SRAM远快于DRAM, 并且SRAM是瓶颈), 并且采用FP32 访问:
- 读取:
- 写回:
对于分块的访问量 , 对于全局HBM 访问量, 你可能会这样想
但是实际上,由于L2 缓存的存在,KV 部分是不会重复加载的, 以长度为2048为例子举例,KV 块的总大小
, 以A100 为例, 其L2缓存大小约 40MB, 远远大于KV 块的大小, 所以实际上执行attention 任务时, 只用从全局加载一次KV
所以实际上的访问量:
- 读取,:
- 写回:
总访问量
则算数强度的计算
在常见的计算中, 我们可以忽略低次方项
带入 计算可以得到
所以从访存上说, 对于长序列, flash 版本有相当大的优势,大约在 500 左右, 极大地增大了计算强度。