1. 技术背景
注意力机制是现代深度学习的核心组件之一,特别是Transformer架构的成功,使注意力机制成为自然语言处理、计算机视觉和多模态领域的基础算子。然而,标准的注意力计算存在显著的计算瓶颈和内存访问问题,限制了模型规模和训练效率。
常见的缩放点积实现方式如下:
在做算子优化的时候, 需要分清楚这个算子是compute bound 还是 memory bound, 另外还需要考虑到GPU 的分级内存结构, 我们通常使用算数强度进行分析,其中算数强度(Arithmetic Intensity)是指在一个计算过程中,每从内存中搬运1字节数据,所执行的浮点运算次数,对于原始实现的数学计算强度分析如下:
对于计算强度:
: FLOPS : FLOPS : FLOPS
总计算强度为
对于内存访问数量, 假设 FP32 精度:
- 读取
: - 写回
: - 读取
: - 写回
: - 读取
: - 写回
:
总访问量为
求得算数强度的表达式为:
我们选用一个比较常见的训练参数
而对比之下, 对于矩阵乘法而言, 当三个个维度相同时并使用FP32精度, 其计算强度为
所以就数学上分析而言,attention 算子是属于memory bound 的一类, 需要对访存进行优化并且高效地存储中间值, 通过引入 OnlineSoftmax 机制, 我们可以高效地解决这一问题。传统的Softmax 计算分为三步,分别是求的指数最大值,计算指数和, 归一化, 其中不管如何都会有两次读取输入参数,并且从HBM中读取而不是从SRAM中读取。 Online softmax 是分块计算指数和并且动态更新输出,最后除以迭代后的指数和, 从而只用读取一次HBM, 节省访问。
传统Softmax 实现伪代码: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25# 输入: S [M, N] - 注意力分数矩阵
# 输出: P [M, N] - softmax概率矩阵
m = full(M, -inf) # 最大值
l = zeros(M) # 指数和
# 第一步: 逐行计算最大值 (HBM 访问)
m = zeros(M) # 每行的最大值
for i in range(M):
m[i] = max(S[i, :])
# 第二步: 逐行计算指数和 (HBM 访问)
l = zeros(M) # 每行的指数和
for i in range(M):
sum_exp = 0
for j in range(N):
sum_exp += exp(S[i, j] - m[i]) # 再次访问同一行
l[i] = sum_exp
# 第三步: 计算归一化结果并写回HBM
P = zeros(M, N)
for i in range(M):
for j in range(N):
P[i, j] = exp(S[i, j] - m[i]) / l[i]
OnlineSoftmax 实现代码: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51# 更高效的实现,避免重复计算指数
# 输入: S [M, N] - 注意力分数矩阵
# 输出: P [M, N] - softmax概率矩阵
Bc = 128 # 列块大小
Tc = ceil(N / Bc)
m_global = full(M, -inf) # 最大值
l_global = zeros(M) # 指数和
for col_block in range(Tc):
col_start = col_block * Bc
col_end = col_start + Bc
# 加载当前块
S_block = S[:, col_start:col_end] # [M, Bc]
# 1. 计算当前块的最大值 (HBM 访问)
m_local = zeros(M)
for i in range(M):
m_local[i] = max(S_block[i, :])
# 2. 更新全局最大值
m_new = zeros(M)
for i in range(M):
m_new[i] = max(m_global[i], m_local[i])
# 3. 更新指数和
for i in range(M):
if col_block > 0:
scale_old = exp(m_global[i] - m_new[i])
l_global[i] = l_global[i] * scale_old
local_sum = 0
for j in range(Bc):
exp_val = exp(S_block[i, j] - m_new[i])
P[i, col_start + j] = exp_val
local_sum += exp_val
scale_new = exp(m_local[i] - m_new[i])
l_global[i] = l_global[i] + scale_new * local_sum
# 注意:此时P中的值还没有除以l_global
# 更新全局最大值
m_global = m_new
# 4. 最终归一化
# 所有块处理完后,一次性归一化
for i in range(M):
for j in range(N):
P[i, j] = P[i, j] / l_global[i]
由此分析得到 flash attention 融合版本:
2. 伪代码实现
1 | # Q: [M, d] # Query |
3. 算数强度分析
之前我们已经求得传统attention 的计算访比
对于计算强度, 由于flash attention 采用的是分块策略, 并没有降低计算量: 1. 计算
对于分块内部更新
对于访问强度, 只用计算HBM访问而忽略SRAM访问(SRAM远快于DRAM, 并且SRAM是瓶颈), 并且采用FP32 访问: 1. 读取
对于分块的访问量
但是实际上,由于L2 缓存的存在,KV 部分是不会重复加载的, 以长度为2048为例子举例,KV 块的总大小
所以实际上的访问量: 1. 读取
总访问量
则算数强度的计算
在常见的计算中, 我们可以忽略低次方项
带入