MLA 矩阵吸收 — GEMM 矩阵形状图解

用矩阵乘法的「色块拼图」直观理解吸收与分离原理

图例说明

= Query 矩阵 (蓝) = 权重矩阵 W (绿) = KV Cache c_t (红) = 结果 (紫)

矩阵图中,宽度代表列数(dim)高度代表行数。两个矩阵相乘时,左矩阵的列数必须等于右矩阵的行数(即相邻边匹配)。

第一步:回顾标准 Attention 的 GEMM

标准MHA score = Q * K^T,其中 Q 和 K 的 head_dim 都是 192(128 nope + 64 rope):

标准 MHA: score = Q * K^T Q ← 192 列 → 1 行 x K^T 192 行 ← L 列 → = score [1, L] 问题在 K ! K 需要从 c_t 解压得到: K = c_t * W_UK [L, 512] x [512, 192] = [L, 192] 每步 decode 都要做!

第二步:Non-Absorbed 完整 GEMM 链

Non-Absorbed 先解压 K = c_t * W_UK,再 score = Q * K^T。两步 GEMM:

Non-Absorbed: 两步 GEMM(每次 decode 都要做!) GEMM-1: K = c_t * W_UK c_t 512 列 L 行 x W_UK [512, 192] 192 列 512行 = K [L, 192] 192 列 巨大! GEMM-2: score = Q * K^T Q [1, 192] x K^T L 列 192行 = score [1, L] Non-Absorbed 的开销 GEMM-1: c_t[L,512] x W_UK[512,192] -> K[L,192] — 每步都做! 读 Lx512 写 Lx192 GEMM-2: Q[1,192] x K^T[192,L] -> score[1,L] — 需要完整 K 总访存: L x 512 (读c_t) + L x 192 (写K) + L x 192 (读K) = L x 896

第三步:矩阵吸收的关键洞察 — 结合律换括号

核心 把上面两步 GEMM 写成一个表达式,然后利用矩阵乘法结合律重新分组:

原始计算顺序 vs 吸收后计算顺序 原始: score = Q x (c_t x W_UK)^T = Q x W_UK^T x c_t^T 先算右边两个 (解压 KV Cache) Q [1, 192] x 先算这一块 (每步 decode 都做!) W_UK^T [192, 512] x c_t^T [512, L] -> 结合律! 换个括号位置 吸收: score = (Q x W_UK^T) x c_t^T = Q' x c_t^T 先算左边两个 (吸收权重到 Q) 先算这一块 (仅 1 token!) Q [1, 192] x W_UK^T [192, 512] = Q' [1, 512] 然后: Q'[1,512] x c_t^T[512,L] = score[1,L]
关键区别一目了然:左图需要先做一个大 GEMM (c_t x W_UK = [L,512] x [512,192]) 生成完整 K,右图把 W_UK 挪到 Q 侧 (Q x W_UK^T = [1,192] x [192,512]),只是 1 行的小 GEMM,然后直接用压缩后的 c_t 做 attention。核心差异在于与历史长度 L 相关的那步 GEMM 被消除了

对比:Non-Absorbed vs Absorbed 实际 GEMM 形状

Non-Absorbed GEMM (每步 decode) GEMM-1: 解压 K c_t 512 L x W_UK 192 512 = K 192 GEMM-2: Q x K^T Q [1,192] x K^T [192, L] 访存: 读 Lx512 + 写读 Lx192 = L x 896 (巨大!) Absorbed GEMM (每步 decode) GEMM-a: Q x W_UK^T (仅1行!) Q [1,192] x W_UK^T [192, 512] = Q' [1, 512] GEMM-b: Q' x c_t^T Q' [1,512] x c_t^T [512, L] 访存: 读 Lx512 (仅c_t!) 无需物化 K! 省 ~43% 核心差异: 与历史长度 L 相关的大 GEMM 被消除了!

第四步:V 侧的矩阵吸收(同理)

V侧 原始 O = Attn x V = Attn x (c_t x W_UV),吸收后 O = (Attn x c_t) x W_UV

Non-Absorbed V 侧 O = attn_weights x (c_t x W_UV) attn [1, L] x c_t [L, 512] x W_UV [512,128] = V [L, 128] -> [1, 128] 先做 [L,512]x[512,128] 大GEMM,再 attn x V Absorbed V 侧 O = (attn_weights x c_t) x W_UV attn [1, L] x c_t [L, 512] = O' [1, 512] x W [512,128] = O [1, 128] attn x c_t 在 512 维潜空间完成,W_UV 只做一次 [1,512]x[512,128]

第五步:矩阵吸收分离 — 为什么 Key 要「拆」成两部分

核心问题 上面我们假设整个 K 都可以吸收,但实际上 K 由两部分组成,其中一部分不可吸收

5.1 Key 的组成结构

在 DeepSeek-V2/V3 的 MLA 中,每个 head 的 Key 向量 [192维] = K_nope [128维] + K_rope [64维],它们的来源完全不同:

Key [192维] = K_nope [128维] + K_rope [64维] K_nope [128 维] — 来自线性投影 公式: K_nope = c_t x W_UK 其中: c_t 是 KV Cache [L, 512] W_UK 是固定权重矩阵 [512, 128] 纯线性投影,与 token 位置无关 -> 可吸收! K_rope [64 维] — 经过 RoPE 旋转 公式: K_rope = RoPE(k_pe, pos) 其中: k_pe 是位置编码投影 [L, 64] RoPE(pos) = 与每个 token 位置相关的旋转 位置相关变换,不是固定权重 -> 不可吸收!

5.2 为什么 RoPE 部分不能吸收?— GEMM 图解

吸收的本质是利用结合律把 W 挪到 Q 侧。但 RoPE 不是固定矩阵,每个位置都不同:

K_nope: 可吸收 (W_UK 是固定权重) score_nope = Q_nope x (c_t x W_UK)^T = Q_nope x W_UK^T x c_t^T = (Q_nope x W_UK^T) x c_t^T = Q'_nope x c_t^T 吸收后的 GEMM: Q'_nope [1, 512] x c_t^T [512, L] = score_nope [1, L] W_UK 已吸收进 Q -> 只读 c_t [L, 512] K_rope: 不可吸收 (RoPE 随位置变化) score_rope = Q_rope x K_rope^T = RoPE(q_pe) x RoPE(k_pe)^T RoPE(pos) 每个位置不同,不是固定矩阵! 必须标准计算: Q_rope [1, 64] x k_rope^T [64, L] = score_rope [1, L] 必须单独缓存 k_rope [L, 64],直接标准计算
直觉理解:W_UK 就像一个固定的「翻译模板」,所有 token 用同一个模板,所以可以提前合并到 Q 侧。但 RoPE 是「随位置旋转的角度」,每个 token 位置角度不同,无法合并成一个固定模板——就像翻译每个人的话都要用不同的翻译标准,无法统一处理。

5.3 分离后的完整计算流程 — 两路合并

矩阵吸收分离: 两路 Score 计算后合并 路径-1 (吸收路径): score_nope Step-a: Q_nope x W_UK^T = Q'_nope [1,128] x W_UK^T [128,512] = Q'_nope [1,512] Step-b: Q'_nope x c_t^T = score_nope [1, 512] x c_t^T [512, L] = score_nope [1, L] 路径-2 (标准路径): score_rope 直接标准 attention 计算 Q_rope [1, 64] x k_rope^T [64, L] = score_rope [1, L] 必须缓存 k_rope [L, 64],无法避免 score = score_nope + score_rope [1, L] + [1, L] = [1, L] KV Cache per token: c_t (512) + k_rope (64) = 576 维
「矩阵吸收分离」总结:把 Key 的 192 维拆成 nope [128维] + rope [64维]。nope 部分是纯线性投影 (K_nope = c_t x W_UK),可以通过结合律将 W_UK 吸收到 Q 侧;rope 部分经过 RoPE 位置旋转 (K_rope = RoPE(k_pe, pos)),每个位置不同无法吸收,必须单独缓存 k_rope 做标准 attention。最后两路 score 相加。这就是「分离」的含义——把能吸收的分离出来吸收,不能吸收的单独保留计算

第六步:一图总结 — Absorbed Decode 完整 GEMM 路径

Absorbed Decode 完整 GEMM 路径(3 个核心矩阵乘) GEMM-a: 吸收 W_UK Q_nope [1, 128] x W_UK^T [128, 512] = Q'_nope [1, 512] 仅 1 token, 瞬间完成 GEMM-b: Score 计算 (两路) Q'_nope x c_t^T [512, L] Q_rope x k_rope^T [64, L] } + = score [1, L] 读 c_t[L,512] + k_rope[L,64] GEMM-c: 输出 (V 侧吸收) attn [1,L] x c_t [L, 512] = O'[1,512] x W_UV = O [1, 128] W_UV 吸收到输出侧 Decode 一步的总访存对比 Non-Absorbed 读c_t + 写读K + 写读V = L x 1152 Absorbed 读c_t + k_rope (可复用) ~ L x 576 * Absorbed 可复用 c_t 读取,实际约 L x 576;Non-Absorbed 必须物化 K/V 中间张量,访存 L x 1152节省约 50%!
总结 Absorbed Decode 的三步 GEMM:
GEMM-a [1,128] x [128,512] -> Q'_nope [1,512](吸收 W_UK,极小计算)
GEMM-b [1,512] x [512,L] + [1,64] x [64,L] -> score [1,L](在潜空间计算,读 c_t 和 k_rope)
GEMM-c [1,L] x [L,512] -> O'[1,512],再 [1,512] x [512,128] -> O[1,128](V 侧吸收,W_UV 最后做)

核心收益:所有与历史长度 L 相关的 GEMM 都在 512 维潜空间完成(而非展开的 192 维),KV Cache 只需存 576 维/token(vs 传统 MHA 的 32768 维)。