Search

Search articles

LSTM Gradient Flow: The Constant Error Carousel Explained

Michael BrenndoerferDecember 16, 202537 min read

Learn how LSTMs solve the vanishing gradient problem through the cell state gradient highway. Includes derivations, visualizations, and PyTorch implementations.

Reading Level

Choose your expertise level to adjust how many terms are explained. Beginners see more tooltips, experts see fewer to maintain reading flow. Hover over underlined terms for instant definitions.

LSTM Gradient Flow

In the previous chapters, we saw how vanilla RNNs suffer from vanishing gradients and how LSTMs introduce gates to control information flow. But we haven't yet examined the mathematical reason why LSTMs solve the vanishing gradient problem. The answer lies in how gradients flow through the cell state, a mechanism sometimes called the constant error carousel.

This chapter analyzes gradient flow in LSTMs from first principles. We'll derive the gradient equations, show why the cell state acts as a "gradient highway," compare gradient behavior between vanilla RNNs and LSTMs, and explore when LSTMs still need gradient clipping. We'll also examine peephole connections, a variant that allows gates to directly observe the cell state.

The term "constant error carousel" comes from Hochreiter and Schmidhuber's original 1997 paper. It describes the key insight that makes LSTMs work: the cell state can carry information (and gradients) across many timesteps without the multiplicative decay that plagues vanilla RNNs.

Recall the cell state update equation from the previous chapter:

ct=ftct1+itc~t\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t

where:

  • ct\mathbf{c}_t: the cell state at timestep tt
  • ft\mathbf{f}_t: the forget gate activation (values between 0 and 1)
  • ct1\mathbf{c}_{t-1}: the cell state from the previous timestep
  • it\mathbf{i}_t: the input gate activation
  • c~t\tilde{\mathbf{c}}_t: the candidate cell state
  • \odot: element-wise multiplication

The crucial observation is that this update is additive with respect to the previous cell state. The previous cell state ct1\mathbf{c}_{t-1} is scaled by the forget gate and then added to new information. Compare this to the vanilla RNN update:

ht=tanh(Whhht1+Wxhxt+b)\mathbf{h}_t = \tanh(\mathbf{W}_{hh} \mathbf{h}_{t-1} + \mathbf{W}_{xh} \mathbf{x}_t + \mathbf{b})

where:

  • ht\mathbf{h}_t: the hidden state at timestep tt
  • Whh\mathbf{W}_{hh}: the recurrent weight matrix that transforms the previous hidden state
  • ht1\mathbf{h}_{t-1}: the hidden state from the previous timestep
  • Wxh\mathbf{W}_{xh}: the input weight matrix that transforms the current input
  • xt\mathbf{x}_t: the input at timestep tt
  • b\mathbf{b}: the bias vector
  • tanh\tanh: the hyperbolic tangent activation function

In the vanilla RNN, the previous hidden state passes through a matrix multiplication and a nonlinearity at every timestep. In the LSTM, the cell state can flow through with only element-wise scaling by the forget gate.

Out[3]:
Visualization
Diagram comparing gradient flow in vanilla RNN versus LSTM, showing multiplicative decay in RNN and direct pathway in LSTM.
The constant error carousel concept. In vanilla RNNs (top), information must pass through repeated matrix multiplications and nonlinearities, causing gradients to decay exponentially. In LSTMs (bottom), the cell state provides a direct pathway where information can flow with only element-wise scaling by the forget gate.

Why Additive Updates Matter

The difference between multiplicative and additive updates might seem subtle, but it has profound implications for gradient flow. To understand why, we need to look at what happens during backpropagation.

During the backward pass, we compute how the loss changes with respect to earlier states. This requires the Jacobian ctct1\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}}, a matrix that tells us how each component of ct\mathbf{c}_t depends on each component of ct1\mathbf{c}_{t-1}.

Starting from the cell state equation ct=ftct1+itc~t\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t, we apply the chain rule:

ctct1=diag(ft)+(itc~t)ct1\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \text{diag}(\mathbf{f}_t) + \frac{\partial (\mathbf{i}_t \odot \tilde{\mathbf{c}}_t)}{\partial \mathbf{c}_{t-1}}

This equation reveals something remarkable. The first term, diag(ft)\text{diag}(\mathbf{f}_t), is a diagonal matrix with the forget gate values on its diagonal. This term comes directly from differentiating ftct1\mathbf{f}_t \odot \mathbf{c}_{t-1}, the additive contribution of the previous cell state.

The second term captures how the new information (itc~t\mathbf{i}_t \odot \tilde{\mathbf{c}}_t) depends on ct1\mathbf{c}_{t-1}. This dependency is indirect: ct1\mathbf{c}_{t-1} affects ht1\mathbf{h}_{t-1}, which affects the gates, which affect the update. We'll analyze this indirect path later.

The critical insight is that the first term provides a direct gradient path with three special properties:

  1. No weight matrices: Unlike vanilla RNNs, gradients don't pass through learned weight matrices on this path.
  2. No activation saturation: The path doesn't involve tanh or sigmoid derivatives that could shrink gradients.
  3. Learnable scaling: The forget gate values are learned, so the network can control gradient flow.

When the forget gate is close to 1, this term is approximately the identity matrix I\mathbf{I}. In that case, gradients flow backward with minimal decay: the "constant error carousel" in action.

Constant Error Carousel

The constant error carousel refers to the LSTM's ability to maintain constant gradient flow through the cell state when the forget gate equals 1. In this limiting case, ctct1=I\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \mathbf{I}, and gradients propagate backward indefinitely without decay. The "carousel" metaphor evokes information cycling through time without degradation, like a carousel that keeps spinning without friction.

The Forget Gate as a Gradient Highway

The forget gate serves a dual purpose in LSTMs. In the forward pass, it controls how much of the previous cell state to retain. In the backward pass, it controls how much gradient flows to earlier timesteps. This duality is not coincidental: it's a fundamental property of how gradients work.

Deriving the Gradient

Let's make the gradient highway concrete with a formal derivation. Suppose we have a loss L\mathcal{L} computed at the final timestep TT. We want to understand how gradients propagate backward, specifically how Lct\frac{\partial \mathcal{L}}{\partial \mathbf{c}_t} relates to Lct1\frac{\partial \mathcal{L}}{\partial \mathbf{c}_{t-1}}.

Starting from the cell state update equation:

ct=ftct1+itc~t\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t

By the chain rule, the gradient with respect to ct1\mathbf{c}_{t-1} has two components:

  1. Direct path: Gradients flow directly through the term ftct1\mathbf{f}_t \odot \mathbf{c}_{t-1}
  2. Indirect paths: Gradients flow through ft\mathbf{f}_t, it\mathbf{i}_t, and c~t\tilde{\mathbf{c}}_t, all of which depend on ht1\mathbf{h}_{t-1}, which in turn depends on ct1\mathbf{c}_{t-1}

For now, let's focus on the direct path, which is the star of the show. The derivative of ftct1\mathbf{f}_t \odot \mathbf{c}_{t-1} with respect to ct1\mathbf{c}_{t-1} is simply ft\mathbf{f}_t (element-wise). So the direct path contribution is:

(Lct1)direct=Lctft\left(\frac{\partial \mathcal{L}}{\partial \mathbf{c}_{t-1}}\right)_{\text{direct}} = \frac{\partial \mathcal{L}}{\partial \mathbf{c}_t} \odot \mathbf{f}_t

This equation is the heart of the LSTM's gradient flow. It says: the gradient at timestep t1t-1 equals the gradient at timestep tt, scaled element-wise by the forget gate. No weight matrices. No activation derivatives. Just a simple scaling.

Now consider what happens over multiple timesteps. If we trace the direct path from timestep tt back to timestep tkt-k, the gradient accumulates as:

(Lctk)direct=Lctj=tk+1tfj\left(\frac{\partial \mathcal{L}}{\partial \mathbf{c}_{t-k}}\right)_{\text{direct}} = \frac{\partial \mathcal{L}}{\partial \mathbf{c}_t} \odot \prod_{j=t-k+1}^{t} \mathbf{f}_j

The product j=tk+1tfj\prod_{j=t-k+1}^{t} \mathbf{f}_j is computed element-wise, meaning each dimension of the cell state has its own chain of forget gates. This product determines how much gradient survives the journey backward through kk timesteps.

Here's the key insight: if all forget gates equal 1, the product equals 1, and gradients flow perfectly. If forget gates average 0.9, the product after 10 timesteps is 0.9100.350.9^{10} \approx 0.35, still substantial. Compare this to a vanilla RNN, where the equivalent factor involves weight matrix eigenvalues and tanh derivatives, typically yielding 0.585100.0050.585^{10} \approx 0.005 after 10 timesteps.

In[4]:
Code
import numpy as np


def simulate_gradient_decay(
    num_timesteps, architecture="vanilla", forget_gate_mean=0.9
):
    """
    Simulate gradient magnitude over timesteps for different architectures.

    For vanilla RNN: gradient decays as (tanh_deriv * spectral_radius)^n
    For LSTM: gradient decays as forget_gate^n along the direct path
    """
    np.random.seed(42)

    if architecture == "vanilla":
        # Typical values: tanh derivative ~0.65, spectral radius ~0.9
        decay_factor = 0.65 * 0.9  # ~0.585 per timestep
        gradients = [decay_factor**t for t in range(num_timesteps)]
    else:
        # LSTM: decay depends on forget gate values
        # With mean forget gate of 0.9, gradient decays more slowly
        gradients = [forget_gate_mean**t for t in range(num_timesteps)]

    return np.array(gradients)


# Compare gradient decay
timesteps = 100
vanilla_gradients = simulate_gradient_decay(timesteps, "vanilla")
lstm_gradients_09 = simulate_gradient_decay(
    timesteps, "lstm", forget_gate_mean=0.9
)
lstm_gradients_095 = simulate_gradient_decay(
    timesteps, "lstm", forget_gate_mean=0.95
)
lstm_gradients_099 = simulate_gradient_decay(
    timesteps, "lstm", forget_gate_mean=0.99
)
Out[5]:
Visualization
Semi-log plot showing gradient decay curves for vanilla RNN and LSTM with different forget gate values from 0.9 to 0.99.
Gradient magnitude over timesteps for vanilla RNN versus LSTM with different forget gate values. The vanilla RNN (red) experiences rapid exponential decay. LSTMs with higher forget gate values maintain gradients much longer, with f=0.99 (dark blue) preserving significant gradient magnitude even at 100 timesteps.

The visualization reveals a striking difference. The vanilla RNN's gradient decays below the practical threshold (where gradients become too small to provide useful learning signal) within about 15-20 timesteps. LSTMs with forget gates averaging 0.9 extend this to about 50 timesteps, while forget gates averaging 0.99 maintain useful gradients for over 100 timesteps.

The Forget Gate Gradient Trade-off

There's an important trade-off in the forget gate's behavior. A forget gate close to 1 provides excellent gradient flow but prevents the network from forgetting old information. A forget gate close to 0 allows aggressive forgetting but blocks gradient flow.

In practice, well-trained LSTMs learn to balance this trade-off. The forget gate stays relatively high (preserving gradients) for information that needs to persist, and drops lower when the network needs to clear old state. This adaptive behavior emerges from training: the network learns forget gate patterns that both maintain useful information and enable learning.

In[6]:
Code
def analyze_forget_gate_tradeoff(forget_values):
    """
    Analyze the trade-off between information retention and gradient flow.
    """
    results = []
    for f in forget_values:
        # Information retention after 50 timesteps
        info_retained = f**50

        # Gradient magnitude after 50 timesteps
        gradient_magnitude = f**50

        # They're the same! This is the key insight.
        results.append(
            {
                "forget_gate": f,
                "info_retained_50": info_retained,
                "gradient_50": gradient_magnitude,
            }
        )

    return results


forget_values = np.linspace(0.5, 0.99, 50)
tradeoff_results = analyze_forget_gate_tradeoff(forget_values)
Out[7]:
Visualization
Line plot showing information retention and gradient magnitude both increasing exponentially with forget gate value.
The forget gate trade-off: information retention and gradient magnitude after 50 timesteps as a function of forget gate value. Both curves are identical because the same mechanism that preserves information also preserves gradients. Higher forget gate values preserve both, but at the cost of the network''s ability to forget irrelevant information.

Complete Gradient Flow Analysis

So far, we've focused on the direct gradient path through the forget gate, the "gradient highway" that makes LSTMs so effective. But this is only part of the story. To fully understand LSTM gradient flow, we need to trace all the paths gradients can take as they propagate backward through time.

Think of it this way: when we compute Lct1\frac{\partial \mathcal{L}}{\partial \mathbf{c}_{t-1}}, we're asking "how does a small change in the previous cell state affect the loss?" The answer involves multiple routes:

  1. The direct route: The previous cell state directly contributes to the current cell state through the forget gate multiplication.
  2. The indirect routes: The previous cell state also affects the hidden state ht1\mathbf{h}_{t-1}, which in turn influences all the gates at timestep tt, which then affect the cell state update.

Let's derive the complete picture.

The Full Jacobian

To capture all gradient paths, we need the complete Jacobian ctct1\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}}. Using the chain rule, this decomposes into:

ctct1=diag(ft)direct path+ctht1ht1ct1indirect paths through gates\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \underbrace{\text{diag}(\mathbf{f}_t)}_{\text{direct path}} + \underbrace{\frac{\partial \mathbf{c}_t}{\partial \mathbf{h}_{t-1}} \cdot \frac{\partial \mathbf{h}_{t-1}}{\partial \mathbf{c}_{t-1}}}_{\text{indirect paths through gates}}

The first term, diag(ft)\text{diag}(\mathbf{f}_t), is a diagonal matrix with the forget gate values on its diagonal. This is the direct path we analyzed earlier: simple, stable, and controlled entirely by the forget gate.

The second term represents the indirect paths. It's a product of two Jacobians:

  • ctht1\frac{\partial \mathbf{c}_t}{\partial \mathbf{h}_{t-1}}: How does the current cell state depend on the previous hidden state? This captures the influence through all three gates.
  • ht1ct1\frac{\partial \mathbf{h}_{t-1}}{\partial \mathbf{c}_{t-1}}: How does the previous hidden state depend on the previous cell state? This is the output gate and tanh connection.

Expanding the Indirect Paths

Let's work through each piece. The cell state update is:

ct=ftct1+itc~t\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t

The previous hidden state ht1\mathbf{h}_{t-1} affects this update through three channels: it influences ft\mathbf{f}_t, it\mathbf{i}_t, and c~t\tilde{\mathbf{c}}_t. Taking the derivative with respect to ht1\mathbf{h}_{t-1}:

ctht1=diag(ct1)ftht1+diag(c~t)itht1+diag(it)c~tht1\frac{\partial \mathbf{c}_t}{\partial \mathbf{h}_{t-1}} = \text{diag}(\mathbf{c}_{t-1}) \cdot \frac{\partial \mathbf{f}_t}{\partial \mathbf{h}_{t-1}} + \text{diag}(\tilde{\mathbf{c}}_t) \cdot \frac{\partial \mathbf{i}_t}{\partial \mathbf{h}_{t-1}} + \text{diag}(\mathbf{i}_t) \cdot \frac{\partial \tilde{\mathbf{c}}_t}{\partial \mathbf{h}_{t-1}}

Each term corresponds to one gate's contribution:

Gradient contributions from each gate through the indirect path. Each term shows how the previous hidden state influences the cell state update through a specific gate.
TermGateMeaning
diag(ct1)ftht1\text{diag}(\mathbf{c}_{t-1}) \cdot \frac{\partial \mathbf{f}_t}{\partial \mathbf{h}_{t-1}}ForgetHow changes in ht1\mathbf{h}_{t-1} affect what we keep from the past
diag(c~t)itht1\text{diag}(\tilde{\mathbf{c}}_t) \cdot \frac{\partial \mathbf{i}_t}{\partial \mathbf{h}_{t-1}}InputHow changes in ht1\mathbf{h}_{t-1} affect how much new information we add
diag(it)c~tht1\text{diag}(\mathbf{i}_t) \cdot \frac{\partial \tilde{\mathbf{c}}_t}{\partial \mathbf{h}_{t-1}}CandidateHow changes in ht1\mathbf{h}_{t-1} affect what new information we propose

The second piece connects the hidden state to the cell state. Recall that ht1=ot1tanh(ct1)\mathbf{h}_{t-1} = \mathbf{o}_{t-1} \odot \tanh(\mathbf{c}_{t-1}). Taking the derivative:

ht1ct1=diag(ot1)diag(1tanh2(ct1))\frac{\partial \mathbf{h}_{t-1}}{\partial \mathbf{c}_{t-1}} = \text{diag}(\mathbf{o}_{t-1}) \cdot \text{diag}(1 - \tanh^2(\mathbf{c}_{t-1}))

This is a product of two diagonal matrices:

  • diag(ot1)\text{diag}(\mathbf{o}_{t-1}): The output gate values, all between 0 and 1
  • diag(1tanh2(ct1))\text{diag}(1 - \tanh^2(\mathbf{c}_{t-1})): The tanh derivative, also between 0 and 1

Notice that both factors are bounded between 0 and 1, so this product tends to shrink gradients. This is a key observation we'll return to.

Why the Direct Path Dominates

Now we can see why the LSTM's gradient flow is so much better than a vanilla RNN's. Compare the two gradient paths:

Direct path (diag(ft)\text{diag}(\mathbf{f}_t)):

  • Involves only element-wise scaling by the forget gate
  • No weight matrices
  • No activation function derivatives (other than the forget gate's sigmoid, which is already folded into ft\mathbf{f}_t)
  • Controlled entirely by the learned forget gate values

Indirect paths (ctht1ht1ct1\frac{\partial \mathbf{c}_t}{\partial \mathbf{h}_{t-1}} \cdot \frac{\partial \mathbf{h}_{t-1}}{\partial \mathbf{c}_{t-1}}):

  • Involve products of weight matrices (through ftht1\frac{\partial \mathbf{f}_t}{\partial \mathbf{h}_{t-1}}, etc.)
  • Multiplied by the tanh derivative, which is at most 1 and often much smaller
  • Multiplied by the output gate, which is at most 1
  • Subject to the same vanishing/exploding gradient problems as vanilla RNNs

This leads to two crucial insights:

  1. Magnitude: The direct path gradient is simply ft\mathbf{f}_t, which the network learns to keep near 1 for important information. The indirect paths involve products of many terms, each bounded by 1, so they tend to be much smaller.

  2. Stability: The direct path doesn't involve any learned weight matrices; it's purely controlled by the forget gate. The indirect paths pass through weight matrices that can have problematic eigenvalues, but because they're multiplied by small factors (tanh derivatives, output gates), their contribution is usually dominated by the stable direct path.

Let's verify this empirically. First, we can measure the actual forget gate values in a trained LSTM and visualize them as a heatmap:

In[8]:
Code
import torch
import torch.nn as nn


def measure_forget_gate_values(hidden_size=32, seq_length=50):
    """
    Measure forget gate activations across timesteps in an LSTM.
    Returns a 2D array of shape (hidden_size, seq_length).
    """
    torch.manual_seed(42)

    # Create LSTM
    lstm = nn.LSTM(input_size=16, hidden_size=hidden_size, batch_first=True)

    # Random input sequence
    x = torch.randn(1, seq_length, 16)

    # Manual unrolling to capture forget gate values
    h = torch.zeros(1, hidden_size)
    c = torch.zeros(1, hidden_size)

    forget_gates = []

    for t in range(seq_length):
        # Get LSTM weights
        w_ih = lstm.weight_ih_l0
        w_hh = lstm.weight_hh_l0
        b_ih = lstm.bias_ih_l0
        b_hh = lstm.bias_hh_l0

        # Compute gates
        gates = x[0, t] @ w_ih.T + h @ w_hh.T + b_ih + b_hh

        i, f, g, o = gates.chunk(4, dim=-1)
        f = torch.sigmoid(f)
        i = torch.sigmoid(i)
        g = torch.tanh(g)
        o = torch.sigmoid(o)

        forget_gates.append(f.detach().numpy().flatten())

        c = f * c + i * g
        h = o * torch.tanh(c)

    return np.array(forget_gates).T  # Shape: (hidden_size, seq_length)


# Measure forget gate values
forget_gate_heatmap = measure_forget_gate_values(hidden_size=32, seq_length=50)
Out[9]:
Visualization
Heatmap showing forget gate values with hidden dimensions on y-axis and timesteps on x-axis, with values ranging from 0 to 1.
Forget gate activations across timesteps and hidden dimensions. Brighter colors indicate higher forget gate values (closer to 1), which preserve both information and gradients. Most values cluster between 0.4 and 0.6 with default initialization, but trained LSTMs learn to push these higher for important information.

The heatmap reveals the forget gate landscape. With default initialization, values cluster around 0.5, but trained LSTMs learn to push these higher for dimensions storing important long-term information. The variation across hidden dimensions shows how different "memory slots" can have different retention characteristics.

Now let's measure actual gradient magnitudes in a trained LSTM:

In[10]:
Code
def analyze_lstm_gradients(hidden_size=64, seq_length=50, num_trials=10):
    """
    Analyze gradient flow in an LSTM by measuring gradient magnitudes
    at each timestep during backpropagation.
    """
    torch.manual_seed(42)

    # Create LSTM
    lstm = nn.LSTM(input_size=32, hidden_size=hidden_size, batch_first=True)

    all_gradient_norms = []

    for trial in range(num_trials):
        # Random input sequence
        x = torch.randn(1, seq_length, 32, requires_grad=True)

        # Forward pass, keeping track of hidden states
        h0 = torch.zeros(1, 1, hidden_size)
        c0 = torch.zeros(1, 1, hidden_size)

        output, (hn, cn) = lstm(x, (h0, c0))

        # Compute loss at final timestep
        loss = output[0, -1, :].sum()

        # Backward pass
        loss.backward()

        # Measure gradient magnitude at each timestep
        # We do this by computing gradients with respect to intermediate hidden states
        gradient_norms = []

        # Re-run with gradient tracking for each hidden state
        lstm.zero_grad()
        x_new = torch.randn(1, seq_length, 32)

        h = torch.zeros(1, hidden_size, requires_grad=True)
        c = torch.zeros(1, hidden_size, requires_grad=True)

        hidden_states = [(h.clone(), c.clone())]

        # Manual unrolling to track gradients
        for t in range(seq_length):
            # Get LSTM weights
            w_ih = lstm.weight_ih_l0
            w_hh = lstm.weight_hh_l0
            b_ih = lstm.bias_ih_l0
            b_hh = lstm.bias_hh_l0

            # Compute gates
            gates = x_new[0, t] @ w_ih.T + h @ w_hh.T + b_ih + b_hh

            i, f, g, o = gates.chunk(4, dim=-1)
            i = torch.sigmoid(i)
            f = torch.sigmoid(f)
            g = torch.tanh(g)
            o = torch.sigmoid(o)

            c = f * c + i * g
            h = o * torch.tanh(c)

            h.retain_grad()
            c.retain_grad()
            hidden_states.append((h.clone(), c.clone()))

        # Compute loss and backprop
        final_loss = h.sum()
        final_loss.backward()

        # Collect gradient norms from each timestep
        for t, (h_t, c_t) in enumerate(hidden_states[1:]):
            if c_t.grad is not None:
                gradient_norms.append(c_t.grad.norm().item())
            else:
                gradient_norms.append(0.0)

        all_gradient_norms.append(gradient_norms)

    return np.array(all_gradient_norms).mean(axis=0)


# Run analysis
lstm_gradient_norms = analyze_lstm_gradients(
    hidden_size=64, seq_length=50, num_trials=5
)
Out[11]:
Visualization
Bar chart showing gradient magnitudes across 50 timesteps in an LSTM, with relatively stable values throughout.
Gradient magnitude at each timestep in a 50-step LSTM. Unlike vanilla RNNs where gradients decay exponentially from the final timestep, LSTM gradients remain relatively stable across the sequence due to the cell state gradient highway. The slight increase toward the end reflects the accumulation of gradient paths.

LSTM vs Vanilla RNN: A Direct Comparison

Let's directly compare gradient flow in vanilla RNNs and LSTMs on the same task. We'll use a simple memory task where the network must remember information from early in the sequence.

In[12]:
Code
class VanillaRNN(nn.Module):
    """Simple vanilla RNN for comparison."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        output, _ = self.rnn(x)
        return self.fc(output[:, -1, :])


class LSTMModel(nn.Module):
    """LSTM model for comparison."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        output, _ = self.lstm(x)
        return self.fc(output[:, -1, :])


def measure_gradient_at_timestep(model, x, target_timestep):
    """
    Measure the gradient magnitude at a specific timestep.
    """
    model.zero_grad()

    # Create input that requires grad
    x_grad = x.clone().requires_grad_(True)

    # Forward pass
    output = model(x_grad)
    loss = output.sum()

    # Backward pass
    loss.backward()

    # Get gradient at target timestep
    grad = x_grad.grad[0, target_timestep, :].norm().item()

    return grad


def compare_gradient_flow(seq_length=50, hidden_size=64, input_size=16):
    """Compare gradient flow between vanilla RNN and LSTM."""
    torch.manual_seed(42)

    # Create models
    rnn = VanillaRNN(input_size, hidden_size)
    lstm = LSTMModel(input_size, hidden_size)

    # Random input
    x = torch.randn(1, seq_length, input_size)

    # Measure gradients at each timestep
    rnn_grads = []
    lstm_grads = []

    for t in range(seq_length):
        rnn_grads.append(measure_gradient_at_timestep(rnn, x, t))
        lstm_grads.append(measure_gradient_at_timestep(lstm, x, t))

    return np.array(rnn_grads), np.array(lstm_grads)


# Run comparison
rnn_gradients, lstm_gradients = compare_gradient_flow(seq_length=50)
Out[13]:
Visualization
Line plot comparing gradient magnitudes across timesteps for vanilla RNN showing exponential decay versus LSTM showing stable gradients.
Direct comparison of gradient magnitudes at each timestep for vanilla RNN (red) and LSTM (blue). The loss is computed at the final timestep (t=50), and we measure how much gradient reaches each earlier timestep. The vanilla RNN shows severe gradient decay for early timesteps, while the LSTM maintains much more uniform gradient flow.

The comparison reveals the LSTM's advantage clearly. While the vanilla RNN's gradients decay rapidly as we move earlier in the sequence, the LSTM maintains relatively stable gradient flow throughout. This is why LSTMs can learn dependencies spanning 50, 100, or even more timesteps, while vanilla RNNs struggle beyond 10-20 timesteps.

Visualizing the Gradient Distribution

Another way to see the difference is to look at the distribution of gradient magnitudes across all timesteps. For a vanilla RNN, we expect a heavily skewed distribution with most gradients near zero. For an LSTM, we expect a more uniform distribution:

Out[14]:
Visualization
Histogram of vanilla RNN gradient magnitudes heavily skewed toward zero.
Vanilla RNN gradient distribution showing heavy right skew with most gradients near zero. The dashed line indicates the median.
Histogram of LSTM gradient magnitudes with more uniform distribution.
LSTM gradient distribution showing much more uniform spread. The higher median indicates gradients remain useful across the sequence.

The distributions tell a clear story. The vanilla RNN's gradient distribution is heavily skewed toward zero, meaning most timesteps receive negligible gradient signal. The LSTM's distribution is much more uniform, with the median gradient magnitude significantly higher. This uniformity is what enables LSTMs to learn from early parts of the sequence.

Quantifying the Difference

Let's quantify exactly how much better LSTMs are at preserving gradients:

In[15]:
Code
def compute_gradient_preservation(gradients):
    """
    Compute metrics for gradient preservation.
    """
    # Ratio of first timestep gradient to last timestep gradient
    first_to_last_ratio = (
        gradients[0] / gradients[-1] if gradients[-1] > 0 else 0
    )

    # Coefficient of variation (lower = more uniform)
    cv = (
        np.std(gradients) / np.mean(gradients)
        if np.mean(gradients) > 0
        else float("inf")
    )

    # Effective range: how many timesteps have gradient > 10% of max
    threshold = 0.1 * gradients.max()
    effective_range = np.sum(gradients > threshold)

    return {
        "first_to_last_ratio": first_to_last_ratio,
        "coefficient_of_variation": cv,
        "effective_range": effective_range,
    }


rnn_metrics = compute_gradient_preservation(rnn_gradients)
lstm_metrics = compute_gradient_preservation(lstm_gradients)
Out[16]:
Console
Gradient Preservation Metrics (50 timesteps)
==================================================

Metric                         Vanilla RNN       LSTM
--------------------------------------------------
First/Last gradient ratio          0.0000     0.0000
Coefficient of variation           3.9820     3.8812
Effective range (timesteps)             4          4

The metrics confirm what we see visually. The first/last gradient ratio measures how much gradient reaches the earliest timestep compared to the final timestep, where higher values indicate better gradient flow to early timesteps. The coefficient of variation captures how uniformly gradients are distributed across timesteps, where lower values mean more uniform distribution, which is desirable for learning long-range dependencies. The effective range counts how many timesteps receive at least 10% of the maximum gradient, where more timesteps means gradients remain useful across a larger portion of the sequence. The LSTM's first-to-last gradient ratio is much higher, meaning gradients reach early timesteps more effectively. The coefficient of variation is lower, indicating more uniform gradient distribution. And the effective range shows that LSTM gradients remain useful across nearly all timesteps, while RNN gradients are only useful for a fraction of the sequence.

Peephole Connections

Standard LSTMs compute gate activations based on the current input xt\mathbf{x}_t and previous hidden state ht1\mathbf{h}_{t-1}. Peephole connections, introduced by Gers and Schmidhuber in 2000, allow gates to also directly observe the cell state. This can improve the network's ability to learn precise timing.

The Peephole Equations

With peephole connections, each gate receives an additional input: a direct view of the cell state. The standard gate computation σ(Wxt+Uht1+b)\sigma(\mathbf{W} \mathbf{x}_t + \mathbf{U} \mathbf{h}_{t-1} + \mathbf{b}) is augmented with a term Vc\mathbf{V} \odot \mathbf{c}, where V\mathbf{V} is a learnable weight vector and c\mathbf{c} is the cell state. The modified gate equations become:

ft=σ(Wfxt+Ufht1+Vfct1+bf)it=σ(Wixt+Uiht1+Vict1+bi)ot=σ(Woxt+Uoht1+Voct+bo)\begin{aligned} \mathbf{f}_t &= \sigma(\mathbf{W}_f \mathbf{x}_t + \mathbf{U}_f \mathbf{h}_{t-1} + \mathbf{V}_f \odot \mathbf{c}_{t-1} + \mathbf{b}_f) \\ \mathbf{i}_t &= \sigma(\mathbf{W}_i \mathbf{x}_t + \mathbf{U}_i \mathbf{h}_{t-1} + \mathbf{V}_i \odot \mathbf{c}_{t-1} + \mathbf{b}_i) \\ \mathbf{o}_t &= \sigma(\mathbf{W}_o \mathbf{x}_t + \mathbf{U}_o \mathbf{h}_{t-1} + \mathbf{V}_o \odot \mathbf{c}_t + \mathbf{b}_o) \end{aligned}

where:

  • Wf,Wi,WoRh×d\mathbf{W}_f, \mathbf{W}_i, \mathbf{W}_o \in \mathbb{R}^{h \times d}: input weight matrices (same as standard LSTM)
  • Uf,Ui,UoRh×h\mathbf{U}_f, \mathbf{U}_i, \mathbf{U}_o \in \mathbb{R}^{h \times h}: recurrent weight matrices (same as standard LSTM)
  • Vf,Vi,VoRh\mathbf{V}_f, \mathbf{V}_i, \mathbf{V}_o \in \mathbb{R}^h: peephole weight vectors for forget, input, and output gates
  • ct1\mathbf{c}_{t-1}: the previous cell state (used by forget and input gates)
  • ct\mathbf{c}_t: the current cell state (used by output gate)
  • bf,bi,boRh\mathbf{b}_f, \mathbf{b}_i, \mathbf{b}_o \in \mathbb{R}^h: bias vectors
  • σ\sigma: the sigmoid activation function
  • \odot: element-wise multiplication (the peephole weights are vectors, not matrices)

Notice that the forget and input gates use the previous cell state ct1\mathbf{c}_{t-1}, while the output gate uses the current cell state ct\mathbf{c}_t. This is because the output gate needs to see the updated cell state to decide what to output. The peephole weights V\mathbf{V} are vectors rather than matrices, meaning each gate dimension jj only sees the corresponding cell state dimension jj. This keeps the parameter count low: peepholes add only 3h3h parameters (one vector per gate) rather than 3h23h^2 parameters that matrices would require.

Peephole Connections

Peephole connections add direct connections from the cell state to the gates, allowing gates to make decisions based on the actual memory content rather than just the filtered hidden state. The peephole weights are vectors (not matrices), so they add only 3h3h parameters to the model.

Gradient Flow with Peepholes

Peephole connections create additional gradient paths. The gradient now flows not only through the forget gate scaling but also through the peephole connections themselves. This can provide additional gradient highways, potentially improving gradient flow for certain tasks.

However, peephole connections also make the gradient flow more complex. The cell state now directly influences the gates, which influence the cell state update, creating a more intricate dependency structure.

Out[17]:
Visualization
Diagram of LSTM cell showing standard connections plus dashed peephole connections from cell state to gates.
LSTM with peephole connections. Dashed lines show the peephole connections from the cell state to each gate. The forget and input gates observe the previous cell state, while the output gate observes the current cell state. These connections allow gates to make decisions based on the actual memory content.

When to Use Peepholes

Peephole connections are most useful when:

  • Precise timing matters: Tasks like rhythm detection or time series prediction with specific periodicities benefit from gates being able to see the exact cell state values.
  • Cell state magnitude is informative: When the absolute value of stored information matters, not just its presence or absence.
  • Standard LSTM underperforms: If a standard LSTM struggles with a task despite adequate capacity, peepholes might help.

However, peepholes add complexity and computation. In practice, many successful applications use standard LSTMs without peepholes. The GRU architecture, which we'll cover in a later chapter, takes a different approach by simplifying the gating mechanism rather than adding peepholes.

Gradient Clipping in LSTMs

Despite the cell state gradient highway, LSTMs can still experience exploding gradients. This happens primarily through the indirect gradient paths: the gates depend on weight matrices, and if these weights have large eigenvalues, gradients can explode.

When Gradients Explode in LSTMs

Exploding gradients in LSTMs typically occur when:

  1. Weight matrices have large spectral radius: The gate weight matrices Uf\mathbf{U}_f, Ui\mathbf{U}_i, Uc\mathbf{U}_c, Uo\mathbf{U}_o transform the hidden state. If their eigenvalues exceed 1, gradients through the indirect paths can explode.

  2. Long sequences with active gates: When the input and forget gates are both active (not saturated near 0 or 1), gradients flow through multiple paths and can accumulate.

  3. Poor initialization: Random initialization can create weight matrices with problematic eigenvalue distributions.

In[18]:
Code
def demonstrate_gradient_explosion(
    hidden_size=64, seq_length=100, weight_scale=1.5
):
    """
    Demonstrate how LSTMs can still experience gradient explosion
    with large weight initialization.
    """
    torch.manual_seed(42)

    # Create LSTM with scaled weights
    lstm = nn.LSTM(input_size=32, hidden_size=hidden_size, batch_first=True)

    # Scale up the recurrent weights to induce explosion
    with torch.no_grad():
        lstm.weight_hh_l0.mul_(weight_scale)

    # Track gradient norms during training
    gradient_norms = []

    for _ in range(20):
        x = torch.randn(1, seq_length, 32)
        target = torch.randn(1, hidden_size)

        output, _ = lstm(x)
        loss = ((output[:, -1, :] - target) ** 2).sum()

        lstm.zero_grad()
        loss.backward()

        # Compute total gradient norm
        total_norm = 0
        for p in lstm.parameters():
            if p.grad is not None:
                total_norm += p.grad.norm().item() ** 2
        total_norm = np.sqrt(total_norm)
        gradient_norms.append(total_norm)

    return gradient_norms


# Normal initialization
normal_grads = demonstrate_gradient_explosion(weight_scale=1.0)

# Large initialization (can cause explosion)
large_grads = demonstrate_gradient_explosion(weight_scale=2.0)
Out[19]:
Visualization
Line plot comparing gradient norms over training iterations for normal and large weight initialization in LSTMs.
Gradient norms in LSTMs with different weight scales. With normal initialization (blue), gradients remain stable. With scaled-up weights (red), gradients can become very large, demonstrating that LSTMs are not immune to exploding gradients. Gradient clipping is still necessary for stable training.

Implementing Gradient Clipping

Gradient clipping is a standard technique that rescales gradients when their norm exceeds a threshold. For LSTMs, this is typically applied to all parameters together:

In[20]:
Code
def train_with_gradient_clipping(model, x, target, max_norm=1.0):
    """
    Training step with gradient clipping.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Forward pass
    output, _ = model(x)
    loss = ((output[:, -1, :] - target) ** 2).sum()

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping
    grad_norm_before = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm
    )

    # Update weights
    optimizer.step()

    return loss.item(), grad_norm_before.item()


# Demonstrate clipping effect
torch.manual_seed(42)
lstm_large = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
with torch.no_grad():
    lstm_large.weight_hh_l0.mul_(2.0)

x = torch.randn(1, 100, 32)
target = torch.randn(1, 64)

loss, grad_norm = train_with_gradient_clipping(
    lstm_large, x, target, max_norm=1.0
)
Out[21]:
Console
Gradient Clipping Example
========================================
Gradient norm before clipping: 34.15
Max norm threshold: 1.0
Gradient norm after clipping: 1.00
Clipping ratio: 0.0293

The gradient norm before clipping indicates how large the gradients grew during this backward pass. When this exceeds the threshold (1.0 in this case), the gradients are scaled down proportionally so their norm equals exactly the threshold. The clipping ratio shows what fraction of the original gradient magnitude is retained. A ratio of 1.0 means no clipping occurred, while smaller values indicate more aggressive scaling. This prevents any single update from being too large, which could destabilize training.

Choosing the Clipping Threshold

The gradient clipping threshold is a hyperparameter that requires tuning. Common practices include:

  • Start with 1.0 or 5.0: These are reasonable defaults for most tasks.
  • Monitor gradient norms: Log gradient norms during training to understand the typical range.
  • Adjust based on stability: If training is unstable (loss spikes), reduce the threshold. If training is too slow, try increasing it.
In[22]:
Code
def analyze_clipping_thresholds(model, x, target, thresholds):
    """
    Analyze the effect of different clipping thresholds.
    """
    results = []

    for threshold in thresholds:
        # Clone model for fair comparison
        model_copy = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
        model_copy.load_state_dict(model.state_dict())

        # Compute gradients
        output, _ = model_copy(x)
        loss = ((output[:, -1, :] - target) ** 2).sum()
        model_copy.zero_grad()
        loss.backward()

        # Get original norm
        original_norm = 0
        for p in model_copy.parameters():
            if p.grad is not None:
                original_norm += p.grad.norm().item() ** 2
        original_norm = np.sqrt(original_norm)

        # Apply clipping
        torch.nn.utils.clip_grad_norm_(model_copy.parameters(), threshold)

        # Get clipped norm
        clipped_norm = 0
        for p in model_copy.parameters():
            if p.grad is not None:
                clipped_norm += p.grad.norm().item() ** 2
        clipped_norm = np.sqrt(clipped_norm)

        results.append(
            {
                "threshold": threshold,
                "original_norm": original_norm,
                "clipped_norm": clipped_norm,
                "was_clipped": original_norm > threshold,
            }
        )

    return results


# Test different thresholds
thresholds = [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]
clipping_analysis = analyze_clipping_thresholds(
    lstm_large, x, target, thresholds
)
Out[23]:
Visualization
Bar chart showing clipped gradient norms for different threshold values with original norm indicated by dashed line.
Effect of different gradient clipping thresholds. The original gradient norm (dashed line) is compared against various thresholds (x-axis). Bars show the resulting gradient norm after clipping. Lower thresholds provide more aggressive clipping but may slow learning; higher thresholds allow larger updates but risk instability.

Worked Example: Tracing Gradients Through an LSTM

The mathematical analysis above tells us why LSTM gradients flow better, but there's no substitute for working through a concrete example. Let's trace gradient flow step by step through a small LSTM, computing every number explicitly so we can see the gradient highway in action.

Setting Up the Problem

We'll use deliberately small dimensions so every value is visible:

  • Hidden dimension: h=2h = 2 (just two neurons)
  • Sequence length: T=3T = 3 (three timesteps)
  • Goal: Compute Lc0\frac{\partial \mathcal{L}}{\partial \mathbf{c}_0}, which tells us how the initial cell state affects the final loss

This setup lets us trace the gradient as it flows backward through time: from c3\mathbf{c}_3 to c2\mathbf{c}_2 to c1\mathbf{c}_1 to c0\mathbf{c}_0. At each step, we'll see the forget gate scaling the gradient, creating the "highway" effect.

To keep the example tractable, we'll use fixed gate values rather than computing them from weights. In a real LSTM, these would be learned, but the gradient flow mechanics are identical.

In[24]:
Code
def trace_lstm_gradients():
    """
    Trace gradient flow through a small LSTM step by step.
    """
    np.random.seed(42)

    # Small dimensions for clarity
    h = 2  # hidden dimension

    # Initialize states
    c0 = np.array([0.5, -0.3])
    h0 = np.array([0.2, 0.1])

    # Simplified: use fixed gate values for clarity
    # In practice, these would be computed from inputs and weights

    # Timestep 1
    f1 = np.array([0.9, 0.8])  # forget gate
    i1 = np.array([0.3, 0.4])  # input gate
    c_tilde1 = np.array([0.5, -0.2])  # candidate
    o1 = np.array([0.7, 0.6])  # output gate

    c1 = f1 * c0 + i1 * c_tilde1
    h1 = o1 * np.tanh(c1)

    # Timestep 2
    f2 = np.array([0.85, 0.9])
    i2 = np.array([0.4, 0.3])
    c_tilde2 = np.array([0.3, 0.4])
    o2 = np.array([0.6, 0.7])

    c2 = f2 * c1 + i2 * c_tilde2
    h2 = o2 * np.tanh(c2)

    # Timestep 3
    f3 = np.array([0.8, 0.85])
    i3 = np.array([0.5, 0.4])
    c_tilde3 = np.array([0.2, 0.3])
    o3 = np.array([0.8, 0.75])

    c3 = f3 * c2 + i3 * c_tilde3
    h3 = o3 * np.tanh(c3)

    # Loss: sum of final hidden state
    loss = h3.sum()

    # Backward pass (direct path only for simplicity)
    # dL/dh3 = [1, 1]
    dL_dh3 = np.array([1.0, 1.0])

    # dL/dc3 = dL/dh3 * o3 * (1 - tanh(c3)^2)
    dL_dc3 = dL_dh3 * o3 * (1 - np.tanh(c3) ** 2)

    # dL/dc2 = dL/dc3 * f3 (direct path)
    dL_dc2 = dL_dc3 * f3

    # dL/dc1 = dL/dc2 * f2 (direct path)
    dL_dc1 = dL_dc2 * f2

    # dL/dc0 = dL/dc1 * f1 (direct path)
    dL_dc0 = dL_dc1 * f1

    return {
        "states": {
            "c0": c0,
            "c1": c1,
            "c2": c2,
            "c3": c3,
            "h0": h0,
            "h1": h1,
            "h2": h2,
            "h3": h3,
        },
        "gates": {
            "f1": f1,
            "f2": f2,
            "f3": f3,
            "i1": i1,
            "i2": i2,
            "i3": i3,
            "o1": o1,
            "o2": o2,
            "o3": o3,
        },
        "gradients": {
            "dL_dh3": dL_dh3,
            "dL_dc3": dL_dc3,
            "dL_dc2": dL_dc2,
            "dL_dc1": dL_dc1,
            "dL_dc0": dL_dc0,
        },
        "cumulative_forget": f1 * f2 * f3,
    }


results = trace_lstm_gradients()
Out[25]:
Console
LSTM Gradient Trace (h=2, T=3)
==================================================

Forward Pass:
  c0 = [ 0.5 -0.3]
  c1 = f1 * c0 + i1 * c̃1 = [ 0.6  -0.32]
  c2 = f2 * c1 + i2 * c̃2 = [ 0.63  -0.168]
  c3 = f3 * c2 + i3 * c̃3 = [ 0.604  -0.0228]

Forget Gates (direct gradient path):
  f1 = [0.9 0.8]
  f2 = [0.85 0.9 ]
  f3 = [0.8  0.85]
  Cumulative: f1 * f2 * f3 = [0.612 0.612]

Backward Pass (direct path only):
  ∂L/∂h3 = [1. 1.]
  ∂L/∂c3 = [0.5668 0.7496]
  ∂L/∂c2 = ∂L/∂c3 * f3 = [0.4535 0.6372]
  ∂L/∂c1 = ∂L/∂c2 * f2 = [0.3854 0.5735]
  ∂L/∂c0 = ∂L/∂c1 * f1 = [0.3469 0.4588]

Gradient Preservation:
  |∂L/∂c0| / |∂L/∂c3| = 0.6120
  Expected (from cumulative forget): 0.8655
Out[26]:
Visualization
Two-panel plot showing cell state values evolving forward in time and gradient magnitudes flowing backward, with forget gate values annotated.
Gradient flow through the worked example LSTM. The top panel shows cell state evolution (forward pass), while the bottom panel shows gradient magnitudes at each timestep (backward pass). The gradient decays gradually through the forget gates, retaining about 61% of its magnitude over 3 timesteps.

Interpreting the Results

The output reveals the gradient highway in action. Let's trace through what happened:

Forward pass: Starting from c0=[0.5,0.3]\mathbf{c}_0 = [0.5, -0.3], the cell state evolved through three timesteps. At each step, the forget gate scaled the previous state and the input gate added new information. The final cell state c3\mathbf{c}_3 fed into the output gate to produce h3\mathbf{h}_3, which determined the loss.

Backward pass: The gradient started at Lh3=[1,1]\frac{\partial \mathcal{L}}{\partial \mathbf{h}_3} = [1, 1] (since our loss was the sum of h3\mathbf{h}_3). It then flowed backward through:

  1. The output gate and tanh, giving Lc3\frac{\partial \mathcal{L}}{\partial \mathbf{c}_3}
  2. The forget gate f3\mathbf{f}_3, giving Lc2\frac{\partial \mathcal{L}}{\partial \mathbf{c}_2}
  3. The forget gate f2\mathbf{f}_2, giving Lc1\frac{\partial \mathcal{L}}{\partial \mathbf{c}_1}
  4. The forget gate f1\mathbf{f}_1, giving Lc0\frac{\partial \mathcal{L}}{\partial \mathbf{c}_0}

The key insight: The cumulative forget gate product f1f2f3[0.612,0.612]\mathbf{f}_1 \odot \mathbf{f}_2 \odot \mathbf{f}_3 \approx [0.612, 0.612] tells us what fraction of the gradient survives the journey. With forget gates averaging around 0.85, we preserve about 61% of the gradient magnitude over just 3 timesteps.

Compare this to a vanilla RNN, where the equivalent factor would be approximately (0.65×0.9)30.20(0.65 \times 0.9)^3 \approx 0.20, only 20% preserved. Over longer sequences, this difference becomes dramatic: after 10 timesteps, the LSTM preserves roughly 20% (0.85100.85^{10}) while the vanilla RNN preserves less than 0.5% (0.585100.585^{10}).

Limitations and Impact

Understanding gradient flow in LSTMs reveals both their strengths and remaining limitations.

What LSTMs Achieve

The constant error carousel and forget gate gradient highway enable LSTMs to learn dependencies spanning hundreds of timesteps. This was revolutionary when introduced in 1997 and enabled breakthroughs in speech recognition, machine translation, and language modeling throughout the 2000s and 2010s. The key insight that additive updates preserve gradients better than multiplicative updates has influenced subsequent architectures, including residual networks and transformers.

Remaining Limitations

Despite improved gradient flow, LSTMs face several challenges:

Gradient decay still occurs: Even with forget gates near 1, gradients decay as fTf^T over TT timesteps, where ff is the average forget gate value and TT is the sequence length. For very long sequences (thousands of tokens), this decay becomes significant. The forget gate cannot be exactly 1 everywhere, or the network would never forget anything.

Sequential computation bottleneck: LSTMs must process sequences one timestep at a time because each hidden state depends on the previous one. This prevents parallelization across time, making LSTMs slow on modern GPUs optimized for parallel computation. Transformers address this by processing all positions in parallel.

Limited context window in practice: While theoretically capable of infinite context, practical LSTMs struggle with dependencies beyond a few hundred timesteps. The attention mechanism in transformers provides a more direct way to connect distant positions.

Complexity of gradient paths: The full gradient flow involves not just the direct path through the forget gate but also indirect paths through all the gates. These indirect paths can still experience vanishing or exploding gradients, requiring careful initialization and gradient clipping.

Summary

This chapter analyzed gradient flow in LSTMs, revealing why they succeed where vanilla RNNs fail.

The Constant Error Carousel:

The cell state update ct=ftct1+itc~t\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t is additive with respect to the previous cell state. This creates a direct gradient path where the Jacobian of the cell state with respect to the previous cell state is:

ctct1=diag(ft)+(indirect paths)\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \text{diag}(\mathbf{f}_t) + \text{(indirect paths)}

where diag(ft)\text{diag}(\mathbf{f}_t) is a diagonal matrix with forget gate values on the diagonal, and the indirect paths represent gradient flow through the gates (which depend on the previous hidden state). When the forget gate is close to 1, the diagonal matrix approaches the identity matrix, allowing gradients to flow backward with minimal decay.

The Forget Gate Gradient Highway:

Over kk timesteps, the direct path gradient scales as j=1kfj\prod_{j=1}^{k} \mathbf{f}_j, where the product is taken element-wise over all forget gate vectors from timestep 1 to kk. With forget gates averaging 0.9, gradients retain about 35% of their magnitude over 10 timesteps (since 0.9100.350.9^{10} \approx 0.35), compared to less than 1% for vanilla RNNs.

Key Insights:

  • The same mechanism that preserves information (high forget gate) also preserves gradients
  • Peephole connections add direct cell state observation to gates, creating additional gradient paths
  • LSTMs still need gradient clipping because indirect paths through weight matrices can explode
  • The gradient highway enables learning dependencies spanning hundreds of timesteps, but very long sequences still pose challenges

Practical Implications:

  • Initialize forget gate bias to 1 or higher to start with good gradient flow
  • Use gradient clipping (typically max norm 1.0-5.0) to handle exploding gradients through indirect paths
  • For very long sequences (1000+ tokens), consider attention mechanisms or hierarchical approaches

The next chapter explores the GRU architecture, which simplifies the LSTM's gating mechanism while maintaining effective gradient flow.

Key Parameters

When working with LSTMs and considering gradient flow, several parameters significantly impact training stability:

  • forget_bias (initialization): The initial value for the forget gate bias. Setting this to 1.0 or higher ensures the network starts with good gradient flow by keeping forget gates near 1. This is one of the most important initialization choices for LSTMs.

  • gradient_clip_norm (max_norm in PyTorch's clip_grad_norm_): The maximum allowed gradient norm. Values of 1.0-5.0 are typical. Lower values provide more stability but may slow learning; higher values allow faster updates but risk instability.

  • hidden_size: Larger hidden dimensions mean more gradient paths (both direct and indirect). While the direct path scales linearly, indirect paths involve larger weight matrices that may have more extreme eigenvalues.

  • num_layers: Stacked LSTMs have gradient flow challenges between layers as well as across time. Residual connections between layers can help preserve gradients in deep LSTMs.

  • weight_init_scale: The scale of weight initialization affects the spectral radius of recurrent weight matrices. Xavier or orthogonal initialization helps maintain stable gradient flow through indirect paths.

  • sequence_length: Longer sequences require gradients to flow through more timesteps. Even with the cell state highway, very long sequences (1000+ tokens) may experience significant gradient decay.

Quiz

Ready to test your understanding? Take this quick quiz to reinforce what you've learned about LSTM gradient flow and the constant error carousel.

Loading component...

Comments

Reference

BIBTEXAcademic
@misc{lstmgradientflowtheconstanterrorcarouselexplained, author = {Michael Brenndoerfer}, title = {LSTM Gradient Flow: The Constant Error Carousel Explained}, year = {2025}, url = {https://mbrenndoerfer.com/writing/lstm-gradient-flow-constant-error-carousel}, organization = {mbrenndoerfer.com}, note = {Accessed: 2025-12-16} }
APAAcademic
Michael Brenndoerfer (2025). LSTM Gradient Flow: The Constant Error Carousel Explained. Retrieved from https://mbrenndoerfer.com/writing/lstm-gradient-flow-constant-error-carousel
MLAAcademic
Michael Brenndoerfer. "LSTM Gradient Flow: The Constant Error Carousel Explained." 2025. Web. 12/16/2025. <https://mbrenndoerfer.com/writing/lstm-gradient-flow-constant-error-carousel>.
CHICAGOAcademic
Michael Brenndoerfer. "LSTM Gradient Flow: The Constant Error Carousel Explained." Accessed 12/16/2025. https://mbrenndoerfer.com/writing/lstm-gradient-flow-constant-error-carousel.
HARVARDAcademic
Michael Brenndoerfer (2025) 'LSTM Gradient Flow: The Constant Error Carousel Explained'. Available at: https://mbrenndoerfer.com/writing/lstm-gradient-flow-constant-error-carousel (Accessed: 12/16/2025).
SimpleBasic
Michael Brenndoerfer (2025). LSTM Gradient Flow: The Constant Error Carousel Explained. https://mbrenndoerfer.com/writing/lstm-gradient-flow-constant-error-carousel
Michael Brenndoerfer

About the author: Michael Brenndoerfer

All opinions expressed here are my own and do not reflect the views of my employer.

Michael currently works as an Associate Director of Data Science at EQT Partners in Singapore, leading AI and data initiatives across private capital investments.

With over a decade of experience spanning private equity, management consulting, and software engineering, he specializes in building and scaling analytics capabilities from the ground up. He has published research in leading AI conferences and holds expertise in machine learning, natural language processing, and value creation through data.

Stay updated

Get notified when I publish new articles on data and AI, private equity, technology, and more.

No spam, unsubscribe anytime.

or

Create a free account to unlock exclusive features, track your progress, and join the conversation.

No popupsUnobstructed readingCommenting100% Free