⛰️ The Energy Landscape of Predictive Coding Networks
📖 TL;DR: Predictive coding makes the loss landscape of feedforward neural networks more benign and robust to vanishing gradients.
This post explains my recent NeurIPS 2024 paper Only Strict Saddles in the Energy Landscape of Predictive Coding Networks?. In it, we provide, in my very humble opinion, the best theory so far on the learning dynamics of predictive coding in terms of explanatory and predictive power. This work was very much inspired by our previous paper which I wrote about in another post.
🙏 I’d like to acknowledge my collaborators El Mehdi Achour, Ryan Singh and my supervisor Christopher L. Buckley.
Overview
- Predictive coding: A refresher
- Toy models (going deeper)
- A landscape theory
- Experiments
- Concluding thoughts
🧠 Predictive coding: A refresher
I gave a primer on predictive coding (PC) in a previous blog post, so here's just a quick refresher. PC is an energy-based learning algorithm that can be used to train deep neural networks as an alternative to backpropagation (BP). The key difference from BP is that, before updating the weights, PC performs iterative inference over the network activities, as schematically shown in the gif below.
More formally, PC minimises an energy function \(\mathcal{F}\), first with respect to activities
\[\textbf{PC inference:} \quad \Delta z \propto - \nabla_{z} \mathcal{F}\] until convergence is reached, i.e. \(\Delta z \approx 0\). (For simplicity, we will denote the converged energy at an inference equilibrium as \(\mathcal{F}^*\).) Then, we update the weights
\[\textbf{PC learning:} \quad \Delta \theta \propto - \nabla_{\theta} \mathcal{F}^*\] How can we gain insight into these learning dynamics? There have been some theories, but they have all tended to make unrealistic assumptions or approximations that end up predicting experimental data poorly. In previous work [1], for example, we showed that first-order updates on the neurons allow one to perform a kind of second-order update on the weights, making PC an implicit second-order method. But this result holds only to a second-order approximation and doesn't provide as much explanatory power as we would like.
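To make this two-step scheme concrete, here is a minimal NumPy sketch of PC training on a single data pair. It is not the implementation used in the paper: the energy convention (\(\mathcal{F} = \tfrac{1}{2}\sum_l \|z_l - W_l f(z_{l-1})\|^2\) with the input and output layers clamped to the data), the tanh non-linearity, the layer sizes, the step sizes and the fixed number of inference iterations are all illustrative choices.

```python
# Minimal sketch of predictive coding (PC) on one data pair (x, y).
# Energy: F = 1/2 * sum_l ||z_l - W_l f(z_{l-1})||^2, with z[0] = x and z[L] = y clamped.
import numpy as np

rng = np.random.default_rng(0)
sizes = [10, 32, 32, 5]                               # input, 2 hidden layers, output (illustrative)
W = [rng.normal(0, 0.1, (sizes[l + 1], sizes[l])) for l in range(len(sizes) - 1)]

f = np.tanh
df = lambda a: 1.0 - np.tanh(a) ** 2

def pc_step(W, x, y, n_infer=100, lr_z=0.1, lr_w=1e-3):
    L = len(W)
    # Initialise the activities with a feedforward pass, then clamp input and output.
    z = [x]
    for l in range(L):
        z.append(W[l] @ f(z[l]))
    z[L] = y

    # PC inference: gradient descent on the energy w.r.t. the hidden activities.
    for _ in range(n_infer):
        eps = [z[l + 1] - W[l] @ f(z[l]) for l in range(L)]    # prediction errors
        for l in range(1, L):                                  # hidden layers only
            z[l] -= lr_z * (eps[l - 1] - df(z[l]) * (W[l].T @ eps[l]))

    # PC learning: one gradient step on the (approximately) equilibrated energy.
    eps = [z[l + 1] - W[l] @ f(z[l]) for l in range(L)]
    for l in range(L):
        W[l] += lr_w * np.outer(eps[l], f(z[l]))
    return W

# Toy usage: repeat inference + weight update on the same pair.
x, y = rng.normal(size=sizes[0]), rng.normal(size=sizes[-1])
for _ in range(100):
    W = pc_step(W, x, y)
```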
🪆 Toy models (going deeper)
It’s often a good idea to start from toy models. In our previous post, we considered the simplest possible deep neural network with a single hidden linear unit \(f(x) = w_2w_1x\). We then showed that PC inference has the effect of reshaping the loss landscape, and that stochastic gradient descent (SGD) on this reshaped landscape (the equilibrated energy) escapes the saddle point at the origin faster than on the loss \(\mathcal{L}\).
Figure 1
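If you want to play with this yourself, here is a small sketch for the two-weight toy model. A short calculation (minimising the energy over the single hidden activity) gives the equilibrated energy in closed form, \(\mathcal{F}^* = (y - w_2 w_1 x)^2 / (2(1 + w_2^2))\), so we can run gradient descent on both the loss and \(\mathcal{F}^*\) from the same point near the origin and count the steps needed to escape. The data pair, initialisation scale, learning rate and escape radius are arbitrary choices for illustration.

```python
# Toy comparison near the origin saddle for f(x) = w2 * w1 * x (single data pair).
# Loss:                L(w)  = 1/2 * (y - w2*w1*x)^2
# Equilibrated energy: F*(w) = (y - w2*w1*x)^2 / (2 * (1 + w2^2))
import numpy as np

x, y = 1.0, 2.0                 # illustrative data pair
lr, w0 = 1e-2, 1e-3             # same learning rate for both; initialise close to the origin

def grad_loss(w1, w2):
    r = y - w2 * w1 * x
    return np.array([-r * w2 * x, -r * w1 * x])

def grad_energy(w1, w2):
    r = y - w2 * w1 * x
    s = 1.0 + w2 ** 2
    # d/dw1 [r^2/(2s)] = -r*w2*x/s ;  d/dw2 [r^2/(2s)] = -r*w1*x/s - r^2*w2/s^2
    return np.array([-r * w2 * x / s, -r * w1 * x / s - (r ** 2) * w2 / s ** 2])

def steps_to_escape(grad_fn, radius=0.5, max_steps=200_000):
    w = np.array([w0, w0])
    for t in range(max_steps):
        if np.linalg.norm(w) > radius:
            return t
        w -= lr * grad_fn(*w)
    return max_steps

print("GD on the loss:               ", steps_to_escape(grad_loss), "steps")
print("GD on the equilibrated energy:", steps_to_escape(grad_energy), "steps")
```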
Now let’s try to go deeper and see if we can get some more intuition. Still considering the origin, what happens if we add just one layer or weight?
Figure 2
We see that SGD on the equilibrated energy escapes significantly faster than on the loss (given the same learning rate). It's not as easy to see from the landscape visualisations, but if you look closely, SGD on the loss spends a lot more time near the saddle (as indicated by the higher concentration of yellow dots 🟡). If this reminds you of "vanishing gradients", it's exactly that, just viewed from a landscape perspective.
What happens if we further increase the network depth and width?
Figure 3
For a more standard network (with 4 layers and non-unit width), PC now escapes orders of magnitude faster than BP (again initialising close to the origin and using SGD with the same learning rate). We can no longer visualise the landscape; however, we can project it onto the maximum and minimum curvature (Hessian) directions. Interestingly, we see that while the loss is flat around the origin, the equilibrated energy has negative curvature.
So it seems that, no matter the network depth and width, PC inference makes the origin saddle much easier to escape, thus making learning more robust to vanishing gradients. Can we say something more formal?
🏔 A landscape theory
In our paper, we use deep linear networks (DLNs) as our theoretical model, since they are the standard model for studies of the loss landscape and are relatively well understood. In contrast to previous theories of PC, this is the only major assumption we make, and we empirically verify that the theory holds for non-linear networks (more on this below).
The first surprising theoretical result is that for DLNs we can derive an exact solution for the energy at the inference equilibrium \(\mathcal{F}^*\). This is important because it’s the effective landscape on which PC learns.
\[\mathcal{F}^* = \frac{1}{2N} \sum_{i=1}^N (\mathbf{y}_i - W_{L:1} \mathbf{x}_i)^T S^{-1} (\mathbf{y}_i - W_{L:1} \mathbf{x}_i)\] where \(\mathbf{x}_i\) and \(\mathbf{y}_i\) are the input and output, respectively, \(W_{L:1}\) is just a shorthand for the network's feedforward map, and \(S\) is a matrix that depends on the network weights. So, in the linear case, the equilibrated energy is simply a rescaled mean squared error (MSE) loss, where the rescaling depends on the network weights. This result formalises the intuition from our toy simulations that PC inference has the effect of reshaping the loss landscape.
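Here is a quick numerical sanity check of this rescaled-MSE form (a sketch, not the paper's code). For linear networks the energy is quadratic in the activities, so inference can be solved exactly with a linear solve; with scalar layers the rescaling \(S\) is a scalar, so the ratio between the per-sample equilibrated energy and the per-sample squared error should come out the same for every sample, depending only on the weights. The weights and data below are arbitrary.

```python
# Check that, for a deep *linear* network, the equilibrated energy is a rescaled
# squared error: F*_i / (1/2 * (y_i - w3*w2*w1*x_i)^2) is the same for every sample.
# Scalar chain x -> z1 -> z2 -> y with energy
#   F = 1/2*(z1 - w1*x)^2 + 1/2*(z2 - w2*z1)^2 + 1/2*(y - w3*z2)^2.
import numpy as np

rng = np.random.default_rng(1)
w1, w2, w3 = 0.7, -1.3, 0.5                    # arbitrary weights
xs, ys = rng.normal(size=8), rng.normal(size=8)

def equilibrated_energy(x, y):
    # Setting dF/dz1 = dF/dz2 = 0 gives a 2x2 linear system for the activities.
    A = np.array([[1 + w2**2, -w2],
                  [-w2,       1 + w3**2]])
    b = np.array([w1 * x, w3 * y])
    z1, z2 = np.linalg.solve(A, b)
    return 0.5 * ((z1 - w1*x)**2 + (z2 - w2*z1)**2 + (y - w3*z2)**2)

for x, y in zip(xs, ys):
    sq_err = 0.5 * (y - w3 * w2 * w1 * x) ** 2
    print(f"F*_i / squared error_i = {equilibrated_energy(x, y) / sq_err:.6f}")
# All printed ratios coincide: the rescaling depends on the weights, not on the data.
```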
But is this rescaling useful? How does the equilibrated energy differ from the loss?
Let’s return to our origin saddle. We know from previous work that this saddle becomes flatter and flatter as you increase the depth of the network. More precisely, the “order-flatness” of the saddle, if you like, is equal to the number of hidden layers. So, if you have 1 hidden layer, then the saddle is flat to order 1 (the gradient is zero), but there is negative curvature. And if you have 2 hidden layers, then there is no curvature around the saddle, and there is an escape direction in the third derivative.
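To see where these orders come from, take the two-hidden-layer scalar chain \(f(x) = w_3 w_2 w_1 x\) with a single training pair \((x, y)\). Expanding the squared error gives
\[
\mathcal{L} = \tfrac{1}{2}\left(y - w_3 w_2 w_1 x\right)^2 = \tfrac{1}{2}y^2 - x y\, w_1 w_2 w_3 + \tfrac{1}{2} x^2 \left(w_1 w_2 w_3\right)^2,
\] which contains no terms of first or second order in the weights: the gradient and the Hessian both vanish at the origin, and the lowest-order escape direction comes from the cubic term, e.g. \(\partial^3 \mathcal{L} / \partial w_3 \partial w_2 \partial w_1 \,|_{\boldsymbol{\theta} = \mathbf{0}} = -xy \neq 0\).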
First-order saddles are also known as "strict", while higher-order saddles are labelled as "non-strict" [2]. You can loosely think of these as "good" and "bad" saddles, respectively, since non-strict ones can effectively trap first-order methods like gradient descent. Surprisingly, it turns out that the origin saddle of the equilibrated energy is always strict, independent of the network depth. More formally,
\[\lambda_{\text{min}}(H_{\mathcal{F}^*}(\boldsymbol{\theta} = \mathbf{0})) < 0, \quad \forall h \geq 1 \quad [\text{strict saddle}]\] where the left-hand side of the inequality is the minimum eigenvalue of the Hessian of the equilibrated energy at the origin, and \(h\) is the number of hidden layers.
This result explains our toy simulations. But what about other non-strict saddles? We know that there are plenty of others in the loss landscape. Do they also become strict in the equilibrated energy, i.e. after PC inference? In the paper we then consider a general type of saddle of which the origin is an example (technically, saddles of rank zero) and prove that these all indeed become strict in the equilibrated energy. We address other saddle types experimentally (see below).
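As a small numerical illustration of the strict-saddle result (again just a sketch under the same scalar-chain assumptions as above, not the paper's code), we can compare the minimum Hessian eigenvalue of the loss and of the equilibrated energy at the origin for two hidden layers, computing \(\mathcal{F}^*\) with an exact inner solve and the Hessians by central finite differences.

```python
# Minimum Hessian eigenvalue at the origin for a scalar linear chain with 2 hidden
# layers: the loss has no curvature there, while the equilibrated energy should have
# a strictly negative direction. Hessians are estimated by central finite differences.
import numpy as np

x, y = 1.0, 2.0                                   # illustrative data pair

def loss(theta):
    w1, w2, w3 = theta
    return 0.5 * (y - w3 * w2 * w1 * x) ** 2

def equilibrated_energy(theta):
    w1, w2, w3 = theta
    # Exact inference equilibrium: the PC energy is quadratic in (z1, z2).
    A = np.array([[1 + w2**2, -w2],
                  [-w2,       1 + w3**2]])
    b = np.array([w1 * x, w3 * y])
    z1, z2 = np.linalg.solve(A, b)
    return 0.5 * ((z1 - w1*x)**2 + (z2 - w2*z1)**2 + (y - w3*z2)**2)

def hessian(fn, theta, h=1e-4):
    n = len(theta)
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            e_i, e_j = np.eye(n)[i] * h, np.eye(n)[j] * h
            H[i, j] = (fn(theta + e_i + e_j) - fn(theta + e_i - e_j)
                       - fn(theta - e_i + e_j) + fn(theta - e_i - e_j)) / (4 * h**2)
    return H

origin = np.zeros(3)
for name, fn in [("loss", loss), ("equilibrated energy", equilibrated_energy)]:
    lam_min = np.linalg.eigvalsh(hessian(fn, origin)).min()
    print(f"min Hessian eigenvalue of the {name} at the origin: {lam_min:+.4f}")
# Expected: ~0 for the loss (non-strict saddle), < 0 for the equilibrated energy (strict).
```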
🧪 Experiments
The above theory is for linear networks. Does it still hold for practical, non-linear ones? Fortunately, yes. We run a variety of experiments with different datasets, architectures and non-linearities and in all cases find that, when initialised close to any of the studied saddles, SGD on the equilibrated energy escapes much faster than on the loss (again for the same learning rate). The figure below shows results for the origin saddle, as an example.
Figure 4
To test saddles that we do not address theoretically, we trained networks on a matrix completion task where we know that, starting near the origin, GD will transition through saddles of successively higher rank. The figure below shows that SGD on the equilibrated energy (PC) quickly escapes all the saddles found on the loss (BP), including higher-order ones that we did not study theoretically, and does not suffer from vanishing gradients.
Figure 5
Based on these and other results in the paper, we conjecture that all the saddles of the equilibrated energy are strict. We don't prove it (hence the question mark in the title of the paper), but the empirical evidence is quite compelling. Code to reproduce all results is available here.
💭 Concluding thoughts
So, we have shown, theoretically and empirically, that PC inference has the effect of reshaping the (MSE) loss landscape, making many (perhaps all) “bad” (non-strict) saddles of the loss “good” (strict) or easier to escape. These saddles include the origin, effectively making PC more robust to vanishing gradients.
The flip side of this story is that the convergence speed of PC inference scales very badly with the depth of the network. So, in a way, the problems in weight space are moved over to inference space. Stay tuned for progress on this!
References
[1] F. Innocenti, R. Singh, and C. L. Buckley. Understanding Predictive Coding as a Second-Order Trust-Region Method. ICML Workshop on Localized Learning (LLW), 2023.
[2] R. Ge, F. Huang, C. Jin, and Y. Yuan. Escaping from saddle points—online stochastic gradient for tensor decomposition. In Conference on Learning Theory, pages 797–842. PMLR, 2015.
[3] M. Stern, A. J. Liu, and V. Balasubramanian. Physical effects of learning. Physical Review E, 109(2):024311, 2024.