大型语言模型的推理分为两个阶段:Prefill(预填充)和 Decoding(解码)。FlashAttention 解决了 Prefill 阶段的效率问题,但 Decoding 阶段有着完全不同的性能瓶颈。
处理整个输入 prompt,Query 长度 = N(可能数千个 token)。Batch × Head 数量多,GPU SM 利用率高。计算密集型。
逐 token 生成,每次 Query 长度 = 1。Batch × Head 数量往往远小于 SM 数量。内存带宽受限型。
Decoding 中通过 KV Cache 缓存历史 token 的 K/V,避免重复计算。但这也意味着:每生成一个 token,都需要从 HBM(高带宽内存)加载整个 KV Cache。
核心瓶颈:以 CodeLlama-34B 为例,在 4 个 A100 GPU 上运行,使用 GQA(Grouped Query Attention,16个 Query head,2个 KV head)。当 batch size = 1、序列长度 = 64k 时,KV Cache 占用约 2GB,每生成一个 token 就需要从 HBM 读取一次,而 HBM 带宽是有限的(A100 约 2TB/s)。
FlashAttention 在 Prefill 时沿 Batch 和 Query 长度方向并行,充分利用了所有 SM。但在 Decoding 时:
论文原话:"During inference, the query length is typically 1: this means that if the batch size is smaller than the number of streaming multiprocessors (SMs) on the GPU (108 for an A100), the operation will only use a small part of the GPU. With a batch size of 1, FlashAttention will use less than 1% of the GPU!"
理解 Flash Decoding 必须先理解 FlashAttention 的核心思想,因为 Flash Decoding 直接在其基础上扩展。
# 标准实现:需要物化 N×N 的中间矩阵 S = Q @ K.T / sqrt(d) # HBM写入: N×N 的 float16 = 2N² 字节 P = softmax(S) # HBM读取 S,写入 P:再 2N² 字节 O = P @ V # HBM读取 P:再 2N² 字节 # 总 HBM 访问量:O(N²d) → 对长序列极其昂贵
FlashAttention 的关键洞察是:将 Q/K/V 分块(tile),在 SRAM 中计算完一块后立刻累加结果,永远不将 N×N 的中间矩阵写到 HBM。
HBM 访问量对比:标准 Attention O(N²d) → FlashAttention O(N²d²/M),其中 M 是 SRAM 大小。典型情况下减少 5-20× 的 HBM 访问。
训练时,FlashAttention 在 Batch × Head 和 Query 序列 方向并行,每个 Thread Block 负责一个 Q-block 与所有 K-block 的交互。这在序列长度大时非常高效,因为并行度充足。
但 Decoding 时 Query 长度 = 1,沿 Query 方向完全没有并行空间。Flash Decoding 的核心创新就是找到了一个新的并行维度。
Flash Decoding 能够将 KV 序列分块并行计算的数学基础,是 Online Softmax(在线 Softmax)的可合并性。
设已处理了前 t-1 个元素,状态为 (m, l, O),现在来了第 t 个块:
Flash Decoding 在不同 SM 上并行计算多个 KV 分块,每个分块输出 (O_split, m_split, l_split)。最终需要将这些局部结果合并成全局正确的输出:
可证明的正确性:由于 softmax 的 log-sum-exp 技巧,上述合并公式数学上等价于在完整序列上计算 softmax。这意味着任意分块顺序、任意并行得到的结果都完全正确,没有近似误差。
每个 KV 分块只需额外存储 1 个标量 L(每个 query head 1个),开销极小。
Flash Decoding 在 FlashAttention 基础上增加了一个新的并行维度——沿 KV 序列方向分块并行,通过三个步骤实现完全正确的结果合并。
分块数 S 的选择:S 的选择目标是让总并行度(Batch × Heads × S)充分填满所有 SM。对 A100(108个SM)来说,若 Batch=1, Heads=16,则 S=8 可使并行度达到 128,充分占用所有 SM。
@triton.jit def flash_decode_kernel1( Q_ptr, K_ptr, V_ptr, Out_ptr, LSE_ptr, # LSE = log-sum-exp seq_len, chunk_size, stride_qh, stride_kh, stride_vh, BLOCK_D: tl.constexpr, ): # 当前 thread block 负责的 split 索引 split_id = tl.program_id(0) head_id = tl.program_id(1) # 加载 Q (1 × d) q = tl.load(Q_ptr + head_id * stride_qh + tl.arange(0, BLOCK_D)) # 初始化局部状态 m_i = tl.full((1,), -1e9, dtype=tl.float32) # running max l_i = tl.zeros((1,), dtype=tl.float32) # running sum acc = tl.zeros((BLOCK_D,), dtype=tl.float32) # running output # 在当前 split 内部做 tiling(FlashAttention 风格) start = split_id * chunk_size for i in range(0, chunk_size, BLOCK_C): # 加载 K block, V block k = tl.load(K_ptr + (start + i) * stride_kh + tl.arange(0, BLOCK_D)) v = tl.load(V_ptr + (start + i) * stride_vh + tl.arange(0, BLOCK_D)) # 计算 QK score s = tl.dot(q[None,:], k[:,None]) * scale # (1,) # Online softmax 更新 m_new = tl.maximum(m_i, s) l_i = l_i * tl.exp(m_i - m_new) + tl.exp(s - m_new) acc = acc * tl.exp(m_i - m_new) + tl.exp(s - m_new) * v m_i = m_new # 归一化并写回 acc = acc / l_i lse = tl.log(l_i) + m_i # log-sum-exp tl.store(Out_ptr + ..., acc) tl.store(LSE_ptr + ..., lse)
@triton.jit def flash_decode_kernel2( Out_splits_ptr, LSE_splits_ptr, # Kernel 1 的输出 Final_out_ptr, num_splits, BLOCK_D: tl.constexpr, ): head_id = tl.program_id(0) d_ids = tl.arange(0, BLOCK_D) # Step 1: 读取所有 split 的 LSE 和 out lse_all = tl.load(LSE_splits_ptr + head_id * num_splits + tl.arange(0, num_splits)) # (S,) # Step 2: 计算全局 log-sum-exp lse_global = tl.reduce(lse_all, 0, tl.math.logsumexp) # Step 3: 加权合并 acc = tl.zeros((BLOCK_D,), dtype=tl.float32) for s in range(num_splits): lse_s = tl.load(LSE_splits_ptr + head_id * num_splits + s) out_s = tl.load(Out_splits_ptr + (head_id * num_splits + s) * BLOCK_D + d_ids) scale = tl.exp(lse_s - lse_global) acc += scale * out_s # Step 4: 写回 tl.store(Final_out_ptr + head_id * BLOCK_D + d_ids, acc)
关键数字:Kernel 2 的额外 HBM 访问量仅为 S × d × sizeof(float)。以 S=8, d=128, float16 为例,额外读写只有 8×128×2 = 2KB——相比 KV Cache 的 GB 级读写,几乎可以忽略不计。
Flash Decoding 的加速本质上是通过增加一个新的并行维度,让 GPU 的所有 SM 都"有活干"。
| 方法 | 并行维度 | 并行度(B=1, H=16, N=64k) | SM 利用率 |
|---|---|---|---|
| 标准 Attention | Batch × Head | 16 | ~15% |
| FlashAttention v2 | Batch × Head × Q-seq | 16(Q长度=1时无效) | ~15% |
| Flash Decoding | Batch × Head × KV-split | 16 × S(S可调) | ~100% |
在 Decoding 阶段,Query 序列长度 = 1(每次只生成一个 token),因此 FlashAttention 沿 Query 方向的并行完全失效。Flash Decoding 转而沿 KV 序列方向分块,这个维度随着上下文变长而增大,正好弥补了 Query 方向并行度的缺失。
性能上界:论文中将"读取整个模型权重 + KV Cache 的时间"作为 Attention 计算的理论上界。Flash Decoding 在长序列时接近这个上界,说明计算本身已经几乎被完全隐藏在内存读取中。
在 4×A100 上测试 CodeLlama-34B(batch=1,测量每秒生成的 token 数):
A100,f16,16个 Query head(dim=128),2个 KV head(GQA):
| 配置 | PyTorch Eager (μs) | FlashAttention v2 (μs) | Flash Decoding (μs) | 加速比 |
|---|---|---|---|---|
| B=256, seq=256 | 3058 | 391 | 63 | 6.2× vs FA2 |
| B=128, seq=512 | 3151 | 366 | 68 | 5.4× |
| B=16, seq=4096 | 3157 | 402 | 57 | 7.1× |
| B=8, seq=8192 | 3173 | 529 | 56 | 9.4× |
| B=2, seq=32768 | 3224 | 1156 | 60 | 19.3× |
| B=1, seq=65536 | 1336 | 2301 | 64 | 35.9× |
| B=1, seq=131072 | 2664 | 4592 | 107 | 43.0× |
规律:Flash Decoding 的延迟对序列长度几乎不敏感(56–107 μs),因为它将 KV 读取分散到多个 SM 并行执行,且 Kernel 2 的归约开销极小。而 FlashAttention 的延迟随序列长度线性增长,因为所有 KV 加载串行在少数几个 SM 上。
# flash_attn/flash_attn_interface.py def flash_attn_with_kvcache( q, # (batch, seqlen_q, nheads, headdim) k_cache, # (batch, seqlen_k, nheads_k, headdim) v_cache, # (batch, seqlen_k, nheads_k, headdim) cache_seqlens=None, # 每个 batch 的实际 KV 长度 softmax_scale=None, causal=False, num_splits=0, # 0 = 自动选择;>0 = 手动指定 split 数 ): """ Flash Decoding 接口。 当 seqlen_q == 1 时(典型 decoding 场景)自动启用 Flash Decoding。 支持 MQA/GQA(nheads_k < nheads)。 """ if num_splits == 0: # 启发式:根据 seq_len 和 SM 数量自动选择 split 数 num_splits = _get_num_splits(seqlen_k, nheads, batch) # Kernel 1:并行计算各 split 的局部 attention out_splits, lse_splits = _flash_decode_split( q, k_cache, v_cache, num_splits, softmax_scale) # Kernel 2:归约合并 out = _flash_decode_reduce(out_splits, lse_splits) return out
import triton import triton.language as tl import torch @triton.jit def _flash_decode_fwd_kernel( Q, K, V, # [B, H, 1, D], [B, H_k, N, D], [B, H_k, N, D] Out, LSE, # [B, H, S, D], [B, H, S] (S = num_splits) stride_qb, stride_qh, stride_qd, stride_kb, stride_kh, stride_kn, stride_kd, stride_ob, stride_oh, stride_os, stride_od, stride_lb, stride_lh, stride_ls, N, # 总序列长度 num_splits, scale, BLOCK_D: tl.constexpr, # head dim,必须是 power of 2 BLOCK_N: tl.constexpr, # 每次处理的 K 行数(tile size) GROUP_H: tl.constexpr, # GQA: Q heads per KV head ): # ── 索引计算 ───────────────────────────────────────── batch_id = tl.program_id(0) head_id = tl.program_id(1) split_id = tl.program_id(2) kv_head_id = head_id // GROUP_H # GQA: 映射到 KV head chunk_size = tl.cdiv(N, num_splits) start = split_id * chunk_size end = tl.minimum(start + chunk_size, N) # ── 加载 Q ────────────────────────────────────────── q = tl.load( Q + batch_id * stride_qb + head_id * stride_qh + tl.arange(0, BLOCK_D), mask=tl.arange(0, BLOCK_D) < BLOCK_D, other=0.0 ).to(tl.float32) # ── 初始化 Online Softmax 状态 ────────────────────── m_i = tl.full((1,), float("-inf"), dtype=tl.float32) l_i = tl.zeros((1,), dtype=tl.float32) acc = tl.zeros((BLOCK_D,), dtype=tl.float32) # ── 在当前 split 内部 tile 循环 ────────────────────── kv_base_k = (batch_id * stride_kb + kv_head_id * stride_kh + start * stride_kn) kv_base_v = (batch_id * stride_kb + kv_head_id * stride_kh + start * stride_kn) for block_start in range(start, end, BLOCK_N): block_end = tl.minimum(block_start + BLOCK_N, end) valid = block_end - block_start n_ids = tl.arange(0, BLOCK_N) mask_n = n_ids < valid # 加载 K block: [BLOCK_N, BLOCK_D] k = tl.load( K + kv_base_k + n_ids[:, None] * stride_kn + tl.arange(0, BLOCK_D)[None, :], mask=mask_n[:, None], other=0.0 ).to(tl.float32) # 加载 V block: [BLOCK_N, BLOCK_D] v = tl.load( V + kv_base_v + n_ids[:, None] * stride_kn + tl.arange(0, BLOCK_D)[None, :], mask=mask_n[:, None], other=0.0 ).to(tl.float32) # QK^T:scores [BLOCK_N] scores = tl.sum(q[None, :] * k, axis=1) * scale scores = tl.where(mask_n, scores, float("-inf")) # Online Softmax 更新 m_new = tl.maximum(m_i, tl.max(scores, axis=0)) p = tl.exp(scores - m_new) # [BLOCK_N] l_new = l_i * tl.exp(m_i - m_new) + tl.sum(p, axis=0) acc = acc * tl.exp(m_i - m_new) + tl.sum(p[:, None] * v, axis=0) m_i, l_i = m_new, l_new kv_base_k += BLOCK_N * stride_kn kv_base_v += BLOCK_N * stride_kn # ── 写回局部结果 ───────────────────────────────────── acc = acc / l_i # 归一化 lse = tl.log(l_i) + m_i # log-sum-exp out_off = (batch_id * stride_ob + head_id * stride_oh + split_id * stride_os + tl.arange(0, BLOCK_D)) tl.store(Out + out_off, acc.to(tl.float16)) lse_off = batch_id * stride_lb + head_id * stride_lh + split_id * stride_ls tl.store(LSE + lse_off, lse) @triton.jit def _flash_decode_reduce_kernel( Out_splits, LSE_splits, # [B, H, S, D], [B, H, S] Final_out, # [B, H, D] num_splits, BLOCK_D: tl.constexpr, ): batch_id = tl.program_id(0) head_id = tl.program_id(1) # 读取所有 split 的 LSE lse_off = batch_id * (num_splits * ...) + head_id * num_splits lse_vals = tl.load(LSE_splits + lse_off + tl.arange(0, num_splits)) # 全局 log-sum-exp:数值稳定实现 lse_max = tl.max(lse_vals, axis=0) lse_global = lse_max + tl.log(tl.sum(tl.exp(lse_vals - lse_max), axis=0)) # 加权合并所有 split 的 output acc = tl.zeros((BLOCK_D,), dtype=tl.float32) for s in range(num_splits): w = tl.exp(tl.load(LSE_splits + lse_off + s) - lse_global) out_s = tl.load(Out_splits + ... + s * BLOCK_D + tl.arange(0, BLOCK_D)) acc += w * out_s.to(tl.float32) tl.store(Final_out + batch_id * ... + head_id * BLOCK_D + tl.arange(0, BLOCK_D), acc.to(tl.float16))
def flash_decode(q, k_cache, v_cache, num_splits=None): """ q: [batch, 1, nheads, headdim] ← seqlen=1 for decoding k_cache: [batch, seqlen, nheads_k, headdim] ← KV Cache v_cache: [batch, seqlen, nheads_k, headdim] """ B, _, H, D = q.shape _, N, H_k, _ = k_cache.shape GROUP_H = H // H_k # GQA group size if num_splits is None: # 启发式:让总 Thread Block 数 ≈ SM 数的 4 倍(充分利用) num_sm = torch.cuda.get_device_properties(0).multi_processor_count num_splits = max(1, num_sm * 4 // (B * H)) num_splits = min(num_splits, N // 64) # 每个 split 至少 64 个 token scale = D ** -0.5 # 临时缓冲区 out_splits = torch.empty(B, H, num_splits, D, dtype=torch.float16, device=q.device) lse_splits = torch.empty(B, H, num_splits, dtype=torch.float32, device=q.device) # Kernel 1 启动 grid1 = (B, H, num_splits) _flash_decode_fwd_kernel[grid1]( q, k_cache, v_cache, out_splits, lse_splits, *q.stride(), *k_cache.stride(), *out_splits.stride(), *lse_splits.stride(), N, num_splits, scale, BLOCK_D=D, BLOCK_N=64, GROUP_H=GROUP_H, ) # Kernel 2 启动(归约) final_out = torch.empty(B, H, D, dtype=torch.float16, device=q.device) grid2 = (B, H) _flash_decode_reduce_kernel[grid2]( out_splits, lse_splits, final_out, num_splits, BLOCK_D=D) return final_out.unsqueeze(1) # [B, 1, H, D]
Flash Decoding++ (MLSys 2024) 进一步优化了原版 Flash Decoding,提出了三个新技术,在 NVIDIA 和 AMD GPU 上实现额外 4× 以上加速。
原版 Flash Decoding 中,合并各 split 时需要知道全局最大值 m_global,这需要两轮扫描(先找最大,再归约),引入跨 SM 同步等待。
1. 各 split 计算局部 (O_s, m_s, l_s)
2. 同步等待 → 找全局 m_global
3. 用 m_global 更新每个 split 的结果
→ 需要两轮 global reduction
用统计分析确定 QK score 的安全上界(例如基于 Q/K 的 L2 范数)
每个 split 直接使用这个统一的全局 max,无需同步
→ 异步执行,消除跨-split barrier
Decoding 阶段的矩阵乘法是 GEMV(矩阵×向量)形式(batch 小,Q 长度=1),而 GPU 的 Tensor Core 为 GEMM 优化,对小矩阵效率低。Flash Decoding++ 用 double buffering 隐藏 GEMM 延迟。
根据不同硬件(A100/H100/MI200)的内存层次特性,动态选择最优的数据复用策略(L1/L2/HBM 的不同层次)。
| 方法 | 与 HuggingFace 对比 | 与 Flash Decoding 对比 |
|---|---|---|
| Flash Decoding++(NVIDIA) | 最高 4.86× | 最高 1.37× |
| Flash Decoding++(AMD) | 最高 4.35× | 类似提升 |
from flash_attn import flash_attn_with_kvcache class LlamaAttention: def forward(self, x, kv_cache, position_ids): B, seq_q, _ = x.shape q = self.q_proj(x).view(B, seq_q, self.n_heads, self.head_dim) k = self.k_proj(x).view(B, seq_q, self.n_kv_heads, self.head_dim) v = self.v_proj(x).view(B, seq_q, self.n_kv_heads, self.head_dim) # 更新 KV Cache kv_cache.k[..., :seq_q, :] = k kv_cache.v[..., :seq_q, :] = v # Flash Decoding (当 seq_q == 1 时自动启用) out = flash_attn_with_kvcache( q, # [B, 1, H, D] kv_cache.k, # [B, N, H_k, D] kv_cache.v, # [B, N, H_k, D] cache_seqlens=kv_cache.seq_lens, # 每条序列的实际长度 softmax_scale=self.head_dim**-0.5, causal=True, ) # out: [B, 1, H, D] out = out.view(B, seq_q, self.n_heads * self.head_dim) return self.o_proj(out)
import xformers.ops as xops # xFormers 会自动 dispatch: # - seqlen_q > 1 → FlashAttention(训练/prefill) # - seqlen_q == 1 → Flash Decoding(decoding) out = xops.memory_efficient_attention( q, k, v, attn_bias=None, scale=head_dim**-0.5, )
| 场景 | 建议 | 原因 |
|---|---|---|
| 短序列(<512)+ 大 Batch | FlashAttention v2 | 并行度已充足,split 开销反而增加 |
| Prefill 阶段(seqlen_q > 1) | FlashAttention v2 | 已沿 Q 维度并行,不需要 split KV |
| 长序列 + 小 Batch | ✓ Flash Decoding | KV split 是唯一可用并行维度 |
| 极长序列(>32k) | ✓ Flash Decoding | 加速比最显著(可达 35×+) |
| 方法 | 主要优化 | 与 Flash Decoding 的关系 |
|---|---|---|
| FlashAttention v1/v2 | IO-aware tiling,减少 HBM 访问 | Flash Decoding 的基础;FA2 加入了 Q-parallel |
| FlashAttention v3 | Hopper WGMMA + TMA,ping-pong 流水 | 包含 Flash Decoding (Split-KV) 作为推理优化 |
| Flash Decoding++ | 异步 Softmax + Flat GEMM + 启发式数据流 | Flash Decoding 的改进版,多后端支持 |
| vLLM (PagedAttention) | KV Cache 分页管理,支持动态 batch | 正交技术,可组合使用 |
| GQA (Grouped Query Attention) | 减少 KV head 数量,降低 KV Cache 大小 | Flash Decoding 原生支持 GQA;两者结合效果更好 |
| Speculative Decoding | Draft model 预测多个 token,验证后批量接受 | DEFT 等工作将 Flash Decoding 扩展到树状 attention |
| DEFT (Tree Attention) | 树形解码结构的 Flash Decoding | 直接扩展,处理分支 KV Cache 的负载均衡 |
核心总结:Flash Decoding 的本质是发现了 LLM Decoding 中被忽视的并行维度——KV 序列方向。通过将 Online Softmax 的可分解性从"块内"扩展到"块间",实现了跨 SM 的并行 KV 加载,以接近硬件带宽极限的速度读取 KV Cache,从而将长序列 Attention 的延迟从 O(N) 降低到接近 O(1)(相对于序列长度)。
参考文献:Tri Dao, Daniel Haziza, Francisco Massa, Grigory Sizov. "Flash-Decoding for long-context inference." Together.ai / Meta, October 2023. 原始博客:crfm.stanford.edu/2023/10/12/flashdecoding.html / pytorch.org/blog/flash-decoding/