CuTe Layout 代数
CuTe(CUDA Templates)的核心是一套完整的 Layout 代数系统,用于在编译期描述任意维度的张量结构,消除手写偏移计算,并让编译器生成最优代码。
Layout 将多维坐标映射到线性内存偏移。Shape 描述各维度大小,Stride 描述各维度的步长。
静态(Int<N> / _N):值在类型中,零运行时开销,循环可完全展开。
动态(int):运行时值,灵活但有少量开销。可混用。
Shape/Stride 可以是任意嵌套的元组,天然表达 GPU 的层级结构:Thread → Warp → Block → Cluster。
核心 API 与基础示例
#include <cute/layout.hpp> #include <cute/tensor.hpp> using namespace cute; // ─── 1. 创建 Layout ─────────────────────────────────────────── // 行主序 4×8 矩阵:stride = (8, 1) auto layout_rm = make_layout(make_shape(_4{}, _8{}), make_stride(_8{}, _1{})); // 列主序 4×8 矩阵:stride = (1, 4) auto layout_cm = make_layout(make_shape(_4{}, _8{}), make_stride(_1{}, _4{})); // 自动推断 stride(行主序) auto layout_auto = make_layout(make_shape(_4{}, _8{})); // stride = (8, 1) // ─── 2. 坐标访问 ────────────────────────────────────────────── int offset = layout_rm(2, 3); // = 2×8 + 3×1 = 19 int offset2 = layout_rm(make_coord(2, 3)); // 等价 // ─── 3. 嵌套 Layout(表达 Warp/Thread 层级)───────────────── // Shape = (4, (2, 4)):第二维是 rank-2 子 layout auto layout_nested = make_layout( make_shape(_4{}, make_shape(_2{}, _4{})), // (4, (2, 4)) make_stride(_8{}, make_stride(_4{}, _1{})) // (8, (4, 1)) ); // 坐标 (i, (j0, j1)) → offset = i×8 + j0×4 + j1×1 int off = layout_nested(1, make_coord(0, 3)); // = 8 + 3 = 11 // ─── 4. Layout 变换 ─────────────────────────────────────────── auto result = tiled_divide(layout_rm, make_shape(_2{}, _4{})); // 原 (4,8) → 结果 ((2,4), (2,2)):tile内坐标, tile编号 // ─── 5. 打印调试 ────────────────────────────────────────────── print_layout(layout_rm); // 打印完整坐标映射表 print(layout_nested); // 打印结构描述
Layout 变换族
坐标系与 idx2crd / crd2idx
using namespace cute; // idx2crd:把线性索引按 Shape 展开为多维坐标(列主序) auto shape = make_shape(_3{}, make_shape(_2{}, _3{})); auto coord = idx2crd(5, shape); // (1, (1, 2)) — 见下方解析 // crd2idx:多维坐标 → 线性索引 int idx = crd2idx(make_coord(1, make_coord(1, 2)), shape); // = 5 // 解析:shape(3, (2,3)) 总 size=18 // slot0: idx%3 = 5%3 = 2 → 但 slot0 size=3, idx=5: coord0 = 5%3 = 2 // 实际 cute 对每个 slot 独立展开: // coord0 = 5 % 3 = 2;剩余 = 5 / 3 = 1 // 内层 (2,3): coord1_0 = 1 % 2 = 1,coord1_1 = (1/2) % 3 = 0 // → (2, (1, 0))
静态 vs 动态整数性能对比
// stride 在编译期已知 Layout<Shape<_4,_8>,Stride<_8,_1>> layout; // PTX: shl.b32 %r1, %r0, 3 (无乘法) // 循环可完全 unroll // 内存占用:0 字节(空结构体)
// stride 运行时才知道 auto layout = make_layout( make_shape(4, 8), // 动态 make_stride(8, 1) // 动态 ); // PTX: mul.lo.s32 %r1, %r0, %stride // 内存占用:每个动态值 4 字节
CuTe Tensor 抽象
Tensor = Engine(数据指针/存储后端)+ Layout(形状和步长)。它是 CuTe 所有计算的载体,能透明地表示 Global Memory、Shared Memory、寄存器等不同存储层。
指向 HBM/DRAM,通过 TMA 或 ldg 访问。Layout 描述全局矩阵的行列结构。
指向 SMEM,Layout 通常带有 Swizzle 以消除 Bank 冲突。用于 MMA 的 operand staging。
存储在寄存器文件中,通过 make_fragment_like 或 MMA 的 partition 创建。零地址开销。
Tensor 的创建与使用
#include <cute/tensor.hpp> using namespace cute; // ─── Global Memory Tensor ────────────────────────────────────── float* ptr_A = /*...*/; int M = 4096, K = 2048; // 动态 shape,行主序 auto gA = make_tensor(ptr_A, make_layout(make_shape(M, K), make_stride(K, 1))); // 静态 shape(编译期已知) auto gA_static = make_tensor(ptr_A, make_layout(make_shape(Int<4096>{}, Int<2048>{}))); // ─── Shared Memory Tensor ────────────────────────────────────── extern __shared__ half_t smem_buf[]; // 带 Swizzle 的 smem layout using SmemLayout = decltype(composition( Swizzle<3,4,3>{}, Layout<Shape<_128,_64>, Stride<_64,_1>>{} )); auto sA = make_tensor(make_smem_ptr(smem_buf), SmemLayout{}); // ─── Register Tensor ─────────────────────────────────────────── auto rC = make_tensor<float>(make_shape(_16{}, _8{})); // 寄存器分配 // 等价:auto rC = make_fragment_like<float>(shape); // ─── Tensor 切片与 local_tile ────────────────────────────────── // 从全局矩阵切出 block 负责的 tile auto gA_blk = local_tile(gA, make_shape(Int<128>{}, Int<64>{}), // tile shape make_coord(blockIdx.x, _0{})); // block 坐标,K 方向从 0 开始 // 返回 shape = (128, 64, K/64),第三维是 K-tile 迭代轴 // ─── 按线程/Warp 分区 ───────────────────────────────────────── // 用 TiledMMA 分区(决定每个 thread 负责哪些元素) TiledMma tiled_mma; auto [tCsA, tCsB, tCrC] = tiled_mma.partition_fragment_C(sA, sB);
数据流:Global → Shared → Register
make_smem_ptr 返回一个带有 shared memory 语义标记的包装指针,让 CuTe 知道这块内存需要通过 smem 地址空间访问,在生成 PTX 时使用 .shared 地址空间修饰符。
TiledMMA — 矩阵乘法抽象
TiledMMA 将单个 MMA atom(如 SM80_16x8x8_F32F16F16F32_TN)在 Warp/Thread 层级铺展,描述整个 Warpgroup 如何协作完成一个大 tile 的矩阵乘法。
硬件 MMA 指令的最小执行单元,例如:
SM80_16x8x8_F32F16F16F32_TN— A100SM90_64x128x16_F32BF16BF16_SS— H100 WGMMASM100_64x128x32_S32S8S8_SS— B200 UMMA
控制 Atom 在 M/N/K 方向如何重复铺展,决定每个 Warp 或 Warpgroup 覆盖的总矩阵大小。
控制每次 MMA 操作中每个线程拥有多少个 A/B/C 的分量(values)。
TiledMMA 的完整使用
#include <cute/atom/mma_atom.hpp> using namespace cute; // ─── 定义 TiledMMA(Ampere SM80 为例)────────────────────────── using MmaAtom = MMA_Atom<SM80_16x8x8_F32F16F16F32_TN>; using TiledMma = TiledMMA< MmaAtom, Layout<Shape<_2,_2,_1>>, // AtomLayoutMNK:Atom 在 MNK 方向的重复次数 Layout<Shape<_1,_2,_1>> // ValLayoutMNK:每个线程的 value 铺展 >; // 总 shape:M=32, N=16, K=8(每 warp 处理) // ─── partition:把 smem tensor 分配给当前 thread ─────────────── TiledMma tiled_mma; auto thr_mma = tiled_mma.get_slice(threadIdx.x); // 把 smem tile 切成当前 thread 负责的部分 auto tCsA = thr_mma.partition_A(sA); // shape: (MMA_M, MMA_K, num_k_tiles) auto tCsB = thr_mma.partition_B(sB); // shape: (MMA_N, MMA_K, num_k_tiles) // 分配累加寄存器 auto tCrC = thr_mma.partition_fragment_C(sA(_, _, 0)); // shape: (MMA_M, MMA_N) clear(tCrC); // 清零 // ─── 分配 A/B 寄存器缓存(从 smem 拷贝后存放)────────────────── auto tCrA = make_fragment_like(tCsA(_, _, 0)); // shape: (MMA_M, MMA_K) auto tCrB = make_fragment_like(tCsB(_, _, 0)); // shape: (MMA_N, MMA_K) // ─── 主循环:依次累加 K tiles ────────────────────────────────── for (int k = 0; k < size<2>(tCsA); ++k) { copy(tCsA(_, _, k), tCrA); // smem → reg(带 lds 指令) copy(tCsB(_, _, k), tCrB); gemm(tiled_mma, tCrA, tCrB, tCrC); // 执行 MMA 指令 }
常用 MMA Atom 速查
| Atom 名称 | 架构 | 形状 (M×N×K) | 数据类型 | 备注 |
|---|---|---|---|---|
SM80_16x8x8_F32F16F16F32_TN | A100 | 16×8×8 | F16→F32 | 标准 wmma |
SM80_16x8x16_F32F16F16F32_TN | A100 | 16×8×16 | F16→F32 | 更宽 K |
SM80_16x8x16_F16F16F16F16_TN | A100 | 16×8×16 | F16→F16 | 半精度累加 |
SM90_64x64x16_F32BF16BF16_SS | H100 | 64×64×16 | BF16→F32 | WGMMA smem |
SM90_64x128x16_F32BF16BF16_SS | H100 | 64×128×16 | BF16→F32 | WGMMA 宽N |
SM90_64x256x16_F32BF16BF16_SS | H100 | 64×256×16 | BF16→F32 | 超宽N |
SM90_64x128x32_S32S8S8_SS | H100 | 64×128×32 | INT8→INT32 | 量化推理 |
SM100_64x128x32_S32S8S8_SS | B200 | 64×128×32 | INT8→INT32 | UMMA |
SM100_64x256x16_F32BF16BF16_SS | B200 | 64×256×16 | BF16→F32 | UMMA 更大N |
TiledCopy — 数据搬运抽象
TiledCopy 将单个 Copy Atom(如 SM80_CP_ASYNC_CACHEGLOBAL)在线程层级铺展,描述一个 Warp/Block 如何高效地将数据从 Global Memory 搬到 Shared Memory(或反向)。
单个线程的最小复制操作。可以是普通 ld/st、cp.async、TMA 等。
通过 SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t> 实现 128-bit 向量化异步拷贝,填满内存带宽。
H100 的 TMA(Tensor Memory Accelerator):单线程触发,硬件自动完成整个 tile 的搬运,支持 Swizzle descriptor。
cp.async 异步拷贝(A100/H100 均支持)
using namespace cute; // ─── 定义 Copy Atom:128-bit 异步 Global→Shared ──────────────── using CopyAtom = Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, half_t>; using TiledCopyA = decltype(make_tiled_copy( CopyAtom{}, Layout<Shape<_32,_4>, Stride<_4,_1>>{}, // ThreadLayout: 32×4 线程 Layout<Shape<_1,_8>>{} // ValLayout: 每线程 8 个 half_t = 128 bit )); // 每次 copy:32×4=128 个线程,各搬 8 个 half = 32×32 个 half = 2KB/cycle TiledCopyA tiled_copy_a; auto thr_copy_a = tiled_copy_a.get_slice(threadIdx.x); // partition:把 global/smem tensor 分配给当前 thread auto tAgA = thr_copy_a.partition_S(gA); // Global source,shape: (CPY, CPY_M, CPY_K, k_tiles) auto tAsA = thr_copy_a.partition_D(sA); // Smem dest,shape: (CPY, CPY_M, CPY_K) // 执行异步拷贝(非阻塞) copy(tiled_copy_a, tAgA(_, _, _, k), tAsA); // 等待所有 cp.async 完成 cp_async_fence(); // 插入 fence cp_async_wait<0>(); // 等待所有未完成的 cp.async __syncthreads(); // 确保 smem 可见
TMA Copy(Hopper H100 专属,推荐)
// ─── Host 端:创建 TMA Descriptor ───────────────────────────── auto tma_load_A = make_tma_copy( SM90_TMA_LOAD{}, tensor_A, // global tensor SmemLayoutA{}, // smem layout(含 swizzle) make_shape(Int<128>{}, Int<64>{}), // tile shape Int<1>{} // multicast(默认 1) ); // tma_load_A 作为 kernel 参数传入 // ─── Device 端:触发 TMA Load ────────────────────────────────── __global__ void kernel(auto tma_load_A, auto tma_load_B) { // 只需一个线程(elected thread)触发 TMA uint64_t* mbar_ptr = /*...*/; // mbarrier 地址 if (elect_one_sync()) { // 计算 global 坐标(block tile 位置) auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _0{}); // 触发 TMA:自动搬运整个 tile,支持 swizzle copy(tma_load_A, tma_coord_A, sA, mbar_ptr); } // 等待 mbarrier 完成 mbarrier_wait(mbar_ptr, phase); // 此后 sA 中数据已就绪 }
Swizzle — 消除 Bank 冲突
Swizzle 通过对内存地址的部分 bit 做 XOR,使不同行的相同列元素落在不同的 SRAM Bank,彻底消除 Shared Memory 的 Bank 冲突,是 GEMM 内核达到峰值带宽的必要条件。
Swizzle<B, M, S> 参数速查
| 参数 | 含义 | 影响 |
|---|---|---|
B (Bits) | 参与 XOR 的 bit 数 | Swizzle 周期 = 2B 行(通常 B=3,周期=8) |
M (Min) | XOR 目标 bit 的起始位 | 操作粒度 = 2M 字节(M=4→16B,M=3→8B) |
S (Shift) | 源到目标的 bit 距离 | 关联矩阵行宽,行宽须 ≥ 2M+S 字节 |
常用 Swizzle 配置
using namespace cute; // ─── Sw<3,4,3>:128-Byte 粒度,H100 TMA 标配 ──────────────────── // 适合:float16/bfloat16 矩阵,K≥128 列,LDG.128 / TMA using Sw343 = Swizzle<3,4,3>; using SmemLayoutA_343 = decltype(composition( Sw343{}, Layout<Shape<_128,_64>, Stride<_64,_1>>{} )); // ─── Sw<3,3,3>:64-Byte 粒度,A100 / 较窄矩阵 ────────────────── // 适合:float16,K=64 列,LDG.64 using Sw333 = Swizzle<3,3,3>; using SmemLayoutA_333 = decltype(composition( Sw333{}, Layout<Shape<_64,_32>, Stride<_32,_1>>{} )); // ─── tile_to_shape:将 SwizzleAtom 铺展到完整 tile ────────────── using SmemLayoutAtomA = SmemLayoutA_343; using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, Shape<Int<128>, Int<128>>{} // 完整 smem tile 大小 )); // ─── 验证:检查第0列的 bank 分配 ───────────────────────────────── auto layout = SmemLayoutA{}; for (int r = 0; r < 8; ++r) { int offset = layout(r, 0); int bank = (offset * sizeof(half_t) / 4) % 32; printf("row=%d bank=%d\n", r, bank); // 各行应不同 }
Hopper SM90 — 架构总览
H100 (SM90) 是 NVIDIA 2022 年推出的数据中心 GPU,引入了 TMA、WGMMA、Thread Block Cluster 三大硬件创新,使 A100 到 H100 实现了 ~3× 的 GEMM 性能提升。
专用硬件单元,负责在 Global↔Shared Memory 间异步搬运整个 tensor tile,支持多维地址生成、Swizzle 和 Multicast,将地址计算完全卸载到硬件。
4个 Warp(128线程)作为 Warpgroup 协同执行 MMA。直接从 Shared Memory 读取 operand,无需先 load 到寄存器,大幅减少寄存器压力。
最多 8 个 Thread Block 组成 Cluster,共享同一 GPC 内的资源,可通过 Distributed Shared Memory (DSM) 访问相邻 Block 的 smem,实现 Block 间通信。
H100 关键规格
| 参数 | H100 SXM5 | 对比 A100 |
|---|---|---|
| SM 数量 | 132 | 108 |
| BF16 TFLOPS (稀疏) | 3958 | 624 |
| HBM3 带宽 | 3.35 TB/s | 2 TB/s |
| L2 Cache | 50 MB | 40 MB |
| Shared Memory / SM | 228 KB | 164 KB |
| Register File / SM | 65536 × 32bit | 65536 × 32bit |
| WGMMA | ✓ 支持 | ✗ 不支持 |
| TMA | ✓ 支持 | ✗ 不支持 |
| Thread Block Cluster | ✓ 支持 | ✗ 不支持 |
SM90 内部结构
每个 SM 包含 4 个 Warp Scheduler,128 KB L1 Cache(可配置为 Shared Memory),以及 4 个 Tensor Core 单元。WGMMA 将 4 个 Warp 绑定为一个 Warpgroup 协同使用 Tensor Core。
TMA — 张量内存加速器
TMA (Tensor Memory Accelerator) 是 H100 中最重要的硬件创新之一。它将 Global→Shared Memory 的数据搬运完全硬件化,支持最多 5 维张量的地址自动生成。
每个线程需要计算自己负责的地址,消耗大量 ALU 周期和寄存器。128线程每人算一次地址 → 大量重复计算。
一个线程触发 TMA descriptor,硬件自动完成整个 tile 的地址生成和搬运,其余线程可以继续计算(或等待 mbarrier)。
TMA Descriptor 的组成
完整 TMA 使用流程(Host + Device)
#include <cute/arch/copy_sm90_tma.hpp> #include <cute/tensor.hpp> using namespace cute; // ════════════ HOST SIDE ════════════════════════════════════════ void launch_kernel(half_t* A, half_t* B, float* C, int M, int N, int K) { // 1. 创建 global tensor auto tensor_A = make_tensor(A, make_layout(make_shape(M, K), make_stride(K, 1))); // 2. 定义 smem layout(带 Swizzle) using SmemLayoutA = decltype(composition( Swizzle<3,4,3>{}, Layout<Shape<_128,_64>,Stride<_64,_1>>{})); // 3. 创建 TMA descriptor(host 端) auto tma_load_A = make_tma_copy( SM90_TMA_LOAD{}, tensor_A, // global tensor SmemLayoutA{}, // smem layout(决定 descriptor 中的 swizzle) make_shape(Int<128>{}, Int<64>{}), // tile shape Int<1>{} // multicast factor ); // 4. 启动 kernel(TMA descriptor 作为参数) dim3 grid(N/128, M/128); dim3 block(128); // 1 Warpgroup gemm_kernel<<<grid, block>>>(tma_load_A, /*...*/); } // ════════════ DEVICE SIDE ══════════════════════════════════════ __global__ void gemm_kernel(auto tma_load_A, auto tma_load_B, float* C) { // 分配 smem(使用 CuTe 的 smem descriptor) extern __shared__ half_t smem[]; auto sA = make_tensor(make_smem_ptr(smem), SmemLayoutA{}); // ── mbarrier 初始化 ───────────────────────────────────── __shared__ uint64_t mbar; if (threadIdx.x == 0) { cute::mbarrier_init(&mbar, 1); // expect 1 TMA transaction } __syncthreads(); // ── 触发 TMA Load ─────────────────────────────────────── if (threadIdx.x == 0) { auto blk_coord = make_coord(blockIdx.x * 128, _0{}); cute::copy(tma_load_A.with(mbar), blk_coord, sA); } // ── 等待 mbarrier ─────────────────────────────────────── cute::mbarrier_wait(&mbar, 0 /*phase*/); // ── 此后 sA 数据就绪,开始 WGMMA ─────────────────────── // ... wgmma code ... }
TMA Multicast(Cluster 场景)
// Cluster 内多个 Block 共享同一份数据:TMA Multicast // 只有 1 个 Block 发起 TMA,其余 Block 直接通过 DSM 访问 auto tma_load_A_multicast = make_tma_copy( SM90_TMA_LOAD_MULTICAST{}, tensor_A, SmemLayoutA{}, make_shape(Int<128>{}, Int<64>{}), Int<2>{} // multicast to 2 blocks in cluster ); // Device 端:只有 Block 0 触发,两个 Block 都收到数据 if (threadIdx.x == 0 && block_rank_in_cluster() == 0) { cute::copy(tma_load_A_multicast.with(mbar, cluster_mask), blk_coord, sA); }
WGMMA — Warpgroup MMA
WGMMA (Warpgroup Matrix Multiply-Accumulate) 是 H100 的核心计算指令。与 A100 的 HMMA 相比,WGMMA 让 128 个线程(4个Warp)作为整体执行 MMA,A/B operand 直接从 Shared Memory 读取,无需先 load 到寄存器,大幅减少寄存器压力和 smem→reg 的数据移动开销。
寄存器压力大,smem→reg 带宽是瓶颈
A/B 直读 smem,C 在寄存器累加,效率最高
WGMMA 指令与同步语义
#include <cute/arch/mma_sm90.hpp> using namespace cute; // ─── 定义 WGMMA TiledMMA ─────────────────────────────────────── using TiledMma = TiledMMA< MMA_Atom<SM90_64x128x16_F32BF16BF16_SS<GMMA::Major::K>>>, Layout<Shape<_2,_2,_1>>, // 2×2 Atoms → 128×256×16 per WG Tile<_128, _256, _16> >; // ─── Device 端使用 ───────────────────────────────────────────── TiledMma tiled_mma; auto thr_mma = tiled_mma.get_slice(threadIdx.x); auto tCsA = thr_mma.partition_A(sA); // smem partition auto tCsB = thr_mma.partition_B(sB); auto tCrC = thr_mma.partition_fragment_C(rC_layout); // reg accumulator clear(tCrC); // ─── WGMMA 执行流程 ──────────────────────────────────────────── for (int k_tile = 0; k_tile < num_k_tiles; ++k_tile) { // 步骤1:fence(确保 smem 写入对 MMA 可见) cute::warpgroup_fence_operand(tCrC); // 步骤2:WGMMA 指令组(异步执行) cute::warpgroup_arrive(); gemm(tiled_mma, tCsA(_, _, k_tile), tCsB(_, _, k_tile), tCrC); cute::warpgroup_commit_batch(); // 步骤3:等待当前 batch 完成 cute::warpgroup_wait<0>(); // 0 = 等待所有未完成的 WGMMA cute::warpgroup_fence_operand(tCrC); } // ─── 写回 C 到 Global Memory ─────────────────────────────────── auto tCgC = thr_mma.partition_C(gC); copy(tCrC, tCgC); // reg → global(自动向量化)
WGMMA 寄存器布局(D matrix fragment)
wait<0> 等待全部,wait<1> 允许 1 个 batch 在飞(用于 Ping-Pong)。这是流水线优化的关键旋钮。
Ping-Pong 流水线
H100 GEMM 的最高性能来自 TMA 搬运与 WGMMA 计算的完全重叠。Ping-Pong 方案使用双缓冲(或多缓冲),使得"搬下一个 K-tile"和"计算当前 K-tile"同时进行。
Persistent Kernel + Ping-Pong
// 核心思想:2个 Stage 的 smem 双缓冲 // Stage 0:TMA load → sA[0], sB[0];WGMMA 计算 sA[1], sB[1] // Stage 1:TMA load → sA[1], sB[1];WGMMA 计算 sA[0], sB[0] // 两组操作完全重叠 constexpr int Stages = 2; // smem 双缓冲:每个 stage 各一份 __shared__ half_t smem_A[Stages][BLK_M * BLK_K]; __shared__ half_t smem_B[Stages][BLK_N * BLK_K]; __shared__ uint64_t mbar[Stages]; // 初始化 mbarrier for (int s = 0; s < Stages; ++s) { if (threadIdx.x == 0) cute::mbarrier_init(&mbar[s], 1); } __syncthreads(); // ── Prologue:预取第0个 K-tile ───────────────────────────── if (threadIdx.x == 0) { copy(tma_A.with(mbar[0]), coord_A_k0, sA[0]); copy(tma_B.with(mbar[0]), coord_B_k0, sB[0]); } // ── Main Loop ────────────────────────────────────────────── int phase = 0; for (int k = 0; k < num_k_tiles; ++k) { int cur = k % Stages; int nxt = (k + 1) % Stages; // A:等待当前 stage 的 TMA 完成 cute::mbarrier_wait(&mbar[cur], phase); // B:触发下一个 stage 的 TMA(与 WGMMA 重叠) if (k + 1 < num_k_tiles && threadIdx.x == 0) { cute::mbarrier_arrive_expect_tx(&mbar[nxt], 2 * tx_bytes); copy(tma_A.with(mbar[nxt]), coord_A(k+1), sA[nxt]); copy(tma_B.with(mbar[nxt]), coord_B(k+1), sB[nxt]); } // C:执行 WGMMA(计算当前 stage,与下一个 TMA 重叠) warpgroup_fence_operand(tCrC); warpgroup_arrive(); gemm(tiled_mma, sA[cur], sB[cur], tCrC); warpgroup_commit_batch(); warpgroup_wait<0>(); warpgroup_fence_operand(tCrC); if (cur == Stages - 1) phase ^= 1; // mbarrier phase flip } // ── Epilogue:写回结果 ───────────────────────────────────── copy(tCrC, tCgC);
三级流水线(Stage=3,最高利用率)
// Stage=3:允许 warpgroup_wait<1>,WGMMA pipeline depth=2 // TMA 可以领先 WGMMA 2个 tile,完全隐藏 smem load latency constexpr int Stages = 3; // Prologue:预取前 Stages-1 个 K-tile for (int s = 0; s < min(Stages-1, num_k_tiles); ++s) { if (threadIdx.x == 0) { cute::mbarrier_arrive_expect_tx(&mbar[s], tx_bytes); copy(tma_A.with(mbar[s]), coord_A(s), sA[s]); copy(tma_B.with(mbar[s]), coord_B(s), sB[s]); } } for (int k = 0; k < num_k_tiles; ++k) { int cur = k % Stages; int nxt = (k + Stages - 1) % Stages; cute::mbarrier_wait(&mbar[cur], (k / Stages) & 1); if (k + Stages - 1 < num_k_tiles && threadIdx.x == 0) { int prefetch_k = k + Stages - 1; cute::mbarrier_arrive_expect_tx(&mbar[nxt], tx_bytes); copy(tma_A.with(mbar[nxt]), coord_A(prefetch_k), sA[nxt]); copy(tma_B.with(mbar[nxt]), coord_B(prefetch_k), sB[nxt]); } warpgroup_fence_operand(tCrC); warpgroup_arrive(); gemm(tiled_mma, sA[cur], sB[cur], tCrC); warpgroup_commit_batch(); warpgroup_wait<1>(); // 允许1个在飞,进一步重叠 warpgroup_fence_operand(tCrC); } warpgroup_wait<0>(); // Epilogue:等待最后的 WGMMA
完整 GEMM 实现 (SM90)
将 TMA、WGMMA、Swizzle、Ping-Pong 流水线组合成一个完整的、可编译的 BF16 GEMM kernel。这是 CUTLASS 3.x 的简化版本。
#include <cute/tensor.hpp> #include <cute/arch/mma_sm90.hpp> #include <cute/arch/copy_sm90_tma.hpp> #include <cutlass/arch/barrier.h> using namespace cute; // ════════════ 编译期配置 ═══════════════════════════════════════ using ElementA = cutlass::bfloat16_t; using ElementB = cutlass::bfloat16_t; using ElementC = float; constexpr int BLK_M = 128, BLK_N = 128, BLK_K = 64; constexpr int Stages = 3; // ── Smem Layout(Swizzle 消除 bank 冲突)──────────────────────── using SmemLayoutAtomA = decltype(composition( Swizzle<3,4,3>{}, Layout<Shape<_8, _64>, Stride<_64,_1>>{})); using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, Shape<Int<BLK_M>, Int<BLK_K>, Int<Stages>>{})); using SmemLayoutAtomB = SmemLayoutAtomA; using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, Shape<Int<BLK_N>, Int<BLK_K>, Int<Stages>>{})); // ── TiledMMA (WGMMA 64x128x16 BF16) ──────────────────────────── using TiledMma = TiledMMA< MMA_Atom<SM90_64x128x16_F32BF16BF16_SS<GMMA::Major::K>>>, Layout<Shape<_2,_2,_1>>>; // 每个 WG 处理:128×128×16(2×2 WGMMA atoms) // ════════════ KERNEL ════════════════════════════════════════════ __global__ void __launch_bounds__(128, 1) hopper_gemm_kernel( auto const tma_load_A, // TMA descriptor for A auto const tma_load_B, // TMA descriptor for B ElementC* C_ptr, int M, int N, int K) { // ── Smem allocation ───────────────────────────────────── extern __shared__ char smem_buf[]; ElementA* smem_A = reinterpret_cast<ElementA*>(smem_buf); ElementB* smem_B = smem_A + BLK_M * BLK_K * Stages; auto sA = make_tensor(make_smem_ptr(smem_A), SmemLayoutA{}); auto sB = make_tensor(make_smem_ptr(smem_B), SmemLayoutB{}); // ── mbarrier 初始化 ───────────────────────────────────── __shared__ uint64_t mbar[Stages]; if (threadIdx.x == 0) for (int s = 0; s < Stages; ++s) cute::mbarrier_init(&mbar[s], 1); __syncthreads(); // ── TiledMMA partition ────────────────────────────────── TiledMma tiled_mma; auto thr_mma = tiled_mma.get_slice(threadIdx.x); auto tCsA = thr_mma.partition_A(sA(_, _, _0{})); auto tCsB = thr_mma.partition_B(sB(_, _, _0{})); auto tCrC = thr_mma.partition_fragment_C( make_layout(Shape<Int<BLK_M>, Int<BLK_N>>{})); clear(tCrC); // ── 计算本 block 的 M/N tile 起始坐标 ────────────────── int m0 = blockIdx.x * BLK_M; int n0 = blockIdx.y * BLK_N; int num_k = K / BLK_K; // ── Prologue:预取前 Stages-1 个 K-tile ──────────────── for (int s = 0; s < Stages-1 && s < num_k; ++s) { if (threadIdx.x == 0) { cute::mbarrier_arrive_expect_tx(&mbar[s], BLK_M*BLK_K*sizeof(ElementA) + BLK_N*BLK_K*sizeof(ElementB)); copy(tma_load_A.with(mbar[s]), make_coord(m0, s*BLK_K), sA(_, _, s)); copy(tma_load_B.with(mbar[s]), make_coord(n0, s*BLK_K), sB(_, _, s)); } } // ── Main Loop ─────────────────────────────────────────── int phase = 0; for (int k = 0; k < num_k; ++k) { int cur = k % Stages; int nxt = (k + Stages - 1) % Stages; // 等待当前 stage TMA 完成 cute::mbarrier_wait(&mbar[cur], (k / Stages) & 1); // 触发下一个 TMA(与 WGMMA 重叠) int prefetch = k + Stages - 1; if (prefetch < num_k && threadIdx.x == 0) { cute::mbarrier_arrive_expect_tx(&mbar[nxt], BLK_M*BLK_K*sizeof(ElementA) + BLK_N*BLK_K*sizeof(ElementB)); copy(tma_load_A.with(mbar[nxt]), make_coord(m0, prefetch*BLK_K), sA(_, _, nxt)); copy(tma_load_B.with(mbar[nxt]), make_coord(n0, prefetch*BLK_K), sB(_, _, nxt)); } // WGMMA 执行 warpgroup_fence_operand(tCrC); warpgroup_arrive(); gemm(tiled_mma, tCsA.with_smem(sA(_, _, cur)), tCsB.with_smem(sB(_, _, cur)), tCrC); warpgroup_commit_batch(); warpgroup_wait<1>(); warpgroup_fence_operand(tCrC); } warpgroup_wait<0>(); // ── Epilogue:写回 C ───────────────────────────────────── auto gC = make_tensor(C_ptr + m0*N + n0, make_layout(Shape<Int<BLK_M>, Int<BLK_N>>{}, Stride<Int<N>, _1>{})); auto tCgC = thr_mma.partition_C(gC); copy(tCrC, tCgC); }
Blackwell SM100 — 架构总览
B200 (SM100) 是 NVIDIA 2024 年发布的最新架构,在 Hopper 基础上引入 UMMA、UTCCP、第五代 NVLink,将 FP8 GEMM 性能推向新高度。
取代 WGMMA,支持更大的 MMA tile(最大 256×256),引入新的 Tensor descriptor 直接描述 smem 访问模式,同时支持 FP4/FP8/FP16/BF16/FP32。
Blackwell 的 TMA 升级版,支持更高带宽、更大的 multicast 范围,以及新的 TMA store 模式用于 epilogue 写回。
原生 FP4/FP6 支持,动态量化无需 Python 包装,Microscaling (MX) 格式支持更细粒度的量化粒度。
B200 vs H100 性能对比
| 参数 | B200 SXM | H100 SXM5 | 提升 |
|---|---|---|---|
| SM 数量 | 170 | 132 | ~1.3× |
| FP8 TFLOPS(稠密) | 9000 | 3958 | ~2.3× |
| FP4 TFLOPS(稠密) | 18000 | 不支持 | ∞ |
| HBM3e 带宽 | 8 TB/s | 3.35 TB/s | ~2.4× |
| L2 Cache | 100 MB | 50 MB | 2× |
| Shared Memory / SM | 256 KB | 228 KB | ~1.1× |
| NVLink | 第五代 1.8TB/s | 第四代 900GB/s | 2× |
| UMMA | ✓ | ✗(用WGMMA) | — |
| FP4 数据类型 | ✓ | ✗ | — |
Blackwell 计算层级
Blackwell 引入了新的 Block Scaling 概念——在 MMA 之前对 A/B 块进行缩放,让 FP4 量化精度更高。每个 MMA tile 有独立的 scaling factor,存储在额外的 smem 缓冲区中。
UMMA — Unified MMA
UMMA (Unified Matrix Multiply-Accumulate) 是 Blackwell 对 WGMMA 的全面升级,支持更大 tile、FP4 原生计算和 Block Scaling,是 B200 达到 18 PFLOPS FP4 性能的核心。
| 特性 | WGMMA (H100) | UMMA (B200) |
|---|---|---|
| 最大 N | 256 | 256+ |
| FP4 | ✗ | ✓ |
| Block scaling | ✗ | ✓ |
| A 来源 | smem/reg | smem(descriptor) |
E2M1 格式:2位指数 + 1位尾数 + 1位符号。每个元素只有 4bit,需要 Block Scaling factor 来还原精度。适合 LLM 推理的 weight-only 量化。
UMMA 使用(CuTe SM100)
#include <cute/arch/mma_sm100.hpp> using namespace cute; // ─── SM100 UMMA:BF16→FP32 ───────────────────────────────────── using MmaAtom_BF16 = MMA_Atom<SM100_64x128x16_F32BF16BF16_SS>; using TiledMma_BF16 = TiledMMA<MmaAtom_BF16, Layout<Shape<_2,_2,_1>>>; // ─── SM100 UMMA:FP8→FP32(更高性能)────────────────────────── using MmaAtom_FP8 = MMA_Atom<SM100_64x128x32_F32E4M3E4M3_SS>; using TiledMma_FP8 = TiledMMA<MmaAtom_FP8, Layout<Shape<_2,_2,_1>>>; // ─── SM100 UMMA:FP4→FP32(峰值性能)────────────────────────── using MmaAtom_FP4 = MMA_Atom<SM100_64x128x64_F32E2M1E2M1_SS>; // ─── Block Scaling(FP4 场景)────────────────────────────────── // 每 16 个 FP4 元素共享一个 FP32 scaling factor // Scale factor 存储在额外的 smem buffer 中 __shared__ float smem_scale_A[BLK_M / 16]; // M 方向每16行一个 scale __shared__ float smem_scale_B[BLK_N / 16]; // N 方向每16列一个 scale // ─── UTCCP:加载 FP4 数据 + scale ────────────────────────────── auto tma_load_A_fp4 = make_tma_copy( SM100_TMA_LOAD{}, tensor_A_fp4, // FP4 global tensor(4bit/元素) SmemLayoutA_fp4{}, make_shape(Int<128>{}, Int<128>{}), Int<1>{}); auto tma_load_ScaleA = make_tma_copy( SM100_TMA_LOAD{}, tensor_ScaleA, // FP32 scale tensor SmemLayoutScale{}, make_shape(Int<8>{}, Int<4>{}), Int<1>{}); // ─── UMMA Kernel 核心循环 ────────────────────────────────────── for (int k = 0; k < num_k; ++k) { // 等待 FP4 数据 + scale 都就绪 mbarrier_wait(&mbar[k % Stages], phase); // UMMA with block scaling warpgroup_fence_operand(tCrC); warpgroup_arrive(); // SM100 UMMA 自动读取 scale 并应用 gemm_with_scale(tiled_mma_fp4, sA_fp4(_, _, k%Stages), smem_scale_A, sB_fp4(_, _, k%Stages), smem_scale_B, tCrC); warpgroup_commit_batch(); warpgroup_wait<1>(); warpgroup_fence_operand(tCrC); }
Cluster & NVLink 通信
Blackwell 将 Thread Block Cluster 规模扩展,并通过第五代 NVLink 实现 GPU 间的高带宽低延迟通信,是超大 LLM 推理和训练的基础。
H100: 最大 8 Block/Cluster
B200: 最大 16 Block/Cluster(扩展模式)
更大 Cluster 意味着 TMA Multicast 可以一次广播给更多 Block,减少重复的 HBM 读取。
Cluster 内的 Block 可以直接读写其他 Block 的 Shared Memory,实现 Block 间高速数据交换,无需经过 HBM,带宽远高于 Global Memory。
单 GPU 双向带宽 1.8 TB/s,支持最多 576 GPU 互联(NVLink Switch 5)。是 DGX B200 实现超线性 LLM 扩展的基础。
Cluster 编程示例
// ─── 启动 Kernel with Cluster ────────────────────────────────── cudaLaunchConfig_t config; config.gridDim = {M/BLK_M, N/BLK_N, 1}; config.blockDim = {128, 1, 1}; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeClusterDimension; attrs[0].val.clusterDim = {2, 1, 1}; // 2 blocks per cluster config.attrs = attrs; config.numAttrs = 1; cudaLaunchKernelEx(&config, cluster_gemm_kernel, /*args*/); // ─── Kernel 内:Cluster 基本操作 ─────────────────────────────── __cluster_dims__(2, 1, 1) __global__ void cluster_gemm_kernel(/*...*/) { // 获取当前 Block 在 Cluster 内的 rank int block_rank = cute::block_rank_in_cluster(); int cluster_size = cute::cluster_size(); // TMA Multicast:Block 0 发起,Cluster 内所有 Block 收到数据 uint16_t cluster_mask = (1 << cluster_size) - 1; // 所有 Block if (threadIdx.x == 0 && block_rank == 0) { cute::copy(tma_load_A_multicast.with(mbar, cluster_mask), blk_coord, sA); } // DSM:访问 Cluster 内其他 Block 的 smem half_t* remote_smem = cute::cluster_local_smem_ptr( smem_ptr, 1 - block_rank); // 访问另一个 Block 的 smem auto remote_tensor = make_tensor( make_smem_ptr(remote_smem), SmemLayoutB{}); // barrier synchronize across cluster cute::cluster_arrive(); cute::cluster_wait(); }
Blackwell GEMM 实现要点
B200 的最优 GEMM 策略与 H100 有所不同,主要差异在于 UMMA 的使用、FP4 block scaling 的处理,以及更大 smem 的利用。
#include <cute/arch/mma_sm100.hpp> #include <cute/arch/copy_sm100.hpp> using namespace cute; // ── SM100 FP8 GEMM 配置 ───────────────────────────────────────── using ElementA = cutlass::float_e4m3_t; // FP8 E4M3 using ElementB = cutlass::float_e4m3_t; using ElementC = float; // BLK_K=128(FP8 的 K 比 BF16 大一倍,吞吐翻倍) constexpr int BLK_M=128, BLK_N=128, BLK_K=128, Stages=4; using TiledMma = TiledMMA< MMA_Atom<SM100_64x128x32_F32E4M3E4M3_SS>, Layout<Shape<_2,_2,_1>>>; // Smem layout:FP8 用 Sw<3,4,3>(每个FP8=1B,128元素/行=128B) using SmemLayoutA = decltype(tile_to_shape( composition(Swizzle<3,4,3>{}, Layout<Shape<_8,_128>,Stride<_128,_1>>{}), Shape<Int<BLK_M>, Int<BLK_K>, Int<Stages>>{})); __global__ void __launch_bounds__(128) sm100_fp8_gemm(auto tma_A, auto tma_B, ElementC* C, int M, int N, int K) { extern __shared__ char smem[]; auto sA = make_tensor(make_smem_ptr((ElementA*)smem), SmemLayoutA{}); auto sB = make_tensor(make_smem_ptr( (ElementB*)(smem + BLK_M*BLK_K*Stages)), SmemLayoutB{}); TiledMma tiled_mma; auto thr_mma = tiled_mma.get_slice(threadIdx.x); auto tCrC = thr_mma.partition_fragment_C( Layout<Shape<Int<BLK_M>,Int<BLK_N>>>{}); clear(tCrC); __shared__ uint64_t mbar[Stages]; if (threadIdx.x == 0) for (int s=0; s<Stages; ++s) mbarrier_init(&mbar[s], 1); __syncthreads(); int m0=blockIdx.x*BLK_M, n0=blockIdx.y*BLK_N, num_k=K/BLK_K; constexpr int tx = BLK_M*BLK_K + BLK_N*BLK_K; // bytes per stage // Prologue for (int s=0; s<min(Stages-1,num_k); ++s) { if (!threadIdx.x) { mbarrier_arrive_expect_tx(&mbar[s], tx); copy(tma_A.with(mbar[s]), make_coord(m0,s*BLK_K), sA(_,_,s)); copy(tma_B.with(mbar[s]), make_coord(n0,s*BLK_K), sB(_,_,s)); } } for (int k=0; k<num_k; ++k) { int cur=k%Stages, nxt=(k+Stages-1)%Stages; mbarrier_wait(&mbar[cur], (k/Stages)&1); if (k+Stages-1<num_k && !threadIdx.x) { int pk=k+Stages-1; mbarrier_arrive_expect_tx(&mbar[nxt], tx); copy(tma_A.with(mbar[nxt]), make_coord(m0,pk*BLK_K), sA(_,_,nxt)); copy(tma_B.with(mbar[nxt]), make_coord(n0,pk*BLK_K), sB(_,_,nxt)); } warpgroup_fence_operand(tCrC); warpgroup_arrive(); gemm(tiled_mma, sA(_,_,cur), sB(_,_,cur), tCrC); warpgroup_commit_batch(); warpgroup_wait<1>(); warpgroup_fence_operand(tCrC); } warpgroup_wait<0>(); auto gC = make_tensor(C+m0*N+n0, Layout<Shape<Int<BLK_M>,Int<BLK_N>>,Stride<Int<N>,_1>>{}); copy(tCrC, thr_mma.partition_C(gC)); }
Occupancy & 寄存器优化
GPU 的计算效率取决于 SM 上同时驻留的 Warp 数量(Occupancy)。过多的寄存器使用会降低 Occupancy,而 WGMMA 的大 accumulator fragment 是主要压力来源。
寄存器预算计算
每 SM:65536 个 32-bit 寄存器
每 Warpgroup(128线程):理论最多 512 个/线程
实际可用:受 block 数量限制
64×128×F32 的 C fragment:
每线程需要 64 个 F32 寄存器
加上 A/B 缓存 → 轻松超过 128 个寄存器/线程
// ─── 限制寄存器数量(强制提高 Occupancy)────────────────────── __global__ void __launch_bounds__( 128, // maxThreadsPerBlock 1 // minBlocksPerMultiprocessor(SM 上至少 1 个 block) ) gemm_kernel(/*...*/) { /*...*/ } // ─── 用 nvcc 标志限制寄存器 ─────────────────────────────────── // nvcc --maxrregcount=128 gemm.cu // 或 per-kernel: // #pragma nv_diag_suppress 177 // 忽略警告 // ─── 检查寄存器使用 ─────────────────────────────────────────── // nvcc -Xptxas -v gemm.cu // 输出:ptxas info: Used N registers, M bytes smem // ─── WGMMA 寄存器估算工具 ───────────────────────────────────── constexpr int MMA_M = 64, MMA_N = 128; // WGMMA tile constexpr int num_atoms_m = 2, num_atoms_n = 2; // tiling // C fragment:每线程负责 (MMA_M * num_atoms_m) / 128 行 × (MMA_N * num_atoms_n) 列 // = (64×2)/128 × (128×2) = 1 × 256 个 F32 // = 256 个寄存器(太多了!需要减小 tile)
Occupancy 计算表
| 寄存器/线程 | 线程数/Block | Blocks/SM | Warps/SM | Occupancy |
|---|---|---|---|---|
| 64 | 128 | 8 | 32 | 100% |
| 96 | 128 | 5 | 20 | 62.5% |
| 128 | 128 | 4 | 16 | 50% |
| 192 | 128 | 2 | 8 | 25% |
| 256 | 128 | 2 | 8 | 25% |
内存层次结构
GPU 内存层次从 HBM 到 L2 Cache、L1/Shared Memory、寄存器,各层的带宽和延迟差异达到数量级,理解并充分利用这一层次是高性能 kernel 的基础。
关键带宽与延迟数据
| 存储层次 | 带宽 (H100) | 延迟 | 容量 | 访问单位 |
|---|---|---|---|---|
| Register File | ~80 TB/s/SM | 0 cycle | 256 KB/SM | 32-bit |
| Shared Memory | ~33 TB/s/SM | ~23 cycle | 228 KB/SM | 128 byte |
| L1 Cache | ~33 TB/s/SM | ~33 cycle | 128 KB/SM | 128 byte |
| L2 Cache | ~12 TB/s | ~200 cycle | 50 MB | 128 byte |
| HBM3 (DRAM) | 3.35 TB/s | ~700 cycle | 80 GB | 32 byte burst |
内存访问优化清单
- 使用 128-bit 向量化加载(float4 / uint4)
- 确保全局内存访问合并(coalesced)
- smem 使用 Swizzle 消除 bank 冲突
- TMA 替代手写 cp.async
- 预取(prefetch)隐藏内存延迟
- L2 持久化:
cudaStreamAttrValue
- 非对齐的全局内存访问
- Shared Memory bank 冲突
- stride = 32 × sizeof(T) 的 smem 访问
- 在热路径上使用原子操作
- 小 tile 导致 HBM 访问量增大
- L1 thrashing(工作集超过 L1)
Nsight 分析方法
使用 NVIDIA Nsight Compute (ncu) 和 Nsight Systems (nsys) 定量分析 kernel 瓶颈,指导优化方向。
Nsight Compute 关键指标
# ── 基础 profiling ────────────────────────────────────────────── ncu --set full -o profile ./gemm_kernel # ── 只收集关键指标(更快)────────────────────────────────────── ncu --metrics \ sm__throughput.avg.pct_of_peak_sustained_elapsed, # SM 利用率 l1tex__throughput.avg.pct_of_peak_sustained_elapsed, # L1 带宽利用率 lts__throughput.avg.pct_of_peak_sustained_elapsed, # L2 带宽利用率 gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed, # HBM 利用率 sm__pipe_tensor_op_hmma_cycles_active.avg.pct_of_peak_sustained_elapsed, # Tensor Core 利用率 sm__warps_active.avg.pct_of_peak_sustained_elapsed # Warp 占用率 -o profile_fast ./gemm_kernel # ── GEMM 效率分析 ─────────────────────────────────────────────── ncu --metrics \ sm__pipe_tensor_op_hmma_cycles_active, # HMMA 活跃周期 smsp__inst_executed_op_hmma.sum, # HMMA 指令数 l1tex__data_bank_conflicts_pipe_lsu.sum # Smem bank 冲突数 ./gemm_kernel # ── 系统级时间线 ──────────────────────────────────────────────── nsys profile --trace=cuda,nvtx -o timeline ./gemm_benchmark
关键指标解读
| 指标 | 理想值 | 低说明什么 | 优化方向 |
|---|---|---|---|
| Tensor Core 利用率 | >80% | WGMMA 没有饱和 | 增大 tile,减少 epilogue 开销 |
| HBM 带宽利用率 | <60%(Compute bound) | 如果 >80% 则内存受限 | 增大 tile,减少 HBM 访问 |
| L2 带宽利用率 | 适中 | — | — |
| Bank 冲突次数 | 0 | >0 则有 bank 冲突 | 检查 Swizzle 配置 |
| SM 利用率 | >95% | 有 idle SM | 检查 grid size,增加 wave |
性能模型:Roofline
# GEMM Roofline 分析 M, N, K = 4096, 4096, 4096 # FLOP:GEMM = 2MNK flops = 2 * M * N * K # = 137.4 GFLOP # Bytes:读 A (MK) + B (KN) + 写 C (MN) # BF16 GEMM:A/B = 2B/elem, C = 4B/elem bytes_io = (M*K + K*N)*2 + M*N*4 # = 101.2 GB # Arithmetic Intensity (AI) AI = flops / bytes_io # = 1.36 FLOP/Byte # H100 Ridge Point(BF16 TFLOPS / HBM 带宽) ridge_point = 989e12 / 3.35e12 # ≈ 295 FLOP/Byte # AI(1.36) << Ridge Point(295) → GEMM 是 Compute Bound ✓ # 理论峰值性能受 Tensor Core 限制,而非带宽 print(f"AI = {AI:.2f} FLOP/Byte") print(f"Ridge Point = {ridge_point:.0f} FLOP/Byte") print(f"Bound: {'Compute' if AI > ridge_point else 'Memory'}")
① 先确认是 Compute Bound 还是 Memory Bound(Roofline 分析)
② Compute Bound → 最大化 Tensor Core 利用率,优化 smem layout(Swizzle),增大 tile
③ Memory Bound → 减少 HBM 访问(融合算子、增大 tile),提高向量化宽度
④ 最后才考虑 Occupancy(WGMMA 场景下 25-50% 是正常的)
快速调试清单
□ 用小 M/N/K 做数值正确性验证
□ ncu 确认 bank 冲突为 0
□ Tensor Core 利用率 > 75%
□ TMA descriptor swizzle 与 smem layout 匹配
□ warpgroup_wait 语义正确(无 race)
□ mbarrier phase 翻转逻辑正确
□ mbarrier expect_tx 字节数算错
□ TMA coord 坐标系搞混(行/列顺序)
□ warpgroup_fence 位置不对导致乱序
□ Swizzle 与 MMA atom 的 K-major/M-major 不匹配
□ Epilogue 时 tCrC 的 layout 理解错误
□ Cluster barrier 使用不当导致死锁