Tensor Core 和 MMA

Tensor Core 与 MMA

自 Volta 架构开始,nvidia 在显卡上装上了 Tensor Core 架构。该架构是为满足深度学习中所需的大量矩阵类运算需求而设计的硬件架构,专门提供高效小块的矩阵乘法:D=A×B+CD=A\times B+C

因为深度学习中高精度 float 运算不是必须的,所以 Tensor Core 还支持更低精度的计算,更低精度的计算意味着更高的计算效率,更少的能量消耗。

来自NVIDIA 的 Tensor Core 油管视频

Tensor Core 提供了两种使用方法。第一种是利用 nvidia 提供的矩阵计算库 cublas 和深度学习库 cudnn,它们封装了常用的矩阵类计算和深度学习计算所需要的函数,以 SDK 的形式提供封装的接口。这里可以参考 Tensor Core 的三种用法

第二种是通过 CUDA API 提供的特定的接口和 PTX 汇编实现。CUDA 编译器 nvcc 提供了 WMMA(Warp Matrix Multiply Accumulate)和 MMA(Matrix Multiply Accumulate)两种形式:

  • WMMA 是一种较高级别的编程接口,允许开发者以小块数据(fragment )为单位进行操作,如使用 load_matrix_sync()store_matrix_sync()mma_sync() 等 API 完成数据的搬运计算,使用 Tensor Core 完成数据计算。WMMA 形式对数据和API都进行了相应的抽象,编程相对简单,但对指令单控制也相对粗糙。
  • MMA 则是一种更加底层的实现,通过 PTX 汇编语言描述,其数据直接面向寄存器表示,计算则是通过 mma.sync 类的函数实现。MMA 形式的编程直接面向寄存器表示和汇编指令,难道较大,容易出错,但是可以实现精细的控制从而达到更高的计算效率。

那么今天这篇博客先介绍一下常见的 MMA 指令以及使用方法。

MMA 指令入门

先具体介绍一下 PTX 汇编实现的 MMA 指令。

ldmatrix

指令格式

1
2
3
4
5
6
7
8
ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];

.shape = {.m8n8, .m16n16};
.num = {.x1, .x2, .x4};
.ss = {.shared{::cta}};
.type = {.b16, .b8};
.dst_fmt = { .b8x16 };
.src_fmt = { .b6x16_p32, .b4x16_p64 };

ldmatrix 指令的使用格式例子: ldmatrix.sync.aligned.m8n8.x1.shared.b16 { %0 }, [ %1 ];
这条 PTX 指令掺杂了太多信息,让我们逐个分析:

  • ldmatrix:PTX 指令名字,说明该指令用于 load matrix from shared memory to register
  • sync 同步而非异步执行的,即 Warp 内所有的线程必须都完成后才能再往下执行
  • aligned 需要 Warp 内所有线程执行相同的 ldmatrix 指令
  • m8n8 意思是矩阵的数据维度,必须是 8 行 8 列
  • x1:加载的矩阵数量,表示一个;x2 和 x4 表示加载的矩阵数量为 2 和 4
  • shared:从 shared memory 中加载
  • b16:数据类型为 bfloat16,即 bf16
  • %0, %1 是占位符,分别对应输出和输入操作数。%0 是输出寄存器,%1 是输入的 shared memory 地址

注意:PTX 汇编本身不会直接使用 {},[] 用于表示内存地址访问间接寻址。它类似于 C/C++ 中的数组下标操作,用于从内存中加载数据或将数据存储到内存中

总结下来,可以这么说,该指令让一个 Warp(32个线程),从 Shared Memory 的 [p] 地址中加载 1 个 8*8(m8n8) 的矩阵(必须要求该矩阵的每行必须连续存放,但行间可以不连续存放),存放到了目标寄存器 %0 中。

行列数据排布

那我们来继续分析一下这条指令其他值得注意的地方,首先 ldmatrix 指令一次执行,可以将 m8n8 的 bf16 矩阵的 64 个数据元素从 shared memory 搬运到寄存器中,一共 128 个字节。

读取 shared memory 时,这 8 行 8 列中,仅要求行内的元素连续,而行间的位置可以不连续,所以需要用户确保 thread0-thread7 的 %1 寄存器填充的是8个行首地址,其他情况见下面表格。

.num Threads 0–7 Threads 8–15 Threads 16–23 Threads 24–31
.x1 addr0–addr7
.x2 addr0–addr7 addr8–addr15 -
.x4 addr0–addr7 addr8–addr15 addr16–addr23 addr24–addr31

也就是说,读取矩阵时,shared memory 的数据是从上面的 Addr0-Addr31 中读取的。以 x1 为例,线程 0-7 获得了 8 个行首地址的 shared memory 的值,执行后,线程 0-3 获得了线程 0 对应的行首地址的 shared memory 的值,线程 4-7 获得了线程 1 对应的行首地址的 shared memory 的值,以此类推。这里借用知乎大佬 reed 的图片,左边就是 shared memory 的数据,右边就是 ldmatrix 后读到寄存器的数据。

存到寄存器的数据排布其实就是上图所反映的那样。四个连续的线程会先加载(到寄存器)连续的一行,即一行的 8 个元素 16 个字节会在一次内读取完成。每个线程加载 2 个元素,共 4 个字节,线程 0 的 %0 会获得头两个元素,更具体的情况见下表:

而对于 ldmatrix.sync.aligned.m16n16.x1.shared.b8 {%0, %1, %2, %3}, [%4];,情况略有不同,因为这里会一次性读入 m16n16 的矩阵共 256 个元素,每个元素 8 bits,所以对应的每个 thread 需要读入 8 个元素供 8 个字节,需要 2 个寄存器来存放矩阵数据,%0 存放第一个 m8n8 的对应数值,%1 存放第二个,数据排布可参考下图:注意区别不一样的是,线程 0 的两个寄存对应的数据位置是分开的,分别在第 0 行和第 8 行。


初步认识好这条 PTX 指令后,又要如何在 CUDA C++ 中使用呢?我们可以使用 C++ 中 asm volatile 指令来将 PTX 指令嵌入到 CUDA C++ 中。

参考使用 PTX 指令模板:

1
2
3
4
__asm__ volatile ("汇编指令"
: "输出操作数"
: "输入操作数"
: "clobber 列表");

可以写出如下宏指令或 inline 函数来帮助我们使用 PTX 指令

1
2
3
4
5
6
7
8
9
#define LDMATRIX_M8N8_X1(R0, addr) asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(R0) : "r"(addr))
#define LDMATRIX_M8N8_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))

static __device__ __forceinline__ void ldmatrix_m8n8_x1(uint32_t &r0, uint64_t addr) {
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0}, [%1];\n" : "=r"(r0) : "r"(addr));
}
static __device__ __forceinline__ void ldmatrix_m8n8_x4(uint32_t &r0, uint32_t &r1, uint32_t &r2, uint32_t &r3, uint64_t addr) {
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" :"=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "r"(addr));
}

举个简单的小例子,比如我们要从 shared memory 中加载一个 16x16 的 bf16 矩阵到寄存器中:

已知使用 ldmatrix 指令一次可以加载 m8n8 大小的矩阵数据,故需要加载 4 个 m8n8 的矩阵,于是可直接使用 ldmatrix.sync.aligned.m8n8.x4.shared.b16 这个 PTX 指令。另外需要注意,我们必须给 Warp 内的每个线程都安排好对应的寄存器,否则就会出现错误。x4 意味着每个线程需要准备 4 个寄存器,加载 8 个 bf16 元素。对于线程 Ti,根据之前的分析,

于是,代码可以这么写:

1
2
3
4
5
6
const int lane_id = threadIdx.x % 32; // First, we get the current thread id in the warp
uint32_t r[4]; // Allocate 4 registers for each thread
uint32_t load_smem_ptr = __cvta_generic_to_shared( // convert generic address(pointer)
&src_matrix[lane_id % 16][(lane_id / 16) * 8] // to share space offset, often used
); // in PTX which requires shared space offset rather than pointer
LDMATRIX_M8N8_X4(r[0], r[1], r[2], r[3], load_smem_ptr);

mma

指令格式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype  d, a, b, c;
mma.sync.aligned.m16n8k8.row.col.dtype.f16.f16.ctype d, a, b, c;
mma.sync.aligned.m16n8k16.row.col.dtype.f16.f16.ctype d, a, b, c;

.alayout = {.row, .col};
.blayout = {.row, .col};
.ctype = {.f16, .f32};
.dtype = {.f16, .f32};


mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 d, a, b, c;
mma.sync.aligned.m16n8k8.row.col.f32.atype.btype.f32 d, a, b, c;
mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 d, a, b, c;
mma.sync.aligned.shape.row.col.dtype.f8type.f8type.ctype d, a, b, c;
mma.sync.aligned.m16n8k32.row.col.kind.dtype.f8f6f4type.f8f6f4type.ctype d, a, b, c;

.atype = {.bf16, .tf32};
.btype = {.bf16, .tf32};
.f8type = {.e4m3, .e5m2};
.f8f6f4type = {.e4m3, .e5m2, .e3m2, .e2m3, .e2m1};
.ctype = {.f16, .f32};
.dtype = {.f16, .f32};
.shape = {.m16n8k16, .m16n8k32};
.kind = {.kind::f8f6f4};

mma 的指令格式要复杂得多,但功能都是一样的,实现 D=A×B+CD=A\times B+C,简单起见,我们主要看一下这条指令(下指 mma m16n8k16):

mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};

mma m16n8k16 中,矩阵 A MxK,矩阵 B KxN,矩阵 C 和 D 都是 MxN,数据类型位 bf16。先来看矩阵 A 的 layout,它有 16 行 16 列共 256 个元素,分成了四个小的 fragments,让一个 warps 重复做相同的四份处理。warps 中,每个线程需要加载 8 个元素,因此需要 4 个寄存器,每个寄存器放 2 个元素:

其中 a0:a1 元素存放在 %2 寄存器内,元素编号 A[0][0] 和 A[0][1],a2:a3 元素存放在 %3 寄存器内,元素编号 A[8][0] 和 A[8][1],以此类推。

对于矩阵 B,它必须是列主序,16行8列的,矩阵 layout 是这样,同样,b0:b1 元素存放在 %6 寄存器内,元素编号 B[0][0] 和 B[1][0],b2:b3 元素存放在 %7 寄存器内,元素编号 B[8][0] 和 B[9][0]:

矩阵 C 和 D 的 layout 是一样的,16 行 8 列。矩阵 C 的 c0:c1 元素存放在 %8 寄存器内,元素编号 C[0][0] 和 C[0][1],c2:c3 元素存放在 %9 寄存器内,元素编号 C[8][0] 和 C[8][1]:


同样地,可以写出如下宏指令或 inline 函数来帮助我们使用 PTX 指令

1
2
3
4
5
6
7
8
9
#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) \
asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 \
{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : \
"=r"(RD0), "=r"(RD1) : \
"r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), \
"r"(RB0), "r"(RB1), \
"r"(RC0), "r"(RC1))

HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]);

使用 MMA 完成矩阵乘法

学会了 mma 指令和 ldmatrix 指令后,我们就可以简单地用他们来练练手,写一个 m16n8k16 bf16 的矩阵乘法:

首先假设,我们已经从 global memory 中加载了矩阵 A 和 B,并放在了 shared memory 中。于是我们就只剩下将他们从 shared memory 中加载到寄存器中,再进行矩阵乘法了!

考虑到矩阵 A 的大小是 16x16 的 bf16 矩阵,于是和之前的 ldmatrix 例子一样,我们使用 ldmatrix.sync.aligned.m8n8.x4.shared.b16 将矩阵 A load 到寄存器中:

1
2
3
4
const int lane_id = threadIdx.x % 32; // First, we get the current thread id in the warp
uint32_t ra[4]; // Allocate 4 registers for each thread
uint32_t load_a_smem_ptr = __cvta_generic_to_shared( &asrc_matrix[lane_id % 16][(lane_id / 16) * 8] );
LDMATRIX_M8N8_X4(ra[0], ra[1], ra[2], ra[3], load_smem_ptr);

敏感的读者或许之前就注意到了,上面的例子中我们唯独没有解释为何 asrc_matrix 在转换为共享地址空间时要经过这么复杂索引变化:[lane_id % 16][(lane_id / 16) * 8]。其实,看懂了 MMA 的矩阵 A layout 后就不难明白,这是为了配合完成下面 MMA 计算。只有让线程 0-15 读取 shared memory 的前 16 列,线程 16-31 读取 shared memory 的后 16 列,才能让 ldmatrix 后的矩阵 A 中寄存器内的数值顺序正确:

那么对于矩阵 B 来说,情况也是类似的,但需要额外注意的是,矩阵 B 是 Trans 的,有 16 行 8 列,所以要用 ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16,而为了配合 MMA 的 B layout,也需要让线程 0-15 读取 shared memory。

1
2
3
4
5
6
7
#define LDMATRIX_X2_T(R0, R1, addr) asm volatile( \
"ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n": \
"=r"(R0), "=r"(R1) : "r"(addr) \
)
uint32_t rb[2];
uint32_t load_smem_b_ptr = __cvta_generic_to_shared(&s_b[lane_id % 16][0]); // convert generic address
LDMATRIX_X2_T(rb[0], rb[1], load_smem_b_ptr);

将所有的数据都加载到寄存器后,接着就是执行 MMA 乘法,并将计算出来的值从寄存器中取出:

1
2
3
4
HMMA16816(rc[0], rc[1], ra[0], ra[1], ra[2], ra[3], rb[0], rb[1], rc[0], rc[1]);

dst_c[lane_id / 4][(lane_id % 4) * 2] = rc[0];
dst_c[lane_id / 4 + 8][(lane_id % 4) * 2] = rc[1];

从上面的 C Layout 我们可以看到,线程 ti 的 rc[0] 寄存器存放的是前八行元素,rc[1] 寄存器是后半行,每行有四个线程,于是通过 lane_id / 4 定位行数,通过 (lane_id % 4) * 2 得到列数(×2 是因为一个寄存器有两个元素)。

如此一来,矩阵乘就完成了,下面给出完整的代码清单:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <mma.h>

#define LDST32BITS(value) (reinterpret_cast<half2*>(&(value))[0])
#define LDST64BITS(value) (reinterpret_cast<float2*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))

HOST_DEVICE_INLINE
int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); }

// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major.
template<const int MMA_M=16, const int MMA_N=8, const int MMA_K=16>
__global__ void hgemm_mma_m16n8k16_naive_kernel(half* A, half* B, half* C,
int M, int N, int K) {
const int bx = blockIdx.x;
const int by = blockIdx.y;
const int NUM_K_TILES = div_ceil(K, MMA_K);
constexpr int BM = MMA_M; // 16
constexpr int BN = MMA_N; // 8
constexpr int BK = MMA_K; // 16

__shared__ half s_a[MMA_M][MMA_K]; // 16x16
__shared__ half s_b[MMA_K][MMA_N]; // 16x8
__shared__ half s_c[MMA_M][MMA_N]; // 16x8

const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
const int lane_id = tid % WARP_SIZE; // 0~31

// s_a[16][16], 每行16,每线程load 8,需要2线程,共16行,需2x16=32线程
const int load_smem_a_m = tid / 2; // row 0~15
const int load_smem_a_k = (tid % 2) * 8; // col 0,8
// s_b[16][8], 每行8,每线程load 8,需要1线程,共16行,需16线程,只需一半线程加载
const int load_smem_b_k = tid; // row 0~31, but only use 0~15
const int load_smem_b_n = 0; // col 0
const int load_gmem_a_m = by * BM + load_smem_a_m; // global m
const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n
if (load_gmem_a_m >= M && load_gmem_b_n >= N) return;

uint32_t RC[2] = {0, 0};

#pragma unroll
for (int k = 0; k < NUM_K_TILES; ++k) {
// gmem_a -> smem_a
int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = (
LDST128BITS(A[load_gmem_a_addr]));

// gmem_b -> smem_b
if (lane_id < MMA_K) {
int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of b
int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = (
LDST128BITS(B[load_gmem_b_addr]));
}
__syncthreads();

uint32_t RA[4];
uint32_t RB[2];

// ldmatrix for s_a, ldmatrix.trans for s_b.
// s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)]
uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
&s_a[lane_id % 16][(lane_id / 16) * 8]);
LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr);
uint32_t load_smem_b_ptr = __cvta_generic_to_shared(
&s_b[lane_id % 16][0]);
LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr);

HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]);

__syncthreads();
}

// s_c[16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
// #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
// [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16]
LDST32BITS(s_c[lane_id / 4 ][(lane_id % 4) * 2]) = LDST32BITS(RC[0]);
LDST32BITS(s_c[lane_id / 4 + 8][(lane_id % 4) * 2]) = LDST32BITS(RC[1]);

__syncthreads();

// store s_c[16][8]
if (lane_id < MMA_M) {
// store 128 bits per memory issue.
int store_gmem_c_m = by * BM + lane_id;
int store_gmem_c_n = bx * BN;
int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
LDST128BITS(C[store_gmem_c_addr]) = (LDST128BITS(s_c[lane_id][0]));
}
}

Tensor Core 和 MMA
https://dingfen.github.io/2025/03/09/2025-3-9-mma/
作者
Bill Ding
发布于
2025年3月9日
更新于
2026年3月8日
许可协议