深度学习手册 / 2023–2024

Flash Decoding
长上下文推理加速

Tri Dao · Daniel Haziza · Francisco Massa · Grigory Sizov  |  Together.ai / Meta / Princeton · Oct 2023
LLM Inference Attention KV Cache GPU Kernel IO-Awareness Online Softmax
目录

问题背景:为什么 Decoding 慢?

大型语言模型的推理分为两个阶段:Prefill(预填充)Decoding(解码)。FlashAttention 解决了 Prefill 阶段的效率问题,但 Decoding 阶段有着完全不同的性能瓶颈。

❶ Prefill 阶段

处理整个输入 prompt,Query 长度 = N(可能数千个 token)。Batch × Head 数量多,GPU SM 利用率高。计算密集型。

❷ Decoding 阶段

逐 token 生成,每次 Query 长度 = 1。Batch × Head 数量往往远小于 SM 数量。内存带宽受限型。

KV Cache:省计算,但带来新瓶颈

Decoding 中通过 KV Cache 缓存历史 token 的 K/V,避免重复计算。但这也意味着:每生成一个 token,都需要从 HBM(高带宽内存)加载整个 KV Cache。

核心瓶颈:以 CodeLlama-34B 为例,在 4 个 A100 GPU 上运行,使用 GQA(Grouped Query Attention,16个 Query head,2个 KV head)。当 batch size = 1、序列长度 = 64k 时,KV Cache 占用约 2GB,每生成一个 token 就需要从 HBM 读取一次,而 HBM 带宽是有限的(A100 约 2TB/s)。

SM 利用率:FlashAttention 的盲区

FlashAttention 在 Prefill 时沿 Batch 和 Query 长度方向并行,充分利用了所有 SM。但在 Decoding 时:

可用并行度 = Batch_size × Head_num
A100 有 108 个 SM
若 Batch=1,Head=32:并行度=32,仅占用 32/108 ≈ 30% 的 SM
若使用 GQA(2个KV head):并行度更低,接近 1% 利用率

论文原话:"During inference, the query length is typically 1: this means that if the batch size is smaller than the number of streaming multiprocessors (SMs) on the GPU (108 for an A100), the operation will only use a small part of the GPU. With a batch size of 1, FlashAttention will use less than 1% of the GPU!"

A100 SM 总数
108
每个 SM 独立执行 Thread Block
FA2 Decoding 利用率
<1%
Batch=1, GQA-2 时的极端情况
Flash Decoding 加速
在 64k 序列长度下 vs FlashAttention

FlashAttention 回顾:训练阶段的优化

理解 Flash Decoding 必须先理解 FlashAttention 的核心思想,因为 Flash Decoding 直接在其基础上扩展。

标准 Attention 的内存问题

PYTHON (Standard Attention)
# 标准实现:需要物化 N×N 的中间矩阵
S = Q @ K.T / sqrt(d)          # HBM写入: N×N 的 float16 = 2N² 字节
P = softmax(S)                  # HBM读取 S,写入 P:再 2N² 字节
O = P @ V                       # HBM读取 P:再 2N² 字节
# 总 HBM 访问量:O(N²d) → 对长序列极其昂贵

FlashAttention 的解法:Tiling + Online Softmax

FlashAttention 的关键洞察是:将 Q/K/V 分块(tile),在 SRAM 中计算完一块后立刻累加结果,永远不将 N×N 的中间矩阵写到 HBM

FlashAttention 前向传播
1将 K, V 分成 T_c 个块,每块大小 B_c × d
2将 Q, O 分成 T_r 个块,每块大小 B_r × d
3外层循环 Q-blocks,内层循环 K-blocks:每次加载到 SRAM,计算局部 softmax,用 online softmax 公式更新 O
4只写回 O(N×d)和 logsumexp L(N×1),不写中间矩阵

HBM 访问量对比:标准 Attention O(N²d) → FlashAttention O(N²d²/M),其中 M 是 SRAM 大小。典型情况下减少 5-20× 的 HBM 访问。

FlashAttention 并行化策略(训练/Prefill)

训练时,FlashAttention 在 Batch × HeadQuery 序列 方向并行,每个 Thread Block 负责一个 Q-block 与所有 K-block 的交互。这在序列长度大时非常高效,因为并行度充足。

但 Decoding 时 Query 长度 = 1,沿 Query 方向完全没有并行空间。Flash Decoding 的核心创新就是找到了一个新的并行维度。

Online Softmax:可分解的核心数学

Flash Decoding 能够将 KV 序列分块并行计算的数学基础,是 Online Softmax(在线 Softmax)的可合并性。

标准 Softmax 的问题

softmax(x)_i = exp(x_i) / Σ exp(x_j)

# 数值不稳定版:x 很大时 exp(x) 溢出
# 安全版:减去 max 再算
softmax(x)_i = exp(x_i - m) / Σ exp(x_j - m),m = max(x)

→ 问题:需要知道全局 max,无法流式计算

Online Softmax:逐块更新公式

设已处理了前 t-1 个元素,状态为 (m, l, O),现在来了第 t 个块:

状态定义:
m_t = max(x_1, ..., x_t) → 当前见到的最大值
l_t = Σ_{i=1}^{t} exp(x_i - m_t) → 归一化因子(分母)
O_t = Σ_{i=1}^{t} exp(x_i - m_t) · v_i / l_t → 当前输出累计

当新的块 (x_{t+1}, v_{t+1}) 到来时:
m_new = max(m_t, x_{t+1})
l_new = l_t · exp(m_t - m_new) + exp(x_{t+1} - m_new)
O_new = [ O_t · l_t · exp(m_t - m_new) + exp(x_{t+1} - m_new) · v_{t+1} ] / l_new

跨分块的合并公式(Flash Decoding 的关键)

Flash Decoding 在不同 SM 上并行计算多个 KV 分块,每个分块输出 (O_split, m_split, l_split)。最终需要将这些局部结果合并成全局正确的输出:

设有两个分块 A 和 B 的局部结果:
分块 A:(O_A, m_A, l_A)
分块 B:(O_B, m_B, l_B)

全局合并:
m_global = max(m_A, m_B)
l_global = l_A · exp(m_A - m_global) + l_B · exp(m_B - m_global)
O_global = [ O_A · l_A · exp(m_A - m_global) + O_B · l_B · exp(m_B - m_global) ] / l_global

# 这等价于计算完整序列 [A, B] 的正确 softmax attention 输出

可证明的正确性:由于 softmax 的 log-sum-exp 技巧,上述合并公式数学上等价于在完整序列上计算 softmax。这意味着任意分块顺序、任意并行得到的结果都完全正确,没有近似误差。

存储的辅助量:log-sum-exp L

L_split = log(l_split) + m_split → log-sum-exp,数值更稳定

使用 L 进行合并:
L_global = log(exp(L_A) + exp(L_B)) → logsumexp(L_A, L_B)
O_global = O_A · exp(L_A - L_global) + O_B · exp(L_B - L_global)

每个 KV 分块只需额外存储 1 个标量 L(每个 query head 1个),开销极小。

Flash Decoding 算法:三步设计

Flash Decoding 在 FlashAttention 基础上增加了一个新的并行维度——沿 KV 序列方向分块并行,通过三个步骤实现完全正确的结果合并。

Step 1:分割 KV(零开销)

将 K ∈ ℝ^{N×d} 分为 S 个块:K = [K_0 | K_1 | ... | K_{S-1}]
将 V ∈ ℝ^{N×d} 分为 S 个块:V = [V_0 | V_1 | ... | V_{S-1}]

每块大小:chunk_size = N / S

注意:不涉及任何 GPU 操作!分块只是原始 KV tensor 的不同视图(view),零内存开销。

分块数 S 的选择:S 的选择目标是让总并行度(Batch × Heads × S)充分填满所有 SM。对 A100(108个SM)来说,若 Batch=1, Heads=16,则 S=8 可使并行度达到 128,充分占用所有 SM。

Step 2:并行计算各分块的局部 Attention(Kernel 1)

Kernel 1:局部 Attention 计算(每个 split 独立运行在不同 SM 上)
1加载 Q ∈ ℝ^{1×d}(或 B_r×d)到 SRAM
2加载当前分块 K_s ∈ ℝ^{chunk×d},V_s ∈ ℝ^{chunk×d} 到 SRAM
3使用 FlashAttention 算法(带 online softmax 的 tiling)计算局部输出:
O_s, m_s, l_s = flash_attn(Q, K_s, V_s)
4将 O_s(head_dim 维向量)和 L_s = log(l_s) + m_s(1个标量)写回 HBM
额外存储:每个 query × 每个分块只需 1 个标量!
TRITON (Kernel 1 简化版)
@triton.jit
def flash_decode_kernel1(
    Q_ptr, K_ptr, V_ptr,
    Out_ptr, LSE_ptr,    # LSE = log-sum-exp
    seq_len, chunk_size,
    stride_qh, stride_kh, stride_vh,
    BLOCK_D: tl.constexpr,
):
    # 当前 thread block 负责的 split 索引
    split_id = tl.program_id(0)
    head_id  = tl.program_id(1)
    
    # 加载 Q (1 × d)
    q = tl.load(Q_ptr + head_id * stride_qh + tl.arange(0, BLOCK_D))
    
    # 初始化局部状态
    m_i = tl.full((1,), -1e9, dtype=tl.float32)  # running max
    l_i = tl.zeros((1,), dtype=tl.float32)        # running sum
    acc = tl.zeros((BLOCK_D,), dtype=tl.float32)    # running output
    
    # 在当前 split 内部做 tiling(FlashAttention 风格)
    start = split_id * chunk_size
    for i in range(0, chunk_size, BLOCK_C):
        # 加载 K block, V block
        k = tl.load(K_ptr + (start + i) * stride_kh + tl.arange(0, BLOCK_D))
        v = tl.load(V_ptr + (start + i) * stride_vh + tl.arange(0, BLOCK_D))
        
        # 计算 QK score
        s = tl.dot(q[None,:], k[:,None]) * scale  # (1,)
        
        # Online softmax 更新
        m_new = tl.maximum(m_i, s)
        l_i = l_i * tl.exp(m_i - m_new) + tl.exp(s - m_new)
        acc  = acc * tl.exp(m_i - m_new) + tl.exp(s - m_new) * v
        m_i  = m_new
    
    # 归一化并写回
    acc = acc / l_i
    lse = tl.log(l_i) + m_i   # log-sum-exp
    tl.store(Out_ptr + ..., acc)
    tl.store(LSE_ptr + ..., lse)

Step 3:归约合并(Kernel 2)

Kernel 2:全局归约(合并所有 split 的结果)
1读取所有 S 个分块的 O_s(d 维)和 L_s(1 维标量)
2计算全局 L_global = logsumexp(L_0, L_1, ..., L_{S-1})
3加权合并:O_global = Σ_s O_s · exp(L_s - L_global)
4写回最终输出 O_global(形状 1×d,与标准 attention 完全相同)
TRITON (Kernel 2 简化版)
@triton.jit
def flash_decode_kernel2(
    Out_splits_ptr, LSE_splits_ptr,   # Kernel 1 的输出
    Final_out_ptr,
    num_splits,
    BLOCK_D: tl.constexpr,
):
    head_id = tl.program_id(0)
    d_ids   = tl.arange(0, BLOCK_D)
    
    # Step 1: 读取所有 split 的 LSE 和 out
    lse_all = tl.load(LSE_splits_ptr + head_id * num_splits 
                      + tl.arange(0, num_splits))  # (S,)
    
    # Step 2: 计算全局 log-sum-exp
    lse_global = tl.reduce(lse_all, 0, tl.math.logsumexp)
    
    # Step 3: 加权合并
    acc = tl.zeros((BLOCK_D,), dtype=tl.float32)
    for s in range(num_splits):
        lse_s = tl.load(LSE_splits_ptr + head_id * num_splits + s)
        out_s = tl.load(Out_splits_ptr + (head_id * num_splits + s) * BLOCK_D + d_ids)
        scale = tl.exp(lse_s - lse_global)
        acc  += scale * out_s
    
    # Step 4: 写回
    tl.store(Final_out_ptr + head_id * BLOCK_D + d_ids, acc)

关键数字:Kernel 2 的额外 HBM 访问量仅为 S × d × sizeof(float)。以 S=8, d=128, float16 为例,额外读写只有 8×128×2 = 2KB——相比 KV Cache 的 GB 级读写,几乎可以忽略不计。

GPU 并行性:从 SM 利用率理解加速原理

Flash Decoding 的加速本质上是通过增加一个新的并行维度,让 GPU 的所有 SM 都"有活干"。

并行度对比

方法并行维度并行度(B=1, H=16, N=64k)SM 利用率
标准 AttentionBatch × Head16~15%
FlashAttention v2Batch × Head × Q-seq16(Q长度=1时无效)~15%
Flash DecodingBatch × Head × KV-split16 × S(S可调)~100%

为什么沿 Q 方向并行不行?

在 Decoding 阶段,Query 序列长度 = 1(每次只生成一个 token),因此 FlashAttention 沿 Query 方向的并行完全失效。Flash Decoding 转而沿 KV 序列方向分块,这个维度随着上下文变长而增大,正好弥补了 Query 方向并行度的缺失。

内存带宽利用率

KV Cache 读取量(以 CodeLlama-34B, 4 GPU, seq=64k 为例):
= 2 × seq_len × d × num_kv_heads × sizeof(bf16) / num_gpu
= 2 × 64000 × 128 × 2 × 2 / 4
82 MB 每个 token

A100 HBM 带宽 = 2 TB/s
→ 理论最短时间 = 82MB / 2TB/s ≈ 41 μs
Flash Decoding 实测:~60 μs(接近硬件极限)

性能上界:论文中将"读取整个模型权重 + KV Cache 的时间"作为 Attention 计算的理论上界。Flash Decoding 在长序列时接近这个上界,说明计算本身已经几乎被完全隐藏在内存读取中。

实测性能:8× 加速的来源

CodeLlama-34B Decoding 吞吐量

在 4×A100 上测试 CodeLlama-34B(batch=1,测量每秒生成的 token 数):

Micro-benchmark:Attention 单算子延迟

A100,f16,16个 Query head(dim=128),2个 KV head(GQA):

配置PyTorch Eager (μs)FlashAttention v2 (μs)Flash Decoding (μs)加速比
B=256, seq=2563058391636.2× vs FA2
B=128, seq=5123151366685.4×
B=16, seq=40963157402577.1×
B=8, seq=81923173529569.4×
B=2, seq=32768322411566019.3×
B=1, seq=65536133623016435.9×
B=1, seq=1310722664459210743.0×

规律:Flash Decoding 的延迟对序列长度几乎不敏感(56–107 μs),因为它将 KV 读取分散到多个 SM 并行执行,且 Kernel 2 的归约开销极小。而 FlashAttention 的延迟随序列长度线性增长,因为所有 KV 加载串行在少数几个 SM 上。

性能不变性分析

Flash Decoding 延迟 ≈ (KV 总量 / 并行 SM 数) / HBM带宽 + 归约开销

当 split 数 S ∝ seq_len 时(保持每个 split 大小固定):
并行 SM 数 ∝ seq_len
KV 总量 ∝ seq_len
→ 延迟 ≈ 常数!这就是为什么长序列时延迟几乎不变

代码实现:完整 Triton Kernel

xFormers / FlashAttention repo 中的实现结构

PYTHON (接口层)
# flash_attn/flash_attn_interface.py
def flash_attn_with_kvcache(
    q,          # (batch, seqlen_q, nheads, headdim)
    k_cache,    # (batch, seqlen_k, nheads_k, headdim)
    v_cache,    # (batch, seqlen_k, nheads_k, headdim)
    cache_seqlens=None,   # 每个 batch 的实际 KV 长度
    softmax_scale=None,
    causal=False,
    num_splits=0,        # 0 = 自动选择;>0 = 手动指定 split 数
):
    """
    Flash Decoding 接口。
    当 seqlen_q == 1 时(典型 decoding 场景)自动启用 Flash Decoding。
    支持 MQA/GQA(nheads_k < nheads)。
    """
    if num_splits == 0:
        # 启发式:根据 seq_len 和 SM 数量自动选择 split 数
        num_splits = _get_num_splits(seqlen_k, nheads, batch)
    
    # Kernel 1:并行计算各 split 的局部 attention
    out_splits, lse_splits = _flash_decode_split(
        q, k_cache, v_cache, num_splits, softmax_scale)
    
    # Kernel 2:归约合并
    out = _flash_decode_reduce(out_splits, lse_splits)
    
    return out

完整 Triton 实现(简化教学版)

TRITON (完整可运行版本)
import triton
import triton.language as tl
import torch

@triton.jit
def _flash_decode_fwd_kernel(
    Q, K, V,           # [B, H, 1, D], [B, H_k, N, D], [B, H_k, N, D]
    Out, LSE,          # [B, H, S, D], [B, H, S]  (S = num_splits)
    stride_qb, stride_qh, stride_qd,
    stride_kb, stride_kh, stride_kn, stride_kd,
    stride_ob, stride_oh, stride_os, stride_od,
    stride_lb, stride_lh, stride_ls,
    N,                 # 总序列长度
    num_splits,
    scale,
    BLOCK_D: tl.constexpr,   # head dim,必须是 power of 2
    BLOCK_N: tl.constexpr,   # 每次处理的 K 行数(tile size)
    GROUP_H: tl.constexpr,   # GQA: Q heads per KV head
):
    # ── 索引计算 ─────────────────────────────────────────
    batch_id = tl.program_id(0)
    head_id  = tl.program_id(1)
    split_id = tl.program_id(2)
    
    kv_head_id = head_id // GROUP_H   # GQA: 映射到 KV head
    
    chunk_size = tl.cdiv(N, num_splits)
    start = split_id * chunk_size
    end   = tl.minimum(start + chunk_size, N)
    
    # ── 加载 Q ──────────────────────────────────────────
    q = tl.load(
        Q + batch_id * stride_qb + head_id * stride_qh 
          + tl.arange(0, BLOCK_D),
        mask=tl.arange(0, BLOCK_D) < BLOCK_D, other=0.0
    ).to(tl.float32)
    
    # ── 初始化 Online Softmax 状态 ──────────────────────
    m_i = tl.full((1,), float("-inf"), dtype=tl.float32)
    l_i = tl.zeros((1,), dtype=tl.float32)
    acc = tl.zeros((BLOCK_D,), dtype=tl.float32)
    
    # ── 在当前 split 内部 tile 循环 ──────────────────────
    kv_base_k = (batch_id * stride_kb + kv_head_id * stride_kh 
                 + start * stride_kn)
    kv_base_v = (batch_id * stride_kb + kv_head_id * stride_kh 
                 + start * stride_kn)
    
    for block_start in range(start, end, BLOCK_N):
        block_end = tl.minimum(block_start + BLOCK_N, end)
        valid = block_end - block_start
        
        n_ids = tl.arange(0, BLOCK_N)
        mask_n = n_ids < valid
        
        # 加载 K block: [BLOCK_N, BLOCK_D]
        k = tl.load(
            K + kv_base_k + n_ids[:, None] * stride_kn 
              + tl.arange(0, BLOCK_D)[None, :],
            mask=mask_n[:, None], other=0.0
        ).to(tl.float32)
        
        # 加载 V block: [BLOCK_N, BLOCK_D]
        v = tl.load(
            V + kv_base_v + n_ids[:, None] * stride_kn 
              + tl.arange(0, BLOCK_D)[None, :],
            mask=mask_n[:, None], other=0.0
        ).to(tl.float32)
        
        # QK^T:scores [BLOCK_N]
        scores = tl.sum(q[None, :] * k, axis=1) * scale
        scores = tl.where(mask_n, scores, float("-inf"))
        
        # Online Softmax 更新
        m_new = tl.maximum(m_i, tl.max(scores, axis=0))
        p = tl.exp(scores - m_new)           # [BLOCK_N]
        l_new = l_i * tl.exp(m_i - m_new) + tl.sum(p, axis=0)
        acc = acc * tl.exp(m_i - m_new) + tl.sum(p[:, None] * v, axis=0)
        m_i, l_i = m_new, l_new
        
        kv_base_k += BLOCK_N * stride_kn
        kv_base_v += BLOCK_N * stride_kn
    
    # ── 写回局部结果 ─────────────────────────────────────
    acc = acc / l_i   # 归一化
    lse = tl.log(l_i) + m_i   # log-sum-exp
    
    out_off = (batch_id * stride_ob + head_id * stride_oh 
               + split_id * stride_os + tl.arange(0, BLOCK_D))
    tl.store(Out + out_off, acc.to(tl.float16))
    
    lse_off = batch_id * stride_lb + head_id * stride_lh + split_id * stride_ls
    tl.store(LSE + lse_off, lse)


@triton.jit
def _flash_decode_reduce_kernel(
    Out_splits, LSE_splits,   # [B, H, S, D], [B, H, S]
    Final_out,                # [B, H, D]
    num_splits,
    BLOCK_D: tl.constexpr,
):
    batch_id = tl.program_id(0)
    head_id  = tl.program_id(1)
    
    # 读取所有 split 的 LSE
    lse_off = batch_id * (num_splits * ...) + head_id * num_splits
    lse_vals = tl.load(LSE_splits + lse_off + tl.arange(0, num_splits))
    
    # 全局 log-sum-exp:数值稳定实现
    lse_max = tl.max(lse_vals, axis=0)
    lse_global = lse_max + tl.log(tl.sum(tl.exp(lse_vals - lse_max), axis=0))
    
    # 加权合并所有 split 的 output
    acc = tl.zeros((BLOCK_D,), dtype=tl.float32)
    for s in range(num_splits):
        w = tl.exp(tl.load(LSE_splits + lse_off + s) - lse_global)
        out_s = tl.load(Out_splits + ... + s * BLOCK_D + tl.arange(0, BLOCK_D))
        acc += w * out_s.to(tl.float32)
    
    tl.store(Final_out + batch_id * ... + head_id * BLOCK_D 
             + tl.arange(0, BLOCK_D), acc.to(tl.float16))

Python 调用入口

PYTHON
def flash_decode(q, k_cache, v_cache, num_splits=None):
    """
    q:       [batch, 1, nheads, headdim]       ← seqlen=1 for decoding
    k_cache: [batch, seqlen, nheads_k, headdim] ← KV Cache
    v_cache: [batch, seqlen, nheads_k, headdim]
    """
    B, _, H, D = q.shape
    _, N, H_k, _ = k_cache.shape
    GROUP_H = H // H_k   # GQA group size
    
    if num_splits is None:
        # 启发式:让总 Thread Block 数 ≈ SM 数的 4 倍(充分利用)
        num_sm = torch.cuda.get_device_properties(0).multi_processor_count
        num_splits = max(1, num_sm * 4 // (B * H))
        num_splits = min(num_splits, N // 64)   # 每个 split 至少 64 个 token
    
    scale = D ** -0.5
    
    # 临时缓冲区
    out_splits = torch.empty(B, H, num_splits, D, dtype=torch.float16, 
                             device=q.device)
    lse_splits  = torch.empty(B, H, num_splits, dtype=torch.float32, 
                              device=q.device)
    
    # Kernel 1 启动
    grid1 = (B, H, num_splits)
    _flash_decode_fwd_kernel[grid1](
        q, k_cache, v_cache, out_splits, lse_splits,
        *q.stride(), *k_cache.stride(),
        *out_splits.stride(), *lse_splits.stride(),
        N, num_splits, scale,
        BLOCK_D=D, BLOCK_N=64, GROUP_H=GROUP_H,
    )
    
    # Kernel 2 启动(归约)
    final_out = torch.empty(B, H, D, dtype=torch.float16, device=q.device)
    grid2 = (B, H)
    _flash_decode_reduce_kernel[grid2](
        out_splits, lse_splits, final_out, num_splits, BLOCK_D=D)
    
    return final_out.unsqueeze(1)  # [B, 1, H, D]

扩展变体:Flash Decoding++

Flash Decoding++ (MLSys 2024) 进一步优化了原版 Flash Decoding,提出了三个新技术,在 NVIDIA 和 AMD GPU 上实现额外 4× 以上加速。

问题 1:Softmax 归约的同步开销

原版 Flash Decoding 中,合并各 split 时需要知道全局最大值 m_global,这需要两轮扫描(先找最大,再归约),引入跨 SM 同步等待。

原版:同步 Softmax

1. 各 split 计算局部 (O_s, m_s, l_s)
2. 同步等待 → 找全局 m_global
3. 用 m_global 更新每个 split 的结果
→ 需要两轮 global reduction

++:统一最大值(Unified Max)

用统计分析确定 QK score 的安全上界(例如基于 Q/K 的 L2 范数)
每个 split 直接使用这个统一的全局 max,无需同步
→ 异步执行,消除跨-split barrier

问题 2:Flat GEMM 效率

Decoding 阶段的矩阵乘法是 GEMV(矩阵×向量)形式(batch 小,Q 长度=1),而 GPU 的 Tensor Core 为 GEMM 优化,对小矩阵效率低。Flash Decoding++ 用 double buffering 隐藏 GEMM 延迟。

问题 3:数据流启发式优化

根据不同硬件(A100/H100/MI200)的内存层次特性,动态选择最优的数据复用策略(L1/L2/HBM 的不同层次)。

方法与 HuggingFace 对比与 Flash Decoding 对比
Flash Decoding++(NVIDIA)最高 4.86×最高 1.37×
Flash Decoding++(AMD)最高 4.35×类似提升

工程集成:在 LLaMA 推理中使用

通过 FlashAttention 库使用

PYTHON
from flash_attn import flash_attn_with_kvcache

class LlamaAttention:
    def forward(self, x, kv_cache, position_ids):
        B, seq_q, _ = x.shape
        
        q = self.q_proj(x).view(B, seq_q, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, seq_q, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, seq_q, self.n_kv_heads, self.head_dim)
        
        # 更新 KV Cache
        kv_cache.k[..., :seq_q, :] = k
        kv_cache.v[..., :seq_q, :] = v
        
        # Flash Decoding (当 seq_q == 1 时自动启用)
        out = flash_attn_with_kvcache(
            q,              # [B, 1, H, D]
            kv_cache.k,     # [B, N, H_k, D]
            kv_cache.v,     # [B, N, H_k, D]
            cache_seqlens=kv_cache.seq_lens,  # 每条序列的实际长度
            softmax_scale=self.head_dim**-0.5,
            causal=True,
        )
        # out: [B, 1, H, D]
        
        out = out.view(B, seq_q, self.n_heads * self.head_dim)
        return self.o_proj(out)

通过 xFormers 使用

PYTHON
import xformers.ops as xops

# xFormers 会自动 dispatch:
# - seqlen_q > 1 → FlashAttention(训练/prefill)
# - seqlen_q == 1 → Flash Decoding(decoding)
out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=None,
    scale=head_dim**-0.5,
)

何时不适用 Flash Decoding?

场景建议原因
短序列(<512)+ 大 BatchFlashAttention v2并行度已充足,split 开销反而增加
Prefill 阶段(seqlen_q > 1)FlashAttention v2已沿 Q 维度并行,不需要 split KV
长序列 + 小 Batch✓ Flash DecodingKV split 是唯一可用并行维度
极长序列(>32k)✓ Flash Decoding加速比最显著(可达 35×+)

与其他方法的关系

方法主要优化与 Flash Decoding 的关系
FlashAttention v1/v2IO-aware tiling,减少 HBM 访问Flash Decoding 的基础;FA2 加入了 Q-parallel
FlashAttention v3Hopper WGMMA + TMA,ping-pong 流水包含 Flash Decoding (Split-KV) 作为推理优化
Flash Decoding++异步 Softmax + Flat GEMM + 启发式数据流Flash Decoding 的改进版,多后端支持
vLLM (PagedAttention)KV Cache 分页管理,支持动态 batch正交技术,可组合使用
GQA (Grouped Query Attention)减少 KV head 数量,降低 KV Cache 大小Flash Decoding 原生支持 GQA;两者结合效果更好
Speculative DecodingDraft model 预测多个 token,验证后批量接受DEFT 等工作将 Flash Decoding 扩展到树状 attention
DEFT (Tree Attention)树形解码结构的 Flash Decoding直接扩展,处理分支 KV Cache 的负载均衡

技术演进图谱

Attention 优化技术演进
IOnline Softmax(2018):Milakov & Gimelshein 提出流式 softmax,奠定可分解 softmax 的数学基础
IIFlashAttention v1(2022):IO-aware tiling,消除 N² 中间矩阵,2-4× 加速训练
IIIFlashAttention v2(2023):改进并行策略(Q-parallel),减少非 matmul FLOP,2-3× vs v1
IVFlash Decoding(2023.10):新增 KV-split 并行维度,专门解决 Decoding 的 SM 利用率问题,长序列 8× 加速
VFlash Decoding++(2024):异步 Softmax 消除跨 split 同步,Flat GEMM 优化,支持 AMD GPU
VIFlashAttention v3(2024):Hopper 专属,WGMMA + TMA + 流水线,集成 Flash Decoding 支持

核心总结:Flash Decoding 的本质是发现了 LLM Decoding 中被忽视的并行维度——KV 序列方向。通过将 Online Softmax 的可分解性从"块内"扩展到"块间",实现了跨 SM 的并行 KV 加载,以接近硬件带宽极限的速度读取 KV Cache,从而将长序列 Attention 的延迟从 O(N) 降低到接近 O(1)(相对于序列长度)。

参考文献:Tri Dao, Daniel Haziza, Francisco Massa, Grigory Sizov. "Flash-Decoding for long-context inference." Together.ai / Meta, October 2023. 原始博客:crfm.stanford.edu/2023/10/12/flashdecoding.html / pytorch.org/blog/flash-decoding/