1. 前置知识回顾:MLA 的 KV 压缩
在 MLA 中,输入 hidden state 经过下投影压缩成低维潜向量 c_t,推理时缓存 c_t(而非完整 K/V)。需要计算注意力时,再通过上投影还原:
下投影 c_t = h_t × WDKV // h_t:[d_model] → c_t:[d_c],d_c=512 ≪ d_model=7168
上投影 K = c_t × WUK // c_t:[d_c] → K:[n_h × d_h],512 → 128×128
上投影 V = c_t × WUV // c_t:[d_c] → V:[n_h × d_h],512 → 128×128
问题来了:在 Decode 阶段(每步只生成 1 个 token),每一步都要对所有历史 token 做上投影还原 K、V。这需要从显存读取每个历史 token 的 c_t,然后做大矩阵乘解压——非常浪费带宽和算力。
❌
Non-Absorbed 的 Decode 问题:假设历史长度 L=4096,则每步需要将 4096 个 c_t 都通过 W_UK/W_UV 解压成完整 K/V,产生 4096 × 128 × 128 × 2 = 128M 参数的中间张量,再做 attention。Decode 本该是一步极轻的操作,却被解压拖慢。
2. 矩阵吸收:逐步推导
2.1 Attention Score 的吸收(K 侧)
标准注意力分数计算:score = Q × KT。将 K 的定义代入:
原始 score = Q × KT
展开 K score = Q × (c_t × WUK)T
转置展开 score = Q × WUKT × c_tT
结合律 score = (Q × WUKT) × c_tT
定义 Q' = Q × WUKT ← WUK 被「吸收」进了 Q
最终 score = Q' × c_tT ← 直接用压缩后的 c_t 计算!
2.2 完整对比
❌ Non-Absorbed(Prefill 阶段用)
K = c_t × W_UK → 先解压
V = c_t × W_UV → 先解压
score = Q × KT → 标准 MHA
O = softmax(score) × V
- 需要物化完整 K/V:[L, 128, 128]
- 优点:大 batch 时 GEMM 效率高
✅ Absorbed(Decode 阶段用)
Q' = Q × W_UKT → 吸收到 Q 侧
score = Q' × c_tT → 直接用潜向量
O' = softmax(score) × c_t → 潜空间
O = O' × W_UV → 最后一次投影
- Cache 只读 [L, 512],缩小 ~32 倍
- 优点:极省带宽,等效 MQA
3. 形象类比:「翻译官」的故事
4. 矩阵吸收「分离」:RoPE 带来的特殊处理
矩阵吸收看似完美,但有一个无法吸收的部分——RoPE(旋转位置编码)。
⚠️
为什么 RoPE 不能被吸收? RoPE 是位置相关的变换,每个 token 位置的旋转矩阵 R(pos) 不同。而矩阵吸收要求 W 是固定权重(所有位置共享)。RoPE 作用在 K 上之后,Q×KT 中的位置信息就无法简单用结合律挪到 Q 侧了。
4.1「分离」的核心思路
MLA 的解决方案是:把 Key 分成两部分——可以吸收的和不能吸收的——分别处理后合并。
Key 分拆
K = [Knope , Krope] // 拼接成完整 Key
Knope = c_t × WUK // ✅ 可吸收! 位置无关的静态投影
Krope = RoPE(k_pe) // ❌ 不可吸收! 位置相关的动态变换
对应 Q
Q = [Qnope , Qrope] // 也分成两部分
Score 计算(分拆后)
score = Qnope × KnopeT + Qrope × KropeT
吸收第一项
score = (Qnope × WUKT) × c_tT + Qrope × KropeT
= Q'nope × c_tT + Qrope × RoPE(k_pe)T
4.2 图解:分离吸收的数据流
🔎
「分离」总结:
① 将 K 拆成 K_nope(无位置编码)和 K_rope(有 RoPE)两部分
② K_nope 来自 c_t × W_UK,可以用结合律把 W_UK 吸收进 Q 侧 → 不需要解压
③ K_rope 经过 RoPE 位置编码,无法吸收 → 必须单独缓存(仅 64 维,很小)
④ 最终 score = 吸收部分得分 + RoPE部分得分,两者相加后做 softmax
6. 实际工程:Prefill vs Decode 动态切换
在实际推理系统(SGLang、vLLM)中,MLA 会根据当前阶段动态选择路径:
📌
切换阈值:当新 token 数 x 较大时(Prefill),解压 KV 的开销可被大 GEMM 的高效率摊薄,Non-Absorbed 更快。当 x=1(Decode 逐 token 生成),解压 KV 变成纯粹的带宽浪费,Absorbed 路径完胜。SGLang/vLLM 中典型阈值为 x=64 左右,chunked prefill 配合 FlashMLA 的 page_size=64 对齐。