多头注意力为什么比单头好
一句话速记
一个头在一组 Q/K/V 子空间里只能学一种关系;多头把 d_model 切成 h 份,让 h 组注意力并行、各学不同子空间里的关系,再拼起来,表达力更强、更稳。每头里仍然要除以各自的 √(d_model/h)。
结构(和单头的差别)
- 设隐藏维是
d_model,头数为h。 - 每个头在维度
d_k = d_v = d_model / h的子空间里做注意力。 - 对同一层、同一位置,有 h 套 独立的
W_Q^i, W_K^i, W_V^i(或等价的投影实现)。
单步上:
- 第 i 个头的输出:
head_i = Attention(Q^i, K^i, V^i),形状与(n, d_k)相关。 - 多头输出:将
head_1, …, head_h在特征维上拼接 得到(n, d_model),再经W_O(输出投影) 变回与残差同维的表示。
单头 = 只保留 h=1 时那一套,子空间一个,能捕获的关系模式有限。
为什么多头通常更好(机制层面)
-
子空间解耦
高维里「一种关系模式」用单独一组Q/K/V很难同时拟合。拆成 h 个低维子空间,相当于 h 个并行专家,各自在更简单的空间里做注意力。 -
关系类型不单一
同一对词可涉及语法(主谓)、指代、语义、共指、局部 n-gram 等。多头允许不同头偏不同类模式(实践中会出现「有的头很局部、有的头很远程」的观测,视任务与层而定)。 -
与单头同算力时:多头 > 大一点的单头(经验上)
论文与后续复现实践里,多中等头 往往比 单大维头 更稳。直觉:多个低维自注意力在优化景观上更友好,避免一个超高维点积里所有关系混在一个 softmax 里抢权重。 -
与 √dk 配套
每头维度是d_k = d_model/h,点积的 scale 是1/√d_k(见 为什么 attention 要除以 √dk)。头越多,每头越窄,单头点积方差越可控;总通道数h * d_k = d_model不变(拼接后仍是d_model)。
不是越多越好
- h 太大 → 每头
d_k太小 → 每头表达能力弱,过碎。 - 工程上要平衡显存、算力与 h;
d_model与h的配比是架构搜索的一部分(如 4096/32 等)。
和「集成 / 多视角」的类比(可选一句)
- 可一句话:多头的并行注意力类似同一位置上的多路独立匹配,再融合(避免只说「更厉害」而说不清机制)。
延伸追问
-
Q:多头和「把 d_model 做大」有啥区别?
答:纯放大单头,仍在一个注意力分布里压所有关系;多头是多个 softmax 的注意力分布并行,再合成,表达的是「多种关系可同时高权」,不是互相挤一个分布。 -
Q:每头的维度为什么一般是 d_model / h?
答:为让拼接后长度仍是d_model,和残差、下一层输入维度对齐;同时每头有独立小空间做 softmax 竞争。 -
Q:一个头能合并成多头吗?
答:在固定算力下,多头结构带来多组不同投影,不是简单把大矩阵切成块不新增参数;参数与结构不同,不等价于一个巨大单头(细节实现另说,但答「结构归纳偏置不同」即可)。 -
Q:头有没有可解释性?
答:有论文做过可视化(如某些头偏句法、某些偏位置),不是每个头都干净对应人类标签;不要过度声称「第 3 头一定管指代」。
我的记法
TODO
状态
- 已背速记
- 能讲通俗版
- 能答追问
- 在实际场景中用上过
参考资料
- Vaswani et al., 2017 §3.2.2 Multi-Head Attention
- Attention 可视化(Tensor2Tensor 文章 / 后来 BERT 头的分析论文)