RMSNorm: Efficient Normalization for Modern LLMs

Michael Brenndoerfer · Updated June 7, 2025 · 37 min read

Learn RMSNorm, the simpler alternative to LayerNorm used in LLaMA, Mistral, and modern LLMs. Understand how removing mean centering improves efficiency while maintaining model quality.

RMSNorm

Layer normalization stabilizes training by centering activations around zero and scaling them to unit variance. But does it need both operations? RMSNorm, introduced by Zhang and Sennrich in 2019, answers with a surprising finding: mean centering is often unnecessary. By removing it, RMSNorm achieves comparable or better performance with reduced computational cost.

This simplification might seem minor, but it matters in practice. Modern large language models perform normalization at every layer, often multiple times per transformer block. When you're running billions of forward passes during training or serving millions of inference requests, even small efficiency gains compound significantly. LLaMA, Mistral, and most contemporary open-source LLMs have adopted RMSNorm as their standard normalization layer.

From LayerNorm to RMSNorm

To appreciate what RMSNorm removes and why that removal works, we first need to understand what LayerNorm does and the distinct roles of its two operations.

Decomposing LayerNorm: Two Operations, Two Purposes

When a vector of activations passes through a neural network layer, its values can drift to arbitrary scales. Some elements might be large and positive, others small and negative. This variability creates problems: gradients become uneven, optimization landscapes shift during training, and networks become sensitive to initialization. LayerNorm addresses this by forcing activations into a consistent statistical profile.

Given an input vector $\mathbf{x}$ with $d$ elements, LayerNorm performs two sequential transformations:

  1. Centering: Subtract the mean so values cluster around zero
  2. Scaling: Divide by the standard deviation so values have unit spread

After these operations, it applies learnable parameters to let the network recover any distribution it needs. The complete formula is:

$$\text{LayerNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where:

  • $\mathbf{x} \in \mathbb{R}^d$: the input vector to normalize
  • $\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$: the mean of the input elements
  • $\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$: the variance of the input elements
  • $\gamma \in \mathbb{R}^d$: learnable scale parameters (initialized to 1)
  • $\beta \in \mathbb{R}^d$: learnable shift parameters (initialized to 0)
  • $\epsilon$: a small constant for numerical stability (typically $10^{-5}$ or $10^{-6}$)
  • $\odot$: element-wise multiplication

The numerator $(\mathbf{x} - \mu)$ centers the data: every element is adjusted so the new mean is exactly zero. The denominator $\sqrt{\sigma^2 + \epsilon}$ scales the data: it measures how spread out the values are and shrinks or expands them to unit variance. Together, these operations produce a standardized distribution, and the learnable parameters $\gamma$ and $\beta$ then reshape it as the network sees fit.
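
To make these two operations concrete, here is a minimal NumPy sketch of LayerNorm applied to a single vector; the example values and the identity-initialized parameters are illustrative, not taken from the article's later code.

```python
import numpy as np

x = np.array([2.0, -1.0, 0.5, 3.5])        # example activations
gamma, beta = np.ones(4), np.zeros(4)       # learnable parameters at initialization
eps = 1e-5

mu = x.mean()                               # step 1: centering statistic
var = ((x - mu) ** 2).mean()                # step 2: scaling statistic (variance)
y = gamma * (x - mu) / np.sqrt(var + eps) + beta

print(round(y.mean(), 6), round(y.std(), 6))  # ~0 and ~1: centered and scaled
```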

But here's the key question Zhang and Sennrich asked: do we actually need both operations?

The Hypothesis: Is Mean Centering Necessary?

The intuition behind centering is compelling. Shifting activations to have zero mean creates a symmetric distribution around the origin. Optimization algorithms can work with positive and negative gradients more evenly. Activations don't accumulate bias through layers. It seems like a good idea.

But consider what happens after normalization. The learnable shift parameter $\beta$ can move the output to any mean value. If the network learns $\beta = \mu$ (setting the shift equal to the original mean), it completely undoes the centering we just performed. The network has the freedom to recover the original distribution.

This observation reveals something subtle: centering is not a hard constraint. It's a soft regularization that the network can override if needed. The real question becomes whether the network benefits enough from the centering operation to justify its computational cost.
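
Before turning to the cost, a quick numerical sketch of that "undoing" argument. Assume, purely for illustration, that the shift parameter has been learned to equal the input's original mean:

```python
import numpy as np

np.random.seed(0)
x = np.random.randn(8) + 3.0                # activations with a clearly non-zero mean
mu, sigma = x.mean(), x.std()

centered = (x - mu) / sigma                 # LayerNorm core with gamma = 1
shifted_back = centered + mu                # beta learned to equal the original mean

print(round(centered.mean(), 6))            # 0.0: centering applied
print(round(shifted_back.mean(), 6))        # ~3: the learned shift restores the mean
```

The shift does not restore the original scale (that would take $\gamma$), but it shows that zero mean is a default the network is free to abandon.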

The cost is not trivial. Computing the mean requires summing all $d$ elements and dividing. Then we subtract this mean from every element. Only after that can we compute the variance (which requires another pass through the data). If we could skip the centering step entirely, we'd eliminate:

  • One reduction operation (computing $\mu$)
  • $d$ subtraction operations (computing $x_i - \mu$ for each element)
  • Dependency chains that limit parallelization

Zhang and Sennrich hypothesized that in transformer architectures with proper initialization, activations naturally stay roughly centered anyway. Residual connections add the original input back, preventing values from drifting too far from zero. If activations are already near-centered, explicitly centering them might be redundant.

Out[2]:
Visualization
Histogram showing the distribution of activation means across multiple simulated transformer layers, clustered tightly around zero.
Distribution of layer-wise mean values in a simulated transformer with residual connections. The means cluster tightly around zero, demonstrating that activations in well-initialized networks are naturally near-centered. This explains why removing explicit mean centering in RMSNorm has minimal impact on model quality.

The RMSNorm Formulation: Keeping Only What Matters

RMSNorm tests this hypothesis by removing mean centering entirely. Instead of measuring spread around the mean (standard deviation), it measures spread around zero (root mean square). This single change eliminates the need to compute or subtract the mean.

The root mean square (RMS) of a vector captures its typical magnitude:

$$\text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2}$$

where:

  • $\text{RMS}(\mathbf{x})$: the root mean square of the input vector
  • $d$: the number of elements in the input vector
  • $x_i$: the $i$-th element of the input vector
  • $\frac{1}{d} \sum_{i=1}^{d} x_i^2$: the mean of squared values (the "mean square")

Think of RMS as answering the question: "How big are these values, on average?" It squares each element (making everything positive), averages them, and takes the square root (returning to the original scale). Large values contribute more; small values contribute less. The result is a single number representing the typical magnitude.

Root Mean Square

The root mean square (RMS) of a set of values is the square root of the arithmetic mean of their squares. Unlike standard deviation, which measures spread around the mean, RMS measures the magnitude of values around zero. For a zero-mean distribution, RMS equals standard deviation.
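
A quick numerical check of that last statement, using two arbitrary example vectors with the same spread:

```python
import numpy as np

centered = np.array([-2.0, -1.0, 1.0, 2.0])   # zero-mean vector
shifted = centered + 5.0                       # same spread, mean of 5

def rms(v):
    return np.sqrt(np.mean(v ** 2))

print(rms(centered), centered.std())   # equal: RMS == std when the mean is zero
print(rms(shifted), shifted.std())     # differ: RMS also absorbs the non-zero mean
```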

Dividing by the RMS normalizes the vector to have unit RMS. Values that were large become order-1; values that were small stay small but in proportion. This is the core of RMSNorm:

$$\text{RMSNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x}}{\text{RMS}(\mathbf{x}) + \epsilon}$$

Expanding the RMS definition:

$$\text{RMSNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x}}{\sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2} + \epsilon}$$

where:

  • $\gamma \in \mathbb{R}^d$: learnable scale parameters (initialized to 1)
  • $\epsilon$: a small constant for numerical stability

Notice what's absent compared to LayerNorm:

  • No mean subtraction ($\mathbf{x} - \mu$): We normalize around zero, not around the data's center
  • No $\beta$ shift parameter: Since we don't center, we don't need to un-center

RMSNorm is purely a scaling operation. Each input element is divided by the same scalar (the RMS), then multiplied by its corresponding learned scale factor. The operation preserves the relative relationships between elements while bringing everything to a consistent magnitude.
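
A small sketch of this pure-scaling behavior, using an arbitrary vector and $\gamma = 1$: dividing by the RMS produces unit RMS while leaving the ratios between elements untouched.

```python
import numpy as np

x = np.array([0.3, -1.2, 2.4, 0.6])
rms = np.sqrt(np.mean(x ** 2))
y = x / rms                                # RMSNorm with gamma = 1

print(np.sqrt(np.mean(y ** 2)))            # ~1.0: the output has unit RMS
print(x[2] / x[0], y[2] / y[0])            # 8.0 8.0: relative relationships preserved
```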

Mathematical Connection Between RMS and Standard Deviation

We've claimed that RMSNorm works because activations in neural networks tend to be near-centered. But how near is near enough? To answer this precisely, we need to understand the mathematical relationship between what RMSNorm computes (the RMS) and what LayerNorm computes (the standard deviation).

This section derives the exact connection between these two quantities. The derivation is worth following carefully because it reveals a beautiful geometric relationship and tells us exactly when the two normalizations diverge.

The Goal: Relating RMS to Standard Deviation

We want to express $\text{RMS}(\mathbf{x})$ in terms of $\sigma$ (the standard deviation) and $\mu$ (the mean). If we can do this, we'll know how much the two normalizations differ based on the mean alone.

Start with the definition of variance, which measures how spread out values are around their mean:

$$\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$$

where:

  • $\sigma^2$: the variance of the input vector
  • $d$: the number of elements in the vector
  • $x_i$: the $i$-th element of the input vector
  • $\mu$: the mean of the input elements

This formula takes each element, measures its distance from the mean, squares that distance, and averages all the squared distances. The square root of the variance gives us the standard deviation $\sigma$, which has the same units as the original data.

Step-by-Step Derivation

Our strategy is to expand the variance formula and recognize familiar terms. We'll use the algebraic identity $(a - b)^2 = a^2 - 2ab + b^2$ to expand the squared term:

$$\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i^2 - 2x_i\mu + \mu^2)$$

Now we can distribute the sum across the three terms. Each term gets its own summation:

$$\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} x_i^2 - \frac{2\mu}{d} \sum_{i=1}^{d} x_i + \frac{1}{d} \sum_{i=1}^{d} \mu^2$$

Let's simplify each term:

  • First term: $\frac{1}{d} \sum_{i=1}^{d} x_i^2$ is the mean of squared values. This is exactly $\text{RMS}(\mathbf{x})^2$.
  • Second term: The sum $\sum_{i=1}^{d} x_i$ equals $d \cdot \mu$ by the definition of the mean. So this term becomes $\frac{2\mu}{d} \cdot (d\mu) = 2\mu^2$.
  • Third term: We're summing the constant $\mu^2$ exactly $d$ times, so this equals $\frac{d \cdot \mu^2}{d} = \mu^2$.

Substituting these simplifications:

$$\sigma^2 = \text{RMS}(\mathbf{x})^2 - 2\mu^2 + \mu^2 = \text{RMS}(\mathbf{x})^2 - \mu^2$$

Rearranging to isolate the RMS:

$$\text{RMS}(\mathbf{x})^2 = \sigma^2 + \mu^2$$

Taking square roots of both sides:

$$\text{RMS}(\mathbf{x}) = \sqrt{\sigma^2 + \mu^2}$$

where:

  • $\text{RMS}(\mathbf{x})$: the root mean square of the input vector
  • $\sigma$: the standard deviation of the input vector
  • $\mu$: the mean of the input vector
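
This identity is easy to confirm numerically. The sketch below uses an arbitrary random vector; note that it relies on the population standard deviation (NumPy's default), which is also what the derivation uses.

```python
import numpy as np

np.random.seed(0)
x = np.random.randn(1024) * 0.7 + 0.3      # arbitrary spread and mean

rms = np.sqrt(np.mean(x ** 2))
sigma, mu = x.std(), x.mean()              # population std (ddof=0)

print(rms ** 2)                            # mean of squares
print(sigma ** 2 + mu ** 2)                # matches: RMS^2 = sigma^2 + mu^2
```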

The Geometric Interpretation

This formula has a beautiful geometric meaning. Think of $\sigma$ and $\mu$ as the two legs of a right triangle. The RMS is the hypotenuse. The Pythagorean theorem tells us that $\text{hypotenuse}^2 = \text{leg}_1^2 + \text{leg}_2^2$, which is exactly what we derived.

Out[3]:
Visualization
Geometric diagram showing right triangles where RMS is the hypotenuse and standard deviation and mean are the legs, demonstrating the Pythagorean relationship.
The Pythagorean relationship between RMS, standard deviation, and mean. RMS forms the hypotenuse of a right triangle with σ and μ as legs. When μ approaches zero (dashed triangle), RMS approaches σ, explaining why the two normalizations converge for centered data.

This geometric picture immediately reveals when RMSNorm and LayerNorm behave similarly:

  • When $\mu = 0$: The triangle collapses to a line. The hypotenuse equals the remaining leg: $\text{RMS} = \sigma$. The two normalizations are identical.
  • When $\mu$ is small relative to $\sigma$: The triangle is nearly flat. The hypotenuse is only slightly longer than $\sigma$. The normalizations are nearly equivalent.
  • When $\mu$ is comparable to or larger than $\sigma$: The triangle is more equilateral or tall. The hypotenuse differs significantly from $\sigma$. The normalizations diverge.

The crucial question for practical use becomes: in real neural networks, how large is $\mu$ compared to $\sigma$?

In[4]:
Code
import numpy as np

np.random.seed(42)


def layer_norm(x, gamma, eps=1e-6):
    """Standard layer normalization with centering."""
    mu = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    x_norm = (x - mu) / np.sqrt(var + eps)
    return gamma * x_norm


def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm without centering."""
    rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True))
    x_norm = x / (rms + eps)
    return gamma * x_norm


# Generate random activations (typical neural network values)
d = 256  # Hidden dimension
batch_size = 32

# Simulate activations after a linear layer + activation
activations = np.random.randn(batch_size, d) * 0.5 + 0.1  # Small positive bias

gamma = np.ones(d)  # Identity scaling for comparison
Out[5]:
Console
Input statistics (per sample):
  Mean magnitude: 0.0976
  Std deviation:  0.5024
  RMS value:      0.5129

Comparison of LayerNorm vs RMSNorm outputs:
  Mean absolute difference: 0.190481
  Max absolute difference:  0.532921
  Correlation:              0.997639

The outputs are highly correlated but not identical. The differences arise from the mean subtraction in LayerNorm. Let's visualize how the two normalizations compare across different input distributions.

In[6]:
Code
# Test with different mean magnitudes
mean_offsets = [0.0, 0.5, 1.0, 2.0, 5.0]
results = []

for offset in mean_offsets:
    x = np.random.randn(1000, d) * 0.5 + offset

    ln_out = layer_norm(x, gamma)
    rms_out = rms_norm(x, gamma)

    # Compute RMS difference
    rms_diff = np.sqrt(np.mean((ln_out - rms_out) ** 2))

    # Compute statistics
    input_mean = np.mean(np.abs(np.mean(x, axis=-1)))
    input_std = np.mean(np.std(x, axis=-1))

    results.append(
        {
            "offset": offset,
            "input_mean": input_mean,
            "input_std": input_std,
            "rms_difference": rms_diff,
        }
    )
Out[7]:
Visualization
Line plot showing RMS difference between LayerNorm and RMSNorm increasing as input mean offset grows from 0 to 5.
Difference between LayerNorm and RMSNorm outputs as input mean increases. When inputs are centered near zero, both normalizations produce nearly identical results. As the mean shifts away from zero, the difference grows because RMSNorm doesn't subtract the mean.

The key insight emerges: when inputs are approximately centered (mean near zero), the two normalizations are nearly equivalent. In deep neural networks with proper initialization and residual connections, activations tend to stay roughly centered. This explains why RMSNorm works as well as LayerNorm in practice.

Implementation

With the mathematical foundation established, let's translate our understanding into working code. We'll build RMSNorm from first principles, compare it with LayerNorm, and observe how the two behave on real data.

Building RMSNorm Step by Step

The implementation follows directly from the formula. For each input vector, we need to:

  1. Square all elements
  2. Compute the mean of these squares
  3. Take the square root (adding epsilon for stability)
  4. Divide the original input by this RMS value
  5. Multiply by the learnable scale parameters

Let's implement both RMSNorm and LayerNorm as Python classes:

In[8]:
Code
class RMSNorm:
    """
    Root Mean Square Layer Normalization.

    Normalizes inputs by their RMS value without mean centering.
    """

    def __init__(self, dim, eps=1e-6):
        """
        Initialize RMSNorm layer.

        Args:
            dim: Dimension of the input features
            eps: Small constant for numerical stability
        """
        self.eps = eps
        self.weight = np.ones(dim)  # Learnable scale parameter (gamma)

    def __call__(self, x):
        """
        Apply RMSNorm to input.

        Args:
            x: Input tensor of shape (..., dim)

        Returns:
            Normalized tensor of the same shape
        """
        # Compute RMS along last dimension
        rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + self.eps)

        # Normalize and scale
        return self.weight * (x / rms)

    def _compute_rms(self, x):
        """Helper to compute RMS for analysis."""
        return np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + self.eps)


class LayerNorm:
    """
    Standard Layer Normalization for comparison.
    """

    def __init__(self, dim, eps=1e-6):
        self.eps = eps
        self.weight = np.ones(dim)
        self.bias = np.zeros(dim)

    def __call__(self, x):
        mean = np.mean(x, axis=-1, keepdims=True)
        var = np.var(x, axis=-1, keepdims=True)
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        return self.weight * x_norm + self.bias

Testing the Implementations

With both classes defined, let's apply them to realistic input data and examine the output statistics. We'll use a tensor shaped like typical transformer activations: batch size 16, sequence length 128, hidden dimension 512.

In[9]:
Code
# Test both implementations
dim = 512
rms_norm_layer = RMSNorm(dim)
layer_norm_layer = LayerNorm(dim)

# Create test input
x_test = np.random.randn(16, 128, dim)  # (batch, seq_len, dim)

# Apply both normalizations
rms_output = rms_norm_layer(x_test)
ln_output = layer_norm_layer(x_test)
Out[10]:
Visualization
Three overlapping histograms showing input distribution, RMSNorm output, and LayerNorm output, demonstrating how each normalization transforms the data.
Distribution of values before and after applying RMSNorm and LayerNorm. Both normalizations concentrate values around unit scale, but LayerNorm centers the distribution at zero while RMSNorm preserves the original mean structure.
Out[11]:
Console
Output statistics after normalization:

RMSNorm output:
  Mean:     0.000651
  Std:      0.999999
  RMS:      0.999999

LayerNorm output:
  Mean:     -0.000000
  Std:      0.999999
  RMS:      0.999999

The statistics reveal the key behavioral difference between the two normalizations. Look at the RMS values: RMSNorm produces output with RMS close to 1.0, which is exactly what it's designed to do. The mean, however, is not forced to zero.

LayerNorm tells a different story. Its output has near-zero mean (by design) and unit standard deviation. Because the mean is zero, the RMS and standard deviation are approximately equal, confirming our earlier mathematical derivation.

Despite these differences, both approaches accomplish the primary goal: they normalize the magnitude of activations to a consistent scale, which is what matters for stable training.

Computational Efficiency

The primary motivation for RMSNorm is computational efficiency. Let's count the operations required for each normalization.

LayerNorm operations per element:

  1. Compute mean: 1 addition (accumulated) + 1 division (shared across elements)
  2. Subtract mean: 1 subtraction
  3. Compute squared difference: 1 subtraction + 1 multiplication
  4. Compute variance: 1 addition (accumulated) + 1 division (shared)
  5. Add epsilon and take square root: 1 addition + 1 sqrt (shared)
  6. Divide by std: 1 division
  7. Scale by gamma and add beta: 1 multiplication + 1 addition

RMSNorm operations per element:

  1. Square each element: 1 multiplication
  2. Compute mean of squares: 1 addition (accumulated) + 1 division (shared)
  3. Add epsilon and take square root: 1 addition + 1 sqrt (shared)
  4. Divide by RMS: 1 division
  5. Scale by gamma: 1 multiplication

The reduction in operations comes from eliminating the mean computation and subtraction, plus removing the bias parameter. Let's measure the actual speedup:

In[12]:
Code
import time


def benchmark_normalization(norm_func, x, n_iterations=1000):
    """Benchmark a normalization function."""
    # Warmup
    for _ in range(100):
        _ = norm_func(x)

    # Timed runs
    start = time.perf_counter()
    for _ in range(n_iterations):
        _ = norm_func(x)
    end = time.perf_counter()

    return (end - start) / n_iterations * 1000  # Convert to milliseconds


# Benchmark with different sizes
sizes = [256, 512, 1024, 2048, 4096]
batch_seq = 64  # batch * sequence length combined

benchmark_results = []
for dim in sizes:
    x = np.random.randn(batch_seq, dim).astype(np.float32)
    rms_norm_layer = RMSNorm(dim)
    layer_norm_layer = LayerNorm(dim)

    rms_time = benchmark_normalization(rms_norm_layer, x)
    ln_time = benchmark_normalization(layer_norm_layer, x)

    benchmark_results.append(
        {
            "dim": dim,
            "rmsnorm_ms": rms_time,
            "layernorm_ms": ln_time,
            "speedup": ln_time / rms_time,
        }
    )
Out[13]:
Console
Benchmark results (pure NumPy, CPU):

 Dimension | RMSNorm (ms) | LayerNorm (ms) |  Speedup
-------------------------------------------------------
       256 |       0.0258 |         0.0530 |    2.06x
       512 |       0.0497 |         0.0785 |    1.58x
      1024 |       0.0663 |         0.1342 |    2.02x
      2048 |       0.1248 |         0.2457 |    1.97x
      4096 |       0.2375 |         0.4694 |    1.98x

The benchmark shows RMSNorm consistently outperforming LayerNorm across all dimensions. In this CPU-based NumPy implementation, RMSNorm is roughly 1.6-2x faster, with the exact factor varying slightly by dimension. On GPUs with optimized CUDA kernels, the speedup is typically in the 5-15% range due to different bottlenecks.

Out[14]:
Visualization
Bar chart comparing RMSNorm and LayerNorm execution times across dimensions from 256 to 4096.
RMSNorm speedup over LayerNorm across different hidden dimensions. The efficiency gain is relatively consistent, with RMSNorm roughly 1.6-2x faster in this NumPy benchmark; gains on optimized GPU kernels are smaller.

The exact speedup depends on hardware and implementation details, and optimized GPU kernels see more modest gains than this NumPy comparison suggests. Even so, the savings add up significantly in large models, where normalization is applied at every layer.

Parameter Efficiency

Beyond computational cost, RMSNorm also reduces the number of parameters. LayerNorm has two learnable vectors per layer ($\gamma$ and $\beta$), while RMSNorm has only one ($\gamma$).

In[15]:
Code
def count_norm_parameters(hidden_dim, num_layers, norm_type="rmsnorm"):
    """Count normalization parameters in a transformer."""
    params_per_layer = {
        "layernorm": 2 * hidden_dim,  # gamma and beta
        "rmsnorm": hidden_dim,  # gamma only
    }
    return params_per_layer[norm_type] * num_layers


# Compare for different model sizes
model_configs = [
    ("GPT-2 Small", 768, 12),
    ("GPT-2 Medium", 1024, 24),
    ("GPT-2 Large", 1280, 36),
    ("LLaMA 7B", 4096, 32),
    ("LLaMA 13B", 5120, 40),
]
Out[16]:
Visualization
Horizontal bar chart comparing LayerNorm and RMSNorm parameter counts across model sizes from GPT-2 Small to LLaMA 13B.
Normalization parameter counts for different model sizes. RMSNorm reduces parameters by 50% compared to LayerNorm since it has only gamma (no beta). For LLaMA 7B, this saves over 500K parameters per model.

The parameter savings scale with model size. For LLaMA 7B, switching from LayerNorm to RMSNorm saves over 500,000 parameters. While this is less than 0.01% of the total model size, these parameters also consume memory bandwidth during inference and require gradient computation during training. Every reduction helps.

Gradient Flow Through RMSNorm

Understanding the backward pass helps us see why RMSNorm is computationally cheaper. The gradient computation for RMSNorm is simpler because it doesn't need to backpropagate through the mean computation.

Consider the forward pass where each output element $y_i$ is computed as:

$$y_i = \gamma_i \cdot \frac{x_i}{\text{RMS}(\mathbf{x})}$$

where:

  • $y_i$: the $i$-th element of the output
  • $\gamma_i$: the $i$-th learnable scale parameter
  • $x_i$: the $i$-th element of the input
  • $\text{RMS}(\mathbf{x})$: the root mean square computed over all input elements

To compute gradients, we need to determine how each input element $x_j$ affects each output element $y_i$. This requires the partial derivative:

$$\frac{\partial y_i}{\partial x_j} = \gamma_i \cdot \frac{\partial}{\partial x_j}\left(\frac{x_i}{\text{RMS}(\mathbf{x})}\right)$$

For notational convenience, let $r = \text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{d}\sum_k x_k^2}$. We first need the derivative of $r$ with respect to $x_j$. Using the chain rule on the square root and sum:

$$\frac{\partial r}{\partial x_j} = \frac{\partial}{\partial x_j}\left(\frac{1}{d}\sum_k x_k^2\right)^{1/2} = \frac{1}{2}\left(\frac{1}{d}\sum_k x_k^2\right)^{-1/2} \cdot \frac{2x_j}{d} = \frac{x_j}{d \cdot r}$$

where:

  • $\frac{\partial r}{\partial x_j}$: how much the RMS changes when $x_j$ changes
  • $x_j$: the $j$-th input element (the one we're differentiating with respect to)
  • $d$: the dimension of the input vector
  • $r$: the RMS value (it appears in the denominator because of the square root derivative)

Now we apply the quotient rule to $\frac{x_i}{r}$. The quotient rule states that $\frac{d}{dx}\left(\frac{u}{v}\right) = \frac{u'v - uv'}{v^2}$:

$$\frac{\partial}{\partial x_j}\left(\frac{x_i}{r}\right) = \frac{\delta_{ij} \cdot r - x_i \cdot \frac{x_j}{d \cdot r}}{r^2} = \frac{\delta_{ij}}{r} - \frac{x_i x_j}{d \cdot r^3}$$

where:

  • $\delta_{ij}$: the Kronecker delta, equal to 1 if $i = j$ and 0 otherwise
  • The first term $\frac{\delta_{ij}}{r}$: the direct effect when $i = j$ (changing $x_j$ directly changes the numerator $x_i$)
  • The second term $\frac{x_i x_j}{d \cdot r^3}$: the indirect effect through the RMS in the denominator (changing any $x_j$ affects the RMS, which affects all outputs)

To compute the gradient of a loss $L$ with respect to input $x_j$, we apply the chain rule, summing over all output elements:

$$\frac{\partial L}{\partial x_j} = \sum_i \frac{\partial L}{\partial y_i} \cdot \gamma_i \cdot \left(\frac{\delta_{ij}}{r} - \frac{x_i x_j}{d \cdot r^3}\right)$$

The Kronecker delta $\delta_{ij}$ selects only the $i = j$ term from the sum for the first part, giving us:

$$\frac{\partial L}{\partial x_j} = \frac{\gamma_j}{r} \cdot \frac{\partial L}{\partial y_j} - \frac{x_j}{d \cdot r^3} \sum_i \gamma_i \cdot x_i \cdot \frac{\partial L}{\partial y_i}$$

where:

  • $\frac{\partial L}{\partial y_j}$: the upstream gradient for output element $j$
  • The first term: the direct gradient path from $x_j$ through $y_j$
  • The second term: the indirect gradient path where $x_j$ affects all outputs via the shared RMS denominator
  • The sum $\sum_i \gamma_i \cdot x_i \cdot \frac{\partial L}{\partial y_i}$: aggregates the indirect effects across all output dimensions

In[17]:
Code
def rmsnorm_backward(dout, x, gamma, eps=1e-6):
    """
    Backward pass for RMSNorm.

    Args:
        dout: Upstream gradient, shape (..., d)
        x: Original input, shape (..., d)
        gamma: Scale parameter, shape (d,)
        eps: Numerical stability constant

    Returns:
        dx: Gradient with respect to input
        dgamma: Gradient with respect to scale
    """
    # Compute RMS from forward pass
    mean_sq = np.mean(x**2, axis=-1, keepdims=True)
    rms = np.sqrt(mean_sq + eps)

    # Normalized input
    x_norm = x / rms

    # Gradient for gamma
    dgamma = np.sum(dout * x_norm, axis=tuple(range(dout.ndim - 1)))

    # Gradient for x
    d = x.shape[-1]

    # Term 1: direct gradient through division
    dx_norm = dout * gamma / rms

    # Term 2: gradient through RMS (chain rule)
    # d(1/rms)/dx_j = -x_j / (d * rms^3)
    sum_term = np.sum(dout * gamma * x, axis=-1, keepdims=True)
    dx_rms = -sum_term * x / (d * rms**3)

    dx = dx_norm + dx_rms

    return dx, dgamma
In[18]:
Code
# Verify gradients numerically
def numerical_gradient(f, x, h=1e-5):
    """Compute numerical gradient using central difference."""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"], op_flags=["readwrite"])
    while not it.finished:
        idx = it.multi_index
        old_val = x[idx]

        x[idx] = old_val + h
        fxph = f(x.copy())

        x[idx] = old_val - h
        fxmh = f(x.copy())

        grad[idx] = (fxph - fxmh) / (2 * h)
        x[idx] = old_val
        it.iternext()

    return grad


# Test setup
np.random.seed(42)
d_test = 8
x_test = np.random.randn(2, d_test)
gamma_test = np.random.randn(d_test)
dout_test = np.random.randn(2, d_test)


def forward_scalar(x):
    """Forward pass returning scalar loss."""
    rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + 1e-6)
    y = gamma_test * (x / rms)
    return np.sum(y * dout_test)


# Analytical gradient
dx_analytical, _ = rmsnorm_backward(dout_test, x_test.copy(), gamma_test)

# Numerical gradient
dx_numerical = numerical_gradient(forward_scalar, x_test.copy())
Out[19]:
Console
Gradient verification:
  Max relative error: 1.88e-08
  Gradient check passed

The gradient check confirms our analytical backward pass implementation is correct. A relative error below $10^{-4}$ indicates that the analytical and numerical gradients agree to high precision, validating both the mathematical derivation and the code implementation.

Out[20]:
Visualization
Heatmap of the Jacobian matrix for RMSNorm showing diagonal dominance with subtle off-diagonal coupling through the shared RMS normalization.
Jacobian matrix showing how each output gradient depends on each input element. The diagonal shows the direct gradient path (term 1), while off-diagonal elements show the indirect coupling through the shared RMS denominator (term 2). The coupling is subtle but present, ensuring gradients flow to all inputs.

RMSNorm vs LayerNorm: Empirical Comparison

The theoretical efficiency of RMSNorm is clear, but does it maintain model quality? Let's compare the two normalizations on a simple task to see their behavior during training.

In[21]:
Code
import torch
import torch.nn as nn


class TorchRMSNorm(nn.Module):
    """RMSNorm implemented in PyTorch."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)


class SimpleTransformerBlock(nn.Module):
    """Simplified transformer block for comparison."""

    def __init__(self, dim, norm_type="rmsnorm"):
        super().__init__()

        # Choose normalization
        if norm_type == "rmsnorm":
            self.norm1 = TorchRMSNorm(dim)
            self.norm2 = TorchRMSNorm(dim)
        else:
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)

        # Simple self-attention substitute (linear projection)
        self.attn = nn.Linear(dim, dim)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        # Pre-norm architecture
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class TinyTransformer(nn.Module):
    """Small transformer for testing normalizations."""

    def __init__(self, vocab_size, dim, n_layers, norm_type="rmsnorm"):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList(
            [SimpleTransformerBlock(dim, norm_type) for _ in range(n_layers)]
        )

        if norm_type == "rmsnorm":
            self.final_norm = TorchRMSNorm(dim)
        else:
            self.final_norm = nn.LayerNorm(dim)

        self.output = nn.Linear(dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        for block in self.blocks:
            x = block(x)
        x = self.final_norm(x)
        return self.output(x)
In[22]:
Code
# Training comparison
def train_model(model, data, targets, epochs=100, lr=1e-3):
    """Train model and return loss history."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    losses = []
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.view(-1, output.size(-1)), targets.view(-1))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    return losses


# Create synthetic data
torch.manual_seed(42)
vocab_size = 100
seq_len = 32
batch_size = 16
dim = 64
n_layers = 4

data = torch.randint(0, vocab_size, (batch_size, seq_len))
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

# Train both models
torch.manual_seed(42)
model_rms = TinyTransformer(vocab_size, dim, n_layers, "rmsnorm")
losses_rms = train_model(model_rms, data, targets)

torch.manual_seed(42)
model_ln = TinyTransformer(vocab_size, dim, n_layers, "layernorm")
losses_ln = train_model(model_ln, data, targets)
Out[23]:
Visualization
Line plot showing training loss curves for RMSNorm and LayerNorm converging to similar values over 100 epochs.
Training loss comparison between RMSNorm and LayerNorm on a small transformer model. Both normalizations achieve similar convergence, validating that RMSNorm's simpler formulation doesn't sacrifice model quality.
Out[24]:
Console
Final training loss comparison:
  RMSNorm:   1.8955
  LayerNorm: 1.8958
  Difference: 0.0003

Both normalizations converge to similar loss values, confirming that RMSNorm doesn't sacrifice model quality for efficiency. In larger-scale experiments on language modeling benchmarks, RMSNorm has been shown to match or slightly exceed LayerNorm performance.

RMSNorm in Modern Architectures

RMSNorm has become the normalization of choice for most modern large language models. Let's examine how it's typically used in practice.

LLaMA Architecture

The LLaMA family of models uses RMSNorm with a specific configuration:

In[25]:
Code
class LLaMAStyleRMSNorm(nn.Module):
    """
    RMSNorm as used in LLaMA models.

    Key differences from basic implementation:
    - Uses float32 for normalization even with mixed precision
    - Specific epsilon value
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Store original dtype for mixed precision training
        input_dtype = x.dtype

        # Compute in float32 for numerical stability
        x = x.float()

        # Compute RMS
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)

        # Normalize and scale
        x = x / rms

        # Apply weight and cast back to original dtype
        return (self.weight * x).to(input_dtype)

The key insight is computing normalization in float32 even when the model uses float16 or bfloat16 for other operations. This prevents numerical instability from small RMS values in reduced precision.

Pre-Norm Placement

Modern transformers use RMSNorm in a pre-normalization configuration, applying it before each sub-layer rather than after:

In[26]:
Code
class ModernTransformerBlock(nn.Module):
    """
    Transformer block with pre-RMSNorm architecture.

    This is the standard configuration for LLaMA, Mistral, etc.
    """

    def __init__(self, dim, n_heads, ffn_mult=4):
        super().__init__()

        # Pre-normalization for attention
        self.attention_norm = LLaMAStyleRMSNorm(dim)

        # Self-attention (simplified)
        self.attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)

        # Pre-normalization for FFN
        self.ffn_norm = LLaMAStyleRMSNorm(dim)

        # Feed-forward network with SwiGLU (simplified to GELU)
        hidden_dim = int(dim * ffn_mult)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        # Attention with pre-norm
        normed = self.attention_norm(x)
        attn_out, _ = self.attention(normed, normed, normed)
        x = x + attn_out

        # FFN with pre-norm
        normed = self.ffn_norm(x)
        x = x + self.ffn(normed)

        return x

The pre-norm placement ensures that inputs to each sub-layer are normalized, stabilizing gradients throughout the network. This has become standard practice after research showed it improves training stability, especially for very deep models.

Limitations and Considerations

RMSNorm's simplicity comes with trade-offs that are worth understanding.

When mean centering matters. If the input distribution has a significant non-zero mean, RMSNorm and LayerNorm produce different outputs. In most transformer architectures with proper initialization and residual connections, activations remain approximately centered, making this difference negligible. However, if you're applying RMSNorm to data with known biases (like all-positive image features), the lack of centering could affect downstream layers.
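
To illustrate this first point, the hedged sketch below feeds all-positive features (an arbitrary offset of 3.0) through minimal re-implementations of both normalizations; with the mean dominating the spread, the two outputs diverge noticeably.

```python
import numpy as np

np.random.seed(0)
x = np.abs(np.random.randn(4, 16) * 0.5) + 3.0   # all-positive features, mean >> spread
gamma = np.ones(16)

ln = gamma * (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-6)
rms_out = gamma * x / (np.sqrt(np.mean(x**2, axis=-1, keepdims=True)) + 1e-6)

print(np.abs(ln - rms_out).mean())   # large average gap between the two outputs
print(rms_out.mean())                # RMSNorm keeps the strongly positive mean (~1)
```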

Interaction with other components. RMSNorm's lack of a $\beta$ bias parameter means it can't shift the output distribution. In pre-norm architectures, this is fine because the subsequent linear layers can learn any necessary bias. But in post-norm configurations or when RMSNorm is the final layer before the output, you might need an explicit bias term elsewhere.

Numerical precision. RMSNorm divides by a single scalar (the RMS), while LayerNorm divides by the standard deviation after subtracting the mean. In rare cases with extreme values, RMSNorm can be less stable because the RMS includes the squared mean contribution. Modern implementations address this by computing in float32 even for mixed-precision training.

Not a drop-in replacement. While RMSNorm can replace LayerNorm in most architectures, models pre-trained with LayerNorm should not have their normalization layers swapped without retraining. The learned $\gamma$ parameters adapt to the specific normalization behavior.

PyTorch Production Implementation

For production use, here's a robust implementation that handles edge cases:

In[27]:
Code
class ProductionRMSNorm(nn.Module):
    """
    Production-ready RMSNorm with all optimizations.
    """

    def __init__(self, dim, eps=1e-6):
        """
        Initialize RMSNorm.

        Args:
            dim: Feature dimension to normalize over
            eps: Epsilon for numerical stability
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Compute RMS normalization."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Apply RMSNorm.

        Handles mixed precision by computing norm in float32.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def extra_repr(self):
        return f"dim={self.weight.shape[0]}, eps={self.eps}"
Out[28]:
Console
Production RMSNorm test:
  Input shape:  (4, 32, 256)
  Output shape: (4, 32, 256)
  Output RMS:   1.0000
  Layer repr:   ProductionRMSNorm(dim=256, eps=1e-06)

The output RMS is close to 1.0, confirming the normalization is working correctly. The shape is preserved, and the layer's extra_repr shows the configured dimension and epsilon value.

The torch.rsqrt function computes $1/\sqrt{x}$ in a single operation, which is faster than a separate division and square root. The type_as call ensures the output matches the input dtype for mixed-precision training.
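
As a usage sketch (the test code behind the console output above isn't shown in the article, so this is a plausible reconstruction), the layer can be exercised like any other nn.Module:

```python
# Assumes `import torch` and the ProductionRMSNorm class defined above.
norm = ProductionRMSNorm(dim=256)
x = torch.randn(4, 32, 256)

y = norm(x)
print(tuple(y.shape))                          # (4, 32, 256): shape is preserved
print(torch.sqrt(torch.mean(y ** 2)).item())   # ~1.0: unit RMS after normalization
print(norm)                                    # ProductionRMSNorm(dim=256, eps=1e-06)
```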

Key Parameters

When implementing or using RMSNorm, several parameters control its behavior:

  • dim: The feature dimension to normalize over. This should match the hidden dimension of your model. For transformer models, this is typically the embedding dimension (e.g., 768 for BERT-base, 4096 for LLaMA-7B).

  • eps (default: $10^{-6}$ to $10^{-5}$): A small constant added for numerical stability, either inside the square root or to the RMS before division depending on the implementation. It prevents division by zero when all input values are near zero. LLaMA uses $10^{-5}$, while some implementations use $10^{-6}$. Smaller values provide more precision but risk numerical instability in low-precision training.

  • weight ($\gamma$): The learnable scale parameter, a vector of dimension dim. Initialized to ones, allowing the network to learn per-feature scaling. Unlike LayerNorm, RMSNorm has no bias ($\beta$) parameter.

  • Computation dtype: For mixed-precision training (float16 or bfloat16), compute the normalization in float32 before casting back to the original dtype. This prevents numerical issues from small RMS values in reduced precision.

  • Placement: Modern architectures use pre-normalization, applying RMSNorm before each attention and feed-forward sub-layer rather than after. This improves gradient flow and training stability.

Summary

RMSNorm simplifies layer normalization by removing mean centering, keeping only the scaling operation based on the root mean square of the input. This reduction provides computational savings and parameter efficiency while maintaining model quality.

Key takeaways from this chapter:

  • RMS vs standard deviation: When inputs are centered around zero, RMS approximately equals the standard deviation. The relationship $\text{RMS}^2 = \sigma^2 + \mu^2$ shows they're equivalent when $\mu = 0$.

  • Computational efficiency: RMSNorm eliminates the mean computation and subtraction, plus removes the $\beta$ bias parameter. This typically provides a 5-15% speedup on GPUs.

  • Parameter efficiency: With only $\gamma$ instead of both $\gamma$ and $\beta$, RMSNorm halves the normalization parameters per layer.

  • Modern adoption: LLaMA, Mistral, and most contemporary LLMs use RMSNorm as their standard normalization layer, demonstrating its effectiveness at scale.

  • Implementation details: Production implementations compute normalization in float32 for numerical stability and use rsqrt for efficiency.

  • Pre-norm placement: RMSNorm is typically used in pre-normalization position, applied before each attention and feed-forward sub-layer.

The success of RMSNorm illustrates a broader principle in deep learning: simpler can be better. By questioning whether mean centering was truly necessary, researchers discovered that models could learn to work without it, gaining efficiency in the process.
