反向传播:故障溯源
本节要解决什么问题
在分布式系统中,当最终请求失败时,工程师最关心的问题不是"失败了",而是"谁的锅"。每个中间环节都可能是嫌疑对象:网关超时?数据库慢查询?缓存击穿?下游服务故障?你需要从最终结果往回追,查每个环节分别负多大责任。
反向传播(Backpropagation)做的事情本质上完全一样:从最终损失出发,往回追查每个参数分别要负多少责任。责任的大小,用梯度来衡量。
这个工具/机制是怎么工作的
链式法则:追责的工具
反向传播的数学核心是链式法则。假设:
L = f(y), y = g(x)
那么 x 的变化对 L 的影响,等于两步影响的乘积:
dL/dx = dL/dy × dy/dx 用故障溯源的比喻:
- dL/dy:你(y 的负责人)要对最终损失负多少责任
- dy/dx:你的上游(x 的负责人)对 y 的偏差负多少责任
- 两者相乘:上游通过你传导,最终对损失负了多少责任
完整的计算图示例
以一个带分叉的网络为例:
h2 = w2 · h1
↗
x ──→ h1 ──→ y = h2 + h3 ──→ L
↘
h3 = w3 · h1
数学定义:
h₁ = w₁ · x h₂ = w₂ · h₁ h₃ = w₃ · h₁ y = h₂ + h₃ L = (1/2) × (y − t)² 设定数值:x=2, t=12, w1=1, w2=2, w3=3
正向传播(计算损失):
h1 = 1 × 2 = 2
h2 = 2 × 2 = 4
h3 = 3 × 2 = 6
y = 4 + 6 = 10
L = (1/2) × (10 - 12)² = 2 ← 损失值
反向传播(追责):
第1步:∂L/∂y = y - t = 10 - 12 = -2 ← 终点开始
第2步:分叉——梯度同时传向两个分支
∂L/∂h2 = ∂L/∂y × ∂y/∂h2 = -2 × 1 = -2
∂L/∂h3 = ∂L/∂y × ∂y/∂h3 = -2 × 1 = -2
第3步:追到 w2 和 w3
∂L/∂w2 = ∂L/∂h2 × ∂h2/∂w2 = -2 × 2 = -4
∂L/∂w3 = ∂L/∂h3 × ∂h3/∂w3 = -2 × 2 = -4
第4步:梯度汇总(关键!)
h1 同时影响了 h2 和 h3,两个分支的"责任"要相加
∂L/∂h1 = ∂L/∂h2 × ∂h2/∂h1 + ∂L/∂h3 × ∂h3/∂h1
= (-2) × 2 + (-2) × 3 = -10
第5步:追到 w1
∂L/∂w1 = ∂L/∂h1 × ∂h1/∂w1 = -10 × 2 = -20
最终每个参数的梯度:
| 参数 | 梯度 | 含义 |
|---|---|---|
| w1 | -20 | 责任最重,影响路径最长 |
| w2 | -4 | 直接影响输出 |
| w3 | -4 | 直接影响输出 |
两个关键概念
梯度分发:在分叉处,损失对某个节点的梯度,会同时传给所有下游分支。
梯度汇总:当多个分支汇聚到同一个节点时,各自传来的梯度在该节点相加(因为这个节点的输出同时被多个上游影响)。
梯度下降:落实"处罚"
反向传播算出梯度后,梯度下降负责执行参数更新:
w ← w − η × dL/dw
- 梯度为负(w₁ 的情况):当前参数值太小,损失偏高 → 增加 w₁
- 梯度为正:当前参数值太大,损失偏高 → 减少 w
链式法则是追责工具,反向传播是追责的执行策略,梯度下降是根据追责结果执行处罚。
形式化
链式法则(多步):∂L/∂w₁ = (∂L/∂h₁) · (∂h₁/∂w₁)
梯度下降:
for _ in range(num_epochs):
# 正向传播:算损失
loss = forward_pass(model, x, y)
# 反向传播:算梯度
grads = backward_pass(loss, model)
# 参数更新
for param, grad in zip(model.parameters(), grads):
param -= learning_rate * grad
本节小结
反向传播不是新的数学,它只是链式法则和偏导数加法规则在复杂计算图上的系统应用:从最终损失出发,把每个节点和参数的"责任"(梯度)一层层追回来。梯度下降再根据追责结果更新参数。深层网络中梯度消失和梯度爆炸,就是这个"追责链条太长"时产生的副作用。