反向传播:故障溯源

本节要解决什么问题

在分布式系统中,当最终请求失败时,工程师最关心的问题不是"失败了",而是"谁的锅"。每个中间环节都可能是嫌疑对象:网关超时?数据库慢查询?缓存击穿?下游服务故障?你需要从最终结果往回追,查每个环节分别负多大责任。

反向传播(Backpropagation)做的事情本质上完全一样:从最终损失出发,往回追查每个参数分别要负多少责任。责任的大小,用梯度来衡量。

这个工具/机制是怎么工作的

链式法则:追责的工具

反向传播的数学核心是链式法则。假设:

L = f(y), y = g(x)

那么 x 的变化对 L 的影响,等于两步影响的乘积:

dL/dx = dL/dy × dy/dx 用故障溯源的比喻:

  • dL/dy:你(y 的负责人)要对最终损失负多少责任
  • dy/dx:你的上游(x 的负责人)对 y 的偏差负多少责任
  • 两者相乘:上游通过你传导,最终对损失负了多少责任

完整的计算图示例

以一个带分叉的网络为例:

            h2 = w2 · h1
          ↗
x ──→ h1                ──→ y = h2 + h3 ──→ L
          ↘
            h3 = w3 · h1

数学定义:

h₁ = w₁ · x h₂ = w₂ · h₁ h₃ = w₃ · h₁ y = h₂ + h₃ L = (1/2) × (y − t)² 设定数值:x=2, t=12, w1=1, w2=2, w3=3

正向传播(计算损失)

h1 = 1 × 2 = 2
h2 = 2 × 2 = 4
h3 = 3 × 2 = 6
y  = 4 + 6 = 10
L  = (1/2) × (10 - 12)² = 2     ← 损失值

反向传播(追责)

第1步:∂L/∂y = y - t = 10 - 12 = -2    ← 终点开始

第2步:分叉——梯度同时传向两个分支
  ∂L/∂h2 = ∂L/∂y × ∂y/∂h2 = -2 × 1 = -2
  ∂L/∂h3 = ∂L/∂y × ∂y/∂h3 = -2 × 1 = -2

第3步:追到 w2 和 w3
  ∂L/∂w2 = ∂L/∂h2 × ∂h2/∂w2 = -2 × 2 = -4
  ∂L/∂w3 = ∂L/∂h3 × ∂h3/∂w3 = -2 × 2 = -4

第4步:梯度汇总(关键!)
  h1 同时影响了 h2 和 h3,两个分支的"责任"要相加
  ∂L/∂h1 = ∂L/∂h2 × ∂h2/∂h1 + ∂L/∂h3 × ∂h3/∂h1
         = (-2) × 2 + (-2) × 3 = -10

第5步:追到 w1
  ∂L/∂w1 = ∂L/∂h1 × ∂h1/∂w1 = -10 × 2 = -20

最终每个参数的梯度:

参数 梯度 含义
w1 -20 责任最重,影响路径最长
w2 -4 直接影响输出
w3 -4 直接影响输出

两个关键概念

梯度分发:在分叉处,损失对某个节点的梯度,会同时传给所有下游分支。

梯度汇总:当多个分支汇聚到同一个节点时,各自传来的梯度在该节点相加(因为这个节点的输出同时被多个上游影响)。

梯度下降:落实"处罚"

反向传播算出梯度后,梯度下降负责执行参数更新:

w ← w − η × dL/dw

  • 梯度为负(w₁ 的情况):当前参数值太小,损失偏高 → 增加 w₁
  • 梯度为正:当前参数值太大,损失偏高 → 减少 w

链式法则是追责工具,反向传播是追责的执行策略,梯度下降是根据追责结果执行处罚。

形式化

链式法则(多步):∂L/∂w₁ = (∂L/∂h₁) · (∂h₁/∂w₁)

梯度下降:

for _ in range(num_epochs):
    # 正向传播:算损失
    loss = forward_pass(model, x, y)

    # 反向传播:算梯度
    grads = backward_pass(loss, model)

    # 参数更新
    for param, grad in zip(model.parameters(), grads):
        param -= learning_rate * grad

本节小结

反向传播不是新的数学,它只是链式法则和偏导数加法规则在复杂计算图上的系统应用:从最终损失出发,把每个节点和参数的"责任"(梯度)一层层追回来。梯度下降再根据追责结果更新参数。深层网络中梯度消失和梯度爆炸,就是这个"追责链条太长"时产生的副作用。

延伸阅读

  • 梯度消失与梯度爆炸 — 当追责链条(网络层数)太长时,梯度连乘会出现什么问题,以及残差连接如何给追责提供一条"直达通道"
  • LayerNorm — 归一化如何让梯度在反向传播中更稳定,不至于追责到一半数值就崩溃
  • 交叉熵 — 交叉熵 + Softmax 的梯度恰好是 p_i - y_i,是反向传播中少有的"简洁追责"

results matching ""

    No results matching ""