AdamW Optimizer: Decoupled Weight Decay for Deep Learning

Michael Brenndoerfer · December 15, 2025 · 27 min read

Master AdamW optimization, the default choice for training transformers and LLMs. Learn why L2 regularization fails with Adam and how decoupled weight decay fixes it.

AdamW

Adam transformed deep learning optimization by combining momentum with adaptive learning rates. Yet researchers noticed something peculiar: the regularization techniques that worked beautifully with SGD seemed less effective with Adam. Models trained with Adam often generalized worse than those trained with vanilla SGD plus weight decay.

The culprit? A subtle but critical difference between L2 regularization and weight decay. For most optimizers, these two techniques are mathematically equivalent. But Adam's adaptive learning rates break this equivalence, causing L2 regularization to behave in unexpected ways. AdamW, introduced by Loshchilov and Hutter in 2017, fixes this problem by decoupling weight decay from the gradient-based update. The result is an optimizer that combines Adam's fast convergence with proper regularization, now the default choice for training transformers and large language models.

L2 Regularization vs Weight Decay

Before diving into AdamW, we need to understand the subtle distinction between L2 regularization and weight decay. These terms are often used interchangeably, but they represent different approaches to the same goal: preventing overfitting by penalizing large weights.

L2 Regularization: Modifying the Loss

L2 regularization adds a penalty term to the loss function based on the squared magnitude of the weights:

$$\mathcal{L}_{\text{reg}} = \mathcal{L} + \frac{\lambda}{2} \|w\|^2$$

where:

  • $\mathcal{L}$: the original loss function (such as cross-entropy or mean squared error)
  • $\mathcal{L}_{\text{reg}}$: the regularized loss that we actually minimize
  • $w$: the weight vector containing all trainable parameters in the network
  • $\|w\|^2 = \sum_i w_i^2$: the squared L2 norm of the weights
  • $\lambda$: the regularization strength hyperparameter (larger values penalize large weights more heavily)
  • $\frac{1}{2}$: a convenience factor that simplifies the gradient calculation

When we compute the gradient of this regularized loss with respect to the weights, the chain rule gives us:

$$\nabla_w \mathcal{L}_{\text{reg}} = \nabla_w \mathcal{L} + \lambda w$$

where:

  • $\nabla_w \mathcal{L}$: the gradient of the original loss with respect to weights
  • $\lambda w$: the regularization gradient, which points in the direction of the current weights

The regularization contributes an additional term $\lambda w$ to the gradient. This means larger weights produce larger gradients, pushing the optimizer to shrink them during training.
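To make this concrete, here is a tiny NumPy check with made-up weight and gradient values (purely illustrative, not from the article's experiments), showing that the penalty simply adds $\lambda w$ element-wise to the loss gradient:

import numpy as np

# Hypothetical weights and loss gradient, chosen only for illustration
w = np.array([0.5, -2.0, 3.0])
loss_grad = np.array([0.1, 0.4, -0.2])
lam = 0.01

# Gradient of L + (lam / 2) * ||w||^2 is the loss gradient plus lam * w
reg_grad = loss_grad + lam * w
print(reg_grad)  # [ 0.105  0.38  -0.17 ]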

Weight Decay: Modifying the Update Rule

Weight decay takes a different approach. Instead of modifying the loss function, it directly modifies the weight update:

$$w_{t+1} = w_t - \eta \nabla_w \mathcal{L} - \eta \lambda w_t$$

where:

  • $w_t$: the weight vector at time step $t$
  • $w_{t+1}$: the updated weight vector after one optimization step
  • $\eta$: the learning rate
  • $\nabla_w \mathcal{L}$: the gradient of the loss with respect to weights
  • $\lambda$: the weight decay coefficient
  • $\eta \lambda w_t$: the decay term that shrinks weights toward zero

We can factor this equation to see the decay more explicitly:

$$w_{t+1} = (1 - \eta \lambda) w_t - \eta \nabla_w \mathcal{L}$$

The factor $(1 - \eta \lambda)$ is slightly less than one, so at each step we multiply the weights by this factor before applying the gradient update. This causes weights to "decay" toward zero over time. The hyperparameter $\lambda$ controls how quickly this decay happens.

Equivalence with SGD

For standard SGD, these two formulations produce identical updates. With L2 regularization, the SGD update becomes:

$$w_{t+1} = w_t - \eta (\nabla_w \mathcal{L} + \lambda w_t) = w_t - \eta \nabla_w \mathcal{L} - \eta \lambda w_t$$

The left side shows the update using the regularized gradient $(\nabla_w \mathcal{L} + \lambda w_t)$, while the right side expands to match the weight decay formulation exactly. Because SGD applies the same learning rate $\eta$ uniformly to all gradient components, the regularization strength $\lambda$ has the same effect in both cases. This mathematical equivalence is why practitioners historically treated the terms as synonyms.
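A quick numerical sketch (with arbitrary values) confirms this: folding $\lambda w_t$ into the gradient and applying the decay as a separate multiplicative shrink produce exactly the same SGD update.

import numpy as np

w = np.array([1.0, -0.5])        # hypothetical weights
grad = np.array([0.2, 0.3])      # hypothetical loss gradient
eta, lam = 0.1, 0.01

w_l2 = w - eta * (grad + lam * w)        # L2 regularization: modify the gradient
w_wd = (1 - eta * lam) * w - eta * grad  # weight decay: shrink weights, then step
print(np.allclose(w_l2, w_wd))           # True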

Why Adam Breaks the Equivalence

Adam's adaptive learning rates fundamentally change the relationship between L2 regularization and weight decay. To see why, recall Adam's update rule:

$$w_{t+1} = w_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$$

where:

  • $w_t$: the weight vector at time step $t$
  • $\eta$: the base learning rate
  • $\hat{m}_t$: the bias-corrected first moment estimate (exponential moving average of gradients)
  • $\hat{v}_t$: the bias-corrected second moment estimate (exponential moving average of squared gradients)
  • $\epsilon$: a small constant for numerical stability (typically $10^{-8}$)
  • $\frac{\eta}{\sqrt{\hat{v}_t} + \epsilon}$: the effective learning rate, which adapts per-parameter based on gradient history

The key insight is that each parameter's update gets divided by $\sqrt{\hat{v}_t}$, which depends on the history of gradients for that parameter. Parameters with consistently large gradients have large $\hat{v}_t$ values, which reduces their effective learning rate.

L2 Regularization with Adam

When we use L2 regularization with Adam, the gradient becomes $\nabla_w \mathcal{L} + \lambda w$. This regularization term enters Adam through both moment estimates:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)(\nabla_w \mathcal{L} + \lambda w)$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2)(\nabla_w \mathcal{L} + \lambda w)^2$$

where:

  • $m_t$: the first moment estimate (momentum), now incorporating the regularization term
  • $v_t$: the second moment estimate, now incorporating the squared regularization term
  • $\beta_1$: the exponential decay rate for the first moment (typically 0.9)
  • $\beta_2$: the exponential decay rate for the second moment (typically 0.999)
  • $\lambda w$: the L2 regularization gradient added to the loss gradient

The problem lies in the second moment $v_t$. For parameters with large weights, the term $(\nabla_w \mathcal{L} + \lambda w)^2$ increases $v_t$, which in turn decreases the effective learning rate $\frac{\eta}{\sqrt{\hat{v}_t} + \epsilon}$ for those parameters.

This creates a problematic feedback loop: L2 regularization is supposed to push large weights toward zero, but Adam's adaptation reduces the update magnitude for parameters that receive large gradients. The regularization signal gets dampened precisely where it should be strongest.

Weight Decay with Adam

True weight decay sidesteps Adam's adaptive mechanism entirely. Instead of modifying the gradient, we apply decay directly to the weights after the Adam update:

$$w_{t+1} = w_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t - \eta \lambda w_t$$

This equation has two distinct terms after $w_t$:

  • $\frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$: the standard Adam update, which adapts to gradient history
  • $\eta \lambda w_t$: the weight decay term, applied directly without passing through Adam's moment estimates

The decay term $\eta \lambda w_t$ operates independently of Adam's gradient history. Large weights decay at the expected rate regardless of their gradient patterns, because the decay is not scaled by the adaptive learning rate $\frac{1}{\sqrt{\hat{v}_t} + \epsilon}$.
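A rough single-step sketch (made-up numbers, momentum ignored for simplicity) shows the size of the difference. For a parameter with a large gradient history, the L2 term is divided by the same large $\sqrt{\hat{v}_t}$ as the rest of the update, while the decoupled decay is not:

import numpy as np

w, grad = 2.0, 4.0               # hypothetical weight and gradient
eta, lam, eps = 1e-3, 0.01, 1e-8
v_hat = grad**2                  # assume the second moment has settled near grad^2

# Adam + L2: the lam * w term passes through the adaptive denominator
l2_shrink = eta * (lam * w) / (np.sqrt(v_hat) + eps)

# AdamW: the decay bypasses the adaptive scaling entirely
adamw_shrink = eta * lam * w

print(f"shrink from L2 term:     {l2_shrink:.1e}")     # ~5.0e-06
print(f"shrink from AdamW decay: {adamw_shrink:.1e}")  # 2.0e-05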

Decoupled Weight Decay

The term "decoupled" in AdamW refers to separating the weight decay from the gradient-based update mechanism. The decay is applied to the weights directly, not through the gradient path that feeds into Adam's moment estimates.

The AdamW Algorithm

Now that we understand why L2 regularization fails with Adam and how decoupling weight decay solves this problem, let's formalize the complete AdamW algorithm. The key insight to keep in mind: we want to preserve everything that makes Adam effective, such as momentum, adaptive learning rates, and bias correction, while ensuring that regularization operates independently of the gradient-based updates.

Think of AdamW as running two parallel processes. The first process is pure Adam: it tracks gradient history, adapts learning rates per parameter, and updates weights based on this accumulated knowledge. The second process is pure weight decay: it shrinks all weights toward zero at a constant rate, completely ignoring what the gradients are doing. The magic happens because these two processes don't interfere with each other.

Building the Algorithm Step by Step

Let's construct AdamW from first principles, understanding what each component contributes.

Step 1: Compute the gradient. We start with the gradient of the unregularized loss. This is the crucial departure from L2 regularization: we don't add any regularization term here.

$$g_t = \nabla_w \mathcal{L}(w_t)$$

where $g_t$ is the gradient at time step $t$ and $\mathcal{L}(w_t)$ is the original loss function. By keeping the gradient pure, we ensure that Adam's moment estimates reflect only the loss landscape, not the regularization penalty.

Step 2: Update the first moment estimate. The first moment $m_t$ tracks the exponential moving average of gradients. This is the "momentum" component that helps smooth out noisy gradients and accelerate convergence along consistent directions.

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$

The hyperparameter $\beta_1$ (typically 0.9) controls how much history to retain. With $\beta_1 = 0.9$, each update blends 90% of the previous estimate with 10% of the new gradient. We initialize $m_0 = 0$, which creates a bias we'll correct later.

Step 3: Update the second moment estimate. The second moment $v_t$ tracks the exponential moving average of squared gradients. This is what gives Adam its adaptive learning rate: parameters with historically large gradients get smaller updates.

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$

Here $g_t^2$ denotes element-wise squaring, and $\beta_2$ (typically 0.999) is chosen larger than $\beta_1$ because we want the variance estimate to be more stable. The squaring means $v_t$ is always non-negative, which allows us to use it for scaling.

Step 4: Apply bias correction. Because we initialize both moment estimates at zero, early estimates are biased toward zero. Bias correction compensates for this:

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

The correction factors $(1 - \beta_1^t)$ and $(1 - \beta_2^t)$ start small and approach 1 as training progresses. At $t = 1$ with $\beta_1 = 0.9$, we divide by 0.1, effectively scaling up the first estimate by 10x. By $t = 100$, the correction is negligible.

Out[2]:
Visualization
Line plot showing two curves converging to 1, with the first moment correction converging faster than the second moment.
Evolution of bias correction factors over training steps. The first moment correction (β₁ = 0.9) converges quickly within about 50 steps, while the second moment correction (β₂ = 0.999) takes much longer due to the higher decay rate. Early in training, these corrections dramatically increase the effective learning rate to compensate for zero initialization.

This visualization reveals an important asymmetry. The first moment correction converges rapidly because $\beta_1 = 0.9$ means 90% retention per step. After just 22 steps, $(1 - 0.9^{22}) > 0.9$, so the correction is already small. The second moment correction with $\beta_2 = 0.999$ takes much longer: we need nearly 700 steps before $(1 - 0.999^{700}) > 0.5$. This slower convergence is intentional, as it provides a more stable variance estimate by incorporating more history.
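You can reproduce the numbers behind this asymmetry in a couple of lines, using the standard $\beta_1 = 0.9$ and $\beta_2 = 0.999$:

beta1, beta2 = 0.9, 0.999
for t in [1, 10, 100, 1000]:
    print(f"t={t:4d}  1-beta1^t={1 - beta1**t:.4f}  1-beta2^t={1 - beta2**t:.4f}")
# t=1:    0.1000 and 0.0010, so early estimates get scaled up by 10x and 1000x
# t=1000: the beta1 factor is essentially 1, while the beta2 factor is still only ~0.63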

Step 5: Update weights with decoupled decay. Finally, we perform both the Adam update and weight decay in a single step:

$$w_{t+1} = w_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda w_t \right)$$

This equation combines two independent forces acting on each weight:

  1. The adaptive gradient step $\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$: This is pure Adam. The bias-corrected momentum $\hat{m}_t$ tells us which direction to move, while dividing by $\sqrt{\hat{v}_t}$ scales the step size inversely with gradient magnitude.

  2. The decoupled weight decay term $\lambda w_t$: This shrinks weights toward zero at rate $\lambda$, applied directly without any scaling by gradient history.

The small constant $\epsilon$ (typically $10^{-8}$) prevents division by zero when $\hat{v}_t$ is very small.

The Multiplicative View

An equivalent formulation makes the decay mechanism more explicit:

$$w_{t+1} = (1 - \eta \lambda) w_t - \frac{\eta \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Here you can see weight decay as a multiplicative factor: before applying the gradient update, we shrink all weights by $(1 - \eta \lambda)$. For typical values like $\eta = 10^{-3}$ and $\lambda = 0.01$, this factor equals 0.99999. Each individual step produces barely perceptible decay, but over thousands of steps, the cumulative effect is substantial.

Out[3]:
Visualization
Line plot showing exponential decay curves for different weight decay coefficients over 20000 steps.
Cumulative effect of weight decay over training steps for different decay coefficients. Without gradient updates pushing back, weights shrink exponentially. Stronger decay (λ = 0.1) reduces weights to 37% of original after 10,000 steps, while weaker decay (λ = 0.01) preserves 90%. In practice, gradients counteract this decay for important weights.

This plot shows why weight decay values in the range 0.01 to 0.1 are typical. With $\lambda = 0.01$, weights decay slowly enough that important features are preserved, but unnecessary weights still shrink over the course of training. With $\lambda = 0.1$, the decay is aggressive, which can help with heavily overparameterized models but may hurt performance if set too high. The key insight is that gradients continually push back against this decay for weights that are important for minimizing the loss, creating a natural equilibrium.
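The numbers in the figure are easy to verify directly. Assuming the same $\eta = 10^{-3}$ as above, the fraction of a weight's magnitude left after 10,000 decay-only steps is:

eta, steps = 1e-3, 10_000
for lam in [0.01, 0.1]:
    remaining = (1 - eta * lam) ** steps
    print(f"lambda={lam}: {remaining:.1%} of the original magnitude remains")
# lambda=0.01 -> ~90.5%, lambda=0.1 -> ~36.8%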

Choosing the Weight Decay Coefficient

The weight decay coefficient $\lambda$ requires careful tuning. Unlike learning rate, which has relatively universal starting points, optimal weight decay varies significantly across architectures and datasets.

Typical Ranges

For transformer models, weight decay values typically fall between 0.01 and 0.1:

  • BERT and variants: Originally trained with $\lambda = 0.01$
  • GPT-2 and GPT-3: Used $\lambda = 0.1$
  • Vision Transformers: Often use $\lambda = 0.05$ to $0.3$
  • ResNets with AdamW: Commonly use $\lambda = 0.01$ to $0.05$

Larger models often benefit from stronger regularization. This makes intuitive sense: with more parameters comes more capacity for overfitting.

Interaction with Learning Rate

Weight decay and learning rate interact because the effective decay per step is $\eta \lambda$. When training with a learning rate schedule, you have two choices:

  • Keep $\lambda$ fixed: The effective decay decreases as learning rate decays. This is the standard approach and generally works well.
  • Scale $\lambda$ with learning rate: Maintains constant effective decay throughout training. Some practitioners prefer this for very long training runs.

Most frameworks use a fixed $\lambda$, and the decreasing effective decay late in training is generally not a problem: by that point the learning rate is small and the model is settling into its final minimum with small, precise updates.
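As a small illustration with hypothetical numbers, compare the effective per-step decay $\eta \lambda$ under a decaying learning rate for the two strategies:

base_lr, lam = 1e-3, 0.1
for lr in [1e-3, 5e-4, 1e-4]:             # learning rate as the schedule decays it
    fixed_decay = lr * lam                # fixed lambda: effective decay shrinks with lr
    rescaled_lam = lam * base_lr / lr     # scaled lambda: grows as lr shrinks
    print(f"lr={lr:.0e}  fixed={fixed_decay:.1e}  scaled={lr * rescaled_lam:.1e}")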

What to Exclude from Weight Decay

Not all parameters should receive weight decay. The standard practice is to exclude:

  • Bias terms: These don't contribute to overfitting in the same way as weights
  • Layer normalization parameters: Both scale ($\gamma$) and shift ($\beta$) parameters
  • Embedding layers: Sometimes excluded, though practices vary

Let's see how to implement this in PyTorch:

In[4]:
Code
import torch.nn as nn

# Example model
model = nn.Sequential(
    nn.Linear(768, 512), nn.LayerNorm(512), nn.ReLU(), nn.Linear(512, 10)
)

# Separate parameters into decay and no-decay groups.
# Inside nn.Sequential the parameter names are just "0.weight", "1.bias", and so on,
# so matching on "LayerNorm" in the name would miss the normalization weights.
# Excluding every 1-D parameter catches biases and LayerNorm's scale and shift.
decay_params = []
no_decay_params = []

for name, param in model.named_parameters():
    if param.ndim == 1:  # biases and normalization parameters
        no_decay_params.append(param)
    else:
        decay_params.append(param)

param_groups = [
    {"params": decay_params, "weight_decay": 0.01},
    {"params": no_decay_params, "weight_decay": 0.0},
]
Out[5]:
Console
Parameters with weight decay: 398,336
Parameters without weight decay: 1,546

The weight matrices contain most of the parameters and receive regularization, while biases and normalization parameters remain unregularized. This split is now standard practice in transformer training.
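These groups plug directly into the optimizer, with the per-group weight_decay overriding the optimizer-wide default:

import torch

# Reuses the param_groups defined above
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
for group in optimizer.param_groups:
    print(len(group["params"]), "tensors, weight_decay =", group["weight_decay"])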

Implementing AdamW from Scratch

The best way to internalize an algorithm is to implement it yourself. Let's build AdamW from scratch using only NumPy, translating each mathematical step into code. This exercise will reveal how elegantly simple the algorithm is once you understand the underlying concepts.

Our implementation needs to maintain state across optimization steps: the moment estimates $m$ and $v$ for each parameter, plus a step counter $t$ for bias correction. We'll structure this as a class that mirrors how production optimizers work.

In[6]:
Code
import numpy as np


class AdamW:
    def __init__(
        self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
    ):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay

        # Initialize moment estimates
        self.m = [np.zeros_like(p) for p in self.params]
        self.v = [np.zeros_like(p) for p in self.params]
        self.t = 0

    def step(self, grads):
        self.t += 1

        for i, (param, grad) in enumerate(zip(self.params, grads)):
            # Update biased first moment estimate
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad

            # Update biased second moment estimate
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * (grad**2)

            # Compute bias-corrected estimates
            m_hat = self.m[i] / (1 - self.beta1**self.t)
            v_hat = self.v[i] / (1 - self.beta2**self.t)

            # AdamW update: decoupled weight decay
            param -= self.lr * (
                m_hat / (np.sqrt(v_hat) + self.eps) + self.weight_decay * param
            )

Notice how the implementation follows our five-step algorithm exactly. The constructor initializes moment estimates to zero (creating the bias that we later correct). The step method increments the counter, updates both moments using exponential moving averages, applies bias correction, and finally performs the decoupled update.

The critical line is the final update: m_hat / (np.sqrt(v_hat) + self.eps) + self.weight_decay * param. We add the weight decay term after computing the adaptive gradient step, ensuring that decay operates independently of gradient history. If we had instead modified the gradient before computing moments (as L2 regularization does), the decay would be dampened by the adaptive mechanism.

Testing on a Simple Optimization Problem

Let's verify our implementation works correctly on a problem where we can visualize the optimization trajectory. We'll minimize $f(x, y) = x^2 + 10y^2$, an elongated bowl that's a classic test for optimizers. The steeper curvature along $y$ means gradients are larger in that direction, which tests whether our adaptive learning rates work correctly.

In[7]:
Code
# Minimize f(x, y) = x^2 + 10*y^2 (elongated bowl)
# With regularization, the minimum shifts slightly from (0, 0)

np.random.seed(42)
params = [np.array([5.0]), np.array([5.0])]  # x and y as separate params
optimizer = AdamW(params, lr=0.1, weight_decay=0.01)

history = {"x": [], "y": [], "loss": []}

for step in range(100):
    x, y = params[0][0], params[1][0]

    # Compute gradients (df/dx = 2x, df/dy = 20y)
    grads = [np.array([2 * x]), np.array([20 * y])]

    # Record before update
    loss = x**2 + 10 * y**2
    history["x"].append(x)
    history["y"].append(y)
    history["loss"].append(loss)

    optimizer.step(grads)
Out[8]:
Console
Starting point: x=5.0000, y=5.0000
After 100 steps: x=-0.042891, y=-0.042891
Final loss: 0.02023593

Our implementation successfully drives both coordinates toward zero, with the loss decreasing by several orders of magnitude. The $y$ coordinate converges faster initially because its gradients are larger (the derivative $\frac{\partial f}{\partial y} = 20y$ is 10 times the derivative with respect to $x$). However, Adam's adaptive mechanism compensates: larger gradients lead to larger second moment estimates, which reduce the effective learning rate along $y$. This balancing act is precisely what makes Adam so effective on problems with different scales across dimensions.

The weight decay provides an additional gentle pull toward the origin. Even if the loss function had a minimum elsewhere, weight decay would bias the solution toward smaller weights, which is exactly the regularization behavior we want in neural network training.

Using PyTorch's AdamW

In practice, you'll use PyTorch's built-in AdamW implementation, which is highly optimized and supports features like gradient scaling for mixed-precision training:

In[9]:
Code
import torch
import torch.nn as nn
import torch.optim as optim

# Create a simple model
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))

# Configure AdamW with parameter groups
optimizer = optim.AdamW(
    [
        {"params": model[0].weight, "weight_decay": 0.01},
        {"params": model[0].bias, "weight_decay": 0.0},
        {"params": model[2].weight, "weight_decay": 0.01},
        {"params": model[2].bias, "weight_decay": 0.0},
    ],
    lr=0.001,
)

# Generate synthetic data
X = torch.randn(100, 10)
y = torch.randn(100, 1)

# Training loop
losses = []
for epoch in range(50):
    optimizer.zero_grad()
    output = model(X)
    loss = nn.functional.mse_loss(output, y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
Out[10]:
Console
Initial loss: 0.8916
Final loss: 0.6132
Reduction: 31.2%

The loss decreases steadily as AdamW optimizes the network. In real applications, you would also track validation loss to monitor generalization.

Visualizing the Difference: Adam vs AdamW

Theory tells us that L2 regularization and weight decay should behave differently with Adam, but seeing is believing. Let's run both approaches side by side on our elongated bowl problem and plot their optimization trajectories. This visualization will make the abstract mathematical difference concrete.

We'll implement Adam with L2 regularization (where $\lambda w$ is added to the gradient before entering the moment calculations) alongside our AdamW implementation (where weight decay is applied after the adaptive update). Both start from the same point and use identical hyperparameters:
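The exact code behind the figures is not reproduced here, but a minimal sketch of such a comparison might look like the following, reusing the elongated bowl from before and assuming the hyperparameters (lr=0.1, lam=0.1) purely for illustration:

import numpy as np

def run(variant, lam=0.1, lr=0.1, steps=200, beta1=0.9, beta2=0.999, eps=1e-8):
    # One optimizer run on f(x, y) = x^2 + 10*y^2, starting from (5, 5)
    w = np.array([5.0, 5.0])
    m, v = np.zeros(2), np.zeros(2)
    path = [w.copy()]
    for t in range(1, steps + 1):
        grad = np.array([2 * w[0], 20 * w[1]])
        if variant == "adam_l2":
            grad = grad + lam * w               # L2: the decay term enters the moments
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad**2
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
        if variant == "adamw":
            w = w - lr * lam * w                # AdamW: decoupled decay after the step
        path.append(w.copy())
    return np.array(path)

path_l2 = run("adam_l2")
path_adamw = run("adamw")
print("Adam + L2 final point:", path_l2[-1])
print("AdamW     final point:", path_adamw[-1])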

Out[11]:
Visualization
Contour plot showing Adam with L2 optimization trajectory with oscillation along the path.
Adam with L2 regularization follows a wandering path. The regularization term enters the adaptive mechanism, dampening the regularization signal where it should be strongest.
Contour plot showing AdamW optimization trajectory with a cleaner, more direct path.
AdamW with decoupled weight decay produces a more direct trajectory. Weight decay operates independently of gradient adaptation, providing consistent regularization pressure.

The trajectories tell a striking story. Adam with L2 regularization follows a more wandering path, particularly in the early stages when the regularization gradients are large. Look carefully at the $y$-axis convergence: because the gradient along $y$ is larger (due to the steeper curvature), the L2 term $\lambda w$ gets scaled down more aggressively by Adam's adaptive mechanism. The regularization signal is weakest precisely where weights are largest.

AdamW, by contrast, produces a cleaner, more direct trajectory. The decoupled weight decay shrinks weights uniformly regardless of gradient history. Along the $y$-axis, where curvature is steep, the adaptive learning rate still reduces step size, but weight decay continues operating at full strength. This is exactly what we want: efficient optimization (via adaptation) combined with consistent regularization (via decoupling).

In neural network training, this difference compounds over millions of steps across millions of parameters. The improved regularization consistency translates directly into better generalization, which is why AdamW has become the standard for training large models.

AdamW as the Default Optimizer

AdamW has become the de facto standard for training transformers and large language models. This wasn't always the case: early language models like the original GPT were trained with Adam plus L2 regularization. The shift to AdamW happened as researchers noticed consistent improvements in generalization.

Why AdamW Dominates Transformer Training

Several factors make AdamW particularly well-suited for transformers:

  • Consistent regularization: Transformers have parameters with vastly different gradient scales. Self-attention weights, feedforward layers, and embeddings all behave differently during training. AdamW's decoupled decay provides uniform regularization regardless of these gradient patterns.

  • Better with large batch sizes: Modern language models use large batches for efficiency. AdamW's proper weight decay helps prevent the generalization gap that can emerge with large-batch training.

  • Stable with learning rate warmup: Transformer training typically uses linear warmup followed by decay. AdamW behaves predictably throughout this schedule because weight decay doesn't interact with the adaptive learning rate mechanism.

The Transformer Training Recipe

A typical configuration for training transformer models with AdamW includes:

  • Learning rate: 1e-4 to 1e-3, with linear warmup over 1-10% of training
  • Weight decay: 0.01 to 0.1
  • Betas: (0.9, 0.999) or (0.9, 0.98) for stability
  • Epsilon: 1e-8 or 1e-6 for numerical stability

Let's see this configuration in action:

In[12]:
Code
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR


# Typical transformer training setup
class SimpleTransformerBlock(nn.Module):
    def __init__(self, d_model=256, n_heads=4, d_ff=512):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            d_model, n_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.ff(x))
        return x


torch.manual_seed(42)
model = SimpleTransformerBlock()


# Separate parameters for weight decay
def get_param_groups(model, weight_decay=0.01):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if "norm" in name or "bias" in name:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


optimizer = torch.optim.AdamW(
    get_param_groups(model, weight_decay=0.01),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
)

# Learning rate schedule with warmup
total_steps = 1000
warmup_steps = 100


def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps
    return max(0.1, 1 - (step - warmup_steps) / (total_steps - warmup_steps))


scheduler = LambdaLR(optimizer, lr_lambda)
Out[13]:
Console
Learning rate at step 0: 0.000000
Learning rate at step 100: 0.000100
Learning rate at step 999: 0.000010
Out[14]:
Visualization
Line plot showing learning rate rising during warmup then declining during decay phase.
Learning rate schedule with linear warmup and linear decay. The warmup phase (first 100 steps) gradually increases the learning rate from zero to the target value, preventing early training instability. After warmup, the learning rate decays linearly, allowing the model to settle into a minimum with smaller, more precise updates.

The warmup phase gradually increases the learning rate, preventing early instability when the model's gradients are large and poorly calibrated. Early in training, weights are randomly initialized and gradients can be very large and erratic. Starting with a small learning rate allows the optimizer to build up accurate moment estimates before taking larger steps. After warmup, the learning rate decays linearly, allowing the model to settle into a good minimum with increasingly precise updates.

Empirical Comparison: Adam vs AdamW

Let's run a controlled experiment comparing Adam with L2 regularization against AdamW on a classification task:

In[15]:
Code
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# Create a synthetic classification dataset with some noise
np.random.seed(42)
torch.manual_seed(42)

n_samples = 1000
n_features = 50
n_classes = 5

# Generate features
X = np.random.randn(n_samples, n_features).astype(np.float32)

# Create non-linear decision boundaries
true_weights = np.random.randn(n_features, n_classes).astype(np.float32)
logits = X @ true_weights + 0.5 * np.random.randn(n_samples, n_classes).astype(
    np.float32
)
y = np.argmax(logits, axis=1)

# Split into train and validation
X_train, X_val = X[:800], X[800:]
y_train, y_val = y[:800], y[800:]

train_dataset = TensorDataset(
    torch.from_numpy(X_train), torch.from_numpy(y_train)
)
val_dataset = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)


# Model architecture
def create_model():
    return nn.Sequential(
        nn.Linear(n_features, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, n_classes),
    )


def train_epoch(model, optimizer, loader, criterion):
    model.train()
    total_loss = 0
    for X_batch, y_batch in loader:
        optimizer.zero_grad()
        output = model(X_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            output = model(X_batch)
            loss = criterion(output, y_batch)
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += (pred == y_batch).sum().item()
            total += y_batch.size(0)
    return total_loss / len(loader), correct / total


# Train with Adam + L2 regularization
torch.manual_seed(42)
model_adam = create_model()
optimizer_adam = optim.Adam(
    model_adam.parameters(), lr=0.001, weight_decay=0.01
)
criterion = nn.CrossEntropyLoss()

adam_train_losses = []
adam_val_losses = []
adam_val_accs = []

for epoch in range(100):
    train_loss = train_epoch(
        model_adam, optimizer_adam, train_loader, criterion
    )
    val_loss, val_acc = evaluate(model_adam, val_loader, criterion)
    adam_train_losses.append(train_loss)
    adam_val_losses.append(val_loss)
    adam_val_accs.append(val_acc)

# Train with AdamW
torch.manual_seed(42)
model_adamw = create_model()
optimizer_adamw = optim.AdamW(
    model_adamw.parameters(), lr=0.001, weight_decay=0.01
)

adamw_train_losses = []
adamw_val_losses = []
adamw_val_accs = []

for epoch in range(100):
    train_loss = train_epoch(
        model_adamw, optimizer_adamw, train_loader, criterion
    )
    val_loss, val_acc = evaluate(model_adamw, val_loader, criterion)
    adamw_train_losses.append(train_loss)
    adamw_val_losses.append(val_loss)
    adamw_val_accs.append(val_acc)
Out[16]:
Visualization
Line plot showing training and validation loss curves for Adam and AdamW optimizers.
Training and validation loss for Adam with L2 regularization and AdamW over 100 epochs on the synthetic classification task.
Line plot showing validation accuracy curves for Adam and AdamW optimizers.
Validation accuracy for the two optimizers over the same 100 epochs.
Out[17]:
Console
Final Results (Epoch 100):
  Adam + L2:  Val Loss = 0.3565, Val Acc = 87.0%
  AdamW:      Val Loss = 0.4753, Val Acc = 85.5%

Best Validation Accuracy:
  Adam + L2:  88.5% (Epoch 32)
  AdamW:      87.0% (Epoch 12)

The results of this small experiment are more equivocal than the theory alone would suggest: both optimizers reduce training loss similarly, and on this synthetic task Adam with L2 regularization actually posts slightly better final validation numbers. With a small model, a short schedule, and modest regularization, the coupling between L2 and the adaptive mechanism has little room to matter. The generalization advantage that Loshchilov and Hutter demonstrated for decoupled weight decay emerges most clearly in larger models trained for many more steps, where parameters span very different gradient scales and the consistency of the regularization compounds over millions of updates.

Limitations and Practical Considerations

AdamW is not a universal solution. Understanding its limitations helps you make informed choices about when to use it and how to configure it effectively.

Memory Overhead

Like Adam, AdamW maintains two moment estimates per parameter, increasing the memory required for optimizer state compared to SGD with momentum. For a model with $n$ parameters, AdamW stores:

  • The parameters themselves: $n$ floats
  • First moment estimates $m$: $n$ floats
  • Second moment estimates $v$: $n$ floats

This totals $3n$ floats, compared to $2n$ for SGD with momentum. For a 7-billion parameter model using 32-bit floats, that is roughly 84 GB for the weights plus optimizer state, of which about 56 GB is the two moment estimates alone. Techniques like gradient checkpointing, mixed-precision training (which stores moments in lower precision), and optimizer state offloading help mitigate this cost, but the overhead remains a consideration when memory is constrained.
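A back-of-the-envelope check of these numbers, assuming 32-bit floats at 4 bytes each:

n_params = 7e9                     # a 7-billion parameter model
bytes_per_float = 4                # fp32
moments = 2 * n_params * bytes_per_float             # m and v
weights_plus_state = 3 * n_params * bytes_per_float  # parameters plus both moments
print(f"moment estimates alone: {moments / 1e9:.0f} GB")             # ~56 GB
print(f"weights + m + v:        {weights_plus_state / 1e9:.0f} GB")  # ~84 GB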

Hyperparameter Sensitivity

While AdamW is more robust than Adam with L2 regularization, it still requires careful hyperparameter tuning. The weight decay coefficient interacts with learning rate, batch size, and training duration. What works for BERT may not work for GPT, and vice versa. Start with established recipes from similar architectures and adjust based on validation performance. Pay particular attention to the learning rate schedule: AdamW works best with warmup followed by gradual decay, and the warmup length matters more than with other optimizers.

Alternatives in Special Cases

For some applications, other optimizers may be preferable:

  • SGD with momentum: Often achieves better final generalization on vision tasks, despite slower convergence
  • Adafactor: Reduces memory by factorizing the second moment estimate, useful for very large models
  • LAMB: Designed specifically for large-batch training, can enable faster distributed training

AdamW remains the safe default for language models and transformers, but don't hesitate to experiment with alternatives if your specific application has different constraints.

Summary

AdamW fixes a fundamental problem with how Adam handles regularization. The key insights are:

  • L2 regularization and weight decay are equivalent for SGD but not for Adam. Adam's adaptive learning rates dampen the L2 regularization signal, reducing its effectiveness precisely where it should be strongest.

  • AdamW decouples weight decay from the gradient update. By applying weight decay directly to the parameters rather than through the gradient path, AdamW maintains consistent regularization regardless of gradient history.

  • The AdamW update adds an explicit decay term. The formula $w_{t+1} = w_t - \eta(\hat{m}_t / \sqrt{\hat{v}_t} + \lambda w_t)$ separates the gradient-based learning from the regularization, where the decay $\lambda w_t$ is not scaled by the adaptive factor.

  • Not all parameters should receive weight decay. Standard practice excludes biases, layer normalization parameters, and sometimes embeddings from weight decay.

  • AdamW is the default optimizer for transformers. Its consistent regularization, stability with learning rate schedules, and robustness across different gradient scales make it ideal for modern language models.

The difference between Adam and AdamW may seem subtle, but it has real impact on model generalization. When training neural networks, especially transformers and language models, AdamW should be your starting point.

Key Parameters

When configuring AdamW for your models, the following parameters have the most significant impact on training:

  • lr (learning rate): Controls the step size for parameter updates. Typical values range from 1e-5 to 1e-3. For transformers, start with 1e-4 and adjust based on validation loss. Use warmup to prevent early training instability.

  • weight_decay: The decoupled regularization coefficient. Values between 0.01 and 0.1 work well for most transformer architectures. Larger models often benefit from stronger decay. This parameter should be excluded for biases and layer normalization parameters.

  • betas: A tuple (β₁, β₂) controlling the exponential decay rates for the moment estimates. The default (0.9, 0.999) works well in most cases. For transformers, some practitioners use (0.9, 0.98) for slightly faster adaptation to gradient changes.

  • eps: A small constant added for numerical stability. The default 1e-8 works for most cases, but 1e-6 can help with mixed-precision training where smaller values may cause overflow.

  • amsgrad: When True, uses the maximum of past squared gradients rather than the exponential moving average. This can help with convergence in some cases but is rarely needed in practice.


