梯度消失与梯度爆炸:信号衰减与失控

本节要解决什么问题

分布式系统的日志链路有一个经典问题:链路越长,追踪越难。每经过一个节点,日志信息可能衰减(关键字段丢失)、被放大(重复记录),甚至被噪声淹没(无关日志淹没关键信息)。

深层神经网络的反向传播也存在类似问题:梯度从输出层往输入层传,每经过一层都要乘以一个因子。层数一多,这些因子连乘的结果要么趋近于零(梯度消失),要么趋近于无穷(梯度爆炸)。这就是为什么训练一个 100 层的网络,曾经被认为是"不可能完成的任务"。

这个工具/机制是怎么工作的

梯度消失:信号越传越弱

在反向传播中,梯度是链式法则的连乘结果。如果每一层的梯度都小于 1,连乘的结果会指数级衰减:

假设每层梯度 ≈ 0.5:

第 5 层梯度:  0.5⁵  ≈ 0.03    ← 开始偏小
第 10 层梯度: 0.5¹⁰ ≈ 0.001   ← 接近于零
第 50 层梯度: 0.5⁵⁰ ≈ 10⁻¹⁵  ← 彻底消失

后果:靠近输入层(网络底部)的参数几乎收不到任何梯度更新信号,相当于一个分布式系统最底层的节点"断连"了——它的配置永远不会被调整。

梯度爆炸:信号失控放大

反过来,如果每层梯度都大于 1,连乘会让梯度指数级膨胀:

假设每层梯度 ≈ 1.5:

第 5 层梯度:  1.5⁵  ≈ 7.6    ← 开始放大
第 10 层梯度: 1.5¹⁰ ≈ 57.7   ← 严重放大
第 50 层梯度: 1.5⁵⁰ ≈ 10⁸    ← 数值爆炸

后果:参数更新步长巨大,损失函数变成 NaN 或 Inf,模型直接发散。

激活函数:影响梯度传播的关键因子

反向传播中,每层梯度都会乘以激活函数的导数。不同激活函数的导数差异巨大:

激活函数 正区间导数 负区间导数 对梯度的影响
Sigmoid 最大 0.25(x=0 处) 接近 0 两端几乎阻断梯度
Tanh 最大 1 接近 0 优于 Sigmoid,但仍会饱和
ReLU 恒等于 1 恒等于 0 正区间梯度无损,但会"杀死"部分神经元
GELU 近似 1(正区间) 平滑过渡 兼顾梯度稳定与函数连续性

残差连接:为梯度提供一条直达通道

残差连接(Residual Connection)是解决深层网络梯度消失的核心结构。其核心思想是:在每层输入和输出之间,额外加一条"直接通路",让梯度可以跳过中间层直接传回去。

没有残差

y = F(x),梯度 = ∂F/∂x

如果 ∂F/∂x 很小,梯度就消失了。

有残差连接

y = x + F(x),梯度 = 1 + ∂F/∂x

"1" 永远存在——无论 ∂F/∂x 多小,加上 1 之后至少保证梯度不会消失。这就像分布式追踪中,除了经过每个微服务的日志链路之外,还额外加了一条"直达 trace ID",即使某个节点日志丢失,也能从直达通道追溯。

传统深层网络:
x → Layer1 → Layer2 → ... → Layer100 → 输出
(梯度要穿过 100 个乘法因子)

残差网络:
x ──────────────────────────────────→ + ──→ 输出
    → Layer1 → Layer2 → ... → Layer100 ──┘
(梯度有"直达通道",直接乘 1 传回去)

Pre-LN 与 LayerNorm 的协同

残差连接解决了梯度传播问题,但单独使用残差的数值可能不断放大(相当于每次加 1,100 层后均值偏移 100)。残差连接配合 LayerNorm 使用,才能在保证梯度通路的同时维持数值稳定。

形式化

残差连接公式(Pre-LN 结构,GPT/LLaMA 等现代大模型采用):

y = x + Attention(LayerNorm(x))      # Self-Attention 子层
z = y + FFN(LayerNorm(y))            # 前馈网络子层

梯度稳定性:∂(x + F(x))/∂x = 1 + ∂F/∂x。即使 ∂F/∂x → 0,整体梯度仍然保持在 1 附近。

本节小结

梯度消失是信号衰减问题(连乘小于 1 的数),梯度爆炸是信号失控问题(连乘大于 1 的数),两者都是深层网络中梯度连乘的必然结果。残差连接通过引入恒等路径(梯度永远乘以 1)打破了"必须依赖中间层传递"的约束,是现代大模型能堆到几十甚至上百层的结构基础。

延伸阅读

  • LayerNorm — LayerNorm 负责控制残差叠加后的数值尺度,与残差连接一起构成 Transformer 的数值稳定双保险
  • 反向传播 — 梯度消失与爆炸的根因在于反向传播的链式法则,了解它才能理解残差连接"为什么有效"

results matching ""

    No results matching ""