FlashAttention算子优化
ViperEkura Lv1

1. 技术背景

注意力机制是现代深度学习的核心组件之一,特别是Transformer架构的成功,使注意力机制成为自然语言处理、计算机视觉和多模态领域的基础算子。然而,标准的注意力计算存在显著的计算瓶颈和内存访问问题,限制了模型规模和训练效率。

常见的缩放点积实现方式如下:

在做算子优化的时候, 需要分清楚这个算子是compute bound 还是 memory bound, 另外还需要考虑到GPU 的分级内存结构, 我们通常使用算数强度进行分析,其中算数强度(Arithmetic Intensity)是指在一个计算过程中,每从内存中搬运1字节数据,所执行的浮点运算次数,对于原始实现的数学计算强度分析如下:

对于计算强度:

  1. : FLOPS
  2. : FLOPS
  3. : FLOPS

总计算强度为 FLOPs

对于内存访问数量, 假设 FP32 精度:

  1. 读取:
  2. 写回:
  3. 读取:
  4. 写回:
  5. 读取:
  6. 写回:

总访问量为

求得算数强度的表达式为:

,

我们选用一个比较常见的训练参数
计算得到

而对比之下, 对于矩阵乘法而言, 当三个个维度相同时并使用FP32精度, 其计算强度为 , 以 为例
计算得到 , 相对而言计算强度更大。

所以就数学上分析而言,attention 算子是属于memory bound 的一类, 需要对访存进行优化并且高效地存储中间值, 通过引入 OnlineSoftmax 机制, 我们可以高效地解决这一问题。传统的Softmax 计算分为三步,分别是求的指数最大值,计算指数和, 归一化, 其中不管如何都会有两次读取输入参数,并且从HBM中读取而不是从SRAM中读取。
Online softmax 是分块计算指数和并且动态更新输出,最后除以迭代后的指数和, 从而只用读取一次HBM, 节省访问。

传统Softmax 实现伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 输入: S [M, N] - 注意力分数矩阵
# 输出: P [M, N] - softmax概率矩阵

m = full(M, -inf) # 最大值
l = zeros(M) # 指数和

# 第一步: 逐行计算最大值 (HBM 访问)
m = zeros(M) # 每行的最大值
for i in range(M):
m[i] = max(S[i, :])

# 第二步: 逐行计算指数和 (HBM 访问)
l = zeros(M) # 每行的指数和
for i in range(M):
sum_exp = 0
for j in range(N):
sum_exp += exp(S[i, j] - m[i]) # 再次访问同一行
l[i] = sum_exp

# 第三步: 计算归一化结果并写回HBM
P = zeros(M, N)
for i in range(M):
for j in range(N):
P[i, j] = exp(S[i, j] - m[i]) / l[i]

OnlineSoftmax 实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# 更高效的实现,避免重复计算指数
# 输入: S [M, N] - 注意力分数矩阵
# 输出: P [M, N] - softmax概率矩阵

Bc = 128 # 列块大小
Tc = ceil(N / Bc)

m_global = full(M, -inf) # 最大值
l_global = zeros(M) # 指数和

for col_block in range(Tc):
col_start = col_block * Bc
col_end = col_start + Bc

# 加载当前块
S_block = S[:, col_start:col_end] # [M, Bc]

# 1. 计算当前块的最大值 (HBM 访问)
m_local = zeros(M)
for i in range(M):
m_local[i] = max(S_block[i, :])

# 2. 更新全局最大值
m_new = zeros(M)
for i in range(M):
m_new[i] = max(m_global[i], m_local[i])

# 3. 更新指数和
for i in range(M):
if col_block > 0:
scale_old = exp(m_global[i] - m_new[i])
l_global[i] = l_global[i] * scale_old

local_sum = 0
for j in range(Bc):
exp_val = exp(S_block[i, j] - m_new[i])
P[i, col_start + j] = exp_val
local_sum += exp_val

scale_new = exp(m_local[i] - m_new[i])
l_global[i] = l_global[i] + scale_new * local_sum
# 注意:此时P中的值还没有除以l_global

# 更新全局最大值
m_global = m_new

# 4. 最终归一化
# 所有块处理完后,一次性归一化
for i in range(M):
for j in range(N):
P[i, j] = P[i, j] / l_global[i]

由此分析得到 flash attention 融合版本:

2. 伪代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Q: [M, d]               # Query
# K: [N, d] # Key
# V: [N, d] # Value
# scale = 1.0 / sqrt(d) # scale factor

Br = block_size_q
Bc = block_size_kv
Tr = ceil(M / Br)
Tc = ceil(N / Bc)

for q_block_idx in range(Tr):
m_prev = tensor((Br,), fill=-inf)
l_prev = tensor((Br,), fill=0.0)
O_acc = tensor((Br, d), fill=0.0)

for kv_block_idx in range(Tc):
q_start = q_block_idx * Br
k_start = kv_block_idx * Bc

Q_tile = Q[batch_idx, head_idx, q_start:q_start+Br, :] # (Br, d)
K_tile = K[batch_idx, head_idx, k_start:k_start+Bc, :] # (Bc, d)
V_tile = V[batch_idx, head_idx, k_start:k_start+Bc, :] # (Bc, d)

S_tile = scale * (Q_tile @ K_tile.T) # (Br, Bc)

m_j = max(S_tile, dim=-1) # (Br,)
P_j = exp(S_tile - m_j[:, None]) # (Br, Bc)
l_j = sum(P_j, dim=1) # (Br,)

m_new = maximum(m_prev, m_j) # (Br,)
l_new = exp(m_prev - m_new) * l_prev + exp(m_j - m_new) * l_j # (Br,)

scale_old = exp(m_prev - m_new)[:, None] # (Br, 1)
scale_new = exp(m_j - m_new)[:, None] # (Br, 1)
O_acc = scale_old * O_acc + scale_new * (P_j @ V_tile) # (Br, d)

m_prev = m_new
l_prev = l_new


O_block = O_acc / l_prev[:, None] # (Br, d)
O[batch_idx, head_idx, q_start:q_start+Br, :] = O_block

3. 算数强度分析

之前我们已经求得传统attention 的计算访比, 现在需要计算FlashAttention 版本的计算访存比。

对于计算强度, 由于flash attention 采用的是分块策略, 并没有降低计算量:

  1. 计算 : , FLOPs =
  2. OnlineSoftmax 计算: 求max , 减去max并 exp , 求和, 归一化因子更新, 总共
  3. 计算 : , FLOPs =
  4. 更新:

对于分块内部更新 , 则对于全局更新计算 FLOPs = (低次项忽略)

对于访问强度, 只用计算HBM访问而忽略SRAM访问(SRAM远快于DRAM, 并且SRAM是瓶颈), 并且采用FP32 访问:

  1. 读取:
  2. 写回:

对于分块的访问量 , 对于全局HBM 访问量, 你可能会这样想

但是实际上,由于L2 缓存的存在,KV 部分是不会重复加载的, 以长度为2048为例子举例,KV 块的总大小
, 以A100 为例, 其L2缓存大小约 40MB, 远远大于KV 块的大小, 所以实际上执行attention 任务时, 只用从全局加载一次KV

所以实际上的访问量:

  1. 读取:
  2. 写回:

总访问量

则算数强度的计算

在常见的计算中, 我们可以忽略低次方项

带入 计算可以得到
所以从访存上说, 对于长序列, flash 版本有相当大的优势,大约在 500 左右, 极大地增大了计算强度。

 REWARD AUTHOR