
反向传播
如何训练一个 ODE
传统的神经网络(如 CNN)由离散层组成,训练它们的反向传播是直截了当的,但是 LTCs 是由一个连续时间的常微分方程(ODE)定义的,与传统的神经网络训练区别很大。文章中提到了两种主流的训练方法,并选择了第二种方法进行训练。
方法一:伴随方法
由 Chen 等在 2018 年提出的方法,它的核心思想是,在计算前向传播时不存储中间状态,反向传播时将中间态重新计算出来,是一种以计算换存储的方法。
伴随方法的数学原理
假设我们的 ODE 是dx/dt = f(x(t), t, \theta),最终损失是L(x(T)),目标是求梯度dL/d\theta,根据链式法则,梯度可以写成
dL/dθ = ∫ (dL/dx) \times (dx/dθ) dt.
这个式子可以理解为\theta对系统状态从x(0)->x(T)的每一步变化的影响的累加,最终形成了损失值L。但是直接计算这个积分非常困难,因此引入了一个新的变量——伴随状态a(t).
a(t)的定义是损失函数L对x(t)的梯度:
a(t) = dL / dx(t).
a(t)的物理意义是,在t时刻,系统状态x(t)的微小变化会对最终的损失值L产生多大的影响,我们的目标是找到一个关于a(t)的微分方程,从而求得所有时刻的a(t),进一步求出dL/d\theta.
首先考虑一个极小的时间间隔dt,在t时刻的状态x(t)首先影响了下一个时刻的状态x(t+dt),然后x(t+dt)又影响了下一个时刻的状态x(t+2dt),以此类推,最终影响损失值L。根据链式法则可以写出
dL / dx(t) = dL / dx(t + dt) \times dx(t + dt) / dx(t).
代入a(t)可得
a(t) = a(t + dt) \times dx(t + dt) / dx(t).
问题变成了计算状态转移得雅可比矩阵dx(t+dt) / dx(t). 利用 ODE 的欧拉法近似(推到过程见附录):
x(t+dt) ≈ x(t) + f(x(t), t, θ) \times dt.
式子两边对x(t)求导,得
dx(t + dt) / dx \approx dx(t) / dx(t) + d(f \times dt) / dx(t).
即
dx(t + dt) / dx \approx I + \frac{\partial f}{\partial x} \times dt.
将dx(t + dt) / dx代入可得
a(t) \approx a(t + dt) \times ( I + \frac{\partial f}{\partial x} \times dt).
展开得
a(t) \approx a(t + dt) + a(t + dt) \times \frac{\partial f}{\partial x} \times dt.
移项并除以dt得
(a(t) - a(t + dt)) / dt \approx a(t + dt) \times \frac{\partial f}{\partial x}.
将(a(t) - a(t + dt)) / dt近似为da / dt,a(t + dt)近似为a(t),即可求得伴随方程
da / dt \approx -a(t) \times \frac{\partial f}{\partial x}.
附录 1
推导公式
x(t+dt) ≈ x(t) + f(x(t), t, θ) \times dt.
根据极限与导数的定义:
dx/dt = \lim\limits_{\Delta t \to 0} \frac{x(t + \Delta t) - x(t)}{\Delta t}
假设一个很小的步长dt,则有
dx/dt \approx \frac{x(t + dt) - x(t)}{dt}
在 ODE 中,dx/dt = f(x(t), t, \theta),代入可得
f(x(t), t, \theta) \approx \frac{x(t + dt) - x(t)}{dt}
两边同时乘dt,得
x(x + dt) - x(t) \approx f(x(t), t, \theta).
移项得
x(t+dt) ≈ x(t) + f(x(t), t, θ) \times dt.