以下内容由AI辅助生成
一、问题背景:我们到底在担心什么?
在 Transformer 的注意力机制中,有一个非常经典的公式:
很多人在第一次看到这个公式时都会产生疑问:
- 为什么注意力机制使用点积来计算相关性?
- 为什么点积之后要接一个 softmax?
- 最关键的是:为什么一定要除以
?
如果只是把它当成论文中的经验公式,那么对注意力机制的理解仍然停留在表面。事实上,这个结构并不是为了“效果更好看”,而是为了解决一个在训练过程中非常具体、而且非常致命的问题。
二、softmax 的作用,以及它对输入尺度的敏感性
softmax 的定义如下:
它的作用是把一组实数打分映射为一个概率分布。但 softmax 有一个非常重要的性质:
softmax 对输入的整体尺度极其敏感
当输入数值较为温和时,softmax 会产生一个相对平滑的分布;而当输入整体变大时,softmax 的输出会迅速变得极端。
一个简单的数值例子
情况一:输入尺度适中
1 | 输入: z = [1, 0] |
两个位置都有明显概率,模型仍然保留不确定性。
情况二:输入尺度变大
1 | 输入: z = [10, 0] |
此时 softmax 的输出已经几乎等同于 one-hot 分布 [1, 0]。
这是因为指数函数会将线性差距放大为指数级差距,只要输入的整体尺度变大,softmax 的输出就会迅速向极端塌缩。
三、注意力分数为什么会天然变大:点积的统计性质
在注意力机制中,softmax 的输入来自查询向量和键向量的点积:
对于单个 query-key 对,这个分数可以写成:
在常见训练设置中(例如经过 LayerNorm 之后),可以合理假设:
- 各维度的
相互独立 - 均值为 0
- 方差为 1
点积方差的数学推导
在上述假设下,对于每一项
- 均值:
- 方差:
当
因此标准差为:
关键结论:点积的典型数值规模与
成正比
这并非实现细节或偶然现象,而是高维点积在统计意义上的必然结果。
具体数值示例
| 维度 |
点积标准差 |
典型点积范围 |
|---|---|---|
| 64 | 8 | [-16, 16] |
| 128 | 11.3 | [-23, 23] |
| 512 | 22.6 | [-45, 45] |
当
四、softmax 的梯度结构:什么是 Jacobian?
softmax 是一个向量到向量的函数:
- 输入:
- 输出:
当输入和输出都是向量时,需要描述这样一件事:
某一个输入分量发生微小变化,会如何影响所有输出分量
所有这些“偏导关系”组成的一整张表,称为 Jacobian 矩阵:
softmax 的 Jacobian 具体形式
对于 softmax,其 Jacobian 有明确的数学形式:
关键观察:
- 当
(饱和状态)时, - 当
时,
梯度大小直接由输出概率本身控制
五、softmax 饱和:输出分布发生了什么变化
对比:适中尺度 vs 大尺度
输入尺度适中时(例如
1 | 位置: 1 2 3 4 |
在这种状态下,多个位置都具有非零概率,输出对输入变化是敏感的。
输入尺度变大时(例如
1 | 位置: 1 2 3 4 |
此时,几乎所有概率质量都集中在单一位置,其余位置的概率被压缩到接近零。
这种从“平滑分布”到“极端分布”的转变,称为 softmax 饱和(saturation)。
六、梯度是如何在 softmax 处消失的
在反向传播中,梯度的传递遵循链式法则:
未饱和状态:梯度正常传播
1 | 反向传播路径: |
饱和状态:梯度被截断
1 | 反向传播路径: |
梯度不是在网络深处逐层衰减的,而是在 softmax 这一层被直接截断的
这是一种发生位置非常明确的梯度消失问题。
七、从点积到梯度消失的完整因果链
将前面的所有环节串联起来,可以得到一条完整因果链:
1 | 高维点积统计性质 |
问题的本质是:点积让 softmax 过早进入了饱和区间,从而导致梯度消失。
八、为什么这不是梯度爆炸问题
这一现象有时会被误认为是梯度爆炸,但两者在机制上完全不同。
梯度爆炸 vs 梯度消失对比
| 特征 | 梯度爆炸 | 本文讨论的问题 |
|---|---|---|
| 梯度大小 | 趋向无穷大 | 趋向零 |
| 发生原因 | 导数累乘 > 1 | softmax 饱和导致 ∂s/∂z ≈ 0 |
| 数值稳定性 | 数值溢出 | 数值下溢 |
| 训练表现 | 参数震荡、NaN | 参数停止更新 |
关键区别:
softmax 的梯度由其输出概率控制,具体形式为
softmax 只会压缩梯度,而不会放大梯度。
因此,这里出现的问题不是梯度失控增大,而是梯度被系统性压缩并最终消失。
九、根本解决方案:Scaled Dot-Product Attention
既然问题的根源在于点积的方差与维度成正比(
缩放后的统计性质
缩放后,点积的方差变为:
效果:
- 控制点积的典型尺度稳定在常数级别(与维度无关)
- 防止 softmax 过早进入饱和区
- 保持输出对输入变化的敏感性
- 让梯度能够持续传回 Q 和 K
这一步的本质是方差归一化(variance normalization)。
十、为什么是
三种缩放方式的对比
| 缩放方式 | 点积方差 | softmax 行为 | 问题 |
|---|---|---|---|
| 不做缩放 | 很快饱和 | 梯度消失 | |
| 除以 |
过于均匀 | 区分能力不足,所有位置概率接近 | |
| 除以 |
稳定且可学习 | ✓ 最优 |
直观解释
- 不做缩放:点积方差随
线性增长,softmax 很快饱和 - **除以
**:矫枉过正,点积方差变成 ,当 很大时会让所有注意力权重过于平均,失去了“注意”的意义 - **除以
**:恰好让方差归一化到 1,既不会饱和也不会过于平滑
这一选择与 Xavier 初始化、LayerNorm 等方法背后的统计思想是一致的:保持信号的方差在网络中稳定传播。
十一、总结
点积过大不会导致梯度爆炸,而是会使 softmax 过早进入饱和状态。一旦 softmax 的输出分布变得极端,其 Jacobian 中的梯度项会趋近于零,导致梯度在这一层被有效截断,学习信号无法继续传播到 Q 和 K。
这是一个精心设计的数学解决方案,背后的统计原理清晰明确。