Switch Transformer: Top-1 Routing & Trillion-Parameter Scaling

Michael Brenndoerfer · Updated January 5, 2026 · 41 min read

Learn how Switch Transformer simplifies MoE with top-1 routing, capacity factors, and training stability for trillion-parameter language models.

Switch Transformer

The previous chapters in this part established the fundamentals of Mixture of Experts: how expert networks specialize, how gating mechanisms route tokens, and how auxiliary losses maintain load balance. Yet despite these elegant solutions, MoE models remained notoriously difficult to train at scale. The Switch Transformer, introduced by Fedus, Zoph, and Shazeer in 2022, changed this by making a counterintuitive choice: route each token to just one expert instead of two or more. This radical simplification, combined with careful engineering around capacity limits and training stability, enabled scaling to over a trillion parameters while maintaining computational efficiency.

The Switch Transformer demonstrated that complexity isn't always the path to better performance. By removing the interpolation between multiple experts that previous MoE designs required, Switch reduced communication overhead, simplified gradient computation, and paradoxically improved model quality. This chapter examines the design decisions behind Switch Transformer, the capacity factor mechanism that handles routing imbalances, and the scaling results that established MoE as a viable path toward ever-larger language models.

The Switch Layer

The core architectural innovation of Switch Transformer is the Switch layer, which replaces the standard feed-forward network (FFN) in each transformer block with a sparse, routed alternative. To understand why this substitution matters, we need to first recall what happens in a standard transformer and then see how the Switch layer transforms this computation into something far more powerful.

Recall from our discussion of feed-forward networks in Part XII that a standard transformer block processes each token through a two-layer MLP. This feed-forward network applies the same transformation to every token in the sequence, using shared parameters regardless of what the token represents or what context it appears in:

$$\text{FFN}(x) = W_2 \cdot \text{activation}(W_1 \cdot x + b_1) + b_2$$

where:

  • $x$: the input vector to the layer
  • $W_1, W_2$: learnable weight matrices for the first and second linear transformations
  • $b_1, b_2$: learnable bias vectors
  • $\text{activation}$: the non-linear activation function (typically ReLU or GeLU)

This standard FFN architecture has served transformers well, but it embodies a fundamental limitation: every token receives identical processing. A token representing a mathematical concept receives the same transformation as a token representing a cooking ingredient. The network must somehow compress all of its knowledge into a single set of parameters.

In a Switch layer, we break free from this constraint. Instead of one FFN, we have $E$ expert networks, each with the same architecture as the original FFN but with independent parameters. Each expert can specialize in different types of tokens or different aspects of language understanding. A router network decides which single expert should process each token, enabling the model to deploy specialized computation where it matters most.

This architectural choice creates a natural division of labor. One expert might become skilled at processing mathematical notation, another at handling named entities, and yet another at processing syntactic function words. The router learns to recognize these patterns and direct each token to the expert best suited to process it.

In[2]:
Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwitchExpert(nn.Module):
    """A single expert network with standard FFN architecture."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Standard FFN: up-project, activate, down-project
        return self.w2(self.dropout(F.gelu(self.w1(x))))

The router serves as the decision-making component of the Switch layer. It examines each incoming token representation and produces a probability distribution indicating how suitable each expert is for processing that particular token. Despite the sophistication of the routing decisions it makes, the router itself is remarkably simple: just a linear layer that projects the token representation to a vector of expert scores, followed by a softmax to convert these scores into probabilities.

In[3]:
Code
class SwitchRouter(nn.Module):
    """Routes tokens to experts using learned linear projection."""

    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.num_experts = num_experts
        self.router = nn.Linear(d_model, num_experts, bias=False)

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # x shape: (batch_size, seq_len, d_model)
        router_logits = self.router(x)  # (batch, seq, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Select the single best expert for each token
        expert_indices = torch.argmax(router_probs, dim=-1)  # (batch, seq)
        expert_weights = router_probs.gather(
            -1, expert_indices.unsqueeze(-1)
        ).squeeze(-1)

        return expert_indices, expert_weights, router_logits

The key difference from previous MoE architectures like those in GShard is the hard assignment to a single expert. Each token goes to exactly one expert, and the router's softmax probability for that expert becomes a multiplicative weight on the expert's output. This design choice, which might initially seem like a limitation, turns out to be the crucial insight that makes Switch Transformer so effective at scale.

Out[4]:
Visualization
Dense feed-forward network architecture. All tokens are processed through a single shared FFN, applying identical transformations regardless of token content.
Switch layer architecture with top-1 routing. Each token is routed to exactly one specialized expert based on content, enabling parameter scaling without increasing computational cost per token.

Top-1 Routing: The Simplification Insight

Prior MoE models, including the influential GShard architecture, used top-2 routing: each token was sent to two experts, with the outputs combined as a weighted average. The reasoning was intuitive: combining perspectives from multiple specialists should produce richer representations. After all, if one expert knows about syntax and another knows about semantics, shouldn't we benefit from consulting both?

Switch Transformer challenged this assumption with empirical evidence that contradicted conventional wisdom. The authors found that routing to a single expert achieves comparable or better performance while providing several practical advantages that compound at scale.

Why Top-1 Works

The effectiveness of top-1 routing can be understood through several lenses, each revealing a different facet of why simpler is better in this context:

Reduced communication cost. In distributed training, each expert typically lives on a different device. This means that routing a token to an expert requires sending that token's representation across the network to wherever the expert resides. Top-2 routing means every token must be sent to two devices and results gathered back, doubling the network traffic. Top-1 cuts this communication in half, which becomes increasingly important as models scale across hundreds or thousands of devices.

Simpler gradient flow. With top-2 routing, the gradient for a token flows through two expert networks weighted by their respective probabilities. This creates interdependencies that complicate optimization: the gradient with respect to one expert's parameters depends on what the other expert produced. Top-1 provides cleaner gradients, where each expert receives full responsibility for the tokens it processes. This clarity in the optimization landscape helps the model converge more reliably.

Higher effective capacity. Each expert has a fixed capacity (more on this shortly), and under top-2 routing every token consumes slots in two experts. Top-1 routing halves that consumption, so the same capacity absorbs twice as many distinct tokens. Each expert takes full ownership of a larger effective batch of tokens, which improves the quality of gradient estimates and exposes each expert to more diverse examples during training (the sketch after this list makes the arithmetic concrete).

Empirical validation. The Switch Transformer paper demonstrated that top-1 routing with the same total compute achieves better perplexity than top-2 routing across multiple model scales. This wasn't a marginal improvement; the gains were consistent and significant, validating the theoretical arguments with concrete evidence.
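
To make the communication and capacity arithmetic concrete, the following back-of-the-envelope sketch compares top-1 and top-2 routing for a single batch. The batch size, expert count, and per-expert capacity are illustrative assumptions, not values from the paper.

Code
# Illustrative arithmetic (hypothetical batch and expert sizes).
tokens = 32 * 512      # tokens in one batch
experts = 16
capacity = 1024        # fixed slots per expert

for k in (1, 2):
    # Each token is dispatched to k experts, so cross-device traffic scales with k.
    messages = tokens * k
    # Each token also consumes k expert slots, so fewer distinct tokens fit
    # before experts begin to overflow and drop tokens.
    tokens_before_overflow = experts * capacity // k
    print(
        f"top-{k}: {messages:,} token-to-expert messages, "
        f"{tokens_before_overflow:,} tokens fit without overflow"
    )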

The mathematical formulation for top-1 routing is straightforward, which is part of its elegance. Given a token representation $x$, the router computes a probability distribution over all available experts:

$$p(x) = \text{softmax}(W_r \cdot x)$$

where:

  • $p(x)$: the probability distribution over experts for the input token
  • $W_r$: the learnable router weight matrix projecting the input to expert logits
  • $x$: the input token representation

This probability distribution represents the router's assessment of how well-suited each expert is for processing this particular token. The selected expert is simply the one with the highest probability: $e^* = \arg\max_i p_i(x)$. Once the expert is selected, the output of the Switch layer combines the expert's processing with a scaling factor:

$$y = p_{e^*}(x) \cdot E_{e^*}(x)$$

where:

  • $y$: the output of the Switch layer for the token
  • $e^*$: the index of the selected expert (the one with the highest probability)
  • $p_{e^*}(x)$: the probability assigned to the selected expert, acting as a gating factor
  • $E_{e^*}(x)$: the output of the selected expert network, computed with its own parameters

The multiplicative weighting by $p_{e^*}(x)$ serves an important purpose beyond just combining the expert output. It provides a differentiable signal that allows gradients to flow back through the router, enabling the router to learn which experts work best for which tokens.
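
To see that this gating is differentiable in practice, here is a minimal sketch that reuses the SwitchRouter and SwitchExpert classes defined above. The dimensions are arbitrary; the point is that calling backward() produces a gradient on the router weights even though the argmax itself has no gradient.

Code
# Minimal check that gradients reach the router through p_{e*} (toy sizes).
d_model, n_experts = 16, 4
router = SwitchRouter(d_model, n_experts)
experts = nn.ModuleList(
    [SwitchExpert(d_model, d_ff=32) for _ in range(n_experts)]
)

x = torch.randn(1, 1, d_model)            # a single token
idx, weight, logits = router(x)           # top-1 expert and its probability
e = idx.item()

# y = p_{e*}(x) * E_{e*}(x): the probability scales the expert output
y = weight.unsqueeze(-1) * experts[e](x)
y.sum().backward()

print("router weight grad norm:", router.router.weight.grad.norm().item())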

Out[5]:
Visualization
Comparison of routing strategy costs. Top-1 routing reduces communication overhead and gradient complexity compared to top-2 approaches, while doubling the effective batch capacity per expert.
Out[6]:
Visualization
Probability weighting mechanism for differentiable routing. The router's probability assignment scales the expert output, ensuring gradients flow back to the router. Signals near the decision boundary (0.5) are dampened, smoothing training dynamics.

The No-Token-Left-Behind Principle

A potential concern with top-1 routing is that hard assignment creates discontinuities. Small changes in a token's representation could flip it to a different expert, potentially causing unstable training dynamics. If a token is hovering on the decision boundary between two experts, tiny perturbations might cause it to oscillate between them during training.

However, the multiplicative weighting smooths this effect in an elegant way. If a token is marginally assigned to expert 3 with probability 0.51 (vs. 0.49 for expert 2), the output is scaled by 0.51. This scaling provides a soft transition that prevents gradient spikes. When the token is near a decision boundary, the output magnitude is reduced, naturally dampening the impact of near-ties. As the router becomes more confident about a token's assignment, the probability approaches 1.0 and the full expert output is used.

This design ensures that tokens near decision boundaries contribute less strongly to the gradient signal, allowing the model to focus its learning on tokens where the routing decision is clear and meaningful.

Capacity Factor: Managing Overflow

Even with load balancing losses pushing the router toward a uniform distribution, routing can never be perfectly uniform in practice. The discrete nature of expert selection means that random fluctuations will always cause some experts to receive more tokens than others. This creates a practical problem that cannot be ignored: how many tokens should each expert be prepared to handle?

The capacity factor ($C$) is Switch Transformer's solution to this challenge. It defines how much buffer capacity each expert has beyond the perfectly balanced allocation, providing a mechanism to handle the inevitable routing imbalances without catastrophic failure.

Capacity Calculation

If a batch contains $n$ total tokens and we have $E$ experts, perfect balance would assign $n/E$ tokens to each expert. In an ideal world, the router would achieve this exact distribution. But in reality, some experts will be more popular than others for any given batch. The capacity factor scales the allocation to accommodate this variance:

$$\text{capacity} = \left\lfloor \frac{n \cdot C}{E} \right\rfloor$$

where:

  • $n$: the total number of tokens in the batch
  • $C$: the capacity factor (a scalar multiplier, typically $> 1.0$)
  • $E$: the number of experts available
  • $\lfloor \cdot \rfloor$: the floor function, ensuring an integer number of slots

Understanding this formula requires thinking about what happens at different values of $C$. A capacity factor of $C = 1.0$ means experts have exactly enough slots for perfect balance. In practice, routing is never perfect, so some experts overflow. The Switch solution is simple and pragmatic: tokens that exceed an expert's capacity skip the expert entirely and pass through via the residual connection.

This design creates a spectrum of tradeoffs that you can navigate based on your specific constraints:

  • $C = 1.0$: Minimal memory usage, but many tokens may be dropped (skipped)
  • $C = 1.25$: A 25% buffer that handles typical load variance well
  • $C = 2.0$: Generous buffer, fewer drops, but double the memory per expert
In[7]:
Code
def compute_capacity(
    batch_size: int, seq_len: int, num_experts: int, capacity_factor: float
) -> int:
    """Calculate expert capacity given batch parameters."""
    total_tokens = batch_size * seq_len
    # Perfect balance allocation with capacity factor buffer
    capacity = int((total_tokens * capacity_factor) / num_experts)
    return max(capacity, 1)  # At least 1 token per expert
In[8]:
Code
# Example setup: 32 batch size, 512 sequence length, 16 experts
batch_size = 32
seq_len = 512
num_experts = 16
total_tokens = batch_size * seq_len

# Calculate capacities for different factors
capacity_results = []
for cf in [1.0, 1.25, 1.5, 2.0]:
    cap = compute_capacity(batch_size, seq_len, num_experts, cf)
    perfect_balance = total_tokens // num_experts
    buffer = cap - perfect_balance
    capacity_results.append((cf, cap, perfect_balance, buffer))
Out[9]:
Console
C=1.00: capacity=1024 tokens/expert (perfect balance=1024, buffer=0)
C=1.25: capacity=1280 tokens/expert (perfect balance=1024, buffer=256)
C=1.50: capacity=1536 tokens/expert (perfect balance=1024, buffer=512)
C=2.00: capacity=2048 tokens/expert (perfect balance=1024, buffer=1024)

With a capacity factor of 1.25, each expert can handle 25% more tokens than perfect balance would require. This buffer accommodates the natural variance in routing decisions without requiring excessive memory. In practice, the authors found $C = 1.25$ provides a good tradeoff between memory efficiency and token coverage, though this value may need adjustment for different batch sizes and expert counts.

Out[10]:
Visualization
Expert capacity scaling with capacity factor. Capacity increases linearly with the capacity factor, providing a buffer against routing imbalances. The recommended value of C=1.25 balances memory efficiency with token coverage.
Token dropping rates for different capacity factors. Higher capacity factors reduce token dropping rates, with diminishing returns observed beyond C=1.5, making C=1.25 a common practical choice.

Token Dropping Behavior

When an expert reaches capacity, additional tokens routed to it are "dropped": they bypass the expert entirely and only the residual connection preserves their information. This might seem problematic at first glance. After all, if the router determined that a token should go to a particular expert, doesn't skipping that expert harm the model's performance?

Several factors mitigate the impact of token dropping, making this design more robust than it initially appears:

  • Load balancing reduces drops. The auxiliary balancing loss (covered in Chapter 6 of this part) incentivizes the router to distribute tokens evenly across experts. A well-trained router rarely overloads any single expert severely, minimizing the frequency of capacity overflow.

  • Residual connections preserve information. Dropped tokens still retain their original representation through the residual connection; they simply don't receive the expert's specialized processing for that layer. The token's information isn't lost, just unchanged by that particular expert (the short sketch after this list illustrates this).

  • Drops are distributed across tokens. With proper load balancing, drops rarely concentrate on specific tokens across multiple layers. A token might skip one expert in layer 5 but be processed normally in layers 1 through 4 and 6 through 12. The impact of a single dropped expert processing is diluted across the many layers of the network.

  • Training learns robustness. The model learns to function with some level of token dropping, developing representations that don't critically depend on every expert processing. This emergent robustness means that occasional capacity overflow doesn't catastrophically harm performance.
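
The residual pass-through for dropped tokens is easy to see in isolation. The toy sketch below (not the paper's implementation) zeroes out the expert contribution for an "overflowed" token and shows that the residual connection carries it forward unchanged.

Code
# Toy illustration: a dropped token receives no expert output, but the
# residual connection preserves its representation exactly.
x = torch.randn(3, 8)                          # three tokens, d_model = 8
expert_out = torch.randn(3, 8)                 # pretend expert outputs
kept = torch.tensor([1.0, 1.0, 0.0])           # third token overflowed capacity

layer_out = x + expert_out * kept.unsqueeze(-1)  # residual + (masked) expert

print(torch.equal(layer_out[2], x[2]))  # True: dropped token passes through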

Complete Switch Layer Implementation

Let's assemble the full Switch layer, incorporating routing, capacity management, and the load balancing loss we discussed in Chapter 6. This implementation ties together all the concepts we've explored: the router that assigns tokens to experts, the capacity mechanism that handles overflow, and the auxiliary loss that encourages balanced utilization.

In[11]:
Code
class SwitchLayer(nn.Module):
    """
    Switch Transformer layer: routes each token to one expert.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int,
        capacity_factor: float = 1.25,
        dropout: float = 0.1,
        balance_loss_weight: float = 0.01,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.balance_loss_weight = balance_loss_weight

        # Router
        self.router = SwitchRouter(d_model, num_experts)

        # Expert networks
        self.experts = nn.ModuleList(
            [SwitchExpert(d_model, d_ff, dropout) for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, d_model = x.shape
        total_tokens = batch_size * seq_len

        # Get routing decisions
        expert_indices, expert_weights, router_logits = self.router(x)

        # Flatten for easier processing
        x_flat = x.view(total_tokens, d_model)
        expert_indices_flat = expert_indices.view(total_tokens)
        expert_weights_flat = expert_weights.view(total_tokens)

        # Calculate capacity
        capacity = compute_capacity(
            batch_size, seq_len, self.num_experts, self.capacity_factor
        )

        # Initialize output
        output = torch.zeros_like(x_flat)

        # Track which tokens are processed for load balancing
        tokens_per_expert = torch.zeros(self.num_experts, device=x.device)

        # Process each expert
        for expert_idx in range(self.num_experts):
            # Find tokens assigned to this expert
            mask = expert_indices_flat == expert_idx
            token_indices = mask.nonzero(as_tuple=True)[0]

            # Apply capacity limit
            num_tokens = min(len(token_indices), capacity)
            token_indices = token_indices[:num_tokens]
            tokens_per_expert[expert_idx] = num_tokens

            if num_tokens > 0:
                # Get token representations
                expert_input = x_flat[token_indices]

                # Process through expert
                expert_output = self.experts[expert_idx](expert_input)

                # Weight by router probability
                weights = expert_weights_flat[token_indices].unsqueeze(-1)
                output[token_indices] = weights * expert_output

        # Reshape output
        output = output.view(batch_size, seq_len, d_model)

        # Add residual connection (critical for dropped tokens!)
        output = output + x

        # Compute load balancing loss
        balance_loss = self._compute_balance_loss(
            router_logits.view(total_tokens, -1), expert_indices_flat
        )

        return output, balance_loss * self.balance_loss_weight

    def _compute_balance_loss(
        self, router_logits: torch.Tensor, expert_indices: torch.Tensor
    ) -> torch.Tensor:
        """Auxiliary loss to encourage balanced routing."""
        # Fraction of tokens routed to each expert
        num_tokens = router_logits.shape[0]

        # f_i: fraction of tokens to expert i
        expert_counts = torch.zeros(
            self.num_experts, device=router_logits.device
        )
        for i in range(self.num_experts):
            expert_counts[i] = (expert_indices == i).float().sum()
        f = expert_counts / num_tokens

        # P_i: mean router probability for expert i
        router_probs = F.softmax(router_logits, dim=-1)
        P = router_probs.mean(dim=0)

        # Balance loss: sum of f_i * P_i, scaled by num_experts
        balance_loss = self.num_experts * torch.sum(f * P)

        return balance_loss

Let's verify the layer works correctly:

In[12]:
Code
# Setup parameters
B, S, D = 4, 64, 256
num_experts = 8

# Create a Switch layer
switch_layer = SwitchLayer(
    d_model=D, d_ff=1024, num_experts=num_experts, capacity_factor=1.25
)

# Test input
test_input = torch.randn(B, S, D)
output, aux_loss = switch_layer(test_input)

# Analyze routing distribution
with torch.no_grad():
    expert_indices, _, _ = switch_layer.router(test_input)
    expert_counts = torch.bincount(
        expert_indices.flatten(), minlength=num_experts
    )
Out[13]:
Console
Input shape:  torch.Size([4, 64, 256])
Output shape: torch.Size([4, 64, 256])
Auxiliary loss: 0.0100

Tokens per expert: [42, 33, 27, 38, 28, 30, 27, 31]
Expected per expert (uniform): 32

The output shows that tokens are distributed across experts, with the auxiliary loss encouraging this balance. As training progresses, the balance loss will push the router toward more uniform distribution while still allowing meaningful specialization.

Key Parameters

The key parameters for the Switch Layer are:

  • num_experts: Number of expert networks. More experts increase capacity without increasing compute.
  • capacity_factor: Multiplier defining the expert buffer size ($C$). Controls the tradeoff between memory and token dropping.
  • balance_loss_weight: Coefficient for the auxiliary load balancing loss to prevent routing collapse.
  • d_ff: Hidden dimension of the expert feed-forward networks.

Training Stability Strategies

Switch Transformers, like other MoE models, can be unstable during training. The sparse routing mechanism introduces discontinuities that dense models don't face, and the discrete expert selection can create optimization challenges. The authors identified several strategies critical for stable training at scale, each addressing a specific source of instability.

Selective Precision

One counterintuitive finding was that using lower precision (bfloat16 or float16) in most of the model while keeping the router in float32 significantly improved stability. This selective precision strategy reflects a deep understanding of where numerical precision matters most.

The router's softmax operation is particularly sensitive to numerical precision because small differences in logits get amplified exponentially. When computing $e^{z_i}$ for logits $z_i$, even small rounding errors can cause some experts to receive disproportionately high or low probabilities. In lower precision, these errors accumulate and can cause the router to make inconsistent or degenerate routing decisions.

By keeping the router computation in float32 while allowing the rest of the model to use memory-efficient bfloat16, Switch Transformer achieves both numerical stability and computational efficiency. The router represents a tiny fraction of total computation, so the precision overhead is minimal.
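
A quick numerical check illustrates why the router is the sensitive component. The sketch below compares softmax over nearly tied logits in float32 and bfloat16 (assuming a PyTorch build with bfloat16 CPU support); the exact rounding depends on hardware and version, but in reduced precision the near-tie typically collapses, so the routing decision can flip. The logit values are made up for illustration.

Code
# Nearly tied logits: float32 preserves the small gap, bfloat16 may not.
logits = torch.tensor([4.1001, 4.1003, -2.0, -3.0])

probs_fp32 = F.softmax(logits, dim=-1)
probs_bf16 = F.softmax(logits.to(torch.bfloat16), dim=-1).float()

print("float32 probs: ", probs_fp32.tolist())
print("bfloat16 probs:", probs_bf16.tolist())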

In[14]:
Code
class StableSwitchRouter(nn.Module):
    """Router with selective precision for stability."""

    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.num_experts = num_experts
        # Router weights stay in float32 even if model is bfloat16
        self.router = nn.Linear(d_model, num_experts, bias=False)

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Cast to float32 for routing computation
        router_input = x.float()
        router_logits = self.router(router_input)
        router_probs = F.softmax(router_logits, dim=-1)

        expert_indices = torch.argmax(router_probs, dim=-1)
        expert_weights = router_probs.gather(
            -1, expert_indices.unsqueeze(-1)
        ).squeeze(-1)

        # Cast weights back to input dtype
        expert_weights = expert_weights.to(x.dtype)

        return expert_indices, expert_weights, router_logits

Router Z-Loss

As we discussed in Chapter 7, the router z-loss penalizes large router logits, preventing the softmax from becoming too peaked. This is particularly important for Switch Transformer because top-1 routing already creates hard decisions. Extremely peaked softmax distributions worsen the discontinuity, making tiny input changes flip routing decisions dramatically.

The z-loss provides a gentle pressure against this behavior:

$$L_z = \frac{1}{n} \sum_{i=1}^{n} \left( \log \sum_{j=1}^{E} e^{z_j^{(i)}} \right)^2$$

where:

  • $L_z$: the auxiliary router z-loss
  • $n$: the total number of tokens in the batch
  • $E$: the total number of experts
  • $z_j^{(i)}$: the logit (pre-softmax score) for expert $j$ given token $i$
  • $\sum_{j=1}^{E} e^{z_j^{(i)}}$: the sum of exponentials (the denominator of the softmax function)

This formula penalizes large logits to maintain numerical stability. Each component serves a specific purpose:

  1. Sum of exponentials: The inner sum $\sum_j e^{z_j}$ captures the aggregate magnitude of the logits. When logits are large, this sum grows exponentially.
  2. Logarithm: The $\log$ scales this value back to the linear domain, essentially computing something close to $\max_j z_j$ plus a small correction.
  3. Square: Squaring the result penalizes large values quadratically, creating a strong incentive to keep logits bounded.

Without this loss, the model could increase logits indefinitely. Due to softmax's translation invariance, logits of [5, 3, 2] and [500, 498, 497] produce identical probability distributions. But the second case risks floating-point overflow and creates extremely peaked distributions that harm optimization. The z-loss prevents this degenerate behavior.
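
The z-loss is only a few lines of code. The sketch below implements the formula above directly with torch.logsumexp and evaluates it on a small set of logits and the same logits shifted by a large constant; both produce identical softmax outputs, but only the shifted version incurs a large penalty.

Code
def router_z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Mean squared log-sum-exp of per-token router logits (formula above)."""
    # router_logits: (num_tokens, num_experts)
    log_z = torch.logsumexp(router_logits, dim=-1)  # log sum_j exp(z_j)
    return (log_z**2).mean()


small_logits = torch.tensor([[5.0, 3.0, 2.0]])
shifted_logits = small_logits + 495.0  # same softmax, much larger logits

print("z-loss (small):  ", router_z_loss(small_logits).item())
print("z-loss (shifted):", router_z_loss(shifted_logits).item())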

Out[15]:
Visualization
Softmax translation invariance. Adding a constant shift to all logits yields identical probability distributions. Without regularization, this property allows logits to drift to arbitrarily high values, causing numerical instability.
Out[16]:
Visualization
Router z-loss penalty curve. The z-loss applies a quadratic penalty to the log-sum-exp of logits. This constraint keeps logits within a safe numerical range (green) and prevents them from growing into the overflow-risk region (red).

Initialization and Dropout

The Switch paper recommended several initialization tweaks that collectively improve training stability:

  • Smaller router initialization: Initialize router weights with smaller standard deviation to prevent early routing collapse
  • Increased dropout: Higher dropout rates (0.1-0.2) in experts help prevent overfitting to specific routing patterns
  • Expert dropout: Occasionally dropping entire experts during training encourages redundancy

These techniques work together to prevent premature specialization, where the model commits too strongly to certain routing patterns before it has seen enough data to make informed decisions.
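
These recommendations map to only a few lines of setup code. The sketch below shows one plausible way to apply them to the classes defined earlier; the standard deviation, dropout rate, and expert-drop probability are illustrative choices, not the paper's exact values.

Code
# Illustrative stability tweaks (values are assumptions, not the paper's).
router = SwitchRouter(d_model=256, num_experts=8)

# Smaller router initialization to avoid early, overconfident routing.
nn.init.trunc_normal_(router.router.weight, std=0.02)

# Higher dropout inside each expert.
expert = SwitchExpert(d_model=256, d_ff=1024, dropout=0.2)


# Expert dropout: sample a keep-mask per training step; routing code would
# re-route tokens whose chosen expert is masked out for this step.
def sample_expert_mask(num_experts: int, drop_prob: float = 0.1) -> torch.Tensor:
    return torch.rand(num_experts) >= drop_prob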

Scaling Results

The most compelling evidence for Switch Transformer's effectiveness came from scaling experiments. The authors compared Switch Transformer against dense T5 models with equivalent compute budgets, demonstrating that the sparse architecture provides consistent benefits across a wide range of scales.

Speed vs. Quality Tradeoffs

The key insight is that Switch models achieve better quality at the same training compute. A Switch-Base model with 7B total parameters (but only activating approximately 100M per token) outperformed T5-Base while using the same FLOPs per training step. This wasn't achieved by throwing more compute at the problem; it emerged from the more efficient use of parameters that sparsity enables.

The improvements were substantial across multiple dimensions:

  • Pre-training speed: Switch-Base achieved T5-Base quality in 1/7th the training time
  • Same-compute quality: Given equal compute, Switch models consistently achieved lower perplexity
  • Scaling efficiency: The gap widened at larger scales, with Switch-XXL showing dramatic improvements over T5-XXL

Scaling to Trillion Parameters

The Switch Transformer demonstrated that MoE enables scaling to unprecedented parameter counts. The largest model, Switch-C, contained 1.6 trillion parameters distributed across 2048 experts. This scale was simply not achievable with dense models given the computational resources available.

Despite this massive parameter count, the computational cost per token remained manageable because only one expert activates per token. A 1.6 trillion parameter model might have the knowledge capacity of its full parameter count, but the inference cost of a model roughly 2000 times smaller.

In[17]:
Code
def compare_model_costs(
    d_model: int,
    d_ff: int,
    num_layers: int,
    num_experts: int,
    seq_len: int = 512,
) -> dict:
    """Compare parameter counts and FLOPs for dense vs Switch."""
    # Dense model: each layer has one FFN
    dense_ffn_params = d_model * d_ff + d_ff * d_model  # Up + down projection
    dense_total_ffn = dense_ffn_params * num_layers

    # Switch model: each layer has num_experts FFNs
    switch_total_ffn = dense_ffn_params * num_layers * num_experts

    # Router parameters (small)
    router_params = d_model * num_experts * num_layers

    # FLOPs per token (both activate same amount)
    dense_flops = 2 * dense_ffn_params  # Forward pass FLOPs
    switch_flops = 2 * dense_ffn_params  # Same! Only one expert activates

    return {
        "dense_params": dense_total_ffn,
        "switch_params": switch_total_ffn + router_params,
        "param_ratio": (switch_total_ffn + router_params) / dense_total_ffn,
        "flops_ratio": switch_flops / dense_flops,
    }
In[18]:
Code
# Compare at different scales: (Name, d_model, d_ff, layers, experts)
configs = [
    ("Base", 768, 3072, 12, 128),
    ("Large", 1024, 4096, 24, 128),
    ("XL", 2048, 8192, 24, 256),
]

comparison_results = []
for name, d_model, d_ff, layers, experts in configs:
    res = compare_model_costs(d_model, d_ff, layers, experts)
    comparison_results.append((name, d_model, experts, res))
Out[19]:
Console
Comparison: Dense vs Switch Transformer
=================================================================

Base (d_model=768, 128 experts):
  Dense FFN params:    56,623,104
  Switch total params: 7,248,936,960
  Parameter ratio: 128.0x more parameters
  FLOPs ratio: 1.0x (same compute!)

Large (d_model=1024, 128 experts):
  Dense FFN params:   201,326,592
  Switch total params: 25,772,949,504
  Parameter ratio: 128.0x more parameters
  FLOPs ratio: 1.0x (same compute!)

XL (d_model=2048, 256 experts):
  Dense FFN params:   805,306,368
  Switch total params: 206,171,013,120
  Parameter ratio: 256.0x more parameters
  FLOPs ratio: 1.0x (same compute!)
Out[20]:
Visualization
Parameter counts for dense versus Switch models at different scales. Switch models scale to significantly higher parameter counts through expert addition, enabling massive capacity growth.
Computational costs per token for dense versus Switch models. Despite having many more parameters, Switch models maintain identical computational cost per token, decoupling capacity from compute.

This comparison illustrates the key advantage: Switch models can have 128x or more parameters while using identical compute per forward pass. The additional parameters provide more capacity for learning without proportionally increasing training cost. This decoupling of parameters from compute was the central insight that enabled scaling to trillion-parameter models.

Sample Efficiency

Beyond raw speed, Switch Transformers showed improved sample efficiency. The models achieved better quality with fewer training tokens, suggesting that the sparse expert structure enables more effective use of training data. This aligns with the intuition that specialized experts can learn more from each example in their domain.

When a token about mathematics routes to a mathematics-specialized expert, that expert receives a concentrated signal for updating its parameters. In contrast, a dense model must update all parameters regardless of the token's content, diluting the learning signal across parameters that may not be relevant.

Out[21]:
Visualization
Training efficiency of Switch Transformer compared to a dense T5 baseline. The Switch model achieves the same perplexity level significantly faster than the dense model, demonstrating superior sample efficiency. The shaded region highlights the reduction in training steps required to reach equivalent model quality.

Distillation and Fine-tuning

One challenge with large MoE models is deployment: serving a 1.6T parameter model requires distributing experts across many devices, increasing latency and infrastructure complexity. The Switch paper explored distillation as a solution, showing that much of the knowledge learned by large MoE models can be transferred to smaller, more deployable architectures.

Distilling to Dense Models

The authors found that Switch Transformers could be distilled into dense models that retain much of the quality gain. A Switch-Base model distilled to a T5-Base architecture achieved 30% of the original quality improvement while eliminating the MoE complexity entirely. This provides a practical path for organizations that want MoE's training benefits without its deployment challenges.

The distillation process is straightforward: train the dense student model to match the Switch teacher's outputs, using a combination of hard labels (for classification) or soft targets (for language modeling). The student learns to approximate the ensemble behavior of the sparse experts using its single, dense feed-forward network.
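
A minimal sketch of such a distillation objective is shown below: the dense student is trained to match the Switch teacher's temperature-softened output distribution, mixed with the ordinary hard-label loss. The mixing weight and temperature are illustrative defaults, not values reported in the paper.

Code
def distillation_loss(
    student_logits: torch.Tensor,  # (batch, seq, vocab) from the dense student
    teacher_logits: torch.Tensor,  # (batch, seq, vocab) from the Switch teacher
    labels: torch.Tensor,          # (batch, seq) ground-truth token ids
    alpha: float = 0.5,            # weight on the soft-target term (assumption)
    temperature: float = 2.0,      # softening temperature (assumption)
) -> torch.Tensor:
    vocab = student_logits.size(-1)
    # Soft targets: KL between temperature-scaled student and teacher distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature**2)
    # Hard labels: standard cross-entropy against the ground-truth tokens.
    hard = F.cross_entropy(student_logits.view(-1, vocab), labels.view(-1))
    return alpha * soft + (1.0 - alpha) * hard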

Fine-tuning Challenges

Fine-tuning MoE models presents unique challenges that you must address:

  • Expert collapse: During fine-tuning on narrow domains, some experts may receive very few tokens, causing their parameters to drift or become useless. The fine-tuning dataset may not cover the full range of content that the pre-trained router learned to handle. Increasing the load balancing loss weight during fine-tuning helps maintain expert diversity.

  • Capacity tuning: Fine-tuning datasets are often smaller than pre-training corpora, changing the optimal capacity factor. With smaller batches, the variance in routing decisions increases, potentially requiring higher capacity factors to avoid excessive token dropping. The authors recommended re-tuning capacity factor for each downstream task.

  • Transfer gaps: While Switch models excelled at pre-training, the gains sometimes diminished after fine-tuning. Dense models occasionally caught up on specific tasks, suggesting that MoE's advantages are most pronounced for general language modeling rather than task-specific optimization. This finding motivates the distillation approach: use MoE for pre-training, then distill to dense for fine-tuning and deployment.

Worked Example: Routing Visualization

Let's visualize how tokens get routed to experts in a trained Switch layer. This visualization helps build intuition about what the router learns and how it distributes tokens across the available experts:

In[22]:
Code
# Create a Switch layer and process some text
torch.manual_seed(42)

num_viz_experts = 8
switch_layer = SwitchLayer(
    d_model=128, d_ff=512, num_experts=num_viz_experts, capacity_factor=1.5
)

# Simulate token embeddings for a sentence
# Different semantic categories should route differently
batch_size, seq_len, d_model = 1, 20, 128
tokens = torch.randn(batch_size, seq_len, d_model)

# Get routing decisions
with torch.no_grad():
    expert_indices, expert_weights, router_logits = switch_layer.router(tokens)
    router_probs = F.softmax(router_logits, dim=-1)
Out[23]:
Visualization
Router probability distribution heatmap. The router assigns high probabilities (dark blue) to specific experts for each token, demonstrating confident, sparse routing decisions.
Out[24]:
Visualization
Expert load distribution. Aggregating routing decisions reveals the number of tokens assigned to each expert. The red line indicates perfect balance; deviations are addressed by the auxiliary load balancing loss during training.

The heatmap shows each token's probability distribution over experts. Notice how most tokens have one dominant expert (high probability in one column), demonstrating the router's confident routing decisions. The bar chart shows the actual token distribution, with the red line indicating perfect balance. Even with random initialization, we see some variance in expert popularity, which the load balancing loss would address during training.

Limitations and Practical Considerations

Despite its successes, Switch Transformer has important limitations that influenced subsequent work and that you should understand before deploying these models.

Training Instability

MoE models, including Switch Transformer, exhibit more training instability than dense counterparts. The routing decisions create discontinuities in the loss landscape, and router collapse (where all tokens route to few experts) can occur if load balancing fails. The selective precision and z-loss strategies help but don't eliminate these issues entirely. Training large Switch models requires careful monitoring and sometimes manual intervention when instability appears.

Fine-tuning Gap

While Switch models excel at pre-training, the advantages often shrink after fine-tuning. On some tasks, dense models with equivalent compute match or exceed Switch performance. This suggests MoE's benefits may be most pronounced for general-purpose language modeling rather than task-specific optimization. You should evaluate whether MoE is appropriate for your specific use case.

Infrastructure Complexity

Deploying MoE models requires expert parallelism, where different experts live on different devices. This introduces all-to-all communication patterns that standard data or model parallelism don't require. The infrastructure burden limited early adoption, though frameworks like DeepSpeed and Megatron have since added MoE support. Organizations considering MoE deployment should evaluate their infrastructure readiness.

Memory vs. Compute Tradeoffs

While FLOPs per token remain constant, the total parameter memory scales with expert count. A 128-expert model requires 128x the FFN memory, even though only 1/128th activates per token. This memory overhead matters for inference, where batch sizes may be small and memory dominates cost. The memory-compute tradeoff differs between training and inference, requiring careful capacity planning.

Token Dropping Impact

Though residual connections preserve dropped tokens' information, consistent dropping hurts performance. For tasks requiring precise token-level processing, the capacity factor must be tuned carefully. In practice, this means MoE models often need larger batch sizes to ensure adequate expert utilization. Applications with strict latency requirements and small batch sizes may find dense models more suitable.

Despite these limitations, Switch Transformer established that simplified MoE architectures could scale efficiently. The design choices (top-1 routing, capacity factor, and training stabilization strategies) became foundational for subsequent work. Mixtral, which we'll explore in the next chapter, builds directly on these foundations while introducing innovations for open-source deployment.

Summary

Switch Transformer demonstrated that radical simplification could unlock MoE scalability. The key innovations include:

  • Top-1 routing: Each token goes to exactly one expert, halving communication cost and simplifying gradients compared to top-2 routing
  • Capacity factor: A tunable buffer ($C$, typically 1.25) determines how many tokens each expert can handle, with overflow tokens skipping to the residual path
  • Selective precision: Keeping the router in float32 while using lower precision elsewhere improves training stability
  • Load balancing: The auxiliary loss from prior MoE work combines with router z-loss to maintain balanced expert utilization

The scaling results were compelling: Switch-Base achieved T5-Base quality in one-seventh the training time, and the architecture scaled to 1.6 trillion parameters. These results established MoE as a practical path toward larger language models, showing that parameter count and computational cost can be decoupled.

The capacity factor mechanism deserves particular attention. By setting a hard limit on tokens per expert and gracefully dropping overflow to residual connections, Switch Transformer handled the inherent imperfection of routing without catastrophic failures. This pragmatic engineering choice, combined with aggressive load balancing, made MoE training reliable enough for production scale.

