flash attention 进化史
flash-attention V1
论文链接:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.
从 HBM 到分块
提升 Transformer 的性能和能力,使之能够处理更长的序列一直是一个非常重要课题。而 Transformer 架构中的 self-attention 计算量和访存量会随着输入序列长度呈二次增长。为了提升 Flash Attention 在 GPU 上的运行性能,特别是优化 attention 计算过程中的访存性能,Tri-Dao 提出了 Flash Attention,旨在优化 self-attention 计算时数据访存行为,因为现代 GPU 上,计算本身已经高度并行化,十分高效,而访存则因为带宽和延时进步不够快而落后。但若将小块数据量放在更小更近的高速缓存中,那么计算和访存的速度就相匹配,能更快地完成 attention 计算(如下左图)。

Tips:标准 Attention 的计算公式如下:
前面提到,Flash Attention 特别注重算法的访存行为,因此将标准 attention 算法抽象为如下图的具体流程:

不难看出,上面的算法中一共包含 8 次 HBM 的读写操作,分别为:
- 第一行对 的读取共两次,对 的写入一次,读写总共三次,访存复杂度为
- 第二行对 读取一次,对 写入一次,读写总共两次;访存复杂度为
- 第三行对 的读取共两次,对 的写入一次,读写总共三次,访存复杂度为
标准 attention 算法的访存复杂度为 ,对于较长的 N(即输入序列)来说,GPU 的内存访问压力太大了。为了减少对 HBM 的读写,Flash Attention 将参与计算的矩阵分块送进 SRAM,来提高整体读写速度(减少了 HBM 读写)。
softmax 的处理
确定了通过分块来降低对 HBM 的操作次数的思想后,一个棘手的问题挡在了我们前面,那就是 softmax。
数学上,标准 softmax 公式如下:
而在程序实现中,为了防止 exp 值过大导致的数值溢出,通常需要对 softmax 做如下修改:将每一项的幂次都减去这里的最大值 ,因为计算是相除,所以结果不变,但避免了溢出。
于是得到了 safe softmax:
但规避了数值溢出问题,另一个问题也紧接而来,如果要求出 m(x),就意味着必须先获取 这一块数据的所有值,才能完成 softmax 计算,这与我们要做分块并行的初衷相违背。
好在 Tri-Dao 等人通过巧妙地公式变换,将这一限制规避开了,接下来我们来看具体实现过程:
假设我们通过 得到了矩阵 ,该矩阵的某一行的向量为 ,因为分块的原因, 它被我们切成了两部分
定义:
- :标准场景下,该行的全局最大值
- :分块1的全局最大值
- :分块2的全局最大值
那么易知:
再定义:
- :标准场景下, 的结果
- :分块场景下, 的结果
- :分块场景下, 的结果
那么可以把幂次中的公共项提取出来,换成 就可以得到
定义:
- :标准场景下, 的结果
- :标准场景下, 的结果
- :标准场景下, 的结果
类似地,有
有了 和 的值后,softmax 的计算结果就可以完全用 替代:
看着比较复杂,其实就是一件事:用分块的最大值来暂时替代全局的最大值完成分块并行计算,如果分块的最大值不准确,那么就乘上全局的最大值与其的差来做弥补,同样可以获得与 safe softmax 一样的计算结果,而且计算变得可并行了。
结合下面的伪代码,我们来理解一下代码10-13行的计算步骤:

- 第10行: 就是当前分块 的局部最大值,相当于前面定义的 。而 ,相当于前面定义的 , 就是前面定义的
- 第11行:在循环中需要不断地更新计算 和 全局最大值
- 第12行:分块并行计算完成后,需要对计算得到的 做修正,还需要对之前累加到 的数值做修正。这其实是一个增量计算的过程,每次需要根据新的“全局最大值”来更新旧的 值。接着再处理下一个分块,然后再更新……
仔细观察内循环(伪代码 5-13 行),在整个计算过程中,只有 ,, 被写回到显存(HBM)中,相比于标准场景下,我们要写回到的是 ,读写量少了很多,缓解了 HBM 带宽限制。
因此,分块计算 safe softmax 的意义,就是抹去对 的写回。
分块计算
softmax 分块计算的棘手问题解决后,我们来看一下 flash attention 分块计算的其他流程。为了最大限度地使用片上 SRAM,flash attention 对 和 做了如下分块:
是 的列方向的分块大小, 是片上 SRAM 的大小,除以 是因为需要同时存放 四个矩阵,该计算确保了 SRAM 能同时存下所需的数据。 是 的行方向的分块大小,且不能超过 ,这是为了保证分块后计算得到的 可以被存放在 SRAM 上。
然后就是执行分块计算,如上面给出的 FlashAttention V1 示意图,中图就展示了分块计算的基本流程。Q 方向的分块是内循环,而 K 方向的分块是外循环。外循环开始时,会从 HBM 中 load 一块 到 SRAM 中,然后开始遍历 (这里假设是 prefill 阶段,Q sequence length >> 1),从 HBM load 一块 后完成 的计算,随后 根据我们之前的介绍流程处理后,得到了 和 ,内循环结束后, 会被写回,再开始下一轮外循环。
让我们来计算一下 HBM 访存复杂度,每块 大小为 , 大小为 ,于是一个外循环加载 只需要 的数据量,而内循环全走一遍需要加载 ,一共有 个外循环,因此整个 flash attention V1 的访存复杂度为 ,在 LLM 中, 通常有 ,因此 flash attention V1 实现了更好的访存性能。
flash-attention V2
flash-attention V3
flash-attention V4
参考链接
[1] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. https://arxiv.org/abs/2205.14135
[2] https://zhuanlan.zhihu.com/p/642962397
[3] 大模型优化-FlashAttention-v1 https://baoblei.github.io/2024/12/07/da-mo-xing-you-hua-flashattention-v1/
flash-attn 实现 prefill 阶段
初步认识好 MMA 指令后,我们来看一下 flash-attn 要怎么用 MMA 指令来实现。
问题分析
首先,需要明确 flash-attn 的输入输出,那么在很多大模型的 prefill 阶段,输入输出 Q K V O 张量的输入张量维度都是:[batch_size, num_heads, seq_len, head_dim],我们为方便起见,简写成 [b, h, sq, d]。
另外,大多模型的 head_dim 维度都是固定为 128,因此我们优先考虑 d=128 的情况。
再来看计算过程。attention 的计算公式是:
但这样的 QKV 张量实在是太大,不可能直接用一个 mma 就完成。对于四维张量,前面两维度是可以完全并行起来的,后面两维度才涉及到矩阵乘法。因此,可将前两个维度放到 GPU block 间做并行(即 grid size),反正他们完全可并行,不需要考虑这么多。
于是,我们得到了这样一张矩阵乘法图,首先计算 QK,得到矩阵 S,然后 S 与 V 相乘获得最后的结果 O(softmax 和 scale 等先略去)。

然而,这样的矩阵乘还是太大了,一个线程块都无法处理。我们需要进一步将矩阵切分,这里涉及到矩阵乘法时要如何切分的问题。对 M 和 N 方向切分还是对 K 方向切分?让我们回到实际问题,M 和 N 方向相当于是 seq_len 的长度,显然它是不确定的,它可能会非常长,也可能会非常短,而 K 方向的长度是一定的,我们只考虑 K=128 的情况,因此,我们选择在 M 和 N 方向切分,将不确定的长度切分成确定的大小,才有利于我们后续实现矩阵乘,否则 seq_len 要适配所有情况,难度太大。
那么我们现在一拍脑袋,先把 Seqlenq 和 Seqlenk 切分成等长 64 的吧(这些值的最有配置需要后续实现完成后 tune 一把,现在实现不考虑性能优劣):
1 | |
切分完成后,我们终于可以让一个 CUDA 线程块来处理深色部分的区域了,得到的计算结果会填充到深绿色地方。

于是,让我们考虑一个线程块内的事情,由于 MMA 指令都是指挥一个 Warp 做事情,因此我们还需要考虑一个线程块内所有 Warp 要如何切分。
这里的关键在于 m16n8k16 的矩阵方块对应一个 Warp,需要用这个16行16列的颗粒度铺满整个矩阵 A,用16行8列颗粒度铺满矩阵 B,进而完成矩阵乘法计算。
在纸上画一下,使用 mma 分割方式只有一种,但用 Warp 分割的方式可以有很多种。比如说,我们假定一个线程块有 8 个 Warps,即 256 个线程,那么可以按下面任意方式将矩阵分成 8 份,每份由一个 Warp 来负责计算,每种 Warps 的切分实现的代码就会不一样。若再考虑到 block size 可以改变的话,实现的方案会更多,至于哪种性能最优,最简单的办法就是 tune 一遍看。

那么,我们就取图中第一种分割方法来实现 mma 吧。
线程与数据的映射
写 CUDA 马虎不得,尤其是处理哪个线程对应处理哪块数据时更加马虎不得。我们仔细地按照上图第一种方式做切分。调用函数情况如下:
1 | |
假设我们已经拿到了 QKV 三个 global tensor 指针:
1 | |
那么首先,要确定哪个线程块对应哪块数据,然后把他们的部分装到 shared memory,否则就在一开始就 load 错数据了。
考虑矩阵 Q,注意它是四维的,[b, h, sq, d],所以寻址会非常复杂。我们之前说过,前两个维度是用 grid 切的 grid(b*h, (sq+nrow)/nrow)。
所以图中,矩阵 A 的一页大矩阵就是一个 blockIdx.x ,它将矩阵 A 切分成了 b*h 份,因而一个 blockIdx.x 对应的大小是 sq*d,所以这块的偏移量是:blockIdx.x * sq * d,然后再考虑 blockIdx.y 它用于寻址页内的block id 数,因此它对应的大小是 64 * d,所以这块偏移量是 blockIdx.y * nrow * d,那么总偏移量就是:
1 | |
对于 K 而言,也是一样的道理:
1 | |
然后要确认线程块内每个 Warp 负责区域的偏移量,要明确每个线程属于哪个 warp:
1 | |
