多头注意力
为什么把注意力切成多头、每头各学到什么功能、现代 LLM 的头数与维度如何权衡
核心要点:
- hidden_size = n_head × d_head,是切分不是复制
- 多头解掉单头"只能学一种关注模式" 的瓶颈
- 不同 head 学不同模式:positional / syntactic / rare word / induction
- Vaswani Table 3: n_head=8 最优 (BLEU 25.8),太多反而下降
- 现代 LLM d_head 工业收敛 128, n_head 随模型大小涨
- GQA / MQA / MLA: KV cache 压缩,不展开外链长上下文章
名词定义
本篇共享名词在 4.1 总览 已定义 (Head / Multi-head attention / Grouped-query attention)。本篇新引入:
| 名词 | 定义 |
|---|---|
| Head dimension ($d_{\text{head}}$) | 单个 head 的 Q/K/V 子空间维度,$d_{\text{head}} = h / n_{\text{head}}$ |
| Output projection ($W_O$) | 多头拼接后用一个 $W_O \in \mathbb{R}^{h \times h}$ 投影回 hidden 维度 |
| Specialized head | Voita 2019 在 head pruning 中发现的功能性专门 head: positional / syntactic / rare word 三类 |
| Previous Token Head | Anthropic 发现的 layer 1 head 类型,写入 "前一个 token 是什么" 信息到 residual stream |
| MQA (Multi-Query Attention) | Shazeer 2019:所有 Q head 共享一组 K/V head, KV cache 缩 $n_{\text{head}}$ 倍,但表达力损失明显 |
| GQA (Grouped-Query Attention) | Ainslie 2023: Q head 分 $g$ 组,每组共享一组 K/V head, $1 \leq g \leq n_{\text{head}}$;是 MQA 与 MHA 之间的平衡 |
@tbl-mha-glossary 本篇新引入名词
单头为什么不够?
核心问题:03 篇引入 Q/K/V 投影后,attention 已经能学反义 / 句法关系 / 角色解耦了,为什么还要分多头?单头用大维度 $d_{\text{head}} = h$ 不就够了?
单头只能让所有 token 学一种 attention 模式,多头让模型同时维护多种独立的关注模式。
单头表达力的具体瓶颈
单头 self-attention 给定 Q/K/V 投影,每对 token 的 attention 权重 $\alpha_{ij}$ 由一个值决定 (单个 $\mathbf{q}_i \cdot \mathbf{k}_j$ 内积). 这意味着:
- 位置 $i$ 注意位置 $j$ 的强度只有一个数,无法同时表达"语义上关注 + 句法上不关注" 这种多维度关注
- $W_Q W_K^\top$ 是单一矩阵,训练时只能收敛到一种模式 (平均权衡)
- 复杂的 attention 模式 (induction head 之类) 单头根本表达不出,必须靠多头组合
多头的核心思路:把 hidden 维度切成 $n_{\text{head}}$ 个独立子空间,每个子空间各自算一份 attention,互不干扰。模型可以让 head 0 学 positional, head 1 学 syntactic, head 2 学 induction,互不冲突。
切分不是复制 (常见误解)
多头切分 $d_{\text{head}} = h / n_{\text{head}}$,总参数量不变:
$$\begin{equation} h = n_{\text{head}} \cdot d_{\text{head}} \label{eq:mha-split} \end{equation}$$- $h = 4096,\ n_{\text{head}} = 32 \to d_{\text{head}} = 128$ (Llama 3 8B)
- $h = 8192,\ n_{\text{head}} = 64 \to d_{\text{head}} = 128$ (Llama 3 70B)
反例 (常见误解):多头不是把同一个 $W_Q$ 复制 $n_{\text{head}}$ 份并行跑——那样总参数翻倍,训练成本翻倍,也学不到多样性 (复制的 weight 训练时差不多走到同一处)。
多头怎么实现?
核心问题:多头切分听起来像很多并行 head,实际怎么用一次矩阵乘搞定?
工程上仍是一次大矩阵乘 + reshape,不真的跑 $n_{\text{head}}$ 个独立 kernel。
完整公式
Vaswani 2017 §3.2.2 给出 multi-head attention 完整定义:
$$\begin{equation} \mathrm{MultiHead}(\mathbf{X}) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_{n_{\text{head}}}) \cdot W_O \label{eq:mha-multihead} \end{equation}$$每个 head:
$$\begin{equation} \mathrm{head}_i = \mathrm{Attention}(\mathbf{X} W_Q^{(i)}, \mathbf{X} W_K^{(i)}, \mathbf{X} W_V^{(i)}) \label{eq:mha-head} \end{equation}$$其中:
- $W_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{h \times d_{\text{head}}}$,每个 head 独立 Q/K/V 投影
- $W_O \in \mathbb{R}^{h \times h}$,拼接后输出投影 (Anthropic Circuits 框架里的 OV circuit 一部分)
工程实现:单一大矩阵 + reshape
虽然概念上每个 head 有独立 $W_Q^{(i)}$,工程上所有 head 的 $W_Q^{(i)}$ 拼成一个大矩阵 $W_Q \in \mathbb{R}^{h \times h}$,一次矩阵乘后 reshape:
# nanoGPT 风格 (Karpathy)
self.c_attn = nn.Linear(h, 3 * h, bias=False) # Q, K, V 一起
self.c_proj = nn.Linear(h, h, bias=False) # W_O
def forward(x):
B, T, C = x.shape # batch, seqlen, hidden
q, k, v = self.c_attn(x).split(h, dim=2)
# reshape to [B, n_head, T, d_head]
q = q.view(B, T, n_head, d_head).transpose(1, 2)
k = k.view(B, T, n_head, d_head).transpose(1, 2)
v = v.view(B, T, n_head, d_head).transpose(1, 2)
# scaled dot-product on each head (batched matmul)
att = (q @ k.transpose(-2, -1)) / math.sqrt(d_head)
att = att.masked_fill(causal_mask, float('-inf')) # 04 篇展开
att = att.softmax(dim=-1)
y = att @ v # [B, n_head, T, d_head]
# 拼接 + 输出投影
y = y.transpose(1, 2).contiguous().view(B, T, h)
y = self.c_proj(y)
return y
关键观察:
- 一次大矩阵乘做完所有 head 的 Q/K/V 投影:
c_attn是 $W_{QKV} \in \mathbb{R}^{h \times 3h}$ view + transpose把 hidden 维拆成 (n_head, d_head):这是"切分而非复制" 的具体落地- Scaled dot-product 在 head 维度上是 batched matmul:所有 head 同时算,互相独立
- 输出 $W_O$ 是 hidden × hidden 全连接,让各 head 输出可以相互混合
不同 head 实际学到什么?
核心问题:概念上多头让不同 head 学不同模式,实际呢?一个训练好的模型,32 个 head 真的学到 32 种不同模式吗?
实测发现少数 specialized head 做绝大部分功能,其他 head 可被剪掉;specialized head 内部又分几类清晰功能。
Voita 2019: head pruning 实验
Voita et al. ACL 2019[1] 在 Transformer encoder (6 层 × 8 head = 48 个 head) 上用 L0 门控剪枝:
- 保留 10 / 48 head (-79%) BLEU 仅降 0.15,几乎不影响性能
- 最后被剪掉的 10 个 head 表现出三类清晰的功能性专门化:
| 类型 | 行为 | 实测特征 |
|---|---|---|
| Positional | 关注相邻 token | 最大 attention 权重 > 0.8 集中在 $j = i \pm 1$ 或 $i \pm 2$ |
| Syntactic | 追踪依存关系 | 高对应度:nsubj (主语-动词) / dobj (动词-宾语) / amod (修饰-名词) |
| Rare word | 指向罕见 token | > 50% 的情况下指向 IDF 最低 (最罕见) 的 token |
@tbl-mha-specialized Voita 2019 发现的三类 specialized head
结论:大部分 head 是冗余的或学到平庸模式,少数 specialized head 承担实际功能——这是 multi-head 设计的"过度参数化 + 自然专门化" 工程哲学。
Anthropic Induction Head:跨层电路
Anthropic mechanistic interpretability 工作[2] 发现了一个跨层 head 电路实现 in-context learning:
- Layer 1 Previous Token Head:把 token A 的信息写入 token B 的 residual stream (B 是 A 后面那个 token),通过 K-composition (跨层 key 复用)
- 后层 Induction Head:看到 pattern
[A][B] ... [A]时,利用 Layer 1 写入的信息预测[B]
这是 LLM 在上下文中"重复学习" 能力的核心电路:给模型几个 Q: ... A: ... 范例,它能在新的 Q: ... 后预测 A: 格式——induction head 是这个能力的电路基础。
适用范围:Olsson 2022 在 2~130B 参数的所有规模模型上都观察到 induction head,是 LLM 的普遍现象。
Vaswani Table 3: n_head=8 最优,太多反而下降
Vaswani 2017 自己做了 n_head 数量的消融[3] (固定总计算量):
| $n_{\text{head}}$ | $d_{\text{head}}$ | BLEU |
|---|---|---|
| 1 | 512 | 24.9 |
| 4 | 128 | 25.5 |
| 8 | 64 | 25.8 (最优) |
| 16 | 32 | 25.4 |
| 32 | 16 | 25.4 |
@tbl-mha-vaswani-ablation Vaswani 2017 Table 3 Row B: $n_{\text{head}}$ 与 BLEU (WMT En-De)
$n_{\text{head}}$ 过多反而下降: $d_{\text{head}}$ 太小 (32 / 16),每个 head 的子空间维度不足以表达足够丰富的 Q/K,表达力下降抵消了多样性收益。实测 $d_{\text{head}}$ 不应小于 32-64。
现代 LLM 的配置:$d_{\text{head}} = 128$ 工业收敛
| 模型 | $h$ | $n_{\text{head}}$ | $n_{\text{kv\_head}}$ (GQA) | $d_{\text{head}}$ |
|---|---|---|---|---|
| GPT-2 124M | 768 | 12 | 12 (MHA) | 64 |
| Llama 3 8B | 4096 | 32 | 8 (GQA) | 128 |
| Llama 3 70B | 8192 | 64 | 8 (GQA) | 128 |
| Llama 3 405B | 16384 | 128 | 8 (GQA) | 128 |
| Qwen2.5 7B | 3584 | 28 | 4 (GQA) | 128 |
| DeepSeek-V3 | (MLA 架构,见外链) |
@tbl-mha-modern-config 现代主流 LLM 的多头配置
两个规律:
- $d_{\text{head}} = 128$ 工业收敛: Llama 1/2/3 / Qwen2.5 全系列,不再像 Vaswani 时代用 64
- $n_{\text{head}}$ 随模型规模涨: $h$ 翻倍 $n_{\text{head}}$ 也翻倍,保持 $d_{\text{head}} = 128$ 不变
GQA / MQA / MLA: KV 架构压缩
核心问题:Llama 3 8B 表里 $n_{\text{kv\_head}} = 8$ 远小于 $n_{\text{head}} = 32$,这就是 GQA。它解决什么问题?怎么演化的?
KV cache 在长上下文 / 大 batch 推理下占主要显存,GQA / MQA / MLA 是 KV cache 压缩方向的三种方案。本篇仅点到为止,详见 knowledge/03-长上下文/05-kv-cache架构压缩。
三档简介
| 方案 | 含义 | KV cache 缩减 | 表达力代价 |
|---|---|---|---|
| MHA (标准多头) | $n_{\text{kv\_head}} = n_{\text{head}}$,每个 Q head 配自己的 K/V head | 无压缩 | 无 |
| MQA (Shazeer 2019)[4] | 所有 Q head 共享 1 组 K/V head, $n_{\text{kv\_head}} = 1$ | 缩 $n_{\text{head}}$ 倍 (32 → 1, 32 倍) | 明显,表达力损失 |
| GQA (Ainslie 2023)[5] | Q head 分 $g$ 组,每组共享 1 组 K/V head | 缩 $n_{\text{head}} / n_{\text{kv\_head}}$ 倍 (常用 4-8 倍) | 实测可忽略 |
| MLA (DeepSeek 2024) | K/V 从 low-rank latent (~512 dim) 重建 | 缩 ~70× (DeepSeek-V3) | 实测优于 MHA |
@tbl-mha-kv-compression KV 架构压缩方案对比 (详见长上下文章)
Llama 3 GQA 实证
Llama 3 全系列用 GQA:8B (32/8 = 4 倍压缩) / 70B (64/8 = 8 倍压缩) / 405B (128/8 = 16 倍压缩)[6]. KV cache 缩 4-16 倍,长上下文推理显存友好,性能下降可忽略。
这是 Llama 3 能稳定支持 128K 上下文的工程前提——MHA 下 KV cache 会撑爆显存。
Takeaway
| 知识点 | 核心结论 |
|---|---|
| 单头瓶颈 | 一对 token 只有一个 attention 权重,无法同时表达多种关注模式 |
| 切分不是复制 | $h = n_{\text{head}} \cdot d_{\text{head}}$,切分子空间,总参数不变 |
| 工程实现 | 单一大矩阵 + reshape,不跑 $n_{\text{head}}$ 个独立 kernel |
| 输出投影 $W_O$ | 多头拼接后 $\mathbb{R}^{h \times h}$ 投影,让 head 间可混合 |
| Specialized head (Voita 2019) | 48 head 保留 10 个 BLEU 仅降 0.15;三类专门化:positional / syntactic / rare word |
| Induction head (Anthropic) | Layer 1 Previous Token + 后层 Induction 的跨层电路,in-context learning 的电路基础 |
| Vaswani 最优 | $n_{\text{head}} = 8, d_{\text{head}} = 64$ (BLEU 25.8);太多 head $d_{\text{head}}$ 不足反而下降 |
| 现代收敛 | $d_{\text{head}} = 128$ 工业标准,$n_{\text{head}}$ 随模型规模涨 |
| GQA / MQA / MLA | KV cache 压缩方向,GQA 主流;详见长上下文章 |
| Llama 3 GQA | 8B 4×, 70B 8×, 405B 16× KV cache 压缩,长上下文推理前提 |
开放问题
- specialized head 的涌现机制:Voita 2019 实测了 specialized head 的存在,但训练动态上这些 head 怎么自发涌现仍开放
- head 数与模型规模的最优关系:现代 LLM $n_{\text{head}}$ 与 $h$ 同步涨保持 $d_{\text{head}} = 128$,是否存在更优关系?还无系统研究
- Induction head 是否大模型仍依赖同样机制:Anthropic 在中小模型观察到,大模型 (175B+) 是否还是这个电路还是涌现了更复杂的,未充分研究
- MQA / GQA / MLA 的最终收敛:业界目前 GQA 主流但 MLA (DeepSeek) 显示 KV 压缩还有空间,是否会出现下一代方案
- 多头是否在 reasoning 上有特殊作用:o1 / R1 类推理模型是否需要特殊 head 配置,还是用同样 multi-head 架构
本章结束:4 步递进走完
至此本章 4 步递进完整:简化版 → Q/K/V → 因果掩码 → 多头。读者拿到的是 GPT block 里 attention 子层的完整心智图。
后续章节继续:05-组装GPT 把 attention 与 FFN / LayerNorm / 残差组装成完整 block 并堆叠 $L$ 层。
延伸阅读
- 上一步:因果掩码 → 4.4 因果掩码
- 下一步:组装 GPT → 05-组装GPT/01-总览
- GQA / MQA / MLA 详细 → knowledge/03-长上下文/05-kv-cache架构压缩
- Anthropic Transformer Circuits 系列 → https://transformer-circuits.pub/
- Voita 2019 head pruning 完整实验 → https://arxiv.org/abs/1905.09418
参考资料
- Voita et al. Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned. ACL 2019. https://arxiv.org/abs/1905.09418
- Olsson et al. In-context Learning and Induction Heads. Anthropic, 2022. https://arxiv.org/abs/2209.11895
- Vaswani et al. Attention Is All You Need. NeurIPS 2017. https://arxiv.org/abs/1706.03762
- Shazeer. Fast Transformer Decoding: One Write-Head is All You Need (MQA). 2019. https://arxiv.org/abs/1911.02150
- Ainslie et al. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023. https://arxiv.org/abs/2305.13245
- Meta AI. The Llama 3 Herd of Models. 2024. https://arxiv.org/abs/2407.21783