Batch Normalization: Stabilizing Deep Network Training

Michael Brenndoerfer · December 15, 2025 · 24 min read

Learn how batch normalization addresses internal covariate shift by normalizing layer inputs, enabling faster training with higher learning rates.

Batch Normalization

Training deep neural networks is notoriously difficult. As networks grow deeper, the distribution of inputs to each layer shifts during training, forcing each layer to continuously adapt to new input statistics. This phenomenon, known as internal covariate shift, slows convergence and makes training unstable. Batch normalization, introduced by Ioffe and Szegedy in 2015, addresses this by normalizing layer inputs using batch statistics, fundamentally changing how we train deep networks.

Batch normalization became one of the most influential techniques in deep learning. It enables higher learning rates, reduces sensitivity to initialization, and acts as a regularizer. Almost every modern architecture, from ResNets to Transformers, incorporates some form of normalization. Understanding how batch normalization works, and its limitations, is essential for building and debugging deep networks.

The Problem: Internal Covariate Shift

When training a neural network, each layer receives inputs from the previous layer. As the weights of earlier layers update during training, the distribution of inputs to later layers changes. This constant shifting of input distributions is called internal covariate shift.

Internal Covariate Shift

The change in the distribution of network activations due to the update of parameters in preceding layers during training.

Consider a simple two-layer network. During a gradient update, the first layer's weights change, which alters the activations it produces. The second layer, which learned to expect inputs with a certain mean and variance, now receives inputs with different statistics. This forces the second layer to re-adapt, slowing learning.

The deeper the network, the worse this problem becomes. Each layer's output depends on all preceding layers, so small changes early in the network cascade into large distributional shifts later. Networks compensate by using small learning rates, but this dramatically slows training.
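To make the idea concrete, here is a minimal NumPy sketch (not part of this chapter's code cells) that passes a fixed batch through a randomly initialized layer, nudges that layer's weights as a stand-in for a gradient update, and shows how the statistics of the next layer's inputs shift:

import numpy as np

np.random.seed(0)
x = np.random.randn(256, 64)              # a fixed batch of inputs
W1 = np.random.randn(64, 64) * 0.1        # first layer's weights

h1 = np.maximum(0, x @ W1)                # activations that feed the second layer
print(f"before update: mean={h1.mean():.3f}, std={h1.std():.3f}")

W1 += 0.05 * np.random.randn(64, 64)      # stand-in for a gradient update to layer 1
h1_shifted = np.maximum(0, x @ W1)        # same inputs, new layer-1 weights
print(f"after update:  mean={h1_shifted.mean():.3f}, std={h1_shifted.std():.3f}")

Even though the input batch is unchanged, the second layer now sees activations with a different mean and spread; stacking many layers compounds this effect.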

Batch normalization tackles this by explicitly normalizing the inputs to each layer, ensuring they maintain consistent statistics (zero mean, unit variance) throughout training. This stabilizes the learning process and allows for much more aggressive optimization.

Batch Statistics Computation

The core idea of batch normalization is elegantly simple: take a group of activations, figure out their typical value and spread, then rescale everything so that the group has zero mean and unit variance. If you've ever standardized data before fitting a machine learning model, you've done something similar. The twist here is that we apply this standardization inside the network, at every layer, during training.

But why does this help? Think about what a layer in a neural network is trying to learn. Each neuron combines its inputs, applies weights, and produces an activation. If those incoming activations have wildly varying scales, some with values around 1000 and others around 0.01, the neuron faces an awkward optimization landscape. The gradients for large-scale features dominate, while small-scale features get ignored. By normalizing activations to a consistent scale, we give every feature an equal footing.

Let's walk through the mathematics step by step. We'll work with a single feature dimension (one neuron's output) and consider a mini-batch of $m$ samples. The same process applies independently to every feature dimension in the layer.

Computing the Batch Mean

Given a mini-batch $\mathcal{B} = \{x_1, x_2, \ldots, x_m\}$ of $m$ activations for a particular feature, we first compute the average activation across the batch:

$$\mu_\mathcal{B} = \frac{1}{m} \sum_{i=1}^{m} x_i$$

where:

  • $\mu_\mathcal{B}$: the mean of activations computed over the current mini-batch
  • $m$: the number of samples in the mini-batch (batch size)
  • $x_i$: the activation value for sample $i$

This mean tells us where the "center" of the activations lies for this batch. If activations tend to be large and positive, the mean will be large and positive. Our goal is to shift the distribution so this center moves to zero.

Computing the Batch Variance

Next, we measure how spread out the activations are around the mean:

$$\sigma^2_\mathcal{B} = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2$$

where:

  • $\sigma^2_\mathcal{B}$: the variance of activations over the mini-batch
  • $(x_i - \mu_\mathcal{B})^2$: the squared deviation of sample $i$ from the batch mean

The variance captures the scale of the distribution. If activations range from -100 to +100, the variance will be large. If they cluster tightly between -0.5 and +0.5, the variance will be small. We need this information to rescale the distribution to unit variance.

Normalizing the Activations

With both the mean and variance in hand, we can now transform each activation:

$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}}$$

where:

  • $\hat{x}_i$: the normalized activation for sample $i$
  • $\epsilon$: a small constant (typically $10^{-5}$) added for numerical stability

Let's unpack what this formula does. The numerator $x_i - \mu_\mathcal{B}$ centers the activation around zero by subtracting the batch mean. The denominator $\sqrt{\sigma^2_\mathcal{B} + \epsilon}$ is the standard deviation (with a tiny epsilon for safety), which rescales the centered value to have unit variance.

The epsilon term deserves special attention. In rare cases, all activations in a batch might be identical, giving a variance of exactly zero. Without epsilon, we would divide by zero and crash. Adding $10^{-5}$ prevents this while being small enough not to affect normal computations.
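As a quick sanity check (a hypothetical snippet, not one of this chapter's numbered cells), here is what happens when a feature is constant across the batch, with and without the epsilon term:

import numpy as np

x_const = np.full(8, 3.0)                  # every activation in the batch identical
var = x_const.var()                        # exactly 0.0
centered = x_const - x_const.mean()

print(centered / np.sqrt(var))             # nan values (0 / 0) and a runtime warning
print(centered / np.sqrt(var + 1e-5))      # all zeros: epsilon keeps the division defined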

After this transformation, the batch of activations has mean zero and variance approximately one. This happens independently for each feature dimension, so a layer with 256 neurons computes 256 separate means and variances, producing 256 independently normalized distributions.

Seeing Normalization in Action

Let's implement batch normalization from scratch to see exactly how these formulas work together:

In[2]:
Code
import numpy as np


def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """
    Forward pass for batch normalization.

    Args:
        x: Input data of shape (batch_size, features)
        gamma: Scale parameter of shape (features,)
        beta: Shift parameter of shape (features,)
        eps: Small constant for numerical stability

    Returns:
        out: Normalized output
        cache: Values needed for backward pass
    """
    # Step 1: Compute batch mean
    mu = np.mean(x, axis=0)

    # Step 2: Center the data
    x_centered = x - mu

    # Step 3: Compute batch variance
    var = np.mean(x_centered**2, axis=0)

    # Step 4: Compute standard deviation with epsilon
    std = np.sqrt(var + eps)

    # Step 5: Normalize
    x_norm = x_centered / std

    # Step 6: Scale and shift with learnable parameters
    out = gamma * x_norm + beta

    # Cache values for backward pass
    cache = (x, x_norm, mu, var, std, gamma, beta, eps)

    return out, cache
In[3]:
Code
# Create sample data: batch of 4 samples, 3 features each
np.random.seed(42)
x = np.random.randn(4, 3) * 5 + 10  # Mean ~10, std ~5

# Initialize gamma=1 and beta=0 (identity transform initially)
gamma = np.ones(3)
beta = np.zeros(3)

# Apply batch normalization
out, cache = batch_norm_forward(x, gamma, beta)
Out[4]:
Console
Input statistics:
  Mean per feature: [15.18  9.91  9.35]
  Std per feature:  [2.58 2.34 2.3 ]

Output statistics:
  Mean per feature: [-0. -0. -0.]
  Std per feature:  [0.999999 0.999999 0.999999]

The input has varying means and standard deviations across features, but after batch normalization, each feature has mean approximately zero and standard deviation approximately one. The small deviations from exactly zero and one come from floating-point precision.

Let's visualize this transformation more clearly. We'll create a larger batch and plot the distribution of activations before and after normalization:

In[5]:
Code
# Create a larger batch to see distributions clearly
np.random.seed(42)
x_large = np.random.randn(1000, 3) * np.array([2, 5, 0.5]) + np.array(
    [10, -3, 7]
)

# Apply batch normalization
gamma_viz = np.ones(3)
beta_viz = np.zeros(3)
out_large, _ = batch_norm_forward(x_large, gamma_viz, beta_viz)
Out[6]:
Visualization
Six histograms arranged in two rows: top row shows varied distributions with different centers and spreads, bottom row shows standardized distributions all centered at zero.
Activation distributions before and after batch normalization. The top row shows raw activations with different means and variances across features. The bottom row shows normalized activations, all centered at zero with unit variance.

The visualization makes the effect immediately clear. Before normalization, each feature has a different center and spread, with Feature 1 centered around 10, Feature 2 around -3, and Feature 3 around 7. After normalization, all three features share the same standardized distribution, centered at zero with comparable spread. This consistency is what enables stable training.

Learnable Scale and Shift Parameters

We've just forced all activations to have zero mean and unit variance. But wait: what if the optimal representation for a particular layer actually needs activations centered around 3.7 with a spread of 0.5? By hardcoding zero mean and unit variance, we've potentially crippled the network's ability to learn the best representation.

This is where batch normalization gets clever. After normalizing, we apply a learnable affine transformation that can recover any mean and variance the network needs:

$$y_i = \gamma \hat{x}_i + \beta$$

where:

  • $y_i$: the final output of batch normalization for sample $i$
  • $\gamma$: a learnable scale parameter (initialized to 1)
  • $\hat{x}_i$: the normalized activation with zero mean and unit variance
  • $\beta$: a learnable shift parameter (initialized to 0)

Think of $\gamma$ and $\beta$ as the network's way of saying "I understand you've normalized everything to standard form, but let me adjust it to what I actually need." The scale parameter $\gamma$ stretches or compresses the distribution, while the shift parameter $\beta$ moves it left or right along the number line.

Here's the beautiful part: if the network learns $\gamma = \sigma_\mathcal{B}$ and $\beta = \mu_\mathcal{B}$, it completely undoes the normalization, recovering the original activations. This means batch normalization can never hurt representational capacity. In the worst case, the network learns to bypass it entirely. In practice, the network finds some intermediate setting that benefits from the normalized optimization landscape while still representing the patterns it needs.

The crucial insight is what gets learned versus what gets computed. The mean and variance ($\mu_\mathcal{B}$, $\sigma^2_\mathcal{B}$) are computed from the batch data, not learned parameters. They stabilize the forward pass by keeping activations well-scaled. Meanwhile, $\gamma$ and $\beta$ are learned through backpropagation, giving the network control over the final representation. This separation decouples the mechanics of stable training from the semantics of learned features.
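We can check this identity-recovery claim directly. The sketch below reuses the batch_norm_forward function and the x array from the cells above, setting $\gamma$ to the batch standard deviation and $\beta$ to the batch mean:

# If gamma = sigma_B and beta = mu_B, batch norm (approximately) undoes itself.
mu_b = x.mean(axis=0)
sigma_b = x.std(axis=0)
out_undo, _ = batch_norm_forward(x, gamma=sigma_b, beta=mu_b)
print(np.allclose(out_undo, x, atol=1e-4))  # True, up to the small epsilon term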

In[7]:
Code
# Demonstrate how gamma and beta transform the output
gamma_custom = np.array([2.0, 0.5, 3.0])  # Different scales per feature
beta_custom = np.array([1.0, -1.0, 5.0])  # Different shifts per feature

out_custom, _ = batch_norm_forward(x, gamma_custom, beta_custom)
Out[8]:
Console
With custom gamma and beta:
  gamma = [2.  0.5 3. ]
  beta = [ 1. -1.  5.]

Output statistics:
  Mean per feature: [ 1. -1.  5.]
  Std per feature:  [2.  0.5 3. ]

With custom parameters, the output mean equals $\beta$ and the standard deviation equals $|\gamma|$, as expected from the transformation $y = \gamma \hat{x} + \beta$ where $\hat{x}$ has zero mean and unit variance.

Let's visualize how different $\gamma$ and $\beta$ values transform the normalized distribution:

In[9]:
Code
# Generate normalized data
np.random.seed(42)
x_demo = np.random.randn(2000, 1)  # Already standard normal

# Apply different gamma/beta transformations
configs = [
    (1.0, 0.0, "Identity (γ=1, β=0)"),
    (2.0, 0.0, "Scale only (γ=2, β=0)"),
    (1.0, 3.0, "Shift only (γ=1, β=3)"),
    (0.5, -2.0, "Scale and shift (γ=0.5, β=-2)"),
]

transformed = []
for gamma_val, beta_val, label in configs:
    y = gamma_val * x_demo + beta_val
    transformed.append((y.flatten(), label))
Out[10]:
Visualization
Four overlapping histograms showing how different gamma and beta values transform a standard normal distribution to have different centers and spreads.
Effect of learnable parameters γ (scale) and β (shift) on the normalized distribution. Starting from a standard normal distribution, different parameter combinations produce distributions with varying means and variances, demonstrating how batch normalization preserves representational flexibility.

The visualization shows how the network can recover any distribution it needs. The blue distribution (identity) shows the standard normalized output. Scaling by $\gamma = 2$ (red) doubles the spread. Shifting by $\beta = 3$ (green) moves the center. Combining both (purple) demonstrates that the network can learn to place the distribution anywhere with any spread. This flexibility is crucial: batch normalization stabilizes training without constraining what the network can represent.

Training vs Inference Mode

Batch normalization behaves differently during training and inference, which is a critical detail that often causes bugs.

During training, batch normalization uses the current mini-batch statistics ($\mu_\mathcal{B}$, $\sigma^2_\mathcal{B}$) for normalization. This introduces stochasticity since different batches have slightly different statistics, which acts as a regularizer.

During inference, using batch statistics is problematic. We might have a single sample (batch size 1), making batch statistics meaningless. We also want deterministic predictions. Instead, we use running averages of mean and variance accumulated during training. After each training batch, we update:

$$\mu_{\text{running}} \leftarrow \alpha \cdot \mu_{\text{running}} + (1 - \alpha) \cdot \mu_\mathcal{B}$$
$$\sigma^2_{\text{running}} \leftarrow \alpha \cdot \sigma^2_{\text{running}} + (1 - \alpha) \cdot \sigma^2_\mathcal{B}$$

where:

  • $\mu_{\text{running}}$: the exponential moving average of batch means, used at inference time
  • $\sigma^2_{\text{running}}$: the exponential moving average of batch variances
  • $\alpha$: the momentum coefficient (typically 0.9 or 0.99), controlling how much weight is given to the existing running average versus the new batch statistics
  • $\mu_\mathcal{B}$, $\sigma^2_\mathcal{B}$: the mean and variance computed from the current mini-batch

These running statistics approximate the population statistics and are used for normalization at inference time.

In[11]:
Code
class BatchNorm:
    def __init__(self, num_features, momentum=0.9, eps=1e-5):
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.eps = eps
        self.momentum = momentum

        # Running statistics for inference
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

        self.training = True

    def forward(self, x):
        if self.training:
            # Use batch statistics
            mu = np.mean(x, axis=0)
            var = np.var(x, axis=0)

            # Update running statistics
            self.running_mean = (
                self.momentum * self.running_mean + (1 - self.momentum) * mu
            )
            self.running_var = (
                self.momentum * self.running_var + (1 - self.momentum) * var
            )
        else:
            # Use running statistics for inference
            mu = self.running_mean
            var = self.running_var

        # Normalize
        x_norm = (x - mu) / np.sqrt(var + self.eps)

        # Scale and shift
        return self.gamma * x_norm + self.beta
In[12]:
Code
# Simulate training with multiple batches
bn = BatchNorm(num_features=3)
np.random.seed(42)

# Process several training batches
for i in range(100):
    batch = np.random.randn(32, 3) * 2 + 5  # Mean=5, std=2
    _ = bn.forward(batch)
Out[13]:
Console
After 100 training batches:
  Running mean: [4.969 5.058 5.017]
  Running var:  [4.407 3.521 3.494]

Expected (population statistics):
  Mean: ~5.0
  Var:  ~4.0

Let's visualize how the running statistics evolve during training to see the convergence process:

In[14]:
Code
# Track running statistics over training
bn_track = BatchNorm(num_features=1)
np.random.seed(42)

running_means = []
running_vars = []
batch_means = []
batch_vars = []

for i in range(200):
    batch = np.random.randn(32, 1) * 2 + 5  # True mean=5, true var=4
    batch_means.append(batch.mean())
    batch_vars.append(batch.var())
    _ = bn_track.forward(batch)
    running_means.append(bn_track.running_mean[0])
    running_vars.append(bn_track.running_var[0])
Out[15]:
Visualization
Scatter plot with line showing running mean converging to 5 over 200 training batches.
Running mean convergence during training. Individual batch means (gray dots) fluctuate around the true population mean, while the exponential moving average (blue line) smoothly converges toward 5.0.
Scatter plot with line showing running variance converging to 4 over 200 training batches.
Running variance convergence during training. Individual batch variances (gray dots) show higher variability, while the exponential moving average (blue line) converges toward the true variance of 4.0.

The visualization reveals an important property: individual batch statistics (gray dots) fluctuate considerably due to sampling noise, but the exponential moving average smooths out these fluctuations, converging steadily toward the true population values. This is why we use running statistics for inference rather than batch statistics. A single test sample would give meaningless batch statistics, but the running average provides stable, reliable normalization.

Gradient Flow Through Batch Normalization

Understanding how gradients flow through batch normalization is essential for both implementation and debugging. The backward pass is more complex than the forward pass, and the reason reveals something fundamental about how batch normalization operates.

In a typical neural network layer, each input affects only its own output. The activation $x_1$ flows through the layer and contributes to $y_1$, independent of what $x_2$ or $x_3$ are doing. Batch normalization breaks this independence. When we compute the batch mean, every input contributes. When we compute the variance, every input contributes again. This means that changing $x_1$ affects not just $y_1$, but every output in the batch, because $x_1$ shifts the mean and variance that normalize everyone.

This interconnection makes the gradient computation more intricate. We can't just compute $\frac{\partial y_i}{\partial x_i}$ and call it a day. We need to account for how $x_i$ affects the batch statistics, and how those statistics affect all outputs. Let's work through this step by step.

Given the loss $L$ and the upstream gradient $\frac{\partial L}{\partial y}$, we need to compute gradients with respect to $\gamma$, $\beta$, and the input $x$.

Gradients for the learnable parameters. Since $y_i = \gamma \hat{x}_i + \beta$, the gradients follow directly from the chain rule:

$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i} \cdot \hat{x}_i$$
$$\frac{\partial L}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}$$

where:

  • $\frac{\partial L}{\partial y_i}$: the upstream gradient flowing back from the loss with respect to output $y_i$
  • $\hat{x}_i$: the normalized activation, which acts as a scaling factor for the $\gamma$ gradient

The $\beta$ gradient is simply the sum of upstream gradients because $\beta$ shifts all outputs equally.

Gradient with respect to the normalized input. This requires applying the chain rule through the scale parameter:

$$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot \gamma$$

Gradient with respect to variance. Here the derivation becomes more involved because the variance affects all normalized outputs. Recall that $\hat{x}_i = (x_i - \mu) / \sqrt{\sigma^2 + \epsilon}$, so the variance appears in the denominator:

$$\frac{\partial L}{\partial \sigma^2} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i} \cdot (x_i - \mu) \cdot \left(-\frac{1}{2}\right)(\sigma^2 + \epsilon)^{-3/2}$$

The term $(\sigma^2 + \epsilon)^{-3/2}$ comes from differentiating $(\sigma^2 + \epsilon)^{-1/2}$ with respect to $\sigma^2$.

Gradient with respect to mean. The mean affects the normalized output both directly (in the numerator) and indirectly through the variance:

$$\frac{\partial L}{\partial \mu} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial L}{\partial \sigma^2} \cdot \frac{-2}{m}\sum_{i=1}^{m}(x_i - \mu)$$

The first term captures the direct effect (the $-\mu$ in the numerator), while the second term captures how changing $\mu$ affects the variance calculation.

Gradient with respect to input. Finally, we combine all pathways through which $x_i$ affects the loss:

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial L}{\partial \sigma^2} \cdot \frac{2(x_i - \mu)}{m} + \frac{\partial L}{\partial \mu} \cdot \frac{1}{m}$$

where:

  • The first term: the direct effect of $x_i$ on its own normalized value $\hat{x}_i$
  • The second term: the effect of $x_i$ on the batch variance (each input contributes to the variance)
  • The third term: the effect of $x_i$ on the batch mean (each input contributes equally, hence the $\frac{1}{m}$ factor)

The key insight is that each input affects all outputs through the shared batch statistics. This creates dependencies that must be properly accounted for during backpropagation.
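A small sketch makes this dependency visible. Reusing batch_norm_forward from above, perturbing a single sample changes every normalized output in the batch, not just its own:

np.random.seed(0)
x_dep = np.random.randn(4, 1)
g, b = np.ones(1), np.zeros(1)

y_base, _ = batch_norm_forward(x_dep, g, b)

x_pert = x_dep.copy()
x_pert[0, 0] += 0.5                         # change only sample 0
y_pert, _ = batch_norm_forward(x_pert, g, b)

print(y_pert - y_base)                      # every row changes, because the shared
                                            # batch mean and variance both shifted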

In[16]:
Code
def batch_norm_backward(dout, cache):
    """
    Backward pass for batch normalization.

    Args:
        dout: Upstream gradients of shape (batch_size, features)
        cache: Values from forward pass

    Returns:
        dx: Gradient with respect to input
        dgamma: Gradient with respect to scale
        dbeta: Gradient with respect to shift
    """
    x, x_norm, mu, var, std, gamma, beta, eps = cache
    m = x.shape[0]

    # Gradients of learnable parameters
    dgamma = np.sum(dout * x_norm, axis=0)
    dbeta = np.sum(dout, axis=0)

    # Gradient with respect to normalized input
    dx_norm = dout * gamma

    # Gradient with respect to variance
    dvar = np.sum(dx_norm * (x - mu) * -0.5 * (var + eps) ** (-1.5), axis=0)

    # Gradient with respect to mean
    dmu = (
        np.sum(dx_norm * -1 / std, axis=0)
        + dvar * np.sum(-2 * (x - mu), axis=0) / m
    )

    # Gradient with respect to input
    dx = dx_norm / std + dvar * 2 * (x - mu) / m + dmu / m

    return dx, dgamma, dbeta
In[17]:
Code
# Numerical gradient check
np.random.seed(42)
x = np.random.randn(4, 3)
gamma = np.random.randn(3)
beta = np.random.randn(3)

# Forward pass
out, cache = batch_norm_forward(x, gamma, beta)

# Fake upstream gradient
dout = np.random.randn(*out.shape)

# Analytical gradients
dx, dgamma, dbeta = batch_norm_backward(dout, cache)

# Numerical gradient for dx[0, 0]
eps_num = 1e-5
x_plus = x.copy()
x_plus[0, 0] += eps_num
out_plus, _ = batch_norm_forward(x_plus, gamma, beta)

x_minus = x.copy()
x_minus[0, 0] -= eps_num
out_minus, _ = batch_norm_forward(x_minus, gamma, beta)

dx_numerical = np.sum((out_plus - out_minus) * dout) / (2 * eps_num)
Out[18]:
Console
Gradient check for dx[0, 0]:
  Analytical: -0.297074
  Numerical:  -0.297074
  Relative error: 2.23e-12

The analytical and numerical gradients match closely, confirming our backward pass implementation is correct. The small relative error is due to floating-point precision in the numerical approximation.

Batch Normalization Placement

Where to place batch normalization in the network architecture has been debated since its introduction. The original paper proposed placing it before the activation function, but subsequent research and practice have explored alternatives.

Before activation (original proposal):

$$y = \text{activation}(\text{BN}(Wx + b))$$

where $W$ is the weight matrix, $x$ is the input, $b$ is the bias, BN denotes batch normalization, and activation is the nonlinear function (e.g., ReLU). The reasoning: normalize the linear transformation's output before the nonlinearity, ensuring the activation receives well-conditioned inputs.

After activation (alternative):

$$y = \text{BN}(\text{activation}(Wx + b))$$

Some practitioners find this works better for certain architectures. The intuition is that normalizing the activation's output directly controls what the next layer sees.

Without bias:

When using batch normalization before or after a linear layer, the bias term becomes redundant. The batch norm's $\beta$ parameter can learn any shift, so we simplify to:

$$y = \text{BN}(Wx)$$

Most frameworks use this optimization by default.
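To see why the bias is redundant, note that any constant added before the normalization is removed by the mean subtraction. A quick check using the NumPy batch_norm_forward defined earlier (the bias values here are arbitrary, for illustration only):

np.random.seed(0)
h = np.random.randn(16, 3)                  # pre-normalization activations
bias = np.array([5.0, -2.0, 0.3])           # arbitrary bias added by a linear layer
g, b = np.ones(3), np.zeros(3)

y_without_bias, _ = batch_norm_forward(h, g, b)
y_with_bias, _ = batch_norm_forward(h + bias, g, b)
print(np.allclose(y_without_bias, y_with_bias))  # True: the bias is cancelled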

In[19]:
Code
import torch.nn as nn

# Common pattern: Linear without bias, then BatchNorm, then activation
model_before = nn.Sequential(
    nn.Linear(256, 128, bias=False),  # No bias needed
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Linear(128, 64, bias=False),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),  # Final layer usually keeps bias
)

# Alternative: Linear, activation, then BatchNorm
model_after = nn.Sequential(
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.BatchNorm1d(128),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.BatchNorm1d(64),
    nn.Linear(64, 10),
)
Out[20]:
Console
Parameter count comparison:
  BN before activation: 41,994
  BN after activation:  42,186

The placement choice often comes down to empirical performance on your specific task. Both approaches work well in practice, though the "before activation" pattern remains more common in modern architectures.

Batch Normalization in Practice with PyTorch

Let's see how batch normalization integrates into a complete training loop using PyTorch:

In[21]:
Code
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


# Create a simple classification network with batch normalization
class MLPWithBatchNorm(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.bn1(self.fc1(x)))
        x = self.relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)
        return x


# Create synthetic dataset
torch.manual_seed(42)
n_samples = 1000
n_features = 20
n_classes = 5

X = torch.randn(n_samples, n_features)
y = torch.randint(0, n_classes, (n_samples,))

dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
In[22]:
Code
import torch.optim as optim

# Initialize model, loss, and optimizer
model = MLPWithBatchNorm(input_dim=20, hidden_dim=64, output_dim=5)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
losses = []
for epoch in range(50):
    model.train()  # Important: sets batch norm to training mode
    epoch_loss = 0
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    losses.append(epoch_loss / len(train_loader))
Out[23]:
Visualization
Line plot showing training loss decreasing from about 1.6 to 1.4 over 50 epochs with smooth convergence.
Training loss over epochs for an MLP with batch normalization. The rapid initial decrease and smooth convergence demonstrate how batch normalization enables stable training with a relatively high learning rate.

The training converges smoothly even with a relatively high learning rate of 0.01. Without batch normalization, the same network might require a much smaller learning rate or fail to converge at all.

In[24]:
Code
# Inspect learned batch norm statistics
model.eval()  # Switch to evaluation mode
Out[25]:
Console
Batch Normalization Layer 1 statistics:
  Gamma (scale): mean=1.169, std=0.153
  Beta (shift):  mean=-0.329, std=0.206
  Running mean:  mean=0.023
  Running var:   mean=4.368

The learned $\gamma$ values hover around 1 and $\beta$ values around 0, indicating the network hasn't needed to drastically rescale the normalized activations. The running statistics show the accumulated mean and variance from training.
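The cell above only switches the model to evaluation mode; the numbers reported were presumably read off the first batch norm layer. One way to do that, reading model.bn1 from the network trained above via the standard attributes of PyTorch's BatchNorm1d:

bn1 = model.bn1
print(f"Gamma (scale): mean={bn1.weight.mean().item():.3f}, std={bn1.weight.std().item():.3f}")
print(f"Beta (shift):  mean={bn1.bias.mean().item():.3f}, std={bn1.bias.std().item():.3f}")
print(f"Running mean:  mean={bn1.running_mean.mean().item():.3f}")
print(f"Running var:   mean={bn1.running_var.mean().item():.3f}")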

Comparing Training With and Without Batch Normalization

To appreciate batch normalization's impact, let's compare identical networks with and without it:

In[26]:
Code
class MLPWithoutBatchNorm(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def train_model(model, train_loader, epochs=50, lr=0.01):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses = []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        losses.append(epoch_loss / len(train_loader))

    return losses


# Train both models
torch.manual_seed(42)
model_bn = MLPWithBatchNorm(input_dim=20, hidden_dim=64, output_dim=5)
losses_bn = train_model(model_bn, train_loader, epochs=50, lr=0.01)

torch.manual_seed(42)
model_no_bn = MLPWithoutBatchNorm(input_dim=20, hidden_dim=64, output_dim=5)
losses_no_bn = train_model(model_no_bn, train_loader, epochs=50, lr=0.01)
Out[27]:
Visualization
Line plot comparing two training curves: blue line for batch norm showing faster decrease, orange line without batch norm showing slower convergence.
Training loss comparison between networks with and without batch normalization using identical hyperparameters. Batch normalization enables faster initial convergence and reaches a lower final loss.

Both networks eventually converge, but the batch normalized version converges faster and more smoothly. The difference becomes more pronounced with deeper networks and higher learning rates.

Limitations and Practical Considerations

Batch normalization transformed deep learning, but it comes with significant limitations that have motivated alternative normalization techniques.

The most fundamental limitation is batch size dependency. Batch normalization requires sufficiently large batches to compute meaningful statistics. With very small batches (fewer than 8-16 samples), the batch mean and variance become noisy estimates of the population statistics, destabilizing training. This is particularly problematic in domains like object detection and 3D medical imaging, where memory constraints force small batch sizes, or in reinforcement learning, where samples within a batch may be highly correlated. In these settings, practitioners either use alternative normalizations (Layer Norm, Group Norm, Instance Norm) or accumulate gradient updates across multiple forward passes before applying them.
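For reference, the alternatives mentioned above are available as drop-in PyTorch modules; the feature and group sizes below are illustrative, not tied to this chapter's models:

import torch.nn as nn

layer_norm = nn.LayerNorm(128)                              # per sample, across features
group_norm = nn.GroupNorm(num_groups=8, num_channels=128)   # per sample, over channel groups
instance_norm = nn.InstanceNorm2d(64)                       # per sample, per channel (2D feature maps)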

The training/inference discrepancy is another source of subtle bugs. During training, batch statistics introduce stochasticity that can differ significantly from the running statistics used at inference. If the data distribution at test time differs from training, the running statistics may be inappropriate. A common symptom is a model that performs well during training but poorly at inference, often traced back to batch normalization layers that haven't accumulated representative statistics. Always ensure you process enough training data for running statistics to stabilize, and call model.eval() during inference.
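The sketch below illustrates the discrepancy with a single BatchNorm1d layer: in training mode the layer normalizes with the batch's own statistics, while in evaluation mode it falls back to running statistics that have only seen one batch (the data here is synthetic, purely for illustration):

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3 + 2               # data far from zero mean, unit variance

bn.train()
y_train = bn(x)                             # uses batch statistics (and updates running stats)

bn.eval()
y_eval = bn(x)                              # uses the barely-initialized running statistics

print(y_train.mean().item())                # close to 0
print(y_eval.mean().item())                 # clearly nonzero: running stats aren't representative yet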

Batch normalization also introduces dependencies between samples in a batch. Each sample's normalized value depends on all other samples through the shared mean and variance. This breaks the independence assumption used in some theoretical analyses and can cause issues when the batch composition is non-random (for example, when all samples in a batch come from the same class). In sequence models, this dependency is particularly problematic, which is why Transformers use Layer Normalization instead.

Finally, batch normalization adds computational overhead. Computing statistics, normalizing, and applying learnable parameters adds operations during both forward and backward passes. The overhead is typically small compared to the benefits, but it's not negligible in latency-sensitive applications.

Despite these limitations, batch normalization remains extremely popular because it works well for most feedforward and convolutional networks with reasonable batch sizes. Understanding when it's appropriate, and when to use alternatives, is a key practical skill.

Key Parameters

When using batch normalization in PyTorch or implementing it from scratch, the following parameters control its behavior:

  • num_features: The number of features (channels) to normalize. Must match the size of the feature dimension in the input tensor. For a fully connected layer with 128 outputs, use BatchNorm1d(128).

  • eps (default: $10^{-5}$): A small constant added to the variance for numerical stability. Prevents division by zero when variance is very small. Rarely needs adjustment.

  • momentum (default: 0.1 in PyTorch): Controls the running statistics update rate. PyTorch uses the convention $(1 - \text{momentum}) \cdot \text{running\_stat} + \text{momentum} \cdot \text{batch\_stat}$, so higher values mean faster adaptation to recent batches. Values between 0.01 and 0.1 work well for most cases (see the sketch at the end of this section).

  • affine (default: True): Whether to include learnable $\gamma$ and $\beta$ parameters. Setting to False removes the scale and shift, which is rarely useful but can reduce parameters in specific architectures.

  • track_running_stats (default: True): Whether to maintain running mean and variance for inference. Set to False only if you want batch statistics at inference time (unusual).

The most common configuration uses default values with bias=False on the preceding linear layer, since the batch norm's $\beta$ parameter subsumes the bias term.
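Note that PyTorch's momentum=0.1 corresponds to $\alpha = 0.9$ in the running-average formula used earlier in this chapter. A small check of the update convention (the constant input is chosen only to make the arithmetic obvious):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(1, momentum=0.1)
x = torch.full((4, 1), 10.0)                # batch mean = 10, batch variance = 0

bn.train()
_ = bn(x)                                   # one forward pass updates the running stats

print(bn.running_mean)                      # tensor([1.])     -> 0.9 * 0 + 0.1 * 10
print(bn.running_var)                       # tensor([0.9000]) -> 0.9 * 1 + 0.1 * 0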

Summary

Batch normalization addresses internal covariate shift by normalizing layer inputs using mini-batch statistics. For each feature, it computes the batch mean and variance, normalizes to zero mean and unit variance, then applies learnable scale ($\gamma$) and shift ($\beta$) parameters. This decoupling of activation statistics from learned representations stabilizes training and enables higher learning rates.

The key concepts covered in this chapter:

  • Internal covariate shift: The shifting input distributions that make deep network training difficult
  • Batch statistics: Mean and variance computed per feature across the mini-batch
  • Learnable parameters: $\gamma$ and $\beta$ that preserve representational capacity
  • Training vs inference: Batch statistics during training, running averages during inference
  • Gradient flow: The backward pass accounts for how each input affects batch statistics
  • Placement: Usually before activation, often without bias in preceding linear layers
  • Limitations: Batch size dependency, training/inference discrepancy, sample dependencies

Batch normalization was a breakthrough that enabled training of much deeper networks. While alternatives like Layer Normalization have become preferred for certain architectures (particularly Transformers), understanding batch normalization remains essential, as it forms the foundation for the entire family of normalization techniques used in modern deep learning.

