多头注意力为什么比单头好

一句话速记

一个头在一组 Q/K/V 子空间里只能学一种关系;多头把 d_model 切成 h 份,让 h 组注意力并行、各学不同子空间里的关系,再拼起来,表达力更强、更稳。每头里仍然要除以各自的 √(d_model/h)。

结构(和单头的差别)

  • 设隐藏维是 d_model,头数为 h
  • 每个头在维度 d_k = d_v = d_model / h 的子空间里做注意力。
  • 对同一层、同一位置,有 h 套 独立的 W_Q^i, W_K^i, W_V^i(或等价的投影实现)。

单步上:

  • 第 i 个头的输出:
    head_i = Attention(Q^i, K^i, V^i),形状与 (n, d_k) 相关。
  • 多头输出:将 head_1, …, head_h 在特征维上拼接 得到 (n, d_model),再经 W_O(输出投影) 变回与残差同维的表示。

单头 = 只保留 h=1 时那一套,子空间一个,能捕获的关系模式有限。

为什么多头通常更好(机制层面)

  1. 子空间解耦
    高维里「一种关系模式」用单独一组 Q/K/V 很难同时拟合。拆成 h 个低维子空间,相当于 h 个并行专家,各自在更简单的空间里做注意力。

  2. 关系类型不单一
    同一对词可涉及语法(主谓)、指代、语义、共指、局部 n-gram 等。多头允许不同头偏不同类模式(实践中会出现「有的头很局部、有的头很远程」的观测,视任务与层而定)。

  3. 与单头同算力时:多头 > 大一点的单头(经验上)
    论文与后续复现实践里,多中等头 往往比 单大维头 更稳。直觉:多个低维自注意力在优化景观上更友好,避免一个超高维点积里所有关系混在一个 softmax 里抢权重。

  4. 与 √dk 配套
    每头维度是 d_k = d_model/h,点积的 scale 是 1/√d_k(见 为什么 attention 要除以 √dk)。头越多,每头越窄,单头点积方差越可控;总通道数 h * d_k = d_model 不变(拼接后仍是 d_model)。

不是越多越好

  • h 太大 → 每头 d_k 太小 → 每头表达能力弱,过碎。
  • 工程上要平衡显存、算力与 h;d_modelh 的配比是架构搜索的一部分(如 4096/32 等)。

和「集成 / 多视角」的类比(可选一句)

  • 可一句话:多头的并行注意力类似同一位置上的多路独立匹配,再融合(避免只说「更厉害」而说不清机制)。

延伸追问

  • Q:多头和「把 d_model 做大」有啥区别?
    答:纯放大单头,仍在一个注意力分布里压所有关系;多头是多个 softmax 的注意力分布并行,再合成,表达的是「多种关系可同时高权」,不是互相挤一个分布。

  • Q:每头的维度为什么一般是 d_model / h?
    答:为让拼接后长度仍是 d_model,和残差、下一层输入维度对齐;同时每头有独立小空间做 softmax 竞争。

  • Q:一个头能合并成多头吗?
    答:在固定算力下,多头结构带来多组不同投影,不是简单把大矩阵切成块不新增参数;参数与结构不同,不等价于一个巨大单头(细节实现另说,但答「结构归纳偏置不同」即可)。

  • Q:头有没有可解释性?
    答:有论文做过可视化(如某些头偏句法、某些偏位置),不是每个头都干净对应人类标签;不要过度声称「第 3 头一定管指代」。

我的记法

TODO

状态

  • 已背速记
  • 能讲通俗版
  • 能答追问
  • 在实际场景中用上过

参考资料

  • Vaswani et al., 2017 §3.2.2 Multi-Head Attention
  • Attention 可视化(Tensor2Tensor 文章 / 后来 BERT 头的分析论文)