前向传播
文章的第二部分重点介绍了 LTCs 的前向传播算法。文章指出 ODEs 系统的状态可以由一个模拟系统从x(0)到x(t)的轨迹的数值 ODE 求解器来计算,ODE 求解器将连续的区间[0, T],分解成离散的时间序列[t_0, t_1, .t_n].
ODE 求解器的设计
文章指出 LTC 的 ODE 是一个刚性的方程组,首先想想什么是刚性方程组
刚性方程组
想象一个系统,其中包含着两个或这两个以上以截然不同的速度演化的部分,一部分变化非常快,另一部分变化非常慢,一个同时包含这两种极端行为的动态系统就被称为刚性系统,其数学描述就是刚性方程。
刚性方程意味着什么
刚性方程组是显式求解器的噩梦,比如标准的欧拉法或文章中提到的 RK 法。由于系统中一部分变化非常快,显示求解器为了准确捕捉到快速变化的部分,必须使用极小的时间步长\Delta t,然而系统要达到最终稳定状态,是由变化缓慢的部分决定的,如果以极小的时间步长\Delta t去计算,需要的计算步骤是巨大的,在计算上是不可行的。
为什么 LTC 的 ODE 是刚性方程组
我们观察 LTC 的 ODE 方程
其中包含了两项,一项是系统的“驱动力”f(x(t), I(t), t, \theta)A,它将系统状态x(t)拉向一个由A决定的状态,一项是“泄露项”-[1/τ + f(x(t), I(t), t, θ)]x(t),它将系统状态x(t)拉向 0,系统的最终状态是由这两项共同决定的。
由于f通常是一个有界的 S 形函数,例如\mathrm{tanh},因此f的值变化相对平滑和缓慢。但是当f的值较小时,时间常数\tau_{sys}会很大,导致x(t)以非常小的速度衰减,但是当f的值较大时,时间常数\tau_{sys}会很小,x(t)衰减速度变快。因此 LTC 的 ODE 本身构成了一个刚性系统,它的动态方程中同时包含了一个变化缓慢的“驱动力”以及一个可能变化迅速的“泄露项”,这完全取决于当前时刻的输入以及网络状态。
混合求解器
基于以上内容,文章得出结论,基于 RK 的 ODE 求解器不适合 LTCs,因此设计了全新的混合求解器,融合了显式和隐式欧拉方法,即混合求解器,它的更新规则是文中的公式 2
特别地,仅在线性项中使用x(t_{i+1})替换x(t_i),将这个规则代入公式 1,可以将公式 1 改写为
因此,x(t_{i+1})可以表示为
将f_{fused}代入可得
展开化简即可得到文中的公式 3
这样就完成了 LTC 网络中的一次状态更新。
为什么这样设计可以解决刚性问题
以下回答来自 AI
这是一个非常深刻的问题,触及了数值分析的核心。要理解为什么融合求解器能解决刚性方程问题,我们需要先理解标准方法为什么会失败,然后看融合求解器是如何巧妙地绕过这个陷阱的。
第一步:回顾问题——为什么标准显式方法会失败?
想象一个简单的衰减方程:dx/dt = -kx,其中 k 是一个很大的正数(代表快速衰减)。
- 显式欧拉法的更新规则是:
x(t+Δt) = x(t) + Δt * (-k * x(t)) = x(t) * (1 - kΔt) - 问题所在:为了保持稳定,我们必须保证
|1 - kΔt| < 1。这意味着kΔt < 2,即Δt < 2/k。 - 困境:因为
k非常大(刚性),所以允许的Δt必须非常非常小。否则,1 - kΔt会变成一个绝对值大于 1 的负数,导致x(t+Δt)的符号与x(t)相反,并且绝对值更大。这会在每一步都造成误差的指数级放大,最终导致计算结果剧烈振荡并发散。
核心问题: 显式方法只“看”当前状态 x(t) 来决定下一步。当系统内在的“拉力”很强时(k 很大),只看当前状态很容易“反应过度”,一步迈得太大,直接“冲过头”到稳定点的另一边,导致不稳定。
第二步:黄金标准——为什么隐式方法能成功?
- 隐式欧拉法的更新规则是:
x(t+Δt) = x(t) + Δt * (-k * x(t+Δt)) - 关键区别:它用未来状态
x(t+Δt)来计算变化率。 - 求解:
x(t+Δt) + kΔt * x(t+Δt) = x(t)=>x(t+Δt) * (1 + kΔt) = x(t)=>x(t+Δt) = x(t) / (1 + kΔt) - 为什么稳定:无论
kΔt有多大,分母(1 + kΔt)永远是正数。这意味着x(t+Δt)的符号永远与x(t)相同,并且其绝对值永远小于x(t)。它永远不会“冲过头”,而是平滑地、稳定地衰减到零。
核心优势: 隐式方法通过“展望未来”,提前知道了强拉力的存在,并据此调整了步长。它天生就具有 A-稳定性,即对于任何稳定的系统,无论步长 Δt 取多大,它都能保证数值解的稳定性。
第三步:融合求解器的巧妙之处——“对症下药”
现在我们来看 LTC 的融合求解器(公式 3):
x(t+Δt) = [x(t) + Δt * f(x(t), I(t), t, θ) * A] / [1 + Δt * (1/τ + f(x(t), I(t), t, θ))]
这个公式的结构是:
x(t+Δt) = (一个显式计算的部分) / (一个起稳定作用的分母)
它解决刚性问题的秘诀在于:它将 ODE 中导致不稳定的“元凶”单独拎出来,用隐式方法处理,而对其他部分则采用简单的显式方法。
- 识别“元凶”:
在 LTC 的 ODEdx/dt = -[1/τ + f(...)]x + f(...)A中,导致刚性和不稳定性的正是那个快速变化的衰减项-[1/τ + f(...)]x。当f(...)很大时,这一项会变得非常“陡峭”,就像我们例子中的-kx。 - “对症下药”——隐式处理:
融合求解器在推导过程中,专门针对这个线性项-[...]*x使用了隐式更新,即用x(t+Δt)替换了x(t)。这个操作最终体现在了公式的分母上:
分母 = 1 + Δt * (1/τ + f(...))
这个分母的结构和隐式欧拉法中的(1 + kΔt)完全一样!它起到了一个**“稳定器”或“阻尼器”**的作用。当f(...)变得很大时,分母也随之变大,从而强制整个更新结果保持稳定,防止“冲过头”。 - 保持效率——显式处理:
对于复杂的非线性项f(...)A,求解器保留了x(t),即采用显式计算。这使得整个公式最终可以被代数求解,而不需要像完全隐式方法那样进行昂贵的迭代计算。
一个生动的比喻
想象你在驾驶一辆速度极快的跑车(刚性系统)通过一个狭窄的 S 弯(快速变化的动态)。
- 显式方法:你只盯着后视镜(当前状态
x(t))来打方向盘。当你看到弯道时,已经晚了,你会以极快的速度撞上墙。 - 隐式方法:你每开一米就停下来,用 GPS 计算未来十米的最佳路线,然后再开。非常安全,但速度慢得像蜗牛。
- 融合方法:你用眼睛看着前方的道路曲率(显式处理
f(...)),但你的脚一直放在刹车上,并且有一个智能防抱死系统(ABS)在实时监测车轮转速(隐式处理-[...]*x)。当你入弯太快时,ABS 会立刻介入,自动帮你减速,确保你不会失控撞墙。你既能开得快,又保证了安全。
总结
融合求解器之所以能解决 LTC 的刚性问题,是因为它并非“一刀切”地使用显式或隐式方法,而是智慧地识别出导致不稳定的特定动态项(线性衰减项),并仅对该项应用了具有 A-稳定性的隐式更新。
这个“稳定器”被巧妙地设计成了公式的分母,从而在保持计算效率(避免了迭代)的同时,获得了处理刚性方程所必需的数值稳定性。这正是它设计的精髓所在。
前向传播算法
算法 1 介绍了 ODE 融合求解器的代码实现,其中核心的函数 FusedStep 就是公式 3. 通过循环调用 FusedStep 函数 L 次,提高数值精度。这里的L是将\Delta t又均分成了L份,而不是将整个时间序列分成L份,因此输入长度为T的时间序列时,算法的时间复杂度为O(T \times L).
算法的参数\theta是 LTC 网络中所有可学习的参数的集合。其中,\tau^{(N \times 1)}是每个神经元的基础时间常数向量,说是基础时间常数向量是因为真正的时间常数其实是\tau_{sys}。\gamma^{(M \times N)}是从输入I(t)到隐藏层的权重矩阵,\gamma_{r}^{(N \times N)}是隐藏层内部的循环权重矩阵,\mu^{(N \times 1)}是隐藏层的偏置向量,这些参数共同构成了非线性函数f,例如,f = \mathrm{tanh} (\gamma_{r} * x + I * \gamma + \mu). A^{(N \times 1)}是偏置向量,决定的是“驱动力”的最终状态。L是一个超参数,调节计算的精度。
算法 1 的工作流程可以概括为:接收当前状态x(t)、输入I(t)以及参数\theta. 调用 FusedStep 函数 L 次,返回\Delta t后的系统状态x(t + \Delta t).