梯度消失与梯度爆炸:信号衰减与失控
本节要解决什么问题
分布式系统的日志链路有一个经典问题:链路越长,追踪越难。每经过一个节点,日志信息可能衰减(关键字段丢失)、被放大(重复记录),甚至被噪声淹没(无关日志淹没关键信息)。
深层神经网络的反向传播也存在类似问题:梯度从输出层往输入层传,每经过一层都要乘以一个因子。层数一多,这些因子连乘的结果要么趋近于零(梯度消失),要么趋近于无穷(梯度爆炸)。这就是为什么训练一个 100 层的网络,曾经被认为是"不可能完成的任务"。
这个工具/机制是怎么工作的
梯度消失:信号越传越弱
在反向传播中,梯度是链式法则的连乘结果。如果每一层的梯度都小于 1,连乘的结果会指数级衰减:
假设每层梯度 ≈ 0.5:
第 5 层梯度: 0.5⁵ ≈ 0.03 ← 开始偏小
第 10 层梯度: 0.5¹⁰ ≈ 0.001 ← 接近于零
第 50 层梯度: 0.5⁵⁰ ≈ 10⁻¹⁵ ← 彻底消失
后果:靠近输入层(网络底部)的参数几乎收不到任何梯度更新信号,相当于一个分布式系统最底层的节点"断连"了——它的配置永远不会被调整。
梯度爆炸:信号失控放大
反过来,如果每层梯度都大于 1,连乘会让梯度指数级膨胀:
假设每层梯度 ≈ 1.5:
第 5 层梯度: 1.5⁵ ≈ 7.6 ← 开始放大
第 10 层梯度: 1.5¹⁰ ≈ 57.7 ← 严重放大
第 50 层梯度: 1.5⁵⁰ ≈ 10⁸ ← 数值爆炸
后果:参数更新步长巨大,损失函数变成 NaN 或 Inf,模型直接发散。
激活函数:影响梯度传播的关键因子
反向传播中,每层梯度都会乘以激活函数的导数。不同激活函数的导数差异巨大:
| 激活函数 | 正区间导数 | 负区间导数 | 对梯度的影响 |
|---|---|---|---|
| Sigmoid | 最大 0.25(x=0 处) | 接近 0 | 两端几乎阻断梯度 |
| Tanh | 最大 1 | 接近 0 | 优于 Sigmoid,但仍会饱和 |
| ReLU | 恒等于 1 | 恒等于 0 | 正区间梯度无损,但会"杀死"部分神经元 |
| GELU | 近似 1(正区间) | 平滑过渡 | 兼顾梯度稳定与函数连续性 |
残差连接:为梯度提供一条直达通道
残差连接(Residual Connection)是解决深层网络梯度消失的核心结构。其核心思想是:在每层输入和输出之间,额外加一条"直接通路",让梯度可以跳过中间层直接传回去。
没有残差:
y = F(x),梯度 = ∂F/∂x
如果 ∂F/∂x 很小,梯度就消失了。
有残差连接:
y = x + F(x),梯度 = 1 + ∂F/∂x
"1" 永远存在——无论 ∂F/∂x 多小,加上 1 之后至少保证梯度不会消失。这就像分布式追踪中,除了经过每个微服务的日志链路之外,还额外加了一条"直达 trace ID",即使某个节点日志丢失,也能从直达通道追溯。
传统深层网络:
x → Layer1 → Layer2 → ... → Layer100 → 输出
(梯度要穿过 100 个乘法因子)
残差网络:
x ──────────────────────────────────→ + ──→ 输出
→ Layer1 → Layer2 → ... → Layer100 ──┘
(梯度有"直达通道",直接乘 1 传回去)
Pre-LN 与 LayerNorm 的协同
残差连接解决了梯度传播问题,但单独使用残差的数值可能不断放大(相当于每次加 1,100 层后均值偏移 100)。残差连接配合 LayerNorm 使用,才能在保证梯度通路的同时维持数值稳定。
形式化
残差连接公式(Pre-LN 结构,GPT/LLaMA 等现代大模型采用):
y = x + Attention(LayerNorm(x)) # Self-Attention 子层
z = y + FFN(LayerNorm(y)) # 前馈网络子层
梯度稳定性:∂(x + F(x))/∂x = 1 + ∂F/∂x。即使 ∂F/∂x → 0,整体梯度仍然保持在 1 附近。
本节小结
梯度消失是信号衰减问题(连乘小于 1 的数),梯度爆炸是信号失控问题(连乘大于 1 的数),两者都是深层网络中梯度连乘的必然结果。残差连接通过引入恒等路径(梯度永远乘以 1)打破了"必须依赖中间层传递"的约束,是现代大模型能堆到几十甚至上百层的结构基础。