MLA 矩阵吸收 & 吸收分离 原理详解

Weight Absorption · Decode 阶段如何避免解压 KV Cache · 以 DeepSeek-V3 为例

0. 一句话直觉

💡
矩阵吸收的本质:利用矩阵乘法的结合律 (A × B) × C = A × (B × C), 把原本需要「先解压 KV Cache,再算注意力」的两步操作,重新组合成「先把解压矩阵吸收进 Q,再直接用压缩后的 Cache 算注意力」。 这样 Decode 时就不需要从显存读取庞大的 K/V,只读取极小的潜向量 c_t,显存带宽开销降到 1/50+。

1. 前置知识回顾:MLA 的 KV 压缩

在 MLA 中,输入 hidden state 经过下投影压缩成低维潜向量 c_t,推理时缓存 c_t(而非完整 K/V)。需要计算注意力时,再通过上投影还原:

下投影 c_t = h_t × WDKV    // h_t:[d_model] → c_t:[d_c],d_c=512 ≪ d_model=7168
上投影 K = c_t × WUK    // c_t:[d_c] → K:[n_h × d_h],512 → 128×128
上投影 V = c_t × WUV    // c_t:[d_c] → V:[n_h × d_h],512 → 128×128

问题来了:在 Decode 阶段(每步只生成 1 个 token),每一步都要对所有历史 token 做上投影还原 K、V。这需要从显存读取每个历史 token 的 c_t,然后做大矩阵乘解压——非常浪费带宽和算力。

Non-Absorbed 的 Decode 问题:假设历史长度 L=4096,则每步需要将 4096 个 c_t 都通过 W_UK/W_UV 解压成完整 K/V,产生 4096 × 128 × 128 × 2 = 128M 参数的中间张量,再做 attention。Decode 本该是一步极轻的操作,却被解压拖慢。

2. 矩阵吸收:逐步推导

2.1 Attention Score 的吸收(K 侧)

标准注意力分数计算:score = Q × KT。将 K 的定义代入:

原始 score = Q × KT

展开 K score = Q × (c_t × WUK)T

转置展开 score = Q × WUKT × c_tT

结合律 score = (Q × WUKT) × c_tT

定义 Q' = Q × WUKT    ← WUK 被「吸收」进了 Q

最终 score = Q' × c_tT    ← 直接用压缩后的 c_t 计算!
BEFORE:Non-Absorbed Q [1, 128, 192] × c_t [L, 512] W_UK K (巨大!) [L, 128, 128] → KT = score 每步都要 解压 c_t→K AFTER:Absorbed (矩阵吸收) Q [1, 128, 192] × W_UKT [192, 512] = Q' [1, 128, 512] Q'  ×  c_tT c_t (小!) [L, 512] = score W_UK 吸收进 Q 侧 无需解压! 直接读 c_t V 侧同理 (Output 吸收) 原始: O = Attn × V = Attn × (c_t × WUV) 结合律: O = (Attn × c_t) × WUV 先在潜空间算 O' = Attn × c_t,再一次性 W_UV 投影 O' shape: [1, 128, 512] → × W_UV → [1, 128, 128] V 侧吸收效果 ✅ Attn 计算在 d_c=512 维度完成 ✅ 不需要先解压出完整 V [128×128] ✅ W_UV 只做一次,与历史长度 L 无关 相当于 MQA:所有 head 共享同一个 d_c 维 KV cache

2.2 完整对比

❌ Non-Absorbed(Prefill 阶段用)

  • K = c_t × W_UK → 先解压
  • V = c_t × W_UV → 先解压
  • score = Q × KT → 标准 MHA
  • O = softmax(score) × V
  • 需要物化完整 K/V:[L, 128, 128]
  • 优点:大 batch 时 GEMM 效率高

✅ Absorbed(Decode 阶段用)

  • Q' = Q × W_UKT → 吸收到 Q 侧
  • score = Q' × c_tT → 直接用潜向量
  • O' = softmax(score) × c_t → 潜空间
  • O = O' × W_UV → 最后一次投影
  • Cache 只读 [L, 512],缩小 ~32 倍
  • 优点:极省带宽,等效 MQA

3. 形象类比:「翻译官」的故事

场景 A:Non-Absorbed(每次都翻译) 书架 (c_t 压缩存储) 中文速记笔记 512 维 翻译官 (W_UK) 中文→英文 完整英文文档 (K: 巨大!) 128 heads×128 dim 每次 decode 都重翻! 读者 (Q) 只能读英文 痛点: 每次提问都要把整个 书架的中文笔记重新 翻译成完整英文文档! 场景 B:Absorbed(教读者学中文) 读者先学翻译 Q' = Q × W_UKT 一次性学习! 书架 (c_t 原样) 中文速记 512 维 score 不需要翻译官! W_UK 已被吸收 🚀 好处: 不管书架有多少笔记(L),读者都直接读中文速记,无需翻译

4. 矩阵吸收「分离」:RoPE 带来的特殊处理

矩阵吸收看似完美,但有一个无法吸收的部分——RoPE(旋转位置编码)

⚠️
为什么 RoPE 不能被吸收? RoPE 是位置相关的变换,每个 token 位置的旋转矩阵 R(pos) 不同。而矩阵吸收要求 W 是固定权重(所有位置共享)。RoPE 作用在 K 上之后,Q×KT 中的位置信息就无法简单用结合律挪到 Q 侧了。

4.1「分离」的核心思路

MLA 的解决方案是:把 Key 分成两部分——可以吸收的和不能吸收的——分别处理后合并。

Key 分拆
  K = [Knope , Krope]    // 拼接成完整 Key

Knope = c_t × WUK    // ✅ 可吸收! 位置无关的静态投影
Krope = RoPE(k_pe)    // ❌ 不可吸收! 位置相关的动态变换

对应 Q
  Q = [Qnope , Qrope]    // 也分成两部分

Score 计算(分拆后)
  score = Qnope × KnopeT  +  Qrope × KropeT

吸收第一项
  score = (Qnope × WUKT) × c_tT  +  Qrope × KropeT

  = Q'nope × c_tT  +  Qrope × RoPE(k_pe)T

4.2 图解:分离吸收的数据流

Absorbed Decode 完整数据流(含 RoPE 分离) Q 侧 Q_nope [1, 128, 128] Q_rope [1, 128, 64] + RoPE × W_UKT (吸收!) [128, 512] Q'_nope [1, 128, 512] Q_rope [1, 128, 64] 不变 KV Cache (仅存这两个!) c_t (compress_kv) [L, 512] ← KV Cache 主体 k_rope_cache [L, 1, 64] ← 位置编码缓存 score_nope = Q'_nope × c_tT [1,128,512] × [512,L] → [1,128,L] score_rope = Q_rope × k_peT [1,128,64] × [64,L] → [1,128,L] + score (完整) [1, 128, L] softmax → × c_t → × W_UV → Output 「分离」的含义 nope 部分 → 可吸收 rope 部分 → 不可吸收 Cache 总计 per token 512 + 64 = 576 维
🔎
「分离」总结:
① 将 K 拆成 K_nope(无位置编码)和 K_rope(有 RoPE)两部分
K_nope 来自 c_t × W_UK,可以用结合律把 W_UK 吸收进 Q 侧 → 不需要解压
K_rope 经过 RoPE 位置编码,无法吸收 → 必须单独缓存(仅 64 维,很小)
④ 最终 score = 吸收部分得分 + RoPE部分得分,两者相加后做 softmax

5. 量化对比:吸收带来的收益

指标 Non-Absorbed (Prefill) Absorbed (Decode) 收益
每 token KV 读取量 128×128×2(K+V) = 32,768 维 512 + 64 = 576 维 ~57x 减少
KV 等效行为 标准 MHA(多头独立 KV) 等效 MQA(共享单潜头) 带宽友好
Arithmetic Intensity 低(Decode 单 token 利用率差) ~2x 提升(数据复用率高) 更接近计算瓶颈
额外计算开销 Q × W_UKT(一次性小矩阵乘) 可忽略
模型效果 无损 无损(数学等价变换) 完全等价
推荐场景 Prefill(大 batch GEMM 效率高) Decode(单 token 省带宽) 动态切换
🚀
核心要点:矩阵吸收是纯数学等价变换(结合律),不引入任何近似或精度损失。它只是重新安排了矩阵乘法的计算顺序——在 Decode 阶段,先把固定权重 W_UK 乘进 Q(计算量小,只做一次),从而避免对每个历史 token 都做耗带宽的 KV 解压。这就是 MLA 在 Decode 时能以 MQA 级别的 KV cache 成本达到 MHA 级别效果的秘密。

6. 实际工程:Prefill vs Decode 动态切换

在实际推理系统(SGLang、vLLM)中,MLA 会根据当前阶段动态选择路径:

新 token 数量 x = ? x ≥ 阈值 (如 ≥ 64) Non-Absorbed (MHA) 先解压 c_t → K, V 用 FlashAttention 做标准 MHA 大 batch GEMM 效率高 Tensor Core 利用率高 适合 Prefill / Chunked Prefill x < 阈值 (如 = 1) Absorbed (MQA) Q' = Q × W_UKT (吸收) 直接 Q' × c_tT KV cache 读取量降 57x 带宽压力极小 适合 Decode / 投机采样
📌
切换阈值:当新 token 数 x 较大时(Prefill),解压 KV 的开销可被大 GEMM 的高效率摊薄,Non-Absorbed 更快。当 x=1(Decode 逐 token 生成),解压 KV 变成纯粹的带宽浪费,Absorbed 路径完胜。SGLang/vLLM 中典型阈值为 x=64 左右,chunked prefill 配合 FlashMLA 的 page_size=64 对齐。
MLA 矩阵吸收原理详解 · DeepSeek-V2/V3 技术