CuTe Swizzle 学习手册

从 Bank 冲突原理到 Sw<3,4,3> vs Sw<3,3,3> 的完整指南

什么是 Shared Memory Bank 冲突?

Shared Memory 的 Bank 结构

GPU 的 Shared Memory 被分成 32 个 Bank,每个 Bank 宽 4 字节,相邻地址交错分布:

Bank 0: 字节 0- 3, 128-131, 256-259 ...
Bank 1: 字节 4- 7, 132-135, 260-263 ...
Bank 2: 字节 8-11, 136-139, 264-267 ...
···
Bank 31: 字节 124-127, 252-255, 380-383 ...

Bank 编号公式:bank = (byte_addr / 4) % 32

什么情况下产生冲突?

冲突:一个 Warp(32个线程)中,多个线程访问同一个 Bank不同地址时,硬件被迫串行化这些访问。

✓ 无冲突 每个线程访问不同 Bank

✓ 广播 所有线程访问同一 Bank 的相同地址(广播,无冲突)

✗ N路冲突 N 个线程访问同一 Bank 的不同地址 → 性能降低 N 倍

冲突的典型场景:矩阵列访问

8行 × 32列 的 float32 矩阵存入 Shared Memory 为例(共1024字节):

元素 (row, col) 的 Bank = (row×32 + col) % 32 = col(因为 32%32=0,row 无贡献)

访问模式:

为什么列访问必然冲突?

读第 0 列:线程 t 读 (t, 0),Bank = t×32 % 32 = 0全部落在 Bank 0!

这是 stride = 32 × sizeof(float32) = 128字节 = 32个 Bank 的整数倍导致的。这种 stride 直接让每一行的起始 Bank 完全相同,所有行同一列的元素挤进同一个 Bank。

行访问为什么没问题?

读第 0 行:线程 t 读 (0, t),Bank = t — 线程 0→Bank 0,线程 1→Bank 1 ... 线程 31→Bank 31,完美无冲突。

但在 GEMM 中,我们通常需要同时在行和列方向高效访问矩阵,单靠行主序布局做不到两全其美。

Swizzle:用 XOR 打散 Bank 分配

Swizzle 的核心思想:将元素的行号的部分 bit 与列号的部分 bit 做 XOR,使每一行的起始 Bank 各不相同,从而消除冲突。

XOR 的几何意义

原始:bank(r, c) = c % 32
→ 每行 Bank 分配完全相同,列访问必然冲突

Swizzle 后:bank(r, c) = (c XOR r×k) % 32
→ 每行的 Bank 分配被整体偏移,所有行 Bank 各不同

交互演示:有无 Swizzle 的对比

矩阵:8×32 float32  |  访问:读第 0 列

Swizzle 数学公式

对线性索引 x,Swizzle<B, M, S> 的变换:

swizzle(x) = x  XOR  ( ((x >> (M+S)) & mask) << M )

其中 mask = (1 << B) - 1,提取 B 个 bit。

等价描述:把 x 的 高位段(从位置 M+S 起的 B 个 bit)XOR 进 x 的 低位段(从位置 M 起的 B 个 bit)。其余 bit 不变。

Swizzle<B, M, S> 三个参数的含义

B — Bits(位宽)

参与 XOR 的 bit 数,决定 Swizzle 的周期
周期 = 2B 行,之后模式重复。

常见值:B=3(8行为一个周期,足够覆盖32个bank的差异)

M — Min/Base(基底偏移)

被 XOR 改写的那段 bit 的起始位置,决定 Swizzle 的最小操作粒度 = 2M 字节。

M=4 → 16字节粒度(128-bit对齐,匹配 LDG.128 / TMA)
M=3 → 8字节粒度(64-bit对齐,匹配 LDG.64)

S — Shift(行间跨度)

从"被改写段"到"XOR源段"之间的 bit 距离,关联到矩阵的行宽

S 越大 → Swizzle 在行方向上的"步伐"越大,适合更宽的矩阵。
常见值:S=3,意味着 M+S 位置往上的3 bit 是行号。

交互式 Bit 图

不同矩阵配置的选择指南

数据类型每行列数 K每行字节数推荐 Swizzle原因
float1664128BSw<3,3,3>64元素对应64bit粒度
float16128256BSw<3,4,3>128bit向量化加载,TMA标配
float3232128BSw<3,4,3>128B = 16字节粒度
float321664BSw<3,3,3>64B = 8字节粒度
bfloat16128256BSw<3,4,3>同 float16 × 128

Swizzle 模式热力图

下图显示对应 Swizzle 参数下,每个元素实际映射到的 Bank 编号(颜色 = Bank,每行颜色各异表示无冲突):

Sw<3,4,3> vs Sw<3,3,3>:架构与场景差异

Sw<3,4,3> — 128-Byte Swizzle

H100 / A100 TMA WGMMA cp.async 128bit

Bit 操作:bit[4:7] XOR bit[7:10]

粒度:24 = 16 字节(128-bit)

适配场景:每次读取128bit(4个float32 或 8个float16),与 LDG.128 / cp.async.cg / TMA 对齐。

矩阵宽度要求:行宽 ≥ 2M+S × dtype = 27 × 2bytes = 256 bytes(float16 需 K≥128)

Sw<3,3,3> — 64-Byte Swizzle

A100 cp.async 64bit 较窄矩阵

Bit 操作:bit[3:6] XOR bit[6:9]

粒度:23 = 8 字节(64-bit)

适配场景:每次读取64bit(2个float32 或 4个float16)。

矩阵宽度要求:行宽 ≥ 26 × 2bytes = 128 bytes(float16 需 K≥64)

具体区别:在 K=64 float16 矩阵上的对比

矩阵:8行 × 64列 float16(每行128字节)。读取第0列时,各行访问的 Bank:

架构总结表

特性Sw<3,4,3>Sw<3,3,3>
XOR 的 bit 位置bit[4:7] ← bit[7:10]bit[3:6] ← bit[6:9]
操作粒度16 字节(128-bit)8 字节(64-bit)
Swizzle 生效的最小行宽128 元素(float16)/ 64 元素(float32)64 元素(float16)/ 32 元素(float32)
H100 TMA descriptor标配 ✓可用,但较少见
A100 cp.async 128bit推荐 ✓也可用
WGMMA smem layout匹配 ✓需验证
32bit 访问(LDG.32)不推荐更合适

一句话记忆法

M 决定粒度,S 决定行宽匹配,B 决定周期深度(通常固定为3)。

H100 + TMA + float16/bfloat16 + K=128 → 首选 Sw<3,4,3>
A100 + cp.async + float16 + K=64 → 首选 Sw<3,3,3>

CuTe 中如何使用 Swizzle

① 创建 Swizzled Layout

// 方式1:composition 直接组合 Swizzle 和 Layout
using SwizzleAtom = Swizzle<3, 4, 3>;

// 128列 float16 的 swizzled smem layout
auto smem_layout = composition(
    SwizzleAtom{},
    make_layout(
        make_shape (_128{}, _64{}),   // 128行, 64列(K方向)
        make_stride( _64{},  _1{})    // 行主序
    )
);

// 等价的简化写法(CuTe 内置的 swizzled layout helper)
using SmemLayoutA = decltype(composition(
    Swizzle<3, 4, 3>{},
    Layout<Shape<_128, _64>, Stride<_64, _1>>{}
));

② 在 Shared Memory Tensor 上使用

// 声明 smem 指针
extern __shared__ half_t smem_buf[];
auto sA = make_tensor(make_smem_ptr(smem_buf), smem_layout);

// 访问元素:CuTe 自动处理 swizzle 的地址计算
auto elem = sA(row, col);  // 逻辑坐标,物理地址已 swizzled

③ H100 TMA + Swizzle(最常见用法)

// TMA descriptor 携带 swizzle 信息
auto tma_load_A = make_tma_copy(
    SM90_TMA_LOAD{},
    tensor_A,                          // global memory tensor
    smem_layout_A.layout(),            // swizzled smem layout
    make_shape(bM, bK),               // tile shape
    Int<1>{}                          // multicast count
);

// Kernel 内部:TMA load 自动 swizzle
copy(tma_load_A, tma_coord, sA);       // 触发异步 TMA,结果直接写入 swizzled smem

④ 用 make_smem_layout_A 的便捷宏(CUTLASS风格)

// CUTLASS 中常见的 smem layout 生成方式
using SmemLayoutAtomA = decltype(
    composition(Swizzle<3, 4, 3>{},
                Layout<Shape <_8,_64>,
                       Stride<_64, _1>>{})
);

// 然后用 tiled 方式铺开整个 smem tile
using SmemLayoutA = decltype(
    tile_to_shape(SmemLayoutAtomA{},
                  Shape<Int<kM>, Int<kK>>{})
);

⑤ WGMMA 配合 Swizzle 的完整流程

// 1. 定义 MMA operation
using MmaOp = SM90_64x128x16_F32BF16BF16_SS<GMMA::Major::K>;
using TiledMma = TiledMMA<
    MMA_Atom<MmaOp>,
    Layout<Shape<_2,_2,_1>>,          // warpgroup tiling
    Tile<_128,_128,_64>               // MN K tile
>;

// 2. 定义 Swizzled smem layout(必须与 WGMMA 期望的 layout 匹配)
using SmemLayoutA = decltype(composition(
    Swizzle<3, 4, 3>{},
    Layout<Shape<_128, _64>, Stride<_64, _1>>{}
));

// 3. Kernel 内:TMA load → smem(自动 swizzle),smem → WGMMA
__shared__ bf16 smem_A[128 * 64];
auto sA = make_tensor(make_smem_ptr(smem_A), SmemLayoutA{});

copy(tma_load_A, tAgA, sA);           // async TMA load
cute::cp_async_fence();

gemm(tiled_mma, tCrA, tCrB, tCrC);   // WGMMA 读 smem,无 bank 冲突

⑥ 如何验证 Swizzle 是否正确消除冲突

// CuTe 提供的 layout print 工具
print_layout(smem_layout);   // 打印 layout 的坐标映射关系

// 检查任意一行的 bank 分配:
for (int c = 0; c < K; ++c) {
    int offset = smem_layout(0, c);   // swizzled linear offset
    int bank   = (offset * sizeof(T) / 4) % 32;
    printf("col=%d bank=%d\n", c, bank);
}
// 正确时:同一行中,每32个元素组内没有重复 bank

⚠ 常见踩坑点

问题原因解决
WGMMA 结果错误smem layout 的 swizzle 与 MMA atom 期望不匹配使用 CUTLASS 推荐的 SmemLayoutAtom
TMA store 地址错误store 的 swizzle descriptor 与 load 不同store/load 用同一个 smem_layout
Sm80 用了 Sm90 的 swizzleH100 的 Sw<3,4,3> 在 A100 cp.async 下 bank 分配错误根据架构和向量宽度选择 M 值
矩阵太窄,swizzle 没效果行宽 < 2M+S × dtype,高位 bit 始终为0换用更小 S 的 swizzle,或 padding