为什么 attention 要除以 √dk
一句话速记
防止 Q·K 点积数值过大导致 softmax 饱和、梯度消失。 dk 是 key 向量的维度。
通俗解释(5 分钟版)
Attention 公式的核心是 softmax(QK^T / √dk) · V。为什么要除以 √dk,而不是直接 softmax(QK^T) · V?
关键问题在 softmax 的脾气:
- 输入值相差越大,softmax 输出越接近 one-hot(最大值那项几乎是 1,其他几乎是 0)
- 一旦接近 one-hot,除了最大那项以外的梯度都接近 0,模型”只学一个位置”,学不到分布
- 这叫 softmax 饱和
而 Q·K 的点积值有多大,取决于 dk:
- Q、K 每一维假设是 N(0,1) 的独立随机变量
- 点积是 dk 个乘积项求和,根据方差加法性,方差正比于 dk
- dk=512 时,点积的标准差约 22.6,softmax 差不多已经是 one-hot 了
除以 √dk 做什么:
- 把方差从 dk 压回 1(标准差从 √dk 压回 1)
- softmax 输入在一个合理范围内,输出是平滑的概率分布
- 梯度能正常回传,模型能学到”注意力应该分散在哪几个位置”
关键细节 / 数学直觉
方差推导:
- 设 Q_i、K_i ~ N(0,1),相互独立
- Q·K = Σ Q_i × K_i,一共 dk 项
- Var(Q_i × K_i) = E[(Q_i × K_i)²] - E[Q_i × K_i]² = 1 × 1 - 0 = 1
- Var(Q·K) = dk × 1 = dk
- 所以标准差是 √dk,除以 √dk 后方差归一到 1
直观阈值:
- dk=64(BERT-base 单头):不缩放的话点积标准差 8,已经有点危险
- dk=512(整个模型维度):不缩放就是灾难
延伸追问
-
Q:为什么是 √dk 不是 dk? 答:因为我们要归一化的是标准差(缩放应与数值量级同阶),方差是 dk,标准差是 √dk。除以 dk 会把方差压到 1/dk,反而过度平滑,attention 分布接近均匀就没区分度了。
-
Q:V 为什么不参与这个缩放? 答:V 只参与加权求和(最后那一步),不进 softmax。softmax 饱和问题和 V 无关。
-
Q:有没有不除 √dk 的方案? 答:有。比如在 QK 之前对 Q、K 做 LayerNorm 或 RMSNorm,本质也是控制数值范围。但工程上 √dk 最便宜——一个常数除法,没有额外参数。
-
Q:多头注意力里这个 dk 是整个模型的维度还是单头的? 答:单头的。多头会把 d_model 切成 h 份,每头的 dk = d_model / h。所以每个头独立做缩放,用各自的 √(d_model/h)。
我的记法
复述时这样讲:
“QK 点积的方差正比于 dk → dk 大的话 softmax 直接饱和 → 梯度回传不了 → 除以 √dk 把方差压回 1。”
能加分就补一句:“V 不进 softmax,所以不缩放 V。“
状态
- 已背速记
- 能讲通俗版
- 能答追问
- 在实际场景中用上过
参考资料
- Attention Is All You Need (Vaswani et al., 2017) §3.2.1
- Jay Alammar, The Illustrated Transformer