使用 MMA 的 flash-attn 实现

Tensor Core 与 MMA

自 Volta 架构开始,nvidia 在显卡上装上了 Tensor Core 架构。该架构是为满足深度学习中所需的大量矩阵类运算需求而设计的硬件架构,专门提供高效小块的矩阵乘法:D=A×B+CD=A\times B+C

因为深度学习中高精度 float 运算不是必须的,所以 Tensor Core 还支持更低精度的计算,更低精度的计算意味着更高的计算效率,更少的能量消耗。

来自NVIDIA 的 Tensor Core 油管视频

Tensor Core 提供了两种使用方法。第一种是利用 nvidia 提供的矩阵计算库 cublas 和深度学习库 cudnn,它们封装了常用的矩阵类计算和深度学习计算所需要的函数,以 SDK 的形式提供封装的接口。这里可以参考 Tensor Core 的三种用法

第二种是通过CUDA编程提供的特定的接口和PTX汇编实现。对于第二种形式,CUDA编译器 nvcc 提供了 WMMA(Warp Matrix Multiply Accumulate)和 MMA(Matrix Multiply Accumulate)两种形式,第一种形式是通过提供fragment数据表示和特定的 load_matrix_sync()store_matrix_sync()mma_sync() 数据加载存储和计算API来触发对 Tensor Core的编程。另一种是通过 PTX 汇编实现,其数据直接面向寄存器表示,计算则是通过 mma.sync 类的函数实现。WMMA 形式对数据和API都进行了相应的抽象,编程相对简单,但对指令单控制也相对粗糙。MMA 形式的编程直接面向寄存器表示和汇编指令,难道较大,容易出错,但是可以实现精细的控制从而达到更高的计算效率。

MMA 指令入门

先具体介绍一下 PTX 汇编实现的 MMA 指令。毕竟目前 LLM 大火导致大家对 attention 计算速度和实现灵活性有了更加极致追求,用其他的方式实现的 attention 已经不能令全球 geek 们满意了。

ldmatrix

指令格式:

1
2
3
4
5
6
7
8
ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];

.shape = {.m8n8, .m16n16};
.num = {.x1, .x2, .x4};
.ss = {.shared{::cta}};
.type = {.b16, .b8};
.dst_fmt = { .b8x16 };
.src_fmt = { .b6x16_p32, .b4x16_p64 };

ldmatrix指令的使用格式例子: ldmatrix.sync.aligned.m8n8.x1.shared.b16 { %0 }, [ %1 ];
这条 PTX 指令掺杂了太多信息,让我们逐个分析:

  • ldmatrix:PTX 指令名字,说明该指令用于 load matrix
  • sync 同步而非异步执行的,即 Warp 内所有的线程必须都完成后才能再往下执行
  • aligned 需要 Warp 内所有线程执行相同的 ldmatrix 指令
  • m8n8 意思是矩阵的数据维度,必须是 8 行 8 列
  • x1:加载的矩阵数量,表示一个
  • shared:从 shared memory 中加载
  • b16:数据类型为 bfloat16,即bf16

总结下来,可以这么说,该指令让一个 Warp(32个线程),从 Shared Memory 的 [p] 地址中加载 1 个 8*8(m8n8) 的矩阵(必须要求该矩阵的每行必须连续存放,但行间可以不连续存放),存放到了目标寄存器 %0 中。

那我们来继续分析一下这条指令的其他细节,首先 m8n8 的 bf16 矩阵有 64 个数据元素,每个元素占两个字节,一共 128 个字节,每个线程获得 2 个元素。

因为行间的位置可以不连续,所以需要用户确保 thread0-thread7 的 %1 寄存器填充的是8个行首地址,其他情况见下面表格。

.num Threads 0–7 Threads 8–15 Threads 16–23 Threads 24–31
.x1 addr0–addr7
.x2 addr0–addr7 addr8–addr15 -
.x4 addr0–addr7 addr8–addr15 addr16–addr23 addr24–addr31

读取矩阵时,四个连续的线程会先加载连续的一行,即 8 个元素 16 个字节。线程 0 的 %0 会获得头两个元素,其他线程情况见下表:

Row\Col 0 1 2 3 4 5 6 7
0 T0:r T0:r T1:r T1:r T2:r T2:r T3:r T3:r
1 T4:r T4:r T5:r T5:r T6:r T6:r T7:r T7:r
2 T8:r T8:r T9:r T9:r T10:r T10:r T11:r T11:r
3 T12:r T12:r T13:r T13:r T14:r T14:r T15:r T15:r
4 T16:r T16:r T17:r T17:r T18:r T18:r T19:r T19:r
5 T20:r T20:r T21:r T21:r T22:r T22:r T23:r T23:r
6 T24:r T24:r T25:r T25:r T26:r T26:r T27:r T27:r
7 T28:r T28:r T29:r T29:r T30:r T30:r T31:r T31:r

而对于 ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];,情况略有不同,因为这里会一次性读入 4 个 m8n8 的矩阵,所以对应的每个 thread 也需要 4 个寄存器来存放矩阵数据,%0 存放第一个 m8n8 的对应数值,%1 存放第二个,以此类推。

mma

指令格式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype  d, a, b, c;
mma.sync.aligned.m16n8k8.row.col.dtype.f16.f16.ctype d, a, b, c;
mma.sync.aligned.m16n8k16.row.col.dtype.f16.f16.ctype d, a, b, c;

.alayout = {.row, .col};
.blayout = {.row, .col};
.ctype = {.f16, .f32};
.dtype = {.f16, .f32};


mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 d, a, b, c;
mma.sync.aligned.m16n8k8.row.col.f32.atype.btype.f32 d, a, b, c;
mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 d, a, b, c;
mma.sync.aligned.shape.row.col.dtype.f8type.f8type.ctype d, a, b, c;
mma.sync.aligned.m16n8k32.row.col.kind.dtype.f8f6f4type.f8f6f4type.ctype d, a, b, c;

.atype = {.bf16, .tf32};
.btype = {.bf16, .tf32};
.f8type = {.e4m3, .e5m2};
.f8f6f4type = {.e4m3, .e5m2, .e3m2, .e2m3, .e2m1};
.ctype = {.f16, .f32};
.dtype = {.f16, .f32};
.shape = {.m16n8k16, .m16n8k32};
.kind = {.kind::f8f6f4};

mma 的指令格式要复杂得多,但功能都是一样的,实现 D=A×B+CD=A\times B+C,结合 flash-attn 的要求,我们主要看一下这条指令:

mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};

mma m16n8k16 中,矩阵 A MxK,矩阵 B KxN,矩阵 C 和 D 都是 MxN,数据类型位 bf16。先来看矩阵 A 的 layout,它有16行16列,每个寄存器放2个元素:

其中 a0:a1 元素存放在 %2 寄存器内,a2:a3 元素存放在 %3 寄存器内,以此类推。

对于矩阵 B,它必须是列主序,16行8列的,矩阵 layout 是这样:

官方的图看上去有点抽象,我用 excel 画了一张图,一个方块代表一个数据元素。深色部分是一个 Warps 读取一次的部分,用白色加粗的字体写明了每个线程加载的方块元素。

Warps 会分四次把矩阵 A load 完成,与此同时还会把矩阵 B 分两次 load 完成,最后完成矩阵计算。

flash-attn 实现 prefill 阶段

初步认识好 MMA 指令后,我们来看一下 flash-attn 要怎么用 MMA 指令来实现。

问题分析

首先,需要明确 flash-attn 的输入输出,那么在很多大模型的 prefill 阶段,输入输出 Q K V O 张量的输入张量维度都是:[batch_size, num_heads, seq_len, head_dim],我们为方便起见,简写成 [b, h, sq, d]

另外,大多模型的 head_dim 维度都是固定为 128,因此我们优先考虑 d=128 的情况。

再来看计算过程。attention 的计算公式是:

attn=softmax(QKTd)V\text{attn} = \text{softmax}(\frac{QK^T}{\sqrt{d}})V

但这样的 QKV 张量实在是太大,不可能直接用一个 mma 就完成。对于四维张量,前面两维度是可以完全并行起来的,后面两维度才涉及到矩阵乘法。因此,可将前两个维度放到 GPU block 间做并行(即 grid size),反正他们完全可并行,不需要考虑这么多。

于是,我们得到了这样一张矩阵乘法图,首先计算 QK,得到矩阵 S,然后 S 与 V 相乘获得最后的结果 O(softmax 和 scale 等先略去)。

然而,这样的矩阵乘还是太大了,一个线程块都无法处理。我们需要进一步将矩阵切分,这里涉及到矩阵乘法时要如何切分的问题。对 M 和 N 方向切分还是对 K 方向切分?让我们回到实际问题,M 和 N 方向相当于是 seq_len 的长度,显然它是不确定的,它可能会非常长,也可能会非常短,而 K 方向的长度是一定的,我们只考虑 K=128 的情况,因此,我们选择在 M 和 N 方向切分,将不确定的长度切分成确定的大小,才有利于我们后续实现矩阵乘,否则 seq_len 要适配所有情况,难度太大。

那么我们现在一拍脑袋,先把 Seqlenq 和 Seqlenk 切分成等长 64 的吧(这些值的最有配置需要后续实现完成后 tune 一把,现在实现不考虑性能优劣):

1
const int nrow = 64;

切分完成后,我们终于可以让一个 CUDA 线程块来处理深色部分的区域了,得到的计算结果会填充到深绿色地方。

于是,让我们考虑一个线程块内的事情,由于 MMA 指令都是指挥一个 Warp 做事情,因此我们还需要考虑一个线程块内所有 Warp 要如何切分。

这里的关键在于 m16n8k16 的矩阵方块对应一个 Warp,需要用这个16行16列的颗粒度铺满整个矩阵 A,用16行8列颗粒度铺满矩阵 B,进而完成矩阵乘法计算。

在纸上画一下,使用 mma 分割方式只有一种,但用 Warp 分割的方式可以有很多种。比如说,我们假定一个线程块有 8 个 Warps,即 256 个线程,那么可以按下面任意方式将矩阵分成 8 份,每份由一个 Warp 来负责计算,每种 Warps 的切分实现的代码就会不一样。若再考虑到 block size 可以改变的话,实现的方案会更多,至于哪种性能最优,最简单的办法就是 tune 一遍看。

那么,我们就取图中第一种分割方法来实现 mma 吧。

线程与数据的映射

写 CUDA 马虎不得,尤其是处理哪个线程对应处理哪块数据时更加马虎不得。我们仔细地按照上图第一种方式做切分。调用函数情况如下:

1
2
3
dim3 gird(b*h, (sq+63)/64);
dim3 block(256);
flash_attn_fwd<<<grid, block>>>(Q, K, V, ...);

使用该图作参考

假设我们已经拿到了 QKV 三个 global tensor 指针:

1
2
3
half* Q;
half* K;
half* V;

那么首先,要确定哪个线程块对应哪块数据,然后把他们的部分装到 shared memory,否则就在一开始就 load 错数据了。

考虑矩阵 Q,注意它是四维的,[b, h, sq, d],所以寻址会非常复杂。我们之前说过,前两个维度是用 grid 切的 grid(b*h, (sq+nrow)/nrow)

所以图中,矩阵 A 的一页大矩阵就是一个 blockIdx.x ,它将矩阵 A 切分成了 b*h 份,因而一个 blockIdx.x 对应的大小是 sq*d,所以这块的偏移量是:blockIdx.x * sq * d,然后再考虑 blockIdx.y 它用于寻址页内的block id 数,因此它对应的大小是 64 * d,所以这块偏移量是 blockIdx.y * nrow * d,那么总偏移量就是:

1
2
// 对应一个深色块起始地址 Q
int Q_blk_gmem_offset = blockIdx.x * sq * d + blockIdx.y * nrow * d;

对于 K 而言,也是一样的道理:

1
2
// 对应一个深色块起始地址 K
int K_blk_gmem_offset = blockIdx.x * sq * d + blockIdx.y * nrow * d;

然后要确认线程块内每个 Warp 负责区域的偏移量,要明确每个线程属于哪个 warp:

1
2
3
4
5
int warpid = threadIdx.x % 32;
const int mmaDimx = 2;
const int mmaDimy = 2;
int Q_warp_smem_offset = warpid / 4 * mmaDimx * 16 * d;
int K_warp_smem_offset = warpid % 4 * mmaDimy * 8 * d;

使用 MMA 的 flash-attn 实现
https://dingfen.github.io/2025/03/09/2025-3-9-mma/
作者
Bill Ding
发布于
2025年3月9日
更新于
2025年3月9日
许可协议