interactive essay · draft

Understanding the learning dynamics of recurrent neural networks

Saddle-to-saddle dynamics underlie many ANN learning processes, and often the saddles are quite interpretable as 'modes' learned one by one. But what are these modes for recurrent systems?

§ 1

Modes are how networks learn

It is often intuited that deep networks learn "simple things first". But what is a "simple" hypothesis?

The learning dynamics of neural networks offer a wonderful lens to understand what functions neural networks are actually learning. Gradient descent follows a trajectory through parameter space, slowly moving from a random initialization towards a solution. At some finite time we truncate this trajectory, and end up with a network. What we learn earlier on this trajectory can be thought of as a "bias" affecting what sorts of functions networks "prefer" to learn.

For example, if you train a deep linear network with a small initialization on a structured target, the loss does not glide smoothly to zero. It drops in steps [1]. Each cliff is the 'aha' moment associated with the learning of a new 'mode' of the target. In deep linear networks, we can say exactly what these modes are: the singular directions of the input-output correlation matrix [2]. If your task is, say, to encode natural images, the first mode will be the dominant principal component of the image dataset, and the learning dynamics of gradient descent learn low spatial frequencies first [3].

Folks at UCL (Yedi Zhang, Andrew Saxe, and Peter Latham) recently distilled this intuitive picture even further into a general principle: saddle-to-saddle dynamics [6]. It's one of those ideas that's immediately intuitive. Picture the dynamics of learning as a staircase, where each step is a saddle point in the loss landscape. The gradient vanishes at a saddle, so gradient descent lingers nearby until a small component along an unstable direction has grown enough to carry it away, and the saddles are escaped one by one. Our 'modes' correspond to the thing that we pick up when we escape a saddle. For linear feedforward networks, the modes are computable via the singular value decomposition of the input-output correlation matrix.

§ 2

What are the dynamics of learning in recurrent networks?

A neuroscientist's everyday network model has memory: a recurrent population $h_t \in \mathbb{R}^n$ driven by a scalar input $x_t$ and read out as a scalar $y_t$,

$$h_{t+1} = W h_t + b\, x_t, \qquad y_t = c^\top h_t.$$

What are the 'modes' we pick up during learning? The saddle-to-saddle picture of learning gives us a recipe. To understand the learning dynamics of this system, we must find the saddles and the modes that emerge when a saddle is passed.

Let's get some incorrect answers out of the way first. Unlike in feedforward linear networks, the answer is not the singular values. If you're familiar with linear dynamical systems, your next guess might be the eigenvalues of $W$. This is close, but eigenvalues are still not quite the object we're looking for. You might also guess that the moments when we learn new modes correspond to bifurcations in the dynamics of $h_t$. But none of these are quite right.

What we will find below is that the saddles of the loss landscape are interpretable as cancellations between poles and zeros of the I/O transfer function, and that learning a new mode corresponds to a pole-zero pair separating from a shared location in the complex plane. The rest of this post is an introduction to what that sentence means, and an exploration of how we can use it to understand the learning dynamics of recurrent networks.

§ 3

Introducing the transfer function

Understanding learning dynamics requires describing the function the network implements, which is best done with the transfer function, a classic tool from control theory.

When we ask "what has the network learned?" we really mean: what map does it implement from the input stream $\{x_t\}$ to the output stream $\{y_t\}$? What describes the input/output behavior?

Unroll the recurrence with $h_0 = 0$ and the map appears directly: every output is a weighted sum of past inputs,

$$y_t \;=\; \sum_{s=0}^{t-1} g_s\, x_{t-1-s}, \qquad g_s \;=\; c^\top W^s b,$$

where the sequence $\{g_s\}_{s \ge 0}$ is the impulse response: a unit kick at $t=0$ produces the output $y_{s+1} = g_s$, delayed by one step because $y_t$ reads $h_t$ before $x_t$ enters. This convolution is the I/O function: hand it any input and the formula returns the output. Two parameter triples $(W, b, c)$ that produce the same $\{g_s\}$ are the same network in different clothes.
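A quick numerical check of the unrolling, as a minimal sketch: the sizes, the random seed and the weight scale are arbitrary assumptions. It iterates the recurrence from § 2 exactly as written and compares the result with the convolution against $g_s = c^\top W^s b$. With that recurrence the input first reaches the readout one step later, so lag $s$ is indexed as $x_{t-1-s}$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, T = 4, 30                              # hidden size and sequence length (arbitrary)
W = 0.3 * rng.standard_normal((n, n))     # small scale keeps the example well behaved
b = rng.standard_normal(n)
c = rng.standard_normal(n)
x = rng.standard_normal(T)

# Route 1: iterate h_{t+1} = W h_t + b x_t with y_t = c^T h_t, starting from h_0 = 0.
h = np.zeros(n)
y_sim = np.zeros(T)
for t in range(T):
    y_sim[t] = c @ h
    h = W @ h + b * x[t]

# Route 2: build the impulse response g_s = c^T W^s b, then convolve with the input.
g = np.array([c @ np.linalg.matrix_power(W, s) @ b for s in range(T)])
y_conv = np.zeros(T)
for t in range(T):
    for s in range(t):                    # x_{t-1-s} reaches y_t through W^s
        y_conv[t] += g[s] * x[t - 1 - s]

max_diff = np.max(np.abs(y_sim - y_conv))
print(max_diff)                           # agreement up to floating-point error
```

The two routes only ever touch the parameters through the products $c^\top W^s b$, which is why any triple with the same $\{g_s\}$ gives the same outputs.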

The impulse response has hidden structure. Assume $W$ is diagonalisable and write $W = V P V^{-1}$, where $P = \mathrm{diag}(p_1, \dots, p_n)$ holds the eigenvalues $\{p_k\}_{k=1}^n$ of $W$ and the columns $v_k$ of $V$ are the right eigenvectors. Eigenvectors are the vectors that $W$ simply rescales, $W v_k = p_k v_k$, and the eigenvalues $p_k$ are the scaling factors. Applying $W$ repeatedly therefore just compounds the scaling: $W^s v_k = p_k^s v_k$. As a result, the impulse response written in the eigenbasis of $W$ breaks up into a sum of pure exponentials:

$$g_s \;=\; c^\top V P^s V^{-1} b \;=\; \sum_{k=1}^{n} r_k\, p_k^{\,s},$$

one term per mode of $W$. Each mode rings as a pure exponential weighted by the residue $r_k = (c^\top v_k)(u_k^\top b)$, where $u_k^\top$ is the $k$-th row of $V^{-1}$. The residue is the product of how strongly the readout listens to mode $k$ and how strongly the input excites it.
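The modal decomposition can be checked the same way. This is a sketch under the assumption that a random $W$ is diagonalisable (true generically); the variable names are illustrative. The rows of $V^{-1}$ supply the vectors $u_k^\top$.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
W = 0.3 * rng.standard_normal((n, n))
b = rng.standard_normal(n)
c = rng.standard_normal(n)

p, V = np.linalg.eig(W)          # eigenvalues p_k, right eigenvectors as columns of V
U = np.linalg.inv(V)             # row k of V^{-1} plays the role of u_k^T
r = (c @ V) * (U @ b)            # residues r_k = (c^T v_k)(u_k^T b), possibly complex

S = 12                           # number of lags to compare (arbitrary)
g_direct = np.array([c @ np.linalg.matrix_power(W, s) @ b for s in range(S)])
g_modal = np.array([np.sum(r * p**s) for s in range(S)])

modal_error = np.max(np.abs(g_direct - g_modal))
imag_leak = np.max(np.abs(g_modal.imag))   # conjugate pairs should cancel the imaginary parts
print(modal_error, imag_leak)
```

Individual residues and eigenvalues can be complex, but for a real $W$ they come in conjugate pairs, so the sum over modes is real up to rounding.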

What we'd really love is a clean description of the I/O map that doesn't depend on an infinite series of terms going back to the beginning of time: a single finite expression with no time index left in it. It turns out there's a trick. Tag each time lag with a power of a dummy variable $z^{-1}$ and sum the contribution of mode $k$ over all lags. Because the recurrence delays the input by one step, the term $r_k\, p_k^{\,s}$ sits at lag $s+1$, so the sum is $\sum_{s \ge 0} r_k\, p_k^{\,s}\, z^{-(s+1)}$. This is a geometric series, so for $|z| > |p_k|$ it sums in closed form, just the way $1 + r + r^2 + \cdots = 1/(1-r)$, and collapses to the single fraction $r_k/(z - p_k)$. Adding the $n$ modes gives the transfer function [5],

$$H(z) \;=\; \sum_{k=1}^{n} \frac{r_k}{z - p_k} \;=\; c^\top (zI - W)^{-1} b.$$

Look at the left-hand form: $H(z)$ blows up whenever $z - p_k = 0$. These singularities are called poles, and they sit exactly at the eigenvalues of $W$. Meanwhile, at certain other values of $z$ the $n$ terms cancel and $H(z) = 0$; these are the zeros (there are at most $n-1$ of them). Poles and zeros together fix $H(z)$ up to overall scale — a finite list of marked points in the complex plane $\mathbb{C}$ that completely describes the I/O map.

This is the benefit of the transfer function. A finite list of marked points in $\mathbb{C}$ (at most $n$ poles and $n-1$ zeros, plus one overall gain) is a complete, basis-free summary of what the network does to any input, ever. This is the answer to "what has the network learned": a configuration of points in the complex plane.
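To make the poles and zeros concrete, here is a small sketch that computes them for a random system. The sizes, seed and test point `z0` are assumptions chosen for illustration. The zeros are found as roots of the numerator polynomial $\sum_k r_k \prod_{j \neq k}(z - p_j)$, which is one straightforward route and not necessarily the best-conditioned one.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 3
W = 0.4 * rng.standard_normal((n, n))
b = rng.standard_normal(n)
c = rng.standard_normal(n)

def H(z):
    """State-space form of the transfer function, c^T (zI - W)^{-1} b."""
    return c @ np.linalg.solve(z * np.eye(n) - W, b)

# Poles are the eigenvalues of W; residues come from the eigenvectors.
p, V = np.linalg.eig(W)
r = (c @ V) * (np.linalg.inv(V) @ b)

# The partial-fraction form and the state-space form should agree at any test point.
z0 = 1.7 + 0.4j
form_gap = abs(H(z0) - np.sum(r / (z0 - p)))

# Numerator coefficients: sum_k r_k * prod_{j != k} (z - p_j), a degree n-1 polynomial.
num = sum(r[k] * np.poly(np.delete(p, k)) for k in range(n))
zeros = np.roots(num)
zero_residual = max(abs(H(z)) for z in zeros)

print(p, zeros, form_gap, zero_residual)
```

The leading numerator coefficient is $\sum_k r_k = c^\top b$, so the count of $n-1$ finite zeros holds whenever $c^\top b \neq 0$.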

Drag a pole or a zero in the playground below; the frequency response and impulse response update live. Because $W$ is real, any non-real poles come in complex-conjugate pairs, and the widget enforces this by mirroring across the real axis.

§ 4

The residue is the readout volume knob

A pole (i.e. eigenvalue of $W$) alone is not enough to produce a resonance: it has to be both excited and read out. The residue at each pole determines how loudly that mode gets talked to by the input and how much it speaks to the output. When a zero sits directly on top of a pole, the residue is zero and the mode is silent — even though the eigenvalue of $W$ is still there. Pulling the zero away activates the mode; the residue grows in proportion to the pole–zero separation (linearly, for small separations).
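As a worked example of the last claim (a sketch: the two-pole system, the gain $\kappa$ and the zero location $q$ are illustrative symbols, not taken from the widget), consider one zero and two distinct poles:

$$H(z) \;=\; \kappa\,\frac{z - q}{(z - p_1)(z - p_2)} \;=\; \frac{r_1}{z - p_1} + \frac{r_2}{z - p_2}, \qquad r_1 = \kappa\,\frac{p_1 - q}{p_1 - p_2}, \quad r_2 = \kappa\,\frac{p_2 - q}{p_2 - p_1}.$$

Setting $q = p_1$ gives $r_1 = 0$ and $H(z) = \kappa/(z - p_2)$: the pole at $p_1$ is still an eigenvalue of $W$ but leaves no trace in the I/O map. Moving the zero to $q = p_1 + \varepsilon$ gives $r_1 = -\kappa\,\varepsilon/(p_1 - p_2)$, which is linear in the separation $\varepsilon$.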

Slide the zero away from the pole in the widget below and watch the resonance peak appear out of nothing.

§ 5

Cancellation is a saddle

Let's recall the big question: what are the learning dynamics of recurrent networks? We can suspect that the dynamics will still be saddle-to-saddle. But what are the saddles?

The answer is: saddles are the points in $(W, b, c)$ parameter space where a pole and a zero cancel each other out, making the residue zero.

The reason comes from the product structure of residues. Recall that the residue factorises as $r_k = \alpha_k \beta_k$, where $\alpha_k = c^\top v_k$ is how strongly the readout $c$ observes mode $k$, and $\beta_k = u_k^\top b$ is how strongly the input $b$ excites it. The output depends on mode $k$ only through this product:

$$\hat{y}_t \;=\; \alpha_k\,\beta_k\;\phi_k(t) \;+\;\underbrace{\sum_{j \neq k} r_j\,\phi_j(t)}_{\text{other modes}},$$

where $\phi_k(t) = \sum_{s=0}^{t-1} p_k^{\,t-1-s}\, x_s$ is mode $k$'s response to the input — the input convolved with the pure exponential $p_k^{\,s}$ that mode $k$ rings at. (This is just the $r_k\,p_k^{\,s}$ term of the impulse response from the transfer-function section, applied to the input stream.)

For a moment let us restrict our attention to this one mode $k$, and consider what we should do holding the other modes fixed. Looking only at the $(\alpha_k, \beta_k)$ subspace, the squared-error loss for those parameters becomes

$$\mathcal{L}(\alpha_k,\,\beta_k) \;=\; \tfrac{1}{2}\bigl\|r_k^{\ast} - \alpha_k\,\beta_k\bigr\|^2 \;+\;\text{const},$$

where $r_k^{\ast}$ stands for the target residue — a single number summarising what the data wants mode $k$ to be. This loss is the recurrent analogue of what we see for feedforward networks, where the loss for mode $k$ is $\mathcal{L} = \frac{1}{2}\|\sigma^{\ast} - u\,v\|^2$ for a factorised singular value.

Now check the gradients at the cancellation point $\alpha_k = \beta_k = 0$:

$$\frac{\partial\mathcal{L}}{\partial\alpha_k} = -\bigl(r_k^{\ast} - \alpha_k\beta_k\bigr)\,\beta_k \;\xrightarrow{\;\alpha_k=\beta_k=0\;}\; 0,$$

$$\frac{\partial\mathcal{L}}{\partial\beta_k} = -\bigl(r_k^{\ast} - \alpha_k\beta_k\bigr)\,\alpha_k \;\xrightarrow{\;\alpha_k=\beta_k=0\;}\; 0.$$

Both gradients vanish — it's a critical point. But it's not a minimum. The Hessian at the origin is

$$\nabla^2\mathcal{L} \;=\; \begin{pmatrix} 0 & -r_k^{\ast} \\[4pt] -r_k^{\ast} & 0 \end{pmatrix}$$

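with eigenvalues $\pm\, r_k^{\ast}$: one positive, one negative. Near the origin the loss surface is a hyperbolic paraboloid in $(\alpha_k, \beta_k)$. For $r_k^{\ast} > 0$ the unstable direction lies along $\alpha_k = +\beta_k$ (for $r_k^{\ast} < 0$ it lies along $\alpha_k = -\beta_k$), so the product $\alpha_k\beta_k$ grows with the sign of the target. Learning escapes the saddle along this direction, turning on the mode.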

Note that this point is a critical point of the full loss only if the other modes also sit at critical points (minima or saddles) of their own restricted losses. Here we are only looking at the loss along this one mode.

Notice what happens at the origin $b = c = 0$: every residue vanishes at once, so all poles are simultaneously cancelled by zeros. This is the highest saddle — every mode is an unstable direction — and small random initialisation places the network nearby. The modes then race to escape — and the loss descends in steps, one for each mode that breaks free.
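The single-mode picture is easy to simulate. The sketch below is a toy: `r_star`, the step size and the symmetric start $\alpha_k = \beta_k = \sigma_0$ are assumptions made for illustration. It confirms the two Hessian eigenvalues and measures how long gradient flow on $\tfrac{1}{2}(r_k^{\ast} - \alpha_k\beta_k)^2$ takes to leave the saddle. For this symmetric start the product $u = \alpha_k\beta_k$ obeys $\dot u = 2(r_k^{\ast} - u)\,u$, a logistic equation, which gives a closed-form escape time to compare against.

```python
import numpy as np

r_star = 0.8                               # hypothetical target residue (positive)

# Hessian of L = 0.5 * (r_star - alpha*beta)^2 at alpha = beta = 0.
hessian = np.array([[0.0, -r_star], [-r_star, 0.0]])
curvatures = np.linalg.eigvalsh(hessian)   # ascending order: [-r_star, +r_star]

def escape_time(sigma0, dt=1e-3, t_max=200.0):
    """Euler-integrate gradient flow from alpha = beta = sigma0 and return
    the time at which alpha*beta first reaches r_star / 2."""
    alpha = beta = sigma0
    t = 0.0
    while alpha * beta < 0.5 * r_star and t < t_max:
        err = r_star - alpha * beta
        # d(alpha)/dt = err * beta, d(beta)/dt = err * alpha (negative gradient)
        alpha, beta = alpha + dt * err * beta, beta + dt * err * alpha
        t += dt
    return t

sigmas = [1e-2, 1e-3, 1e-4]
times = [escape_time(s) for s in sigmas]

# Logistic solution: u reaches r_star/2 at t = log(r_star/sigma0^2 - 1) / (2 r_star).
predicted = [np.log(r_star / s**2 - 1.0) / (2.0 * r_star) for s in sigmas]
print(curvatures, times, predicted)
```

Each tenfold reduction of $\sigma_0$ delays the escape by roughly the same amount, $\log(100)/(2 r_k^{\ast})$, so in this toy the time spent near the saddle grows only logarithmically as the initialisation shrinks.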

§ 6

The staircase, decoded

We're now prepared to actually look at the dynamics of training our recurrent linear network. Play around with this playground for a while. Notice especially how the step-like learning behavior manifests. Each step is one pole–zero pair separating in the complex plane: one mode escaping its saddle.

§ 7

Plateaus resemble best-fit degree-k models

What does the function we learn at each plateau look like?

The best $k$-pole rational approximation

Set the network aside for a moment. Suppose we are simply handed the target impulse response $g_t^{\ast} = \sum_j r_j^{\ast}\,(p_j^{\ast})^{t}$ and asked: what is the best linear recurrent system with exactly $k$ poles? This seems like a reasonable hypothesis to check against, and this optimum is something we can write down. Concretely, choose $k$ poles $\{p_1, \ldots, p_k\}$ and $k$ residues $\{r_1, \ldots, r_k\}$ to minimise

$$\mathcal{L}(p, r) \;=\; \tfrac{1}{2}\sum_{t \ge 0}\Bigl( \sum_j r_j\,p_j^{\,t} - g_t^{\ast}\Bigr)^2.$$

The problem is non-convex in the poles, but the residues are easy: at any fixed configuration of distinct poles inside the unit disc, the loss is quadratic in $r$, so the optimal residues are the unique solution of the normal equations

$$G\,r = d, \qquad G_{jl} = \frac{1}{1 - p_j\,p_l}, \qquad d_l = \sum_j \frac{r_j^{\ast}}{1 - p_j^{\ast}\,p_l},$$

where $r_j^{\ast}$ and $p_j^{\ast}$ are the target's residues and poles (the star marks the target, not a complex conjugate). Plugging the optimal residues back in profiles them out and leaves a closed-form objective in the poles alone, the Gram identity

$$\mathcal{L}^{\ast}(p) \;\equiv\; \min_r \mathcal{L}(p, r) \;=\; \tfrac{1}{2}\bigl(\|g^{\ast}\|^2 - d^\top G^{-1} d\bigr).$$

The best degree-$k$ rational approximation to $g^{\ast}$ is then the global minimum of $\mathcal{L}^{\ast}(p)$ over all configurations of $k$ real poles (or complex poles with conjugate symmetry).
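The profiled loss is short enough to write out. The sketch below assumes real poles inside the unit disc and uses a made-up three-pole target (`p_tgt`, `r_tgt` are illustrative values, not the ones behind the widget). It checks the closed form against a truncated time-domain sum and then grid-searches the best single pole.

```python
import numpy as np

def profile_loss(p, p_tgt, r_tgt):
    """Loss after minimising over residues at fixed real poles p (all |p| < 1).
    Returns (loss, optimal residues)."""
    p, p_tgt, r_tgt = (np.asarray(a, dtype=float) for a in (p, p_tgt, r_tgt))
    G = 1.0 / (1.0 - np.outer(p, p))                               # G_jl
    d = (r_tgt[:, None] / (1.0 - np.outer(p_tgt, p))).sum(axis=0)  # d_l
    G_tgt = 1.0 / (1.0 - np.outer(p_tgt, p_tgt))
    norm_sq = r_tgt @ G_tgt @ r_tgt                                # ||g*||^2
    r_opt = np.linalg.solve(G, d)                                  # normal equations G r = d
    return 0.5 * (norm_sq - d @ r_opt), r_opt

def brute_loss(p, r, p_tgt, r_tgt, T=2000):
    """Same loss computed directly from the impulse responses, truncated at T lags."""
    t = np.arange(T)
    g_fit = (np.asarray(r)[:, None] * np.asarray(p)[:, None] ** t).sum(axis=0)
    g_tgt = (r_tgt[:, None] * p_tgt[:, None] ** t).sum(axis=0)
    return 0.5 * np.sum((g_fit - g_tgt) ** 2)

p_tgt = np.array([0.9, 0.5, -0.3])     # hypothetical target poles
r_tgt = np.array([1.0, 0.6, -0.4])     # hypothetical target residues

p_try = np.array([0.8])
loss_closed, r_opt = profile_loss(p_try, p_tgt, r_tgt)
loss_brute = brute_loss(p_try, r_opt, p_tgt, r_tgt)

# One-dimensional landscape for k = 1: scan the pole and take the minimum.
grid = np.linspace(-0.99, 0.99, 397)
losses = np.array([profile_loss([q], p_tgt, r_tgt)[0] for q in grid])
best_pole = grid[np.argmin(losses)]
print(loss_closed, loss_brute, best_pole, losses.min())
```

For $k > 1$ the same `profile_loss` can be handed to any general-purpose optimiser over the pole vector; the landscape is non-convex, so several starting points are advisable.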

Does gradient descent learn the best $k$-pole rational approximation?

Playing around with the widget above, what seems to be happening is that at the rank-$k$ plateau of gradient descent, the student's $k$ active poles and their residues together resemble the best $k$-pole rational approximation to the target.

Note that those optimal pole locations are generically not at any subset of the target's poles. Each plateau represents a compromise pole configuration optimised for the rank-$k$ budget. This means that old poles readjust when a new one activates.

We can test this hypothesis in simulation. Below, an $n = 3$ student trains on a three-real-pole target at small initialisation $\sigma_0 \approx 0.01$. We've plotted the losses of the best 1-pole and 2-pole approximations to the target, and you can see that the plateaus of gradient descent line up with these best-approximation losses very closely.

Because $(b, c)$ evolve much faster than the eigenvalues of $W$ (§ 8, item 1), the residues continuously equilibrate to the optimal values $r^{\mathrm{opt}}(p) = G^{-1}d$ for whatever the current eigenvalues are, and the slow drift of the eigenvalues is gradient descent on the closed-form $\mathcal{L}^{\ast}(p)$ landscape, descending through its successive minima as more directions in $W$ activate.

§ 8

Some observations on dynamics

Beyond the fact that pole–zero cancellations are saddles, there is a lot more we can observe about the dynamics.

  1. $b$ and $c$ learn faster than $W$.

    It turns out that the input and output weights evolve much faster than the recurrent weights when our initialization is small. In fact, $(b, c)$ evolve on a timescale $\sim 1/\sigma_0$ and $W$ on $\sim 1/\sigma_0^{\,2}$, which can be a really big difference for tiny $\sigma_0$.

    To see this, we can write out the BPTT gradients for the recurrent linear network. The BPTT gradients of $\mathcal{L} = \tfrac{1}{2}\sum_t e_t^{\,2}$ with $e_t = \hat y_t - y_t^{\ast}$ are

    $$\frac{\partial \mathcal{L}}{\partial c} = \sum_t e_t\, h_t,\qquad \frac{\partial \mathcal{L}}{\partial b} = \sum_t x_t\, \lambda_{t+1},\qquad \frac{\partial \mathcal{L}}{\partial W} = \sum_t \lambda_{t+1}\, h_t^{\top},$$

    with adjoint $\lambda_t = e_t\, c + W^{\top} \lambda_{t+1}$. At initialisation with all parameters at scale $\sigma_0$, count powers of $\sigma_0$ in each factor: the hidden state $h_{t+1} = W h_t + b\, x_t$ is $O(\sigma_0)$, the student prediction $\hat y_t = c^{\top} h_t = O(\sigma_0^{\,2})$ is negligible against the target so $e_t = O(1)$, and the adjoint $\lambda_t = e_t\,c + W^{\top}\lambda_{t+1}$ is $O(\sigma_0)$. Substituting,

    $$\frac{\partial \mathcal{L}}{\partial c} = O(\sigma_0), \qquad \frac{\partial \mathcal{L}}{\partial b} = O(\sigma_0), \qquad \frac{\partial \mathcal{L}}{\partial W} = O(\sigma_0^{\,2}).$$

    This makes the input/output weights $(b, c)$ the fast variables, and $W$ the slow one.

  2. Other saddles exist — such as the $b^\top c$ sign flip.

    Before any single pole/zero separation happens, the very first saddle we find actually corresponds to learning the non-recurrent pathway that bypasses the recurrent weights. This is the "zero-lag" path. For our system with 1D inputs and 1D outputs, this is structurally a rank-1 residue.

    This 0-lag path learns first before any poles separate from zero. Perhaps you've seen this in the learning trajectories. Half the time the random initialisation gives it the wrong sign, and learning must flip it. In pole-space this looks bizarre: the single real eigenvalue of $W$ runs off the real axis through the origin, out to $\pm \infty$, and returns from the other side. The plateau structure is the same, but the trajectory is a reminder that "saddle-to-saddle" lives in $(W, b, c)$ space, not in pole space, and the map between them has its own topology.

  3. The input spectrum decides which modes can be learned.

    The §7 widget uses a unit impulse, which excites every frequency equally — so every mode of the target eventually activates. With other inputs this stops being automatic. At each saddle, the gradient that drives the $(\alpha_k, \beta_k)$ escape for mode $k$ is the cross-energy of the input filtered through mode-$k$'s kernel against the residual error,

    $$J_k \;=\; \sum_t \overline{\phi_k(t)}\, e_t, \qquad \phi_k(t) \;=\; \sum_{s \ge 0} p_k^{\,s}\, x_{t-1-s}.$$

    If the input has no energy at mode $k$'s frequency — say a bandlimited signal whose passband doesn't include $\arg p_k$ — then $\phi_k(t) \equiv 0$, so $J_k = 0$, and the gradient in the $(\alpha_k, \beta_k)$ directions is exactly zero. The mode isn't slow to learn; it's invisible. A lowpass training signal cannot teach the network a high-frequency oscillation, period.

  4. The first eigenvalue is always real.

    Even when the target wants a complex pair, the rank-1 plateau places both eigenvalues of $W$ on the real axis. The reason is structural: near $W = 0$ the gradient is approximately $\partial \mathcal{L}/\partial W \approx \big(\sum_t e_t\, x_t\big)\, c\, b^\top$ — a rank-1 real matrix. A rank-1 real update can only move eigenvalues along $\mathbb{R}$; complex eigenvalues require two real eigenvalues to first collide on the real axis and bifurcate, which can happen only once $W$ has built up rank-$2$ structure. Real dynamical content (exponential decay) precedes oscillatory content (damped sinusoids) as a geometric consequence of how gradient descent builds up the weight matrix from zero.

  5. Oscillations switch on with a square-root ramp.

    Once two real eigenvalues collide and split into a complex pair, the oscillation frequency doesn't appear at full strength: it grows as the square root of the time elapsed since the collision. That is, $|\mathrm{Im}\,p| \propto \sqrt{\tau - \tau_{\rm coll}}$, where $\tau_{\rm coll}$ is the moment of collision. The reason is that the imaginary part is $\tfrac{1}{2}\sqrt{-D}$, where $D = (\operatorname{tr} W)^2 - 4\det W$ is the discriminant of the characteristic polynomial of a $2 \times 2$ $W$ (or of the $2 \times 2$ block carrying the colliding pair). Since $W$ evolves smoothly under gradient descent, so does $D$, and a smooth function that crosses zero (rather than just touching it and bouncing back) does so linearly to leading order: $D(\tau) \approx D'(\tau_{\rm coll})(\tau - \tau_{\rm coll})$. Taking the square root of something linear gives the square-root onset. I verified this in some simulations, and it holds up well.

  6. Plateau dwell time is logarithmic in $\sigma_0$.

    Near each saddle, the unstable mode grows exponentially with a target-specific rate $|J|$, starting at amplitude of order $\sigma_0$, so the escape time is approximately $\tau_{\rm esc} \approx \frac{1}{2|J|} \log\!\bigl(1/\sigma_0^2\bigr) + \text{const}$. Smaller initialisation stretches the plateaux but leaves the sequence of visited configurations — which best-fit $k$-pole model the network sits at on each plateau — entirely unchanged. I'm not sure what the prefactor $|J|$ is but it depends on the target geometry.

§ 9

Conclusion

The learning dynamics of recurrent networks are saddle-to-saddle, just like feedforward networks. This gives linear networks a powerful inductive bias: they learn simple functions first, and then gradually build up to more complex ones. The definition of simple is roughly the degree of the transfer function, which is the number of poles in the I/O map. We do indeed pick up eigenvalues of $W$ one by one, in the sense that we have a sequential series of pole-zero pair separations.

It's possible that this framework could help us understand the dynamics of learning tasks in neuroscience, especially in motor control where you have a timeseries output. However, I think it would be crucial to extend the story in a few ways to make it more realistic for neuroscience.

What this leaves out

The picture above is for a single-input, single-output linear recurrent network with a squared-error loss on the impulse response. Changes for 'real' networks may break our story.

There are many directions for future work that would make this story more realistic for both neuroscience and machine learning. For example, learning in neuroscience never starts with near-zero initializations. Yet, as usual, the theory is cleanest for tiny initializations when you start right at a global saddle point. Some other key points of departure from the idealised story above are:

SISO → MIMO

Multi-input, multi-output systems replace a scalar pole with a pole accompanied by an input/output direction in $\mathbb{R}^{n_y}\times\mathbb{R}^{n_x}$. The transfer function becomes matrix-valued.

Linear → nonlinear

Real recurrent networks have an activation: $h_{t+1} = \phi(W h_t + b\, x_t)$. The hidden-state distribution then has finite support, and the linearisation of the network around any operating point has its own poles and residues.

References

  1. A. M. Saxe, J. L. McClelland & S. Ganguli. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. International Conference on Learning Representations (2014). arXiv:1312.6120.
  2. A. M. Saxe, J. L. McClelland & S. Ganguli. A mathematical theory of semantic development in deep neural networks. Proc. Natl. Acad. Sci. USA 116(23), 11537–11546 (2019). doi:10.1073/pnas.1820226116.
  3. A. S. Benjamin, L.-Q. Zhang, C. Qiu, A. A. Stocker & K. P. Kording. Efficient neural codes naturally emerge through gradient descent learning. Nature Communications 13, 7972 (2022). doi:10.1038/s41467-022-35659-7.
  4. A. Jacot, F. Ged, B. Şimşek, C. Hongler & F. Gabriel. Saddle-to-saddle dynamics in deep linear networks: Small initialization training, symmetry, and sparsity. arXiv:2106.15933 (2021).
  5. T. Kailath. Linear Systems. Prentice-Hall (1980).
  6. Y. Zhang, A. Saxe & P. E. Latham. Saddle-to-saddle dynamics explains a simplicity bias across neural network architectures. International Conference on Learning Representations (2026). arXiv:2512.20607.