Layer Normalization: Stabilizing Transformer Training

Michael Brenndoerfer · Updated June 11, 2025 · 30 min read

Learn how layer normalization enables stable transformer training by normalizing across features rather than batches, with implementations and gradient analysis.

Layer Normalization

Batch normalization transformed how we train deep feedforward networks, but it stumbles when applied to transformers. The batch dimension becomes problematic: batch sizes vary, sequences have different lengths, and the statistics computed across a batch of diverse sentences lack semantic coherence. Layer normalization, introduced by Ba, Kiros, and Hinton in 2016, sidesteps these issues by normalizing across features rather than across the batch. This seemingly simple change made layer normalization the default normalization technique for transformers, from the original "Attention Is All You Need" architecture to modern large language models.

In this chapter, we'll explore why layer normalization works so well for transformers, how it differs from batch normalization in both computation and behavior, and the subtle implementation details that affect training stability. We'll also examine how the placement of layer normalization within transformer blocks affects learning dynamics, a design choice that has evolved significantly since the original transformer architecture.

Why Batch Normalization Fails for Transformers

Before diving into layer normalization, it's worth understanding exactly why batch normalization doesn't work well for sequence models. The core issue is that batch normalization computes statistics across the batch dimension, assuming that each position in a layer sees similar data across samples.

In a transformer processing sentences of varying lengths, each position in the sequence represents something different. Position 0 might be "The" in one sentence and "Scientists" in another. Position 50 might be a verb in one sentence, a noun in another, and padding in a third. Computing a mean and variance across these semantically unrelated positions produces statistics that don't reflect any meaningful property of the data.

In[2]:
Code
import torch

# Simulate a batch of sequences with varying content
torch.manual_seed(42)
batch_size = 4
seq_len = 8
hidden_dim = 16

# Different sequences have very different activation patterns
activations = torch.randn(batch_size, seq_len, hidden_dim)
# Exaggerate differences between sequences
activations[0] *= 0.5  # First sequence has small activations
activations[1] *= 3.0  # Second sequence has large activations
activations[2] += 5.0  # Third sequence has positive shift
activations[3] -= 5.0  # Fourth sequence has negative shift
Out[3]:
Console
Activation statistics per sequence (averaged across positions and features):
  Sequence 0: mean=  0.041, std=0.481
  Sequence 1: mean=  0.135, std=3.056
  Sequence 2: mean=  5.027, std=0.976
  Sequence 3: mean= -5.042, std=0.935

Batch statistics at position 0:
  Mean range: [-1.47, 1.68]
  Var range:  [8.38, 28.55]

The batch statistics are dominated by the extreme sequences, and these statistics change dramatically between positions. Layer normalization avoids this problem entirely by computing statistics within each sample independently, treating each token's representation as a self-contained unit to normalize.
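Statistics like the ones quoted above could be computed along the following lines. This is a sketch rather than the article's original cell: it rebuilds the same exaggerated batch and contrasts a reduction over the batch axis at position 0 with a reduction over the feature axis. The names `batch_mean`, `batch_var`, `token_mean`, and `token_var` are illustrative.

```python
import torch

torch.manual_seed(42)
batch_size, seq_len, hidden_dim = 4, 8, 16

activations = torch.randn(batch_size, seq_len, hidden_dim)
activations[0] *= 0.5
activations[1] *= 3.0
activations[2] += 5.0
activations[3] -= 5.0

# Batch-norm-style statistics: reduce over the batch axis at position 0,
# giving one mean and one variance per feature
position0 = activations[:, 0, :]                  # (batch, hidden_dim)
batch_mean = position0.mean(dim=0)                # (hidden_dim,)
batch_var = position0.var(dim=0, unbiased=False)  # (hidden_dim,)

# Layer-norm-style statistics: reduce over the feature axis,
# giving one mean and one variance per token
token_mean = activations.mean(dim=-1)             # (batch, seq_len)
token_var = activations.var(dim=-1, unbiased=False)

print("batch mean range:", batch_mean.min().item(), batch_mean.max().item())
print("batch var range: ", batch_var.min().item(), batch_var.max().item())
print("per-token stats shape:", tuple(token_mean.shape))
```

The batch-axis variance is inflated by the gap between the shifted sequences, while the per-token statistics describe each token on its own.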

The Layer Normalization Formula

To understand layer normalization, let's start with a fundamental question: what does it mean for a neural network layer to have "unstable" activations, and how can we fix it?

The Problem: Activation Scale Drift

Imagine a token's hidden representation as a vector of 768 numbers (a typical transformer dimension). During training, these numbers can drift: some become very large, others very small, and their collective distribution shifts unpredictably. This creates two problems. First, downstream layers must constantly adapt to changing input statistics, making learning inefficient. Second, when values grow too large or too small, gradients either explode or vanish, destabilizing training entirely.

The solution is elegant: before each token's representation moves to the next layer, we transform it to have a predictable, standardized distribution. Specifically, we want the 768 features to have zero mean and unit variance. This "resets" the scale at every layer, preventing drift from accumulating.

Step 1: Finding the Center

The first step is computing where the current distribution is centered. Given a hidden state vector $\mathbf{x} = [x_1, x_2, \ldots, x_d]$ representing one token, we calculate its mean:

$$\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$$

where:

  • $\mu$: the arithmetic mean of all $d$ features in this token's representation
  • $d$: the hidden dimension (e.g., 768 for BERT-base, 4096 for LLaMA-7B)
  • $x_i$: the value of the $i$-th feature

This tells us the "center of mass" of the representation. If $\mu = 2.5$, the features are shifted toward positive values; if $\mu = -1.3$, they lean negative. The goal is to shift this center to zero.

Step 2: Measuring the Spread

Next, we need to know how spread out the values are. A representation where all values cluster tightly around the mean is very different from one where values are scattered widely. We capture this with variance:

$$\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$$

where:

  • $\sigma^2$: the variance, measuring how much the features deviate from their mean
  • $(x_i - \mu)^2$: the squared deviation of each feature from the mean

Squaring ensures that positive and negative deviations don't cancel out. If $\sigma^2$ is large, the features are spread out; if small, they're tightly clustered. We'll use the standard deviation $\sigma = \sqrt{\sigma^2}$ to rescale the values to unit variance.

Step 3: The Normalization Transform

With mean and variance in hand, we can now standardize each feature:

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

where:

  • $\hat{x}_i$: the normalized value of the $i$-th feature
  • $\epsilon$: a tiny constant (typically $10^{-5}$) added to prevent division by zero if variance is extremely small

This two-part transformation is exactly what you'd do to standardize any dataset: subtract the mean (centering at zero), then divide by the standard deviation (scaling to unit variance). The result $\hat{x}_i$ has zero mean and approximately unit variance across the $d$ features.
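To make Steps 1 through 3 concrete, here is a small worked example on a single four-feature vector. The input values are arbitrary and chosen so the arithmetic can be followed by hand.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])  # one token with d = 4 features
eps = 1e-5

mu = x.mean()                              # Step 1: (1 + 2 + 3 + 4) / 4 = 2.5
var = ((x - mu) ** 2).mean()               # Step 2: (2.25 + 0.25 + 0.25 + 2.25) / 4 = 1.25
x_hat = (x - mu) / torch.sqrt(var + eps)   # Step 3: center, then rescale

print(mu.item(), var.item())
print(x_hat)
print(x_hat.mean().item(), x_hat.var(unbiased=False).item())
```

The normalized vector has zero mean, and its variance is $1.25 / (1.25 + \epsilon)$, which is just under 1 because of the epsilon in the denominator.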

Step 4: Restoring Flexibility with Learnable Parameters

Here's where layer normalization becomes clever rather than restrictive. Forcing every representation to have exactly zero mean and unit variance might seem limiting: what if the optimal representation for some layer actually needs a different distribution?

The solution is to add learnable parameters that can undo the normalization if needed:

$$y_i = \gamma_i \cdot \hat{x}_i + \beta_i$$

where:

  • $y_i$: the final output for the $i$-th feature
  • $\gamma_i$: a learned scale parameter for feature $i$ (initialized to 1)
  • $\beta_i$: a learned shift parameter for feature $i$ (initialized to 0)

These parameters are learned during training, just like weights and biases. If the network discovers that feature $i$ works better centered near 3.5 with a spread of about 2.0, it can learn $\gamma_i = 2.0$ and $\beta_i = 3.5$ to scale and shift the normalized value accordingly. The affine step therefore restores much of the flexibility that normalization removes: the layer starts from a stable baseline, and each feature can still settle at the scale and offset it needs.
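In PyTorch's `nn.LayerNorm`, the scale and shift are exposed as the `weight` and `bias` attributes. The short check below, a sketch on random input, shows their initial values and confirms that the affine step starts out as the identity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(8)

# gamma is stored as `weight`, beta as `bias`; both are trainable parameters
print(ln.weight)  # all ones at initialization
print(ln.bias)    # all zeros at initialization

# With gamma = 1 and beta = 0, the module output equals the plain normalized input
x = torch.randn(3, 8)
x_hat = (x - x.mean(-1, keepdim=True)) / torch.sqrt(
    x.var(-1, keepdim=True, unbiased=False) + ln.eps
)
print(torch.allclose(ln(x), x_hat, atol=1e-5))
```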

The Complete Formula

Putting all the pieces together, layer normalization transforms a hidden state vector $\mathbf{x}$ of dimension $d$ into:

$$\text{LayerNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where:

  • $\mathbf{x} = [x_1, x_2, \ldots, x_d]$: the input vector representing one token's hidden state
  • $\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$: the mean across all features
  • $\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$: the variance across all features
  • $\epsilon$: a small stability constant (typically $10^{-5}$ or $10^{-6}$)
  • $\gamma = [\gamma_1, \gamma_2, \ldots, \gamma_d]$: learned scale parameters, initialized to ones
  • $\beta = [\beta_1, \beta_2, \ldots, \beta_d]$: learned shift parameters, initialized to zeros
  • $\odot$: element-wise multiplication

The formula reads naturally: subtract the mean, divide by the standard deviation (with a safety epsilon), then apply a learned scale and shift.

Why Features, Not Samples?

The key insight that distinguishes layer normalization from batch normalization is the dimension over which we compute statistics. Batch normalization asks: "What's the typical value of feature $i$ across all samples in this batch?" Layer normalization asks: "What's the typical value across all features for this particular token?"

For transformers, the layer normalization approach is far more natural. Each token is processed independently, and we want stable statistics regardless of what other tokens or samples happen to be in the batch. This independence also means layer normalization works identically during training and inference, with no need for running statistics or batch size considerations.
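The batch independence described here can be checked directly. The sketch below normalizes a batch of four random samples and then only the first two, using `nn.LayerNorm` and, for contrast, `nn.BatchNorm1d` in its default training mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 8)   # 4 samples, 8 features

ln = nn.LayerNorm(8)
bn = nn.BatchNorm1d(8)  # modules start in training mode, so BN uses batch statistics

# Normalize the full batch, then only the first two samples
ln_full, ln_half = ln(x), ln(x[:2])
bn_full, bn_half = bn(x), bn(x[:2])

# LayerNorm: a sample's output does not depend on its batch neighbours
print(torch.allclose(ln_full[:2], ln_half, atol=1e-6))
# BatchNorm: the same two samples come out differently in a different batch
print(torch.allclose(bn_full[:2], bn_half, atol=1e-6))
```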

Out[4]:
Visualization
Diagram showing batch normalization highlighting a column across batch samples.
Batch normalization computes statistics across samples (blue column). For each feature, we compute mean and variance over all batch samples.
Diagram showing layer normalization highlighting a row across features.
Layer normalization computes statistics across features (orange row). For each sample, we compute mean and variance over all features.

Implementing Layer Normalization from Scratch

Now that we understand the formula conceptually, let's translate it into code. Building layer normalization from scratch will solidify our understanding and reveal the implementation details that matter in practice.

The Forward Pass

Our implementation follows the mathematical steps exactly: compute mean, compute variance, normalize, then apply the learnable transformation.

In[5]:
Code
def layer_norm_forward(x, gamma, beta, eps=1e-5):
    """
    Layer normalization forward pass.

    Args:
        x: Input tensor of shape (batch, seq_len, hidden_dim)
        gamma: Scale parameters of shape (hidden_dim,)
        beta: Shift parameters of shape (hidden_dim,)
        eps: Small constant for numerical stability

    Returns:
        Normalized output with same shape as input
    """
    # Step 1: Compute mean across the feature dimension (last axis)
    mu = x.mean(dim=-1, keepdim=True)

    # Step 2: Compute variance across the feature dimension
    var = x.var(dim=-1, keepdim=True, unbiased=False)

    # Step 3: Normalize to zero mean and unit variance
    x_norm = (x - mu) / torch.sqrt(var + eps)

    # Step 4: Apply learnable affine transformation
    out = gamma * x_norm + beta

    return out, (x, x_norm, mu, var, gamma, eps)

The dim=-1 argument tells PyTorch to compute statistics across the last dimension (features), which is exactly what layer normalization requires. The keepdim=True preserves the dimension for broadcasting during subtraction and division.
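The effect of `keepdim=True` is easiest to see from the shapes. In this small sketch, dropping the reduced dimension makes the subtraction fail because the trailing dimensions no longer line up for broadcasting.

```python
import torch

x = torch.randn(2, 4, 8)  # (batch, seq_len, hidden_dim)

mu_keep = x.mean(dim=-1, keepdim=True)  # shape (2, 4, 1)
mu_drop = x.mean(dim=-1)                # shape (2, 4)
print(mu_keep.shape, mu_drop.shape)

# (2, 4, 8) - (2, 4, 1) broadcasts: each token's mean is subtracted
# from that token's 8 features
centered = x - mu_keep

# (2, 4, 8) - (2, 4) does not broadcast, because the trailing
# dimensions 8 and 4 are incompatible
try:
    x - mu_drop
    broadcast_ok = True
except RuntimeError:
    broadcast_ok = False
print(broadcast_ok)
```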

A Worked Example

Let's trace through layer normalization with concrete numbers to see exactly what happens at each step.

In[6]:
Code
# Create a simple example: 2 samples, 4 tokens each, 8 features per token
torch.manual_seed(42)

batch_size = 2
seq_len = 4
hidden_dim = 8

# Generate input with non-standard distribution (mean ~2, varied std)
x = torch.randn(batch_size, seq_len, hidden_dim) * 3 + 2

# Initialize gamma=1 and beta=0 (identity affine transform)
gamma = torch.ones(hidden_dim)
beta = torch.zeros(hidden_dim)

# Apply layer normalization
output, cache = layer_norm_forward(x, gamma, beta)
Out[7]:
Console
Input statistics (per token):
  Shape: torch.Size([2, 4, 8])
  First token mean: 2.002, std: 4.497
  Second token mean: 1.178, std: 2.962

Output statistics (per token):
  First token mean: -0.000000, std: 1.069045
  Second token mean: -0.000000, std: 1.069044

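The input tokens have varying means (around 1-2) and standard deviations (around 3-4.5), reflecting the non-standard distribution we created. After layer normalization, each token has mean essentially zero. The reported standard deviation of 1.069 is not a precision artifact: layer normalization divides by the population standard deviation (dividing by $d$), while the printed value uses the sample standard deviation (dividing by $d - 1$), and $\sqrt{8/7} \approx 1.069$. Measured the same way layer normalization measures it, the standard deviation is essentially one.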
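A quick way to see where a value like 1.069 comes from, sketched with an arbitrary 8-feature vector rather than the article's data: `Tensor.std()` applies Bessel's correction by default, whereas layer normalization uses the uncorrected variance.

```python
import math
import torch

torch.manual_seed(0)
d = 8
x = torch.randn(d) * 3 + 2

# Normalize with the population variance, as layer normalization does
x_hat = (x - x.mean()) / torch.sqrt(x.var(unbiased=False) + 1e-5)

print(x_hat.std(unbiased=False).item())  # population std: close to 1
print(x_hat.std().item())                # default sample std: close to sqrt(d / (d - 1))
print(math.sqrt(d / (d - 1)))            # about 1.069 for d = 8
```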

Let's visualize this transformation to see exactly how layer normalization reshapes the feature distribution.

Out[8]:
Visualization
Histogram showing feature values with positive mean around 2-3.
Feature distribution before layer normalization. The 8 features of a single token show varied values with a non-zero mean (dashed line) and non-unit variance. Each feature contributes to the overall distribution.
Histogram showing feature values centered at zero with unit spread.
Feature distribution after layer normalization. The same token's features are now centered at zero with unit variance. The transformation standardizes each token independently.

The histograms make the transformation crystal clear. Before normalization, the feature values are scattered around a positive mean with varied spread. After normalization, they're centered at zero with approximately unit variance. This happens independently for every token in the sequence.

Notice that each token is normalized independently: the first token's statistics don't affect the second token's normalization. This independence is precisely what makes layer normalization suitable for transformers, where tokens must be processed in parallel and sequences have variable lengths.

The Role of Learnable Parameters

After normalizing activations to zero mean and unit variance, we've effectively forced all features into a standardized distribution. But what if the network actually needs some features to have a larger spread, or to be centered around a non-zero value? The learnable parameters $\gamma$ (scale) and $\beta$ (shift) solve this problem.

For each feature dimension $i$, the final output is:

$$y_i = \gamma_i \cdot \hat{x}_i + \beta_i$$

where:

  • $y_i$: the final output for the $i$-th feature
  • $\hat{x}_i$: the normalized value (zero mean, unit variance)
  • $\gamma_i$: the learned scale for feature $i$, which controls the spread of values
  • $\beta_i$: the learned shift for feature $i$, which controls the center of the distribution

Here's the key insight: for a token with original mean $\mu_{\text{original}}$ and standard deviation $\sigma_{\text{original}}$, setting $\gamma_i = \sigma_{\text{original}}$ and $\beta_i = \mu_{\text{original}}$ would undo the normalization and recover the original values. Because $\gamma$ and $\beta$ are shared across all tokens while the mean and variance are computed per token, this exact undo holds only when tokens have similar statistics. Even so, the affine step restores much of the flexibility that normalization removes. In practice, the network finds an intermediate setting that benefits from stable optimization while still representing the patterns it needs.

In[9]:
Code
# Demonstrate how gamma and beta affect the output
gamma_custom = torch.tensor([2.0, 0.5, 1.5, 0.8, 1.0, 3.0, 0.3, 2.5])
beta_custom = torch.tensor([1.0, -1.0, 0.0, 2.0, -0.5, 0.5, 0.0, -2.0])

output_custom, _ = layer_norm_forward(x, gamma_custom, beta_custom)
Out[10]:
Console
Feature-wise comparison (first token):
Feature  gamma    beta     Output mean 
----------------------------------------
0        2.0      1.0      0.592       
1        0.5      -1.0     -0.809      
2        1.5      0.0      -0.175      
3        0.8      2.0      2.176       
4        1.0      -0.5     -0.953      
5        3.0      0.5      1.979       
6        0.3      0.0      -0.146      
7        2.5      -2.0     -1.582      

The output distribution for each feature is controlled by its corresponding $\gamma$ and $\beta$ values. Features with larger $\gamma$ values have wider distributions, while $\beta$ shifts the center. This per-feature control allows different dimensions of the representation to operate at different scales, which is crucial for transformers where different attention heads and feature dimensions may need different dynamic ranges.

Out[11]:
Visualization
Bar chart comparing feature output means with identity parameters versus custom gamma and beta values.
Effect of learned gamma and beta parameters on output distributions. Each bar shows the mean output value for one feature. With gamma=1, beta=0 (identity), all features have mean near zero. With custom parameters, each feature can have its own center (controlled by beta) and spread (controlled by gamma).

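With identity parameters (gamma=1, beta=0), all feature means are near zero as expected. With custom parameters, each feature's mean moves toward its corresponding beta value. The match is approximate because the normalized values of a single feature do not average exactly to zero over so few tokens, and gamma scales that residual. The comparison demonstrates how the learnable parameters give each dimension independent control over its output distribution.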

PyTorch's LayerNorm

PyTorch provides a built-in nn.LayerNorm that handles all these details efficiently. Let's verify our implementation matches PyTorch's behavior.

In[12]:
Code
import torch.nn as nn

# Create PyTorch LayerNorm
pytorch_ln = nn.LayerNorm(hidden_dim, eps=1e-5)

# Initialize with same parameters
with torch.no_grad():
    pytorch_ln.weight.fill_(1.0)  # gamma
    pytorch_ln.bias.fill_(0.0)  # beta

# Compare outputs
pytorch_output = pytorch_ln(x)
our_output, _ = layer_norm_forward(x, gamma, beta)
Out[13]:
Console
Maximum difference between implementations: 2.38e-07
Outputs match: True

The outputs match within floating-point precision, confirming our implementation is correct.
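The same computation is also available in functional form as `torch.nn.functional.layer_norm`, which takes the normalized shape, scale, shift, and epsilon as explicit arguments. A brief sketch on random input, reusing the module's own parameters:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_dim = 8
x = torch.randn(2, 4, hidden_dim)

ln = nn.LayerNorm(hidden_dim, eps=1e-5)

# Functional form: normalized shape, weight (gamma), bias (beta), and eps are passed explicitly
out_functional = F.layer_norm(
    x, (hidden_dim,), weight=ln.weight, bias=ln.bias, eps=ln.eps
)

print(torch.allclose(ln(x), out_functional))
```

The functional form can be convenient when the scale and shift are produced elsewhere, for example by another module, instead of being stored on the layer itself.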

Layer Normalization in Transformers

In transformer architectures, layer normalization appears in two key locations within each block: at the attention sublayer and at the feed-forward sublayer. The original transformer used "post-norm" placement, where normalization comes after the residual connection:

In[14]:
Code
class PostNormTransformerBlock(nn.Module):
    """Transformer block with post-normalization (original architecture)."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            d_model, n_heads, batch_first=True
        )
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Attention with residual, then normalize
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_out))

        # FFN with residual, then normalize
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))

        return x

Modern architectures like GPT-2, GPT-3, and LLaMA use "pre-norm" placement, where normalization comes before the sublayer:

In[15]:
Code
class PreNormTransformerBlock(nn.Module):
    """Transformer block with pre-normalization (modern architecture)."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            d_model, n_heads, batch_first=True
        )
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Normalize, then attention with residual
        normed = self.norm1(x)
        attn_out, _ = self.attention(normed, normed, normed)
        x = x + self.dropout(attn_out)

        # Normalize, then FFN with residual
        ff_out = self.ff(self.norm2(x))
        x = x + self.dropout(ff_out)

        return x

The difference in placement has significant implications for gradient flow, which we'll explore in detail in the pre-norm vs post-norm chapter.
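Stripped of attention and dropout, the two placements differ only in where the normalization sits relative to the residual addition. The sketch below uses a single `nn.Linear` as a hypothetical stand-in for either sublayer, so it illustrates the wiring and not a full block.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
sublayer = nn.Linear(d_model, d_model)  # hypothetical stand-in for attention or the FFN
norm = nn.LayerNorm(d_model)

x = torch.randn(2, 5, d_model) * 4 + 3  # deliberately un-normalized input

post = norm(x + sublayer(x))   # post-norm: normalize after the residual add
pre = x + sublayer(norm(x))    # pre-norm: normalize only the sublayer's input

# Post-norm output is normalized per token; pre-norm leaves the
# residual stream at whatever scale it has accumulated
print(post.mean(dim=-1).abs().max().item())
print(pre.mean(dim=-1).abs().max().item())
```

Because the pre-norm residual stream is never normalized inside the block, pre-norm models such as GPT-2 apply one additional layer normalization after the final block.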

Gradient Flow Through Layer Normalization

Understanding how gradients flow through layer normalization is essential for debugging training issues and understanding why normalization stabilizes training.

The backward pass through layer normalization is more complex than a simple element-wise operation because each output depends on all inputs through the mean and variance computation. When we change a single input $x_i$, it affects not only its own normalized output but also the mean $\mu$ and variance $\sigma^2$, which in turn affect every output element. This coupling makes the gradient computation more intricate.

Given the loss $L$ and the upstream gradient $\frac{\partial L}{\partial y}$ (the gradient flowing back from later layers), we need to compute three gradients: $\frac{\partial L}{\partial x}$ for backpropagation, and $\frac{\partial L}{\partial \gamma}$ and $\frac{\partial L}{\partial \beta}$ for updating the learnable parameters.

Gradients for Learnable Parameters

The output of layer normalization is $y_i = \gamma_i \cdot \hat{x}_i + \beta_i$, where $\hat{x}_i$ is the normalized input. Since this is a simple affine transformation, the gradients follow directly from the chain rule.

For the scale parameter $\gamma_i$:

$$\frac{\partial L}{\partial \gamma_i} = \sum_{n} \frac{\partial L}{\partial y_{n,i}} \cdot \hat{x}_{n,i}$$

where:

  • $\frac{\partial L}{\partial \gamma_i}$: the gradient of the loss with respect to the $i$-th scale parameter
  • $n$: an index over all tokens (across batch and sequence dimensions)
  • $\frac{\partial L}{\partial y_{n,i}}$: the upstream gradient for the $i$-th feature of the $n$-th token
  • $\hat{x}_{n,i}$: the normalized value of the $i$-th feature for the $n$-th token

Intuitively, this sums up how much each token's normalized value contributed to the loss through this scale parameter.

For the shift parameter $\beta_i$:

$$\frac{\partial L}{\partial \beta_i} = \sum_{n} \frac{\partial L}{\partial y_{n,i}}$$

where:

  • $\frac{\partial L}{\partial \beta_i}$: the gradient of the loss with respect to the $i$-th shift parameter

This is simply the sum of upstream gradients, since $\beta_i$ adds directly to the output.
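Both parameter gradients can be checked against autograd in a few lines. The following is a sketch using `nn.LayerNorm` at its default initialization and a random tensor `dout` as a stand-in for the upstream gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(8)
x = torch.randn(2, 4, 8)

y = ln(x)
dout = torch.randn_like(y)  # stand-in for the upstream gradient dL/dy
y.backward(dout)

# Recompute the normalized input by hand
x_hat = (x - x.mean(-1, keepdim=True)) / torch.sqrt(
    x.var(-1, keepdim=True, unbiased=False) + ln.eps
)

dgamma = (dout * x_hat).sum(dim=(0, 1))  # sum over all tokens n
dbeta = dout.sum(dim=(0, 1))

print(torch.allclose(ln.weight.grad, dgamma, atol=1e-5))
print(torch.allclose(ln.bias.grad, dbeta, atol=1e-5))
```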

Gradient for Input

The gradient with respect to the input is more involved because $\hat{x}_i$ depends on $x_i$ in three ways: directly, through the mean $\mu$, and through the variance $\sigma^2$. Writing $\frac{\partial L}{\partial \hat{x}_j} = \gamma_j \frac{\partial L}{\partial y_j}$ for the gradient that reaches the normalized values, the chain rule yields:

$$\frac{\partial L}{\partial x_i} = \frac{1}{\sigma} \left( \frac{\partial L}{\partial \hat{x}_i} - \frac{1}{d}\sum_{j=1}^{d}\frac{\partial L}{\partial \hat{x}_j} - \frac{\hat{x}_i}{d}\sum_{j=1}^{d}\frac{\partial L}{\partial \hat{x}_j}\hat{x}_j \right)$$

where:

  • $\frac{\partial L}{\partial x_i}$: the gradient of the loss with respect to the $i$-th input element
  • $\frac{\partial L}{\partial \hat{x}_j} = \gamma_j \frac{\partial L}{\partial y_j}$: the upstream gradient scaled by the learned scale parameter of the $j$-th feature. Each $\gamma_j$ stays inside the sums; it factors out as a single $\gamma_i$ only when all scale parameters are equal
  • $\sigma$: shorthand here for $\sqrt{\sigma^2 + \epsilon}$, the standard deviation with epsilon included for stability
  • $d$: the feature dimension (number of elements in the input vector)
  • $\hat{x}_i$: the normalized input, equal to $(x_i - \mu) / \sigma$

Let's break down the three terms inside the parentheses:

  1. Direct contribution $\frac{\partial L}{\partial \hat{x}_i}$: The gradient that would flow if normalization were a simple scaling operation.

  2. Mean correction $-\frac{1}{d}\sum_{j=1}^{d}\frac{\partial L}{\partial \hat{x}_j}$: Accounts for how changing $x_i$ affects $\mu$, which affects all outputs. This term subtracts the average gradient, centering the gradient distribution.

  3. Variance correction $-\frac{\hat{x}_i}{d}\sum_{j=1}^{d}\frac{\partial L}{\partial \hat{x}_j}\hat{x}_j$: Accounts for how changing $x_i$ affects $\sigma^2$, which scales all outputs. This term is proportional to the normalized value $\hat{x}_i$, meaning inputs far from the mean get larger corrections.

This formula reveals something important: the gradient for each input element depends on the gradients of all other elements through the mean and variance terms. This coupling helps distribute gradient information across features, which can improve training stability.

In[16]:
Code
def layer_norm_backward(dout, cache):
    """
    Layer normalization backward pass.

    Args:
        dout: Upstream gradient of shape (batch, seq_len, hidden_dim)
        cache: Values from forward pass

    Returns:
        dx: Gradient with respect to input
        dgamma: Gradient with respect to scale
        dbeta: Gradient with respect to shift
    """
    x, x_norm, mu, var, gamma, eps = cache
    d = x.shape[-1]

    # Gradients for learnable parameters (sum over batch and sequence)
    dgamma = (dout * x_norm).sum(dim=(0, 1))
    dbeta = dout.sum(dim=(0, 1))

    # Gradient for normalized input
    dx_norm = dout * gamma

    # Gradient for input (the complex part)
    std = torch.sqrt(var + eps)

    # Three terms in the gradient
    term1 = dx_norm / std
    term2 = dx_norm.mean(dim=-1, keepdim=True) / std
    term3 = (dx_norm * x_norm).mean(dim=-1, keepdim=True) * x_norm / std

    dx = term1 - term2 - term3

    return dx, dgamma, dbeta
In[17]:
Code
# Verify against PyTorch autograd
torch.manual_seed(42)
x_test = torch.randn(2, 4, 8, requires_grad=True)
gamma_test = torch.ones(8, requires_grad=True)
beta_test = torch.zeros(8, requires_grad=True)

# Forward pass with our implementation
output, cache = layer_norm_forward(x_test, gamma_test, beta_test)

# Create fake upstream gradient
dout = torch.randn_like(output)

# Our backward pass
dx_ours, dgamma_ours, dbeta_ours = layer_norm_backward(dout, cache)

# PyTorch autograd backward
output.backward(dout)
Out[18]:
Console
Gradient comparison with PyTorch autograd:
  dx max difference: 2.38e-07
  dgamma max difference: 0.00e+00
  dbeta max difference: 0.00e+00

The differences are on the order of $10^{-7}$ or smaller, well within floating-point precision. This confirms our manual backward pass implementation correctly computes the gradients that PyTorch's autograd produces automatically.
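One further property of the input gradient is easy to test: adding the same constant to every feature of a token leaves the layer normalization output unchanged, so the input gradient must sum to zero across features for each token. The sketch below checks this with autograd, using an arbitrary non-uniform scale so the result does not rely on $\gamma$ being all ones.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4, 8, requires_grad=True)
gamma = torch.rand(8) + 0.5   # arbitrary non-uniform scale
beta = torch.randn(8)

y = F.layer_norm(x, (8,), weight=gamma, bias=beta, eps=1e-5)
y.backward(torch.randn_like(y))

# Per-token sum of the input gradient over the feature axis
print(x.grad.sum(dim=-1).abs().max().item())
```

The printed value should be at the level of floating-point rounding, which is the mean-correction term of the gradient doing its job.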

Visualizing Layer Normalization's Effect

Let's visualize how layer normalization transforms the activation distribution during a forward pass through multiple transformer blocks.

In[19]:
Code
class StackedTransformerBlocks(nn.Module):
    """Stack of transformer blocks for visualization."""

    def __init__(self, d_model, n_heads, d_ff, n_layers, use_layernorm=True):
        super().__init__()
        self.use_layernorm = use_layernorm
        self.layers = nn.ModuleList(
            [
                PreNormTransformerBlock(d_model, n_heads, d_ff)
                for _ in range(n_layers)
            ]
        )
        if not use_layernorm:
            # Replace LayerNorm with identity
            for layer in self.layers:
                layer.norm1 = nn.Identity()
                layer.norm2 = nn.Identity()

    def forward_with_activations(self, x):
        """Return activations after each layer."""
        activations = [x.detach().clone()]
        for layer in self.layers:
            x = layer(x)
            activations.append(x.detach().clone())
        return activations


# Create models with and without layer normalization
torch.manual_seed(42)
d_model, n_heads, d_ff, n_layers = 64, 4, 256, 6

model_with_ln = StackedTransformerBlocks(
    d_model, n_heads, d_ff, n_layers, use_layernorm=True
)
model_without_ln = StackedTransformerBlocks(
    d_model, n_heads, d_ff, n_layers, use_layernorm=False
)

# Forward pass
x = torch.randn(4, 16, d_model)
acts_with_ln = model_with_ln.forward_with_activations(x)
acts_without_ln = model_without_ln.forward_with_activations(x)
Out[20]:
Visualization
Line plot showing stable activation mean near zero and std near one across 6 layers.
Activation statistics with layer normalization. The mean stays centered near zero and standard deviation remains stable around 1 across all layers, enabling stable training even in deep networks.
Line plot showing activation mean and std that may drift away from normalized values across layers.
Activation statistics without layer normalization. Without normalization, activations can drift in mean and variance, potentially causing training instability in deeper networks.

With layer normalization, activations maintain stable statistics throughout the network. Without it, activations can drift, though the effect depends heavily on initialization. In practice, this stability becomes crucial during training when weight updates can cause activation statistics to shift dramatically without normalization to anchor them.

Epsilon: A Small but Critical Detail

The epsilon parameter ($\epsilon$) appears in the denominator of the normalization formula:

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

where:

  • $\hat{x}_i$: the normalized value for the $i$-th feature
  • $x_i$: the original input value
  • $\mu$: the mean across all features
  • $\sigma^2$: the variance across all features
  • $\epsilon$: a small constant added to prevent division by zero

The purpose of $\epsilon$ is to ensure numerical stability. If all input values are identical (or nearly so), the variance $\sigma^2$ approaches zero. Without $\epsilon$, we would divide by zero, producing infinity or NaN. Adding a small positive constant like $10^{-5}$ ensures the denominator is always positive.

The choice of epsilon can affect numerical stability, especially with mixed-precision (FP16) training where very small values may underflow.
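The FP16 concern can be made concrete by casting candidate epsilon values to half precision. This is a sketch of the representational limit only; it does not model a full mixed-precision training setup.

```python
import torch

fp16 = torch.finfo(torch.float16)
print(fp16.tiny)  # smallest positive normal FP16 value, about 6.1e-05

# Cast candidate epsilon values to FP16 and see what survives
for eps in (1e-3, 1e-5, 1e-8, 1e-12):
    as_fp16 = torch.tensor(eps, dtype=torch.float16).item()
    print(f"eps={eps:.0e} -> fp16 value {as_fp16:.3e}")
```

Values such as $10^{-8}$ fall below the smallest FP16 subnormal (about $6 \times 10^{-8}$) and round to zero, which removes the safety margin altogether, while $10^{-5}$ survives only as a subnormal with reduced precision. This is one reason normalization statistics are commonly computed in FP32 even when the rest of the model runs in half precision.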

In[21]:
Code
# Demonstrate epsilon's role with near-constant input
near_constant = torch.ones(1, 4, 8) * 5.0
near_constant[0, 0, 0] = 5.001  # Tiny variation


def test_epsilon(x, eps):
    """Test layer normalization with different epsilon values."""
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    x_norm = (x - mu) / torch.sqrt(var + eps)
    # PyTorch division by zero yields inf/NaN instead of raising,
    # so check the result explicitly rather than catching an exception
    if not torch.isfinite(x_norm).all():
        return float("nan"), "Division issue"
    return x_norm.std().item(), "OK"


epsilons = [0, 1e-12, 1e-8, 1e-5, 1e-3]
Out[22]:
Console
Effect of epsilon on near-constant input:
  Input variance: 1.25e-07

Epsilon      Output std      Status
----------------------------------------
0e+00        NaN/Inf         Division issue
1e-12        0.507998        OK
1e-08        0.486255        OK
1e-05        0.052836        OK
1e-03        0.005312        OK

The input variance is extremely small (around $10^{-7}$ for the perturbed token and exactly zero for the three constant tokens), which means we're dividing by a very small number. With epsilon = 0, the constant tokens produce $0/0$, so the output contains NaN. As epsilon increases, the division becomes safe, but once epsilon exceeds the true variance it dominates the denominator and shrinks the output, which is why the output standard deviation falls from about 0.5 to about 0.005 in the table. The standard choice of $10^{-5}$ (some models use $10^{-6}$) strikes a balance: it's large enough to prevent numerical issues but small enough not to distort the normalization when variance is reasonably sized.
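The degenerate case is worth seeing in isolation. In this sketch every feature of the token is identical, so the variance is exactly zero; the built-in function with a positive epsilon returns zeros, while the same arithmetic with epsilon set to zero returns NaN.

```python
import torch
import torch.nn.functional as F

constant = torch.full((1, 8), 5.0)  # every feature identical, so variance is exactly 0

# With eps > 0 the numerator is 0 and the denominator is sqrt(eps) > 0
safe = F.layer_norm(constant, (8,), eps=1e-5)
print(safe)

# The same computation by hand with eps = 0 gives 0 / 0
mu = constant.mean(dim=-1, keepdim=True)
var = constant.var(dim=-1, keepdim=True, unbiased=False)
unsafe = (constant - mu) / torch.sqrt(var + 0.0)
print(unsafe)
```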

Layer Normalization with Different Normalized Shapes

PyTorch's nn.LayerNorm accepts a normalized_shape parameter that controls which dimensions are normalized. For transformers, we typically normalize over the feature dimension only:

In[23]:
Code
# Different normalized shapes
x = torch.randn(2, 4, 8)  # (batch, seq_len, features)

# Normalize over features only (most common for transformers)
ln_features = nn.LayerNorm(8)

# Normalize over sequence and features
ln_seq_features = nn.LayerNorm([4, 8])

# Normalize over every dimension, including batch (rarely useful in practice,
# since it ties the module to one fixed batch size)
ln_all = nn.LayerNorm([2, 4, 8])

out_features = ln_features(x)
out_seq_features = ln_seq_features(x)
Out[24]:
Console
LayerNorm with normalized_shape=(8,) - normalize over features:
  Output shape: torch.Size([2, 4, 8])
  Each token normalized independently
  Token 0,0 mean: 0.000000
  Token 0,1 mean: 0.000000

LayerNorm with normalized_shape=(4, 8) - normalize over seq+features:
  Output shape: torch.Size([2, 4, 8])
  Each sample normalized as a whole
  Sample 0 mean: -0.000000

Notice the difference: with normalized_shape=(8,), each individual token has zero mean (Token 0,0 and Token 0,1 both have mean approximately 0). With normalized_shape=(4, 8), the entire sample is normalized together, so individual tokens may have non-zero means but the sample as a whole has zero mean.

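Normalizing over features only is the standard choice for transformers because it treats each token independently: a token's statistics never depend on other positions. This fits autoregressive language models, where no information about later tokens should leak into earlier ones through the normalization, and it lets the model process sequences of any length.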
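The independence claim can be checked with a small experiment. This sketch assumes the same toy shapes as above (sequence length 4, 8 features) and compares the two modules when only the last token changes, and when the sequence is shortened. Variable names such as `x_changed` are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 4, 8)          # (batch, seq_len, features)
x_changed = x.clone()
x_changed[0, 3] += 1.0            # perturb only the last token

ln_features = nn.LayerNorm(8)            # per-token statistics
ln_seq_features = nn.LayerNorm([4, 8])   # statistics shared across the sequence

# Per-token normalization: the first three tokens are unaffected by the change.
per_token_unchanged = torch.allclose(
    ln_features(x)[0, :3], ln_features(x_changed)[0, :3]
)

# Joint normalization: the shared mean and variance shift, so earlier tokens change too.
joint_unchanged = torch.allclose(
    ln_seq_features(x)[0, :3], ln_seq_features(x_changed)[0, :3]
)

# Per-token normalization accepts a different sequence length.
shorter_ok = ln_features(x[:, :2]).shape == (1, 2, 8)

# A module built for seq_len=4 rejects other lengths.
try:
    ln_seq_features(x[:, :2])
    joint_accepts_shorter = True
except RuntimeError:
    joint_accepts_shorter = False

print(per_token_unchanged, joint_unchanged, shorter_ok, joint_accepts_shorter)
```

With joint normalization, a change to a later token alters the output at earlier positions, which is exactly what a causal model needs to avoid.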

Limitations and Impact

Layer normalization has become ubiquitous in transformer architectures, but it's not without drawbacks.

The primary computational overhead comes from computing statistics for every token at every layer. For a model with hidden dimension $d$, each layer normalization requires computing a mean (sum of $d$ elements) and variance (sum of $d$ squared differences), then normalizing all $d$ elements. While these operations are memory-bandwidth bound rather than compute-bound on modern GPUs, they still add up in models with hundreds of layers.

The learned $\gamma$ and $\beta$ parameters add $2d$ parameters per layer normalization, which is negligible compared to attention and FFN parameters but contributes to model complexity. More importantly, these parameters can be a source of numerical issues when they grow very large or approach zero, requiring careful initialization and sometimes explicit constraints.
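To put the parameter count in perspective, here is a rough comparison against a single feed-forward block. The hidden size of 768 and the 4x FFN expansion are assumptions chosen for illustration, not values taken from a specific model, and `count_params` is a hypothetical helper.

```python
import torch
import torch.nn as nn

d_model = 768                      # assumed hidden size
ln = nn.LayerNorm(d_model)
ffn = nn.Sequential(               # assumed feed-forward block with 4x expansion
    nn.Linear(d_model, 4 * d_model),
    nn.GELU(),
    nn.Linear(4 * d_model, d_model),
)


def count_params(module):
    """Total number of learnable scalars in a module."""
    return sum(p.numel() for p in module.parameters())


ln_params = count_params(ln)       # gamma and beta: 2 * d_model
ffn_params = count_params(ffn)

print(f"LayerNorm: {ln_params:,} parameters")
print(f"FFN:       {ffn_params:,} parameters")
print(f"Ratio:     {ln_params / ffn_params:.5f}")

# PyTorch initializes gamma (weight) to ones and beta (bias) to zeros,
# so a freshly created LayerNorm applies plain normalization.
starts_as_identity_affine = bool(
    torch.all(ln.weight == 1) and torch.all(ln.bias == 0)
)
```

Under these assumptions the normalization parameters are a small fraction of a percent of the feed-forward block's parameters.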

Layer normalization also introduces a subtle form of coupling between features that can affect interpretability. Because each feature is normalized relative to the others, the absolute activation value of any single feature becomes less meaningful. This makes it harder to interpret individual neurons or feature dimensions in isolation.
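A tiny example makes the coupling concrete. In this pure-Python sketch (the `layer_norm` helper and the toy values are illustrative), only feature 0 changes between the two inputs, yet the normalized value of feature 3 changes as well, even though its raw value stays at 4.0.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a list of floats to zero mean and (approximately) unit variance."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


token = [1.0, 2.0, 3.0, 4.0]
bumped = [10.0, 2.0, 3.0, 4.0]     # only feature 0 differs

before = layer_norm(token)
after = layer_norm(bumped)

# Feature 3 is 4.0 in both inputs, but its normalized value depends on the others.
print(f"feature 3 before: {before[3]:+.3f}")
print(f"feature 3 after:  {after[3]:+.3f}")
```

Feature 3 starts as the largest value in the token, so it normalizes to a positive number. After feature 0 is raised, feature 3 falls below the new mean and its normalized value becomes negative.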

Despite these limitations, layer normalization's impact on transformer training stability cannot be overstated. Before normalization techniques were widely adopted, training deep networks required careful learning rate tuning, extensive warmup periods, and often failed entirely for very deep models. Layer normalization enables stable training with higher learning rates, reduces sensitivity to initialization, and allows models to scale to unprecedented depths. The original transformer used layer normalization, and every major language model since has relied on some form of normalization to train successfully.

The success of layer normalization has also spurred research into alternatives. RMSNorm, which we'll cover in the next chapter, removes the mean-centering step to improve computational efficiency while maintaining most of the stability benefits.

Key Parameters

When using nn.LayerNorm in PyTorch, understanding the key parameters helps you configure it correctly for your architecture:

  • normalized_shape: The shape of the trailing input dimensions to normalize over. For transformers, this is typically the hidden dimension d_model (e.g., 768, 1024). You can also pass a list like [seq_len, d_model] to normalize over multiple dimensions, though normalizing over features only is the standard choice.

  • eps: The epsilon value added to the denominator for numerical stability. Default is 1e-5, which works well for most cases. For mixed-precision (FP16) training, you may need a larger value like 1e-4 to avoid underflow issues when variance is very small.

  • elementwise_affine: Whether to include learnable $\gamma$ and $\beta$ parameters. Default is True. Setting to False removes the learnable parameters, reducing model size slightly but limiting the network's ability to learn optimal feature scales.
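The parameters above can be combined as in the following sketch. The hidden size of 16 and the eps value of 1e-4 are arbitrary illustrative choices. The manual computation at the end is included only as a sanity check that `eps` enters the denominator as described.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16                           # small illustrative hidden size
x = torch.randn(2, 5, d_model)         # (batch, seq_len, d_model)

# Default configuration: learnable gamma/beta, eps=1e-5
ln = nn.LayerNorm(d_model)

# Larger eps and no learnable parameters
ln_plain = nn.LayerNorm(d_model, eps=1e-4, elementwise_affine=False)

n_affine = sum(p.numel() for p in ln.parameters())        # gamma + beta
n_plain = sum(p.numel() for p in ln_plain.parameters())   # nothing to learn

# Manual reference for ln_plain: biased variance over the last dimension
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mu) / torch.sqrt(var + ln_plain.eps)
matches = torch.allclose(ln_plain(x), manual, atol=1e-5)

print(f"affine params: {n_affine}, plain params: {n_plain}, matches manual: {matches}")
```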

Summary

Layer normalization is a fundamental component of transformer architectures that enables stable training of deep models. Unlike batch normalization, which computes statistics across the batch dimension, layer normalization operates on each sample independently, making it well-suited for variable-length sequences and small batch sizes.

The core operation normalizes each token's representation to zero mean and unit variance, then applies learned scale ($\gamma$) and shift ($\beta$) parameters to recover representational flexibility. This simple transformation stabilizes activations throughout the network, prevents gradient issues during training, and reduces sensitivity to initialization.

Key takeaways:

  • Feature-wise normalization: Layer normalization computes mean and variance across the feature dimension, treating each token independently
  • Learnable parameters: $\gamma$ and $\beta$ allow the network to undo normalization when beneficial, preserving representational capacity
  • Placement matters: Pre-norm (normalize before sublayer) has become the modern standard, improving gradient flow in deep networks
  • Epsilon for stability: A small constant prevents division by zero with near-constant inputs
  • No batch dependency: Works with any batch size, including single samples during inference

Reference

BibTeX
@misc{layernormalizationstabilizingtransformertraining,
  author = {Michael Brenndoerfer},
  title = {Layer Normalization: Stabilizing Transformer Training},
  year = {2025},
  url = {https://mbrenndoerfer.com/writing/layer-normalization-transformers-implementation},
  organization = {mbrenndoerfer.com},
  note = {Accessed: 2025-12-19}
}
About the author: Michael Brenndoerfer

All opinions expressed here are my own and do not reflect the views of my employer.

Michael currently works as an Associate Director of Data Science at EQT Partners in Singapore, leading AI and data initiatives across private capital investments.

With over a decade of experience spanning private equity, management consulting, and software engineering, he specializes in building and scaling analytics capabilities from the ground up. He has published research in leading AI conferences and holds expertise in machine learning, natural language processing, and value creation through data.
