1. 技术背景
注意力机制是现代深度学习的核心组件之一,特别是Transformer架构的成功,使注意力机制成为自然语言处理、计算机视觉和多模态领域的基础算子。然而,标准的注意力计算存在显著的计算瓶颈和内存访问问题,限制了模型规模和训练效率。
常见的缩放点积实现方式如下:
在做算子优化的时候, 需要分清楚这个算子是compute bound 还是 memory bound, 另外还需要考虑到GPU 的分级内存结构, 我们通常使用算数强度进行分析,其中算数强度(Arithmetic Intensity)是指在一个计算过程中,每从内存中搬运1字节数据,所执行的浮点运算次数,对于原始实现的数学计算强度分析如下:
对于计算强度:
: 2MNd FLOPS- P = softmax(S) : 5MN FLOPS
- O = PV : 2MNd FLOPS
总计算强度为 4MNd + 5MN FLOPs
对于内存访问数量, 假设 FP32 精度:
- 读取Q, K: 4 × (Md + Nd)
- 写回S: 4 × MN
- 读取S: 4 × MN
- 写回P: 4 × MN
- 读取P, V: 4 × (MN + Nd)
- 写回O: 4 × Md
总访问量为 4 × (2Md + 2Nd + 4MN)
求得算数强度的表达式为:
我们选用一个比较常见的训练参数 M = N = 2048, d = 64 计算得到
而对比之下, 对于矩阵乘法而言, 当三个个维度相同时并使用FP32精度, 其计算强度为
所以就数学上分析而言,attention 算子是属于memory bound 的一类, 需要对访存进行优化并且高效地存储中间值, 通过引入 OnlineSoftmax 机制, 我们可以高效地解决这一问题。传统的Softmax 计算分为三步,分别是求的指数最大值,计算指数和, 归一化, 其中不管如何都会有两次读取输入参数,并且从HBM中读取而不是从SRAM中读取。 Online softmax 是分块计算指数和并且动态更新输出,最后除以迭代后的指数和, 从而只用读取一次HBM, 节省访问。
传统Softmax 实现伪代码: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25# 输入: S [M, N] - 注意力分数矩阵
# 输出: P [M, N] - softmax概率矩阵
m = full(M, -inf) # 最大值
l = zeros(M) # 指数和
# 第一步: 逐行计算最大值 (HBM 访问)
m = zeros(M) # 每行的最大值
for i in range(M):
m[i] = max(S[i, :])
# 第二步: 逐行计算指数和 (HBM 访问)
l = zeros(M) # 每行的指数和
for i in range(M):
sum_exp = 0
for j in range(N):
sum_exp += exp(S[i, j] - m[i]) # 再次访问同一行
l[i] = sum_exp
# 第三步: 计算归一化结果并写回HBM
P = zeros(M, N)
for i in range(M):
for j in range(N):
P[i, j] = exp(S[i, j] - m[i]) / l[i]
OnlineSoftmax 实现代码: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51# 更高效的实现,避免重复计算指数
# 输入: S [M, N] - 注意力分数矩阵
# 输出: P [M, N] - softmax概率矩阵
Bc = 128 # 列块大小
Tc = ceil(N / Bc)
m_global = full(M, -inf) # 最大值
l_global = zeros(M) # 指数和
for col_block in range(Tc):
col_start = col_block * Bc
col_end = col_start + Bc
# 加载当前块
S_block = S[:, col_start:col_end] # [M, Bc]
# 1. 计算当前块的最大值 (HBM 访问)
m_local = zeros(M)
for i in range(M):
m_local[i] = max(S_block[i, :])
# 2. 更新全局最大值
m_new = zeros(M)
for i in range(M):
m_new[i] = max(m_global[i], m_local[i])
# 3. 更新指数和
for i in range(M):
if col_block > 0:
scale_old = exp(m_global[i] - m_new[i])
l_global[i] = l_global[i] * scale_old
local_sum = 0
for j in range(Bc):
exp_val = exp(S_block[i, j] - m_new[i])
P[i, col_start + j] = exp_val
local_sum += exp_val
scale_new = exp(m_local[i] - m_new[i])
l_global[i] = l_global[i] + scale_new * local_sum
# 注意:此时P中的值还没有除以l_global
# 更新全局最大值
m_global = m_new
# 4. 最终归一化
# 所有块处理完后,一次性归一化
for i in range(M):
for j in range(N):
P[i, j] = P[i, j] / l_global[i]
由此分析得到 flash attention 融合版本:
2. 伪代码实现
1 | # Q: [M, d] # Query |
3. 算数强度分析
之前我们已经求得传统attention 的计算访比
对于计算强度, 由于flash attention 采用的是分块策略, 并没有降低计算量: 1. 计算 QiKjT : (Br, d) × (d, Bc) − > (Br, Bc), FLOPs = 2BrBcd 2. OnlineSoftmax 计算: 求max BrBc, 减去max并 exp 2BrBc, 求和BrBc, 归一化因子更新BrBc + 7BrBc, 总共 12BrBc 3. 计算 PijVj: (Br, Bc) × (Bc, d) − > (Br, d), FLOPs = 2BrBcd 4. Oi 更新: 2Brd
对于分块内部更新 4BrBcd + 12BrBc + 2Brd, 则对于全局更新计算 FLOPs = TrTc(4BrBcd + 12BrBc + 2Brd) ≈ 4MNd + 12MN (低次项忽略)
对于访问强度, 只用计算HBM访问而忽略SRAM访问(SRAM远快于DRAM, 并且SRAM是瓶颈), 并且采用FP32 访问: 1. 读取Qi, Kj, Vj: 4 × (Brd + 2Bcd) 2. 写回Oi: 4 × Brd
对于分块的访问量 4 × (2Brd + 2Bcd), 对于全局HBM 访问量, 你可能会这样想
但是实际上,由于L2 缓存的存在,KV 部分是不会重复加载的, 以长度为2048为例子举例,KV 块的总大小 2 × 2048 × 64 × 4 = 1048576Bytes ≈ 1MB, 以A100 为例, 其L2缓存大小约 40MB, 远远大于KV 块的大小, 所以实际上执行attention 任务时, 只用从全局加载一次KV
所以实际上的访问量: 1. 读取Q, K,V: 4 × (Md + 2Nd) 2. 写回O: 4 × Md
总访问量 8(Md + Nd)
则算数强度的计算
在常见的计算中, 我们可以忽略低次方项
带入 M = N = 2048, d = 64 计算可以得到 R = 33 所以从访存上说, 对于长序列, flash 版本有相当大的优势,大约在 500 FLOPs/Byte 左右, 极大地增大了计算强度。