MLA (Multi-head Latent Attention) Simplified Data-Flow Diagram
Non-absorbed path · using DeepSeek-V3 parameters as an example (n_heads=128, d_c=512, d_rope=64, d_head=128)
Tensor (shape)
Linear projection
Operation node
KV cache storage
Attention computation
hidden
[bs, seq, 7168]
Q down-projection (Down Proj)
W: [7168, 1536]
compress_q
[bs, seq, 1536]
RMSNorm → Q up-projection (Up Proj)
W: [1536, 128×(128+64)] = [1536, 24576]
View & Split
q_nope
[bs, 128, seq, 128]
q_pe
[bs, 128, seq, 64]
+ RoPE
Concat (dim=-1)
Q
[bs, 128, seq, 192]
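The weight shapes on the Q branch imply a parameter count that can be checked with a few lines of arithmetic. The sketch below compares the two-stage projection from the diagram against a hypothetical single full-rank projection [7168, 24576]; that baseline is an assumption introduced only for comparison, and biases and RMSNorm weights are ignored. The function name `q_projection_params` is illustrative.

```python
# Dimensions taken from the diagram labels.
H_DIM = 7168
Q_LORA_RANK = 1536
N_HEADS = 128
D_NOPE = 128
D_ROPE = 64


def q_projection_params(h_dim, q_rank, n_heads, d_nope, d_rope):
    """Return (low_rank_params, full_rank_params) for the Q projection.

    low_rank  : down-projection [h_dim, q_rank] plus up-projection [q_rank, q_out]
    full_rank : hypothetical single projection [h_dim, q_out] (comparison only)
    """
    q_out = n_heads * (d_nope + d_rope)
    low_rank = h_dim * q_rank + q_rank * q_out
    full_rank = h_dim * q_out
    return low_rank, full_rank


low, full = q_projection_params(H_DIM, Q_LORA_RANK, N_HEADS, D_NOPE, D_ROPE)
print("low-rank Q params :", low)
print("full-rank Q params:", full)
print("ratio             : %.2f" % (full / low))
```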
KV down-projection (Down Proj)
W: [7168, 512+64] = [7168, 576]
Split
compress_kv (KV Cache)
[bs, seq, 512] ← cached in place of the full K/V (together with k_rope)
k_rope (KV Cache)
[bs, 1, seq, 64]
+ RoPE
RMSNorm → KV up-projection (Up Proj)
W: [512, 128×(128+128)] = [512, 32768]
View & Split
k_nope
[bs, 128, seq, 128]
v
[bs, 128, seq, 128]
broadcast k_rope 1→128 heads
Concat (k_nope, k_rope, dim=-1)
K
[bs, 128, seq, 192]
V
[bs, 128, seq, 128]
Q
K
V
Attention (MHA)
QK^T/√d (d = 128+64 = 192) → Softmax → ×V
O (Attention Output)
[bs, 128, seq, 128]
Output projection (Output Proj)
W: [128×128, 7168] = [16384, 7168]
hidden (output)
[bs, seq, 7168]
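The shapes in the diagram can be checked with a small shape-tracing sketch. It performs no tensor math; it only propagates shapes through each node of the non-absorbed path, using the dimensions labeled in the diagram. The function name `trace_mla_shapes` and the dictionary keys are illustrative and do not come from any library.

```python
# Dimensions as labeled in the diagram (DeepSeek-V3 example).
H_DIM, N_HEADS = 7168, 128
Q_LORA_RANK, KV_LORA_RANK = 1536, 512
D_NOPE, D_ROPE, D_V = 128, 64, 128


def trace_mla_shapes(bs, seq):
    """Propagate tensor shapes through the non-absorbed MLA path (no real math)."""
    s = {"hidden": (bs, seq, H_DIM)}

    # Q branch: down-project, RMSNorm + up-project, view per head, split nope/rope.
    s["compress_q"] = (bs, seq, Q_LORA_RANK)
    s["q_nope"] = (bs, N_HEADS, seq, D_NOPE)
    s["q_pe"] = (bs, N_HEADS, seq, D_ROPE)        # RoPE keeps the shape
    s["Q"] = (bs, N_HEADS, seq, D_NOPE + D_ROPE)  # concat on the last dim

    # KV branch: one down-projection of width 512+64, then split.
    s["compress_kv"] = (bs, seq, KV_LORA_RANK)    # cached
    s["k_rope"] = (bs, 1, seq, D_ROPE)            # cached, shared by all heads
    s["k_nope"] = (bs, N_HEADS, seq, D_NOPE)      # from the KV up-projection
    s["V"] = (bs, N_HEADS, seq, D_V)              # from the KV up-projection
    s["K"] = (bs, N_HEADS, seq, D_NOPE + D_ROPE)  # k_rope broadcast 1 -> N_HEADS

    # Attention: Q and K must agree on the last dim; output takes V's last dim.
    assert s["Q"][-1] == s["K"][-1]
    s["scores"] = (bs, N_HEADS, seq, seq)
    s["O"] = (bs, N_HEADS, seq, D_V)
    s["output"] = (bs, seq, H_DIM)                # after the [16384, 7168] projection
    return s


for name, shape in trace_mla_shapes(bs=2, seq=16).items():
    print(f"{name:12s} {shape}")
```

Note that only `compress_kv` and `k_rope` have shapes independent of the number of heads, which is why they are the tensors chosen for caching.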
Key parameters
h_dim = 7168
n_heads = 128
d_head (qk, nope part) = 128
d_head (v) = 128
d_c (kv_lora) = 512
d_rope = 64
q_lora_rank = 1536
KV Cache = 512+64
= 576 elements per token (per layer)
KV Cache
Stores only:
c_kv [512]
k_rope [64]
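To put the 576-element figure in context, the sketch below compares it with caching the fully expanded K [128, 192] and V [128, 128] from the diagram, per token and per layer. The 2-byte element size is an assumption (16-bit storage such as bf16), and the function name `cache_elements_per_token` is illustrative.

```python
# Dimensions as labeled in the diagram.
N_HEADS = 128
D_NOPE, D_ROPE, D_V = 128, 64, 128
KV_LORA_RANK = 512


def cache_elements_per_token():
    """Return (mla_elements, full_kv_elements) per token per layer."""
    mla = KV_LORA_RANK + D_ROPE                            # c_kv + k_rope
    full_kv = N_HEADS * (D_NOPE + D_ROPE) + N_HEADS * D_V  # expanded K + expanded V
    return mla, full_kv


mla, full_kv = cache_elements_per_token()
BYTES_PER_ELEMENT = 2  # assumption: 16-bit storage such as bf16
print("MLA cache:", mla, "elements,", mla * BYTES_PER_ELEMENT, "bytes")
print("Full K/V :", full_kv, "elements,", full_kv * BYTES_PER_ELEMENT, "bytes")
print("Reduction: %.1fx" % (full_kv / mla))
```

The trade-off in the non-absorbed path is that K and V must be rebuilt from `c_kv` through the KV up-projection each time attention is computed.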