flash attention 进化史

flash-attention V1

论文链接:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.

从 HBM 到分块

提升 Transformer 的性能和能力,使之能够处理更长的序列一直是一个非常重要课题。而 Transformer 架构中的 self-attention 计算量和访存量会随着输入序列长度呈二次增长。为了提升 Flash Attention 在 GPU 上的运行性能,特别是优化 attention 计算过程中的访存性能,Tri-Dao 提出了 Flash Attention,旨在优化 self-attention 计算时数据访存行为,因为现代 GPU 上,计算本身已经高度并行化,十分高效,而访存则因为带宽和延时进步不够快而落后。但若将小块数据量放在更小更近的高速缓存中,那么计算和访存的速度就相匹配,能更快地完成 attention 计算(如下左图)。

FlashAttention V1 示意图

Tips:标准 Attention 的计算公式如下:

Attn=softmax(QKTd)V \text{Attn} = \text{softmax}(\frac{QK^T}{\sqrt{d}})V

前面提到,Flash Attention 特别注重算法的访存行为,因此将标准 attention 算法抽象为如下图的具体流程:

标准 attention 算法流程

不难看出,上面的算法中一共包含 8 次 HBM 的读写操作,分别为:

  • 第一行对 Q,KQ,K 的读取共两次,对 SS 的写入一次,读写总共三次,访存复杂度为 Θ(Nd+N2)\Theta(Nd+N^2)
  • 第二行对 SS 读取一次,对 PP 写入一次,读写总共两次;访存复杂度为 Θ(N2)\Theta(N^2)
  • 第三行对 P,VP,V 的读取共两次,对 OO 的写入一次,读写总共三次,访存复杂度为 Θ(Nd+N2)\Theta(Nd+N^2)

标准 attention 算法的访存复杂度为 Θ(N2)\Theta(N^2),对于较长的 N(即输入序列)来说,GPU 的内存访问压力太大了。为了减少对 HBM 的读写,Flash Attention 将参与计算的矩阵分块送进 SRAM,来提高整体读写速度(减少了 HBM 读写)。

softmax 的处理

确定了通过分块来降低对 HBM 的操作次数的思想后,一个棘手的问题挡在了我们前面,那就是 softmax。

数学上,标准 softmax 公式如下:

softmax(xi)=exij=1dexj \text{softmax}(x_i) = \frac{e^{x_i}}{\sum^d_{j=1}e^{x_j}}

而在程序实现中,为了防止 exp 值过大导致的数值溢出,通常需要对 softmax 做如下修改:将每一项的幂次都减去这里的最大值 m(x)=max(x)m(x)=\max(x),因为计算是相除,所以结果不变,但避免了溢出。

于是得到了 safe softmax:

safe softmax(xi)=exim(x)j=1dexjm(x) \text{safe softmax}(x_i) = \frac{e^{x_{i}-m(x)}}{\sum^d_{j=1}{e^{x_{j}-m(x)}}}

但规避了数值溢出问题,另一个问题也紧接而来,如果要求出 m(x),就意味着必须先获取 SS 这一块数据的所有值,才能完成 softmax 计算,这与我们要做分块并行的初衷相违背。

好在 Tri-Dao 等人通过巧妙地公式变换,将这一限制规避开了,接下来我们来看具体实现过程:

假设我们通过 QKTQK^T 得到了矩阵 SS,该矩阵的某一行的向量为 x=[x1,x2,,xd]x=[x_1,x_2,…,x_d],因为分块的原因, 它被我们切成了两部分 x=[x(1),x(2)]x=[x^{(1)},x^{(2)}]

定义:

  • m(x)m(x):标准场景下,该行的全局最大值
  • m(x(1))m(x^{(1)}):分块1的全局最大值
  • m(x(2))m(x^{(2)}):分块2的全局最大值

那么易知:m(x)=max(x(1),x(2))=max(m(x(1)),m(x(2)))m(x)=\max(x^{(1)},x^{(2)})=\max(m(x^{(1)}),m(x^{(2)}))

再定义:

  • f(x)f(x):标准场景下,exm(x)e^{x−m(x)} 的结果
  • f(x(1))f(x^{(1)}):分块场景下,ex(1)m(x(1))e^{x^{(1)}−m(x^{(1)})} 的结果
  • f(x(2))f(x^{(2)}):分块场景下,ex(2)m(x(2))e^{x^{(2)}−m(x^{(2)})} 的结果

那么可以把幂次中的公共项提取出来,换成 m(x(1))m(x)m(x^{(1)})-m(x) 就可以得到 f(x)f(x)

f(x)=[em(x(1))m(x)f(x(1)),em(x(2))m(x)f(x(2))]f(x)=[e^{m(x^{(1)})-m(x)}f(x^{(1)}),e^{m(x^{(2)})-m(x)}f(x^{(2)})]

定义:

  • l(x)l(x):标准场景下,if(x)i\sum_i{f(x)_i} 的结果
  • l(x(1))l(x^{(1)}):标准场景下,if(x(1))i\sum_i{f(x^{(1)})_i} 的结果
  • l(x(2))l(x^{(2)}):标准场景下,if(x(2))i\sum_i{f(x^{(2)})_i} 的结果

类似地,有

l(x)=em(x(1))m(x)l(x(1))+em(x(2))m(x)l(x(2))l(x)=e^{m(x^{(1)})−m(x)}l(x^{(1)})+e^{m(x^{(2)})−m(x)}l(x^{(2)})

有了 f(x)f(x)l(x)l(x) 的值后,softmax 的计算结果就可以完全用 f(x)f(x) 替代:

safe softmax(x)=f(x)l(x)\text{safe softmax}(x)=\frac{f(x)}{l(x)}

看着比较复杂,其实就是一件事:用分块的最大值来暂时替代全局的最大值完成分块并行计算,如果分块的最大值不准确,那么就乘上全局的最大值与其的差来做弥补,同样可以获得与 safe softmax 一样的计算结果,而且计算变得可并行了。

结合下面的伪代码,我们来理解一下代码10-13行的计算步骤:

flash attention V1 算法流程

  • 第10行:mij~\tilde{m_{ij}} 就是当前分块 SijS_{ij} 的局部最大值,相当于前面定义的 m(x)m(x)。而 Pij~=exp(Sijmij~)\tilde{P_{ij}}=\text{exp}(S_{ij}-\tilde{m_{ij}}),相当于前面定义的 f(x)f(x)lij~\tilde{l_{ij}} 就是前面定义的 l(x)l(x)
  • 第11行:在循环中需要不断地更新计算 m(x)m(x)l(x)l(x) 全局最大值
  • 第12行:分块并行计算完成后,需要对计算得到的 Pij~\tilde{P_{ij}} 做修正,还需要对之前累加到 OiO_i 的数值做修正。这其实是一个增量计算的过程,每次需要根据新的“全局最大值”来更新旧的 Oij,Pij~O_{ij},\tilde{P_{ij}} 值。接着再处理下一个分块,然后再更新……

仔细观察内循环(伪代码 5-13 行),在整个计算过程中,只有 mim_i,lil_i,OiO_i 被写回到显存(HBM)中,相比于标准场景下,我们要写回到的是 S,P,OS,P,O ,读写量少了很多,缓解了 HBM 带宽限制。

因此,分块计算 safe softmax 的意义,就是抹去对 S,PS,P 的写回。

分块计算

softmax 分块计算的棘手问题解决后,我们来看一下 flash attention 分块计算的其他流程。为了最大限度地使用片上 SRAM,flash attention 对 QQK,VK,V 做了如下分块:

Bc=M4dB_c=\lceil\frac{M}{4d}\rceilK,VK,V 的列方向的分块大小,MM 是片上 SRAM 的大小,除以 4d4d 是因为需要同时存放 Q,K,V,OQ,K,V,O 四个矩阵,该计算确保了 SRAM 能同时存下所需的数据。Br=min(M4d,d)B_r=\min(\lceil\frac{M}{4d}\rceil, d)QQ 的行方向的分块大小,且不能超过 dd,这是为了保证分块后计算得到的 SijS_{ij} 可以被存放在 SRAM 上。

然后就是执行分块计算,如上面给出的 FlashAttention V1 示意图,中图就展示了分块计算的基本流程。Q 方向的分块是内循环,而 K 方向的分块是外循环。外循环开始时,会从 HBM 中 load 一块 Kj,VjK_j,V_j 到 SRAM 中,然后开始遍历 QiQ_i(这里假设是 prefill 阶段,Q sequence length >> 1),从 HBM load 一块 Qi,OiQ_i,O_i 后完成 Sij=QiKjTS_{ij}=Q_iK^T_j 的计算,随后 SijS_{ij} 根据我们之前的介绍流程处理后,得到了 PijP_{ij}OiO_i,内循环结束后,OiO_i 会被写回,再开始下一轮外循环。

让我们来计算一下 HBM 访存复杂度,每块 Kj,VjK_j,V_j 大小为 (Bc,d)(B_c, d)Qi,OiQ_i,O_i 大小为 (Br,d)(B_r, d),于是一个外循环加载 Kj,VjK_j,V_j 只需要 Θ(M)\Theta(M) 的数据量,而内循环全走一遍需要加载 Θ(M×NdM)\Theta(M\times \frac{Nd}{M}),一共有 NdM\frac{Nd}{M} 个外循环,因此整个 flash attention V1 的访存复杂度为 Θ(N2d2M)=Θ(Nd)\Theta(\frac{N^2d^2}{M})=\Theta(Nd),在 LLM 中, 通常有 d<<Nd << N,因此 flash attention V1 实现了更好的访存性能。

flash-attention V2

flash-attention V3

flash-attention V4

参考链接

[1] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. https://arxiv.org/abs/2205.14135
[2] https://zhuanlan.zhihu.com/p/642962397
[3] 大模型优化-FlashAttention-v1 https://baoblei.github.io/2024/12/07/da-mo-xing-you-hua-flashattention-v1/

flash-attn 实现 prefill 阶段

初步认识好 MMA 指令后,我们来看一下 flash-attn 要怎么用 MMA 指令来实现。

问题分析

首先,需要明确 flash-attn 的输入输出,那么在很多大模型的 prefill 阶段,输入输出 Q K V O 张量的输入张量维度都是:[batch_size, num_heads, seq_len, head_dim],我们为方便起见,简写成 [b, h, sq, d]

另外,大多模型的 head_dim 维度都是固定为 128,因此我们优先考虑 d=128 的情况。

再来看计算过程。attention 的计算公式是:

attn=softmax(QKTd)V\text{attn} = \text{softmax}(\frac{QK^T}{\sqrt{d}})V

但这样的 QKV 张量实在是太大,不可能直接用一个 mma 就完成。对于四维张量,前面两维度是可以完全并行起来的,后面两维度才涉及到矩阵乘法。因此,可将前两个维度放到 GPU block 间做并行(即 grid size),反正他们完全可并行,不需要考虑这么多。

于是,我们得到了这样一张矩阵乘法图,首先计算 QK,得到矩阵 S,然后 S 与 V 相乘获得最后的结果 O(softmax 和 scale 等先略去)。

然而,这样的矩阵乘还是太大了,一个线程块都无法处理。我们需要进一步将矩阵切分,这里涉及到矩阵乘法时要如何切分的问题。对 M 和 N 方向切分还是对 K 方向切分?让我们回到实际问题,M 和 N 方向相当于是 seq_len 的长度,显然它是不确定的,它可能会非常长,也可能会非常短,而 K 方向的长度是一定的,我们只考虑 K=128 的情况,因此,我们选择在 M 和 N 方向切分,将不确定的长度切分成确定的大小,才有利于我们后续实现矩阵乘,否则 seq_len 要适配所有情况,难度太大。

那么我们现在一拍脑袋,先把 Seqlenq 和 Seqlenk 切分成等长 64 的吧(这些值的最有配置需要后续实现完成后 tune 一把,现在实现不考虑性能优劣):

1
const int nrow = 64;

切分完成后,我们终于可以让一个 CUDA 线程块来处理深色部分的区域了,得到的计算结果会填充到深绿色地方。

于是,让我们考虑一个线程块内的事情,由于 MMA 指令都是指挥一个 Warp 做事情,因此我们还需要考虑一个线程块内所有 Warp 要如何切分。

这里的关键在于 m16n8k16 的矩阵方块对应一个 Warp,需要用这个16行16列的颗粒度铺满整个矩阵 A,用16行8列颗粒度铺满矩阵 B,进而完成矩阵乘法计算。

在纸上画一下,使用 mma 分割方式只有一种,但用 Warp 分割的方式可以有很多种。比如说,我们假定一个线程块有 8 个 Warps,即 256 个线程,那么可以按下面任意方式将矩阵分成 8 份,每份由一个 Warp 来负责计算,每种 Warps 的切分实现的代码就会不一样。若再考虑到 block size 可以改变的话,实现的方案会更多,至于哪种性能最优,最简单的办法就是 tune 一遍看。

那么,我们就取图中第一种分割方法来实现 mma 吧。

线程与数据的映射

写 CUDA 马虎不得,尤其是处理哪个线程对应处理哪块数据时更加马虎不得。我们仔细地按照上图第一种方式做切分。调用函数情况如下:

1
2
3
dim3 gird(b*h, (sq+63)/64);
dim3 block(256);
flash_attn_fwd<<<grid, block>>>(Q, K, V, ...);

使用该图作参考

假设我们已经拿到了 QKV 三个 global tensor 指针:

1
2
3
half* Q;
half* K;
half* V;

那么首先,要确定哪个线程块对应哪块数据,然后把他们的部分装到 shared memory,否则就在一开始就 load 错数据了。

考虑矩阵 Q,注意它是四维的,[b, h, sq, d],所以寻址会非常复杂。我们之前说过,前两个维度是用 grid 切的 grid(b*h, (sq+nrow)/nrow)

所以图中,矩阵 A 的一页大矩阵就是一个 blockIdx.x ,它将矩阵 A 切分成了 b*h 份,因而一个 blockIdx.x 对应的大小是 sq*d,所以这块的偏移量是:blockIdx.x * sq * d,然后再考虑 blockIdx.y 它用于寻址页内的block id 数,因此它对应的大小是 64 * d,所以这块偏移量是 blockIdx.y * nrow * d,那么总偏移量就是:

1
2
// 对应一个深色块起始地址 Q
int Q_blk_gmem_offset = blockIdx.x * sq * d + blockIdx.y * nrow * d;

对于 K 而言,也是一样的道理:

1
2
// 对应一个深色块起始地址 K
int K_blk_gmem_offset = blockIdx.x * sq * d + blockIdx.y * nrow * d;

然后要确认线程块内每个 Warp 负责区域的偏移量,要明确每个线程属于哪个 warp:

1
2
3
4
5
int warpid = threadIdx.x % 32;
const int mmaDimx = 2;
const int mmaDimy = 2;
int Q_warp_smem_offset = warpid / 4 * mmaDimx * 16 * d;
int K_warp_smem_offset = warpid % 4 * mmaDimy * 8 * d;

flash attention 进化史
https://dingfen.github.io/2026/03/08/2026-3-7-flashattn/
作者
Bill Ding
发布于
2026年3月8日
更新于
2026年3月8日
许可协议