MLA (Multi-head Latent Attention): Simplified Data-Flow Diagram

Non-absorbed path (K and V are explicitly up-projected from the cached latent) · DeepSeek-V3 parameters as the example (h=128, d_c=512, d_rope=64, d_head=128)
Query path
  hidden [bs, seq, 7168]
  → Q down-projection, W: [7168, 1536] → compress_q [bs, seq, 1536]
  → RMSNorm
  → Q up-projection, W: [1536, 128×(128+64)] = [1536, 24576]
  → View & Split:
      q_nope [bs, 128, seq, 128]
      q_pe   [bs, 128, seq, 64] → RoPE
  → Concat (dim=-1) → Q [bs, 128, seq, 192]

KV path
  hidden [bs, seq, 7168]
  → KV down-projection, W: [7168, 512+64] = [7168, 576]
  → Split:
      compress_kv (KV cache) [bs, seq, 512]
      k_rope (KV cache) [bs, 1, seq, 64] → RoPE
  compress_kv → RMSNorm → KV up-projection, W: [512, 128×(128+128)] = [512, 32768]
  → View & Split:
      k_nope [bs, 128, seq, 128]
      v      [bs, 128, seq, 128]
  k_rope → broadcast across heads (1 → 128) → k_pe [bs, 128, seq, 64]
  Concat (k_nope, k_pe) → K [bs, 128, seq, 192]
  V = v [bs, 128, seq, 128]

Attention
  Q, K, V → multi-head attention: Softmax(QK^T / √d) × V, with d = 128 + 64 = 192 (the full Q/K head dimension)
  → O (attention output) [bs, 128, seq, 128]
  → transpose and merge heads → [bs, seq, 16384]
  → output projection, W: [128×128, 7168] = [16384, 7168]
  → hidden (output) [bs, seq, 7168]

Key parameters
  h_dim = 7168
  n_heads = 128
  d_head (qk, excluding the RoPE part) = 128
  d_head (v) = 128
  d_c (kv_lora_rank) = 512
  d_rope = 64
  q_lora_rank = 1536
  KV cache = 512 + 64 = 576 elements per token (per layer)

The KV cache stores only: c_kv [512] and k_rope [64]
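The same non-absorbed data flow can be traced end to end with a small NumPy sketch. It is an illustration under stated assumptions, not DeepSeek's reference implementation: the function names (mla_forward, rms_norm, apply_rope) and the weight-dict keys are hypothetical, weights are stored as [in, out] to match the W shapes in the diagram, RMSNorm has no learned gain, RoPE uses the plain "rotate-half" form without YaRN scaling, and a causal mask is assumed for decoder-style use. The demo uses toy sizes so it runs instantly; the comments give the DeepSeek-V3 values.

```python
import numpy as np


def rms_norm(x, eps=1e-6):
    # RMSNorm over the last dim; the learned gain is omitted (assumed to be 1)
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


def apply_rope(x, base=10000.0):
    # Plain RoPE, "rotate-half" convention, positions 0..seq-1 (assumption: no YaRN)
    # x: [..., seq, d] with d even
    seq, d = x.shape[-2], x.shape[-1]
    inv_freq = 1.0 / base ** (np.arange(0, d, 2) / d)          # [d/2]
    angles = np.arange(seq)[:, None] * inv_freq[None, :]       # [seq, d/2]
    cos = np.concatenate([np.cos(angles)] * 2, axis=-1)        # [seq, d]
    sin = np.concatenate([np.sin(angles)] * 2, axis=-1)
    rotated = np.concatenate([-x[..., d // 2:], x[..., : d // 2]], axis=-1)
    return x * cos + rotated * sin


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def mla_forward(hidden, W, h, d_nope, d_rope, d_v, d_c):
    """Non-absorbed MLA forward pass (hypothetical helper); weights are [in, out]."""
    bs, seq, _ = hidden.shape
    d_qk = d_nope + d_rope

    # Query path: down-proj -> RMSNorm -> up-proj -> split per head -> RoPE on q_pe
    c_q = rms_norm(hidden @ W["q_down"])                            # [bs, seq, q_lora]
    q = (c_q @ W["q_up"]).reshape(bs, seq, h, d_qk).transpose(0, 2, 1, 3)
    Q = np.concatenate([q[..., :d_nope], apply_rope(q[..., d_nope:])], axis=-1)

    # KV path: one down-proj yields the latent c_kv and a single shared RoPE key
    kv_a = hidden @ W["kv_down"]                                    # [bs, seq, d_c + d_rope]
    c_kv = kv_a[..., :d_c]                                          # cached: [bs, seq, d_c]
    k_rope = apply_rope(kv_a[:, None, :, d_c:])                     # cached: [bs, 1, seq, d_rope]
    kv = (rms_norm(c_kv) @ W["kv_up"]).reshape(bs, seq, h, d_nope + d_v).transpose(0, 2, 1, 3)
    k_nope, V = kv[..., :d_nope], kv[..., d_nope:]                  # [bs, h, seq, d_nope], [bs, h, seq, d_v]
    k_pe = np.broadcast_to(k_rope, (bs, h, seq, d_rope))            # broadcast 1 -> h heads
    K = np.concatenate([k_nope, k_pe], axis=-1)                     # [bs, h, seq, d_qk]

    # Ordinary MHA on the materialized K/V with a causal mask (assumption)
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_qk)            # [bs, h, seq, seq]
    future = np.triu(np.ones((seq, seq), dtype=bool), k=1)
    O = softmax(np.where(future, -np.inf, scores)) @ V              # [bs, h, seq, d_v]

    # Merge heads, then apply the output projection
    out = O.transpose(0, 2, 1, 3).reshape(bs, seq, h * d_v) @ W["o"]  # [bs, seq, hidden]
    return out, c_kv, k_rope


if __name__ == "__main__":
    # Toy sizes (hypothetical). DeepSeek-V3 would use hidden=7168, q_lora=1536,
    # h=128, d_nope=128, d_rope=64, d_v=128, d_c=512.
    rng = np.random.default_rng(0)
    hid, q_lora, h, d_nope, d_rope, d_v, d_c = 32, 24, 4, 8, 4, 8, 16
    W = {
        "q_down": rng.standard_normal((hid, q_lora)),
        "q_up": rng.standard_normal((q_lora, h * (d_nope + d_rope))),
        "kv_down": rng.standard_normal((hid, d_c + d_rope)),
        "kv_up": rng.standard_normal((d_c, h * (d_nope + d_v))),
        "o": rng.standard_normal((h * d_v, hid)),
    }
    out, c_kv, k_rope = mla_forward(rng.standard_normal((2, 5, hid)), W, h, d_nope, d_rope, d_v, d_c)
    print(out.shape, c_kv.shape, k_rope.shape)  # (2, 5, 32) (2, 5, 16) (2, 1, 5, 4)
```

Only c_kv and k_rope are returned as cache state, which mirrors the diagram: in this non-absorbed form, every decode step has to re-run the KV up-projection over the whole cached sequence to rebuild K and V.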
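Every projection width and the 576-element cache figure follow from the key parameters alone. The stdlib sketch below derives them; the class and method names are hypothetical, and the "materialized" count is an assumed comparison point: it is what one would store per token per layer if the up-projected K [128, 192] and V [128, 128] from the diagram were cached instead of c_kv and k_rope.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MLAConfig:
    # Defaults are the DeepSeek-V3 values listed under "Key parameters"
    hidden: int = 7168
    n_heads: int = 128
    q_lora_rank: int = 1536
    kv_lora_rank: int = 512   # d_c
    d_nope: int = 128         # per-head Q/K dim without the RoPE part
    d_rope: int = 64
    d_v: int = 128

    @property
    def d_qk(self) -> int:
        # Full per-head Q/K dim; also the d in the 1/sqrt(d) attention scale
        return self.d_nope + self.d_rope

    def weight_shapes(self) -> dict:
        # [in, out] shapes, matching the W annotations in the diagram
        return {
            "q_down": (self.hidden, self.q_lora_rank),                              # (7168, 1536)
            "q_up": (self.q_lora_rank, self.n_heads * self.d_qk),                   # (1536, 24576)
            "kv_down": (self.hidden, self.kv_lora_rank + self.d_rope),              # (7168, 576)
            "kv_up": (self.kv_lora_rank, self.n_heads * (self.d_nope + self.d_v)),  # (512, 32768)
            "o_proj": (self.n_heads * self.d_v, self.hidden),                       # (16384, 7168)
        }

    def cache_elems_per_token(self) -> int:
        # MLA caches only c_kv and the single shared k_rope
        return self.kv_lora_rank + self.d_rope

    def materialized_kv_elems_per_token(self) -> int:
        # Assumed comparison: caching the up-projected K and V for all heads
        return self.n_heads * (self.d_qk + self.d_v)


cfg = MLAConfig()
for name, shape in cfg.weight_shapes().items():
    print(name, shape)
print(cfg.cache_elems_per_token())            # 576
print(cfg.materialized_kv_elems_per_token())  # 40960
print(round(cfg.materialized_kv_elems_per_token() / cfg.cache_elems_per_token(), 1))  # 71.1
```

Under that comparison the latent cache is roughly 71 times smaller per token per layer; the trade-off is the extra KV up-projection work at attention time.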