DeepSeek MLA — Prefix Cache 与 Chunked Prefill 深度解析

Cache 存压缩 latent,Prefill 要全量 KV —— vLLM / SGLang 如何桥接这一鸿沟?

目录

1. 核心矛盾:Cache 格式 vs Prefill 格式 2. MLA KV Cache 的存储结构 3. 两条注意力路径:Non-Absorbed (Prefill) vs Absorbed (Decode) 4. vLLM 的实现方案 5. SGLang 的实现方案 6. 核心数据流全景图:Prefix Cache → Chunked Prefill → Attention Output 7. 性能对比与工程取舍 8. 总结

1. 核心矛盾:Cache 格式 vs Prefill 格式

DeepSeek MLA 的独特之处在于:存储在 KV Cache 中的数据格式Prefill 阶段注意力计算需要的数据格式是不一样的。这是 MLA 区别于 MHA/GQA/MQA 的根本工程挑战。

KV Cache 存储格式 (每 token 仅存压缩表示) c_t (压缩 latent) [512 dims] k_rope (RoPE Key) [64 dims] 合计: 576 dims/token (所有 128 heads 共享) Prefill 注意力需要的格式 (Non-Absorbed / MHA 模式) K = [K_nope; K_rope] 128 heads × (128+64) = 128 × 192 dims V = c_t × W_UV 128 heads × 128 dims 合计: 128 × (192+128) = 40,960 dims/token
核心矛盾:Cache 存 576 维 (共享),Prefill 需要 40,960 维 (per-head)。当使用 Prefix Cache 命中已有缓存时,这些 576 维压缩数据需要在线解压缩 (on-demand decompression) 为全量 KV 才能参与 Prefill 的 FlashAttention 计算。

2. MLA KV Cache 的存储结构

2.1 每 token 的缓存内容

MLA 的 KV Cache 不像传统 MHA 那样存储每个 head 的 K 和 V,而是只存储联合压缩的 latent 向量

MLA KV Cache 每 token 存储布局 c_t = h_t × W_DKV 压缩 latent 向量 [d_c = 512] offset 0 k_rope RoPE Key [d_rope = 64] offset 512 512 维 64 维 总计 576 维/token — 所有 128 个 attention head 共享这一份缓存

2.2 FP8 量化存储 (vLLM)

vLLM 中对 MLA KV Cache 进一步应用了 FP8 量化:

字段内容字节数精度
c_t512 个 float8_e4m3 值512 BFP8
scale factors4 个 float32 (per-group scale)16 BFP32
k_rope64 个 bfloat16 值 (不量化)128 BBF16
总计656 B/token混合
对比 MHA:标准 128-head MHA,每 token 存储 128 × (128+128) × 2 = 65,536 B (BF16)。MLA 仅需 656 B,约 100 倍压缩,这正是 MLA 的巨大优势。

3. 两条注意力路径:Non-Absorbed (Prefill) vs Absorbed (Decode)

两条路径对比 Non-Absorbed 路径 (Prefill/Extend) compute-bound, 适合长序列 1. K_nope = c_t × W_UK [512 → 128 heads × 128] 2. K = [K_nope ; K_rope] [128 heads × 192] 3. V = c_t × W_UV [512 → 128 heads × 128] 4. Q = q_c × W_UQ + RoPE(q_rope) [128 × 192] 5. FlashAttention(Q, K, V) head_dim_k = 192, head_dim_v = 128 n_heads = 128 (真正的 MHA) 需要解压缩 c_t → 全量 K/V Absorbed 路径 (Decode) memory-bound, 适合单 token 生成 1. Q' = q × (W_UQ^T × W_UK) [→ 512] 2. Score_nope = Q' × c_t^T (直接用 latent) 3. Score_rope = Q_rope × K_rope^T 4. Score = Score_nope + Score_rope 5. FlashMLA / PagedAttention(Score, c_t) head_dim_k = 576, head_dim_v = 512 等效 MQA (1 shared KV head) 无需解压缩,直接用 c_t + k_rope
关键洞察:Decode 走 Absorbed 路径时,直接以 c_t 参与计算(矩阵吸收让 W_UK 被吸进 Q 侧),无需解压缩。但 Prefill 走 Non-Absorbed 路径时,必须将 c_t 通过 W_UK/W_UV 解压回全量 K/V。Prefix Cache 命中的那些 token,其缓存里只有 c_t + k_rope,因此 prefill 阶段必须做在线解压缩

4. vLLM 的实现方案 vLLM

4.1 核心代码架构

vLLM 的 MLA 注意力后端位于 vllm/v1/attention/backends/mla/common.py,核心类为 MLACommonBackend

4.2 _forward_prefill 的两阶段设计

当一个请求同时包含 prefix-cached tokens (context) 和 new tokens (extend) 时,vLLM 把 prefill 拆分为两个阶段:

vLLM _forward_prefill 两阶段流程 输入 Prompt: [prefix-cached tokens | new tokens] Stage 1: Context (Prefix Cache 命中部分) _compute_prefill_context() 1. 从 Paged KV Cache 分块读取 (c_t, k_rope) 2. 在线解压缩: K = c_t × kv_b_proj[:, :nope] V = c_t × kv_b_proj[:, nope:] 3. 拼接 K = [K_nope ; k_rope] → per-head K 4. 分 context_chunk 跑 FlashAttention 5. 得到 context_output + context_lse (attention output 和 log-sum-exp) 分块原因: context 可能很长 (数万 token) 一次解压太多会 OOM, 所以分 chunk 处理 Stage 2: Extend (新 token 部分) 标准 prefill 前向计算 1. 新 token 通过模型得到 q, kv_c, k_rope 2. 解压: K_new = kv_c × W_UK, V_new = kv_c × W_UV (新 token 本身也要解压, 但数量少) 3. 写入 Paged KV Cache (存 c_t + k_rope) 4. FlashAttention(Q_new, K_new, V_new) 5. 得到 extend_output + extend_lse (attention output 和 log-sum-exp)

4.3 merge_attn_states: 合并两阶段结果

两个阶段各自产出 (output, lse),最后通过 merge_attn_states 进行数值稳定的加权合并

# 伪代码 — merge_attn_states 的核心逻辑
def merge_attn_states(output_a, lse_a, output_b, lse_b):
    # lse = log-sum-exp of attention scores
    # 利用 log-sum-exp trick 做数值稳定的加权平均
    max_lse = torch.max(lse_a, lse_b)
    weight_a = torch.exp(lse_a - max_lse)
    weight_b = torch.exp(lse_b - max_lse)

    merged_output = (weight_a * output_a + weight_b * output_b) / (weight_a + weight_b)
    return merged_output
数学等价性:这个合并操作保证了最终结果与直接对全序列做 attention 完全等价。利用了 FlashAttention 计算过程中的 log-sum-exp 统计量,避免了需要重新计算全局 softmax。

4.4 _compute_prefill_context 的分块处理

对于 prefix cached 的 context 部分,vLLM 使用 context chunk 机制来控制内存:

# 伪代码 — context chunk 处理流程
for chunk_idx in range(num_context_chunks):
    # 1. 从 Paged KV Cache 读取一个 chunk 的 c_t 和 k_rope
    cache_kv_c, cache_k_pe = _gather_cache(
        kv_cache, block_table, chunk_start, chunk_len)

    # 2. 解压缩: 通过 kv_b_proj 将 c_t 投影回全量 K/V
    #    kv_b_proj = [W_UK; W_UV] 拼接, shape: [d_c, n_heads*(d_nope+d_v)]
    kv_full = cache_kv_c @ kv_b_proj   # [chunk_len, 512] × [512, n_heads*(128+128)]
    k_nope = kv_full[:, :n_heads*d_nope].reshape(chunk_len, n_heads, d_nope)
    v = kv_full[:, n_heads*d_nope:].reshape(chunk_len, n_heads, d_v)

    # 3. 拼接 K = [K_nope; K_rope] 
    k = torch.cat([k_nope, cache_k_pe.expand(..., n_heads, ...)], dim=-1)

    # 4. 跑 FlashAttention, 得到 chunk_output 和 chunk_lse
    chunk_out, chunk_lse = flash_attention(q_for_context, k, v)

    # 5. 将 chunk 结果累积合并
    ctx_output, ctx_lse = merge_attn_states(
        ctx_output, ctx_lse, chunk_out, chunk_lse)
为什么要分块?假设 prefix cache 命中了 32K token,一次性解压缩的显存开销为 32K × 128 heads × 320 dims × 2B ≈ 2.5 GB。分成 4K 的 chunk 则每次只需 ~320 MB,显存压力可控。

4.5 vLLM Paged KV Cache 对 MLA 的适配

特性标准 MHAMLA 模式
每 block 存储K_head [n_heads, d_k] + V_head [n_heads, d_v]c_t [d_c=512] + k_rope [d_rope=64]
block size16 token/block16 token/block
每 block 字节16 × 128 × 256 × 2 = 1 MB (BF16)16 × 656 = ~10 KB (FP8 mixed)
Hash 键token IDs + layer + positiontoken IDs + layer + position (相同)
Prefix 匹配直接使用匹配的 KV匹配后需在线解压缩才能给 prefill 用

5. SGLang 的实现方案 SGLang

5.1 RadixAttention 与 MLA KV Cache

SGLang 使用 RadixTree 数据结构来管理 prefix cache(称为 RadixAttention),比 vLLM 的 block hash 方案更高效地支持前缀共享:

SGLang RadixTree Prefix Cache (MLA) Root System Prompt c_t[512]+k_rope[64] × N tokens User A query c_t+k_rope User B query c_t+k_rope Few-shot Examples c_t[512]+k_rope[64] × M tokens User C query c_t+k_rope 每个节点存储: kv_buffer[layer] = Tensor[n_tokens, kv_lora_rank + qk_rope_head_dim]

5.2 MLA KV Buffer 格式

SGLang 中 MLA 的 KV 缓存以单一 buffer 存储:

# SGLang MLA KV Buffer 结构
kv_buffer: List[Tensor]   # [num_layers]
# 每层 shape: [max_total_tokens, kv_lora_rank + qk_rope_head_dim]
#           = [max_total_tokens, 512 + 64]
#           = [max_total_tokens, 576]
# 单一 buffer, interleaved: [c_t_0..c_t_511 | k_rope_0..k_rope_63]

5.3 Chunked Prefix Cache 优化 (FlashAttention3)

SGLang 的核心创新是 Chunked Prefix Cache 优化,目前仅在 FlashAttention3 后端可用:

SGLang Chunked Prefix Cache 数据流 Prefix (Cache 命中) — 12K tokens Extend (新 token) — 2K tokens 切分为 chunks (e.g., 4K/chunk) Chunk 1: [0, 4K) 读取 c_t + k_rope 解压 → K, V → MHA Attn Chunk 2: [4K, 8K) 读取 c_t + k_rope 解压 → K, V → MHA Attn Chunk 3: [8K, 12K) 读取 c_t + k_rope 解压 → K, V → MHA Attn (out_1, lse_1) (out_2, lse_2) (out_3, lse_3) merge_attn_states() 最终 Attention Output

5.4 SGLang 支持的 MLA Attention 后端

后端PrefillDecodeChunked Prefix Cache适用 GPU
FlashAttention3Non-AbsorbedAbsorbed支持Hopper+
FlashInferNon-AbsorbedAbsorbed不支持Ampere+
FlashMLA-Absorbed-Hopper
CutlassMLA-Absorbed-Hopper+
TRTLLM MLANon-AbsorbedAbsorbed不支持Blackwell
TritonNon-AbsorbedAbsorbed不支持通用
重要限制:Chunked Prefix Cache 优化目前仅在 FlashAttention3 后端可用。使用其他后端时,prefix cache 命中的 token 仍会走完整的重新计算路径而非缓存复用路径,或者采用更基础的处理方式。FA3 的 merge_states API 是实现此优化的关键。

6. 核心数据流全景图:Prefix Cache → Chunked Prefill → Attention Output

端到端数据流 (以 vLLM 为例) 请求到达: "System prompt + User query" (10K tokens) Step 1: Prefix Hash/Radix 匹配 命中 8K prefix token → 需要 extend 2K 新 token Step 2: Scheduler 分配 — 只对 2K 新 token 做模型前向 新 token 通过 Embedding → RMSNorm → ... → Attention Layer 输入 Step 3: Attention Layer 内部处理 3a. Q 计算 (新 2K token) q = h × W_Q → q_nope, q_rope → apply RoPE 3b. 新 token KV (2K token) kv_c = h × W_DKV, k_rope = ... → 存入 Cache 3c. Context 解压缩 (prefix-cached 8K tokens) — 分 chunk 处理 Chunk [0, 4K): c_t×W_UK→K, c_t×W_UV→V, FA Chunk [4K, 8K): c_t×W_UK→K, c_t×W_UV→V, FA 每 chunk 解压显存: 4K×128h×320d×2B ≈ 320 MB 3d. Extend FlashAttention (2K new tokens) Q_new × K_new, V_new → (extend_out, extend_lse) 3e. merge_attn_states() context_out + extend_out → final_out 3f. Attention Output → O × W_O Step 4: 新 token 写入 Cache 只存 c_t[512] + k_rope[64] = 576 dims Step 5: 后续 Decode 走 Absorbed 路径 (无需解压)
关键公式:解压缩的 GEMM 计算量 = prefix_len × d_c × n_heads × (d_nope + d_v) = 8K × 512 × 128 × 256 ≈ 135 GFLOP。虽然这是额外开销,但相比完全重新计算 8K token 的全部 Transformer 层(约 8K × 隐藏层宽度 × 参数量 的量级),prefix cache + 在线解压的方案仍然是巨大的节省。

7. 性能对比与工程取舍

7.1 Prefix Cache 收益分析

场景无 Prefix Cache有 Prefix Cache加速比
8K system prompt + 2K query全部 10K token 做 prefill仅 2K prefill + 8K 解压~3-5x
32K 长文档 + 512 query全部 32.5K token 做 prefill仅 512 prefill + 32K 解压~5-10x
重复 system prompt (多用户)每个用户都重算所有用户共享 cache线性增长

7.2 解压缩的额外开销

与标准 MHA 的 prefix cache "零开销命中" 不同,MLA 的 prefix cache 命中后仍有解压缩开销:

Prefix Cache 命中后的开销对比 MHA ~0 GQA ~0 MLA 解压 GEMM 无 Cache 全量计算 MHA/GQA: cache 命中 → 直接使用 K/V → 0 额外计算 MLA: cache 命中 → 解压 c_t → GEMM → K/V 额外开销 ≈ prefix_len × d_c × n_heads × d_head 无 Cache: 全部 token 过完整 Transformer

7.3 vLLM vs SGLang 实现对比

特性vLLMSGLang
Prefix Cache 结构Paged Block + Hash TableRadixTree (RadixAttention)
KV Cache 格式FP8 c_t [512] + BF16 k_rope [64] = 656B混合精度 buffer [576 dims]
Prefill 解压方式context chunk loop + kv_b_proj GEMMchunk 切分 + MHA attn + merge
Attention 合并merge_attn_states (lse-based)merge_attn_states (lse-based)
最佳后端FlashInfer / FA4 for MLAFlashAttention3 (chunked prefix cache)
Decode 路径Absorbed (FlashMLA / PagedAttn)Absorbed (FlashMLA / CutlassMLA / FlashInfer)
吞吐提升~3-5x (prefix cache 场景)~7x (综合优化后)

8. 总结

一句话总结:MLA 的 Prefix Cache 存的是压缩 latent (c_t + k_rope = 576 dims),但 Prefill 需要全量 KV (40,960 dims)。解决方案是在线解压缩 + 分块处理 + attention state 合并

关键技术要点

#要点说明
1Cache 只存压缩表示c_t [512] + k_rope [64] = 576 dims,所有 128 heads 共享
2Prefill 必须解压通过 kv_b_proj (= [W_UK; W_UV]) 将 c_t 投影回 per-head K/V
3分块处理控制显存长 prefix 分成 chunk(如 4K token/chunk),每次只解压一个 chunk
4LSE-based 合并各 chunk 结果通过 log-sum-exp 加权合并,保证数学等价性
5Decode 无此问题Absorbed 路径直接用 c_t,矩阵吸收消除了解压需求
6总体仍然值得解压 GEMM 的开销远小于完整 Transformer 重算的开销
工程启示:MLA 的 prefix cache 不是"免费"的——它以额外的解压 GEMM 为代价换取了 ~100x 的缓存空间节省和高效的前缀共享。这个 trade-off 在实际生产中(多用户共享 system prompt、长文档问答等场景)收益非常显著。
MLA Prefix Cache & Chunked Prefill 深度解析 | 基于 vLLM / SGLang 源码分析 | 2025