为什么点积太大会导致 Softmax 梯度问题,以及这个问题是如何被解决的

以下内容由AI辅助生成

一、问题背景:我们到底在担心什么?

在 Transformer 的注意力机制中,有一个非常经典的公式:

很多人在第一次看到这个公式时都会产生疑问:

  • 为什么注意力机制使用点积来计算相关性?
  • 为什么点积之后要接一个 softmax?
  • 最关键的是:为什么一定要除以

如果只是把它当成论文中的经验公式,那么对注意力机制的理解仍然停留在表面。事实上,这个结构并不是为了“效果更好看”,而是为了解决一个在训练过程中非常具体、而且非常致命的问题


二、softmax 的作用,以及它对输入尺度的敏感性

softmax 的定义如下:

它的作用是把一组实数打分映射为一个概率分布。但 softmax 有一个非常重要的性质:

softmax 对输入的整体尺度极其敏感

当输入数值较为温和时,softmax 会产生一个相对平滑的分布;而当输入整体变大时,softmax 的输出会迅速变得极端。

一个简单的数值例子

情况一:输入尺度适中

1
2
输入: z = [1, 0]
输出: softmax(z) ≈ [0.73, 0.27]

两个位置都有明显概率,模型仍然保留不确定性。


情况二:输入尺度变大

1
2
输入: z = [10, 0]
输出: softmax(z) ≈ [0.9999, 0.0001]

此时 softmax 的输出已经几乎等同于 one-hot 分布 [1, 0]

这是因为指数函数会将线性差距放大为指数级差距,只要输入的整体尺度变大,softmax 的输出就会迅速向极端塌缩。


三、注意力分数为什么会天然变大:点积的统计性质

在注意力机制中,softmax 的输入来自查询向量和键向量的点积:

对于单个 query-key 对,这个分数可以写成:

在常见训练设置中(例如经过 LayerNorm 之后),可以合理假设:

  • 各维度的 相互独立
  • 均值为 0
  • 方差为 1

点积方差的数学推导

在上述假设下,对于每一项

  • 均值:
  • 方差:

项独立相加时:

因此标准差为:

关键结论:点积的典型数值规模与 成正比

这并非实现细节或偶然现象,而是高维点积在统计意义上的必然结果。

具体数值示例

维度 点积标准差 典型点积范围
64 8 [-16, 16]
128 11.3 [-23, 23]
512 22.6 [-45, 45]

时,点积值可能达到 ±45,这会让 softmax 严重饱和。


四、softmax 的梯度结构:什么是 Jacobian?

softmax 是一个向量到向量的函数:

  • 输入:
  • 输出:

当输入和输出都是向量时,需要描述这样一件事:

某一个输入分量发生微小变化,会如何影响所有输出分量

所有这些“偏导关系”组成的一整张表,称为 Jacobian 矩阵

softmax 的 Jacobian 具体形式

对于 softmax,其 Jacobian 有明确的数学形式:

关键观察:

  • (饱和状态)时,
  • 时,

梯度大小直接由输出概率本身控制


五、softmax 饱和:输出分布发生了什么变化

对比:适中尺度 vs 大尺度

输入尺度适中时(例如 ):

1
2
3
4
5
6
7
8
9
位置:    1      2      3      4
输入: 2 1 0 -1
概率: 0.52 0.28 0.14 0.06

可视化:
位置 1: ██████████████████ (52%)
位置 2: ████████████ (28%)
位置 3: ██████ (14%)
位置 4: ██ (6%)

在这种状态下,多个位置都具有非零概率,输出对输入变化是敏感的。


输入尺度变大时(例如 ):

1
2
3
4
5
6
7
8
9
位置:    1      2       3        4
输入: 20 10 0 -10
概率: 1.00 0.00 0.00 0.00

可视化:
位置 1: ████████████████████████████████ (≈100%)
位置 2: (≈0%)
位置 3: (≈0%)
位置 4: (≈0%)

此时,几乎所有概率质量都集中在单一位置,其余位置的概率被压缩到接近零。

这种从“平滑分布”到“极端分布”的转变,称为 softmax 饱和(saturation)


六、梯度是如何在 softmax 处消失的

在反向传播中,梯度的传递遵循链式法则:

未饱和状态:梯度正常传播

1
2
3
4
5
6
7
8
9
10
11
12
反向传播路径:

Loss
│ ∂L/∂s (来自上游)

softmax
│ ∂s/∂z ≈ [0.2, 0.3, 0.1, ...] (梯度有效)

logits z
│ ∂z/∂Q, ∂z/∂K (梯度继续传播)

Q, K (参数可以更新)

饱和状态:梯度被截断

1
2
3
4
5
6
7
8
9
10
11
12
反向传播路径:

Loss
│ ∂L/∂s (来自上游)

softmax
│ ∂s/∂z ≈ [0.00001, 0, 0, ...] (梯度几乎为0!)

logits z
│ ∂z/∂Q ≈ 0, ∂z/∂K ≈ 0 (梯度消失)

Q, K (参数无法更新!)

梯度不是在网络深处逐层衰减的,而是在 softmax 这一层被直接截断的

这是一种发生位置非常明确的梯度消失问题。


七、从点积到梯度消失的完整因果链

将前面的所有环节串联起来,可以得到一条完整因果链:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
高维点积统计性质

点积方差 ∝ d_k

点积值的典型尺度 ∝ √d_k

softmax 输入整体变大

指数函数放大差距

softmax 输出趋向 one-hot 分布

softmax Jacobian 中梯度项趋近于 0

∂s/∂z ≈ 0 导致梯度截断

Q、K 无法更新,注意力权重过早固定

问题的本质是:点积让 softmax 过早进入了饱和区间,从而导致梯度消失。


八、为什么这不是梯度爆炸问题

这一现象有时会被误认为是梯度爆炸,但两者在机制上完全不同。

梯度爆炸 vs 梯度消失对比

特征 梯度爆炸 本文讨论的问题
梯度大小 趋向无穷大 趋向零
发生原因 导数累乘 > 1 softmax 饱和导致 ∂s/∂z ≈ 0
数值稳定性 数值溢出 数值下溢
训练表现 参数震荡、NaN 参数停止更新

关键区别:

softmax 的梯度由其输出概率控制,具体形式为 。由于概率值 ,这些导数在数值上是有上界的,不可能随着输入增大而放大梯度。

softmax 只会压缩梯度,而不会放大梯度。

因此,这里出现的问题不是梯度失控增大,而是梯度被系统性压缩并最终消失。


九、根本解决方案:Scaled Dot-Product Attention

既然问题的根源在于点积的方差与维度成正比(),那么最直接的解决方式就是对点积进行缩放:

缩放后的统计性质

缩放后,点积的方差变为:

效果:

  • 控制点积的典型尺度稳定在常数级别(与维度无关)
  • 防止 softmax 过早进入饱和区
  • 保持输出对输入变化的敏感性
  • 让梯度能够持续传回 Q 和 K

这一步的本质是方差归一化(variance normalization)


十、为什么是

三种缩放方式的对比

缩放方式 点积方差 softmax 行为 问题
不做缩放 很快饱和 梯度消失
除以 过于均匀 区分能力不足,所有位置概率接近
除以 稳定且可学习 ✓ 最优

直观解释

  • 不做缩放:点积方差随 线性增长,softmax 很快饱和
  • **除以 **:矫枉过正,点积方差变成 ,当 很大时会让所有注意力权重过于平均,失去了“注意”的意义
  • **除以 **:恰好让方差归一化到 1,既不会饱和也不会过于平滑

这一选择与 Xavier 初始化、LayerNorm 等方法背后的统计思想是一致的:保持信号的方差在网络中稳定传播


十一、总结

点积过大不会导致梯度爆炸,而是会使 softmax 过早进入饱和状态。一旦 softmax 的输出分布变得极端,其 Jacobian 中的梯度项会趋近于零,导致梯度在这一层被有效截断,学习信号无法继续传播到 Q 和 K。

的引入,正是为了把点积的方差从 归一化到 1,从而把 softmax 的输入尺度拉回到一个模型仍然能够持续学习的区间。

这是一个精心设计的数学解决方案,背后的统计原理清晰明确。