从 Bank 冲突原理到 Sw<3,4,3> vs Sw<3,3,3> 的完整指南
GPU 的 Shared Memory 被分成 32 个 Bank,每个 Bank 宽 4 字节,相邻地址交错分布:
Bank 编号公式:bank = (byte_addr / 4) % 32
冲突:一个 Warp(32个线程)中,多个线程访问同一个 Bank 的不同地址时,硬件被迫串行化这些访问。
✓ 无冲突 每个线程访问不同 Bank
✓ 广播 所有线程访问同一 Bank 的相同地址(广播,无冲突)
✗ N路冲突 N 个线程访问同一 Bank 的不同地址 → 性能降低 N 倍
以 8行 × 32列 的 float32 矩阵存入 Shared Memory 为例(共1024字节):
元素 (row, col) 的 Bank = (row×32 + col) % 32 = col(因为 32%32=0,row 无贡献)
读第 0 列:线程 t 读 (t, 0),Bank = t×32 % 32 = 0 — 全部落在 Bank 0!
这是 stride = 32 × sizeof(float32) = 128字节 = 32个 Bank 的整数倍导致的。这种 stride 直接让每一行的起始 Bank 完全相同,所有行同一列的元素挤进同一个 Bank。
读第 0 行:线程 t 读 (0, t),Bank = t — 线程 0→Bank 0,线程 1→Bank 1 ... 线程 31→Bank 31,完美无冲突。
但在 GEMM 中,我们通常需要同时在行和列方向高效访问矩阵,单靠行主序布局做不到两全其美。
Swizzle 的核心思想:将元素的行号的部分 bit 与列号的部分 bit 做 XOR,使每一行的起始 Bank 各不相同,从而消除冲突。
原始:bank(r, c) = c % 32
→ 每行 Bank 分配完全相同,列访问必然冲突
Swizzle 后:bank(r, c) = (c XOR r×k) % 32
→ 每行的 Bank 分配被整体偏移,所有行 Bank 各不同
对线性索引 x,Swizzle<B, M, S> 的变换:
其中 mask = (1 << B) - 1,提取 B 个 bit。
等价描述:把 x 的 高位段(从位置 M+S 起的 B 个 bit)XOR 进 x 的 低位段(从位置 M 起的 B 个 bit)。其余 bit 不变。
参与 XOR 的 bit 数,决定 Swizzle 的周期。
周期 = 2B 行,之后模式重复。
常见值:B=3(8行为一个周期,足够覆盖32个bank的差异)
被 XOR 改写的那段 bit 的起始位置,决定 Swizzle 的最小操作粒度 = 2M 字节。
M=4 → 16字节粒度(128-bit对齐,匹配 LDG.128 / TMA)
M=3 → 8字节粒度(64-bit对齐,匹配 LDG.64)
从"被改写段"到"XOR源段"之间的 bit 距离,关联到矩阵的行宽。
S 越大 → Swizzle 在行方向上的"步伐"越大,适合更宽的矩阵。
常见值:S=3,意味着 M+S 位置往上的3 bit 是行号。
| 数据类型 | 每行列数 K | 每行字节数 | 推荐 Swizzle | 原因 |
|---|---|---|---|---|
| float16 | 64 | 128B | Sw<3,3,3> | 64元素对应64bit粒度 |
| float16 | 128 | 256B | Sw<3,4,3> | 128bit向量化加载,TMA标配 |
| float32 | 32 | 128B | Sw<3,4,3> | 128B = 16字节粒度 |
| float32 | 16 | 64B | Sw<3,3,3> | 64B = 8字节粒度 |
| bfloat16 | 128 | 256B | Sw<3,4,3> | 同 float16 × 128 |
下图显示对应 Swizzle 参数下,每个元素实际映射到的 Bank 编号(颜色 = Bank,每行颜色各异表示无冲突):
Bit 操作:bit[4:7] XOR bit[7:10]
粒度:24 = 16 字节(128-bit)
适配场景:每次读取128bit(4个float32 或 8个float16),与 LDG.128 / cp.async.cg / TMA 对齐。
矩阵宽度要求:行宽 ≥ 2M+S × dtype = 27 × 2bytes = 256 bytes(float16 需 K≥128)
Bit 操作:bit[3:6] XOR bit[6:9]
粒度:23 = 8 字节(64-bit)
适配场景:每次读取64bit(2个float32 或 4个float16)。
矩阵宽度要求:行宽 ≥ 26 × 2bytes = 128 bytes(float16 需 K≥64)
矩阵:8行 × 64列 float16(每行128字节)。读取第0列时,各行访问的 Bank:
| 特性 | Sw<3,4,3> | Sw<3,3,3> |
|---|---|---|
| XOR 的 bit 位置 | bit[4:7] ← bit[7:10] | bit[3:6] ← bit[6:9] |
| 操作粒度 | 16 字节(128-bit) | 8 字节(64-bit) |
| Swizzle 生效的最小行宽 | 128 元素(float16)/ 64 元素(float32) | 64 元素(float16)/ 32 元素(float32) |
| H100 TMA descriptor | 标配 ✓ | 可用,但较少见 |
| A100 cp.async 128bit | 推荐 ✓ | 也可用 |
| WGMMA smem layout | 匹配 ✓ | 需验证 |
| 32bit 访问(LDG.32) | 不推荐 | 更合适 |
M 决定粒度,S 决定行宽匹配,B 决定周期深度(通常固定为3)。
H100 + TMA + float16/bfloat16 + K=128 → 首选 Sw<3,4,3>。
A100 + cp.async + float16 + K=64 → 首选 Sw<3,3,3>。
// 方式1:composition 直接组合 Swizzle 和 Layout using SwizzleAtom = Swizzle<3, 4, 3>; // 128列 float16 的 swizzled smem layout auto smem_layout = composition( SwizzleAtom{}, make_layout( make_shape (_128{}, _64{}), // 128行, 64列(K方向) make_stride( _64{}, _1{}) // 行主序 ) ); // 等价的简化写法(CuTe 内置的 swizzled layout helper) using SmemLayoutA = decltype(composition( Swizzle<3, 4, 3>{}, Layout<Shape<_128, _64>, Stride<_64, _1>>{} ));
// 声明 smem 指针 extern __shared__ half_t smem_buf[]; auto sA = make_tensor(make_smem_ptr(smem_buf), smem_layout); // 访问元素:CuTe 自动处理 swizzle 的地址计算 auto elem = sA(row, col); // 逻辑坐标,物理地址已 swizzled
// TMA descriptor 携带 swizzle 信息 auto tma_load_A = make_tma_copy( SM90_TMA_LOAD{}, tensor_A, // global memory tensor smem_layout_A.layout(), // swizzled smem layout make_shape(bM, bK), // tile shape Int<1>{} // multicast count ); // Kernel 内部:TMA load 自动 swizzle copy(tma_load_A, tma_coord, sA); // 触发异步 TMA,结果直接写入 swizzled smem
// CUTLASS 中常见的 smem layout 生成方式 using SmemLayoutAtomA = decltype( composition(Swizzle<3, 4, 3>{}, Layout<Shape <_8,_64>, Stride<_64, _1>>{}) ); // 然后用 tiled 方式铺开整个 smem tile using SmemLayoutA = decltype( tile_to_shape(SmemLayoutAtomA{}, Shape<Int<kM>, Int<kK>>{}) );
// 1. 定义 MMA operation using MmaOp = SM90_64x128x16_F32BF16BF16_SS<GMMA::Major::K>; using TiledMma = TiledMMA< MMA_Atom<MmaOp>, Layout<Shape<_2,_2,_1>>, // warpgroup tiling Tile<_128,_128,_64> // MN K tile >; // 2. 定义 Swizzled smem layout(必须与 WGMMA 期望的 layout 匹配) using SmemLayoutA = decltype(composition( Swizzle<3, 4, 3>{}, Layout<Shape<_128, _64>, Stride<_64, _1>>{} )); // 3. Kernel 内:TMA load → smem(自动 swizzle),smem → WGMMA __shared__ bf16 smem_A[128 * 64]; auto sA = make_tensor(make_smem_ptr(smem_A), SmemLayoutA{}); copy(tma_load_A, tAgA, sA); // async TMA load cute::cp_async_fence(); gemm(tiled_mma, tCrA, tCrB, tCrC); // WGMMA 读 smem,无 bank 冲突
// CuTe 提供的 layout print 工具 print_layout(smem_layout); // 打印 layout 的坐标映射关系 // 检查任意一行的 bank 分配: for (int c = 0; c < K; ++c) { int offset = smem_layout(0, c); // swizzled linear offset int bank = (offset * sizeof(T) / 4) % 32; printf("col=%d bank=%d\n", c, bank); } // 正确时:同一行中,每32个元素组内没有重复 bank
| 问题 | 原因 | 解决 |
|---|---|---|
| WGMMA 结果错误 | smem layout 的 swizzle 与 MMA atom 期望不匹配 | 使用 CUTLASS 推荐的 SmemLayoutAtom |
| TMA store 地址错误 | store 的 swizzle descriptor 与 load 不同 | store/load 用同一个 smem_layout |
| Sm80 用了 Sm90 的 swizzle | H100 的 Sw<3,4,3> 在 A100 cp.async 下 bank 分配错误 | 根据架构和向量宽度选择 M 值 |
| 矩阵太窄,swizzle 没效果 | 行宽 < 2M+S × dtype,高位 bit 始终为0 | 换用更小 S 的 swizzle,或 padding |