Auxiliary Balancing Loss: Preventing Expert Collapse in MoE

Michael Brenndoerfer · November 17, 2025 · 35 min read

Learn how auxiliary balancing loss prevents expert collapse in MoE models. Covers loss formulations, coefficient tuning, and PyTorch implementation.


Auxiliary Balancing Loss

In the previous chapter, we explored why load balancing matters for Mixture of Experts models: without it, a few experts monopolize computation while others sit idle. But understanding the problem doesn't solve it. The gating network, trained purely to minimize task loss, has no incentive to spread tokens evenly. We need to explicitly tell the model that balance matters.

The solution is an auxiliary loss: a secondary objective added to the main task loss that penalizes imbalanced expert usage. This chapter covers the mathematical formulation of balancing losses, how to tune the coefficient that controls their strength, the fundamental tension between balancing and task performance, and how to implement these losses in practice.

Why Gating Networks Cause Imbalance

Before diving into loss formulations, let's understand why the problem exists in the first place. The imbalance we observe in MoE models isn't a bug in the gating mechanism; rather, it's an emergent property of how neural networks learn through gradient descent. To see this clearly, consider what happens during early training when all experts and the router begin with random weights.

By chance, some experts receive slightly more tokens than others. This initial asymmetry might seem insignificant, perhaps one expert receives 27% of tokens while another receives 23%. However, this small difference sets in motion a cascade of effects. The experts that receive more tokens also receive more gradient updates, since each processed token generates gradients that flow back through the expert network. With more updates, these experts improve faster. Their weights adjust more quickly to the training distribution, and they become genuinely better at processing the types of tokens they see.

The gating network, meanwhile, is doing its job diligently. It observes the reconstruction quality or task performance when tokens are routed to different experts, and it adjusts its weights to favor better-performing routes. When it notices that certain experts produce better outputs, it increases the routing probabilities toward those experts. This is exactly the behavior we want from a gating network when experts have different competencies, but during early training, this feedback creates a self-reinforcing cycle.

This creates a feedback loop: success breeds more success, failure breeds abandonment. The experts that received slightly more tokens initially become genuinely better, which makes them receive even more tokens, which makes them improve even further. Meanwhile, the experts that received fewer tokens fall behind. They update less frequently, improve more slowly, and become comparatively worse choices. The gating network, observing this growing gap, routes even fewer tokens to the struggling experts.

The gating network is doing exactly what we trained it to do: minimize task loss by routing tokens to the best experts. The problem is that "best" becomes self-fulfilling. Experts that receive no tokens never improve, so they remain poor choices, so they continue receiving no tokens. Left unchecked, this process converges to a degenerate state where one or two experts handle nearly all computation while the remaining experts contribute nothing. This represents a catastrophic waste of model capacity, as we've paid the memory cost of multiple expert networks but receive the computational benefit of only a few.

Out[3]:
Visualization
Line plot showing four expert usage fractions over 50 training steps. One expert grows to dominate while others decline.
Expert collapse dynamics without auxiliary loss. Starting from near-uniform usage, a feedback loop causes token distribution to concentrate on fewer experts. Expert 2 dominates while Expert 3 becomes nearly unused, wasting model capacity.

An auxiliary loss breaks this cycle by adding a cost for imbalance. The core insight is that we can modify what "success" means during training. Instead of optimizing purely for task performance, we add a secondary objective that penalizes concentration of expert usage. The total training objective becomes:

\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \alpha \cdot \mathcal{L}_{\text{balance}}

where:

  • \mathcal{L}_{\text{total}}: the combined objective function used for training
  • \mathcal{L}_{\text{task}}: the primary objective (e.g., cross-entropy loss)
  • \alpha: a scalar coefficient controlling the weight of the balancing penalty
  • \mathcal{L}_{\text{balance}}: the auxiliary loss term that penalizes imbalanced expert usage

The balancing loss increases when expert usage becomes skewed, forcing the optimizer to consider both task performance and expert utilization. This formulation allows us to tune how much we care about balance relative to task performance through the coefficient \alpha. A small \alpha gently encourages balance while prioritizing task quality; a larger \alpha enforces stricter balance at the potential cost of task performance.

Importance-Based Auxiliary Loss

The original MoE paper by Shazeer et al. (2017) introduced an importance-based auxiliary loss. The intuition behind this approach is straightforward: if we can measure how much attention the router pays to each expert, we can penalize situations where some experts receive vastly more attention than others. The key insight is that router probabilities, before the hard selection of which expert to use, provide a soft measure of expert "importance" that we can aggregate and analyze.

For each token x in a batch, the gating network produces probabilities g(x)_i for routing to expert i. These probabilities sum to one across experts and represent the router's assessment of how suitable each expert is for processing that particular token. A high probability indicates the router considers that expert a good match; a low probability indicates a poor match. The importance of expert i across the batch is the sum of these probabilities:

\text{Importance}_i = \sum_{x \in \text{batch}} g(x)_i

where:

  • \text{Importance}_i: the total importance score for expert i across the batch
  • x: a token in the current batch
  • g(x)_i: the probability assigned to expert i by the gating network for token x
  • \sum: the summation over all tokens in the batch

To understand what this measures, consider a concrete example. If our batch contains 100 tokens and the router assigns an average probability of 0.4 to expert 3 across all tokens, then the importance of expert 3 would be 40. If all tokens strongly prefer expert 3, then \text{Importance}_3 will be large while other experts have small importance values. In a perfectly balanced scenario with 4 experts, each expert would have an importance of 25 (since probabilities must sum to 1 for each token, the total importance across all experts equals the batch size).

The importance measure captures something subtle but important: it reflects the router's preferences before any hard decisions are made. Even if a token ultimately gets routed to expert 1 because it has probability 0.35 compared to expert 2's 0.30, both experts contributed to the importance calculation. This provides a smoother signal than counting actual routing decisions.
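The importance computation described above can be sketched in a few lines of PyTorch. This is an illustrative snippet with our own variable names, not code from a particular library:

```python
import torch

# Sketch of the importance measure: sum the router's probabilities for
# each expert across a toy batch of simulated softmax outputs.
torch.manual_seed(0)
num_tokens, num_experts = 100, 4

# Each row sums to 1, like real router probabilities
router_probs = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)

# Importance_i = sum over tokens of g(x)_i
importance = router_probs.sum(dim=0)  # shape (num_experts,)

# Because each token's probabilities sum to 1, total importance across
# experts equals the batch size (here, 100)
print(importance)
print(importance.sum().item())  # ≈ 100.0
```

Note that the per-expert values vary with the random probabilities, but their sum is pinned to the token count, exactly as the balanced-scenario arithmetic above suggests.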

To quantify the imbalance in importance scores, we need a metric that captures how spread out or concentrated the distribution is. We use the squared coefficient of variation (CV):

\mathcal{L}_{\text{importance}} = \text{CV}(\text{Importance})^2 = \left(\frac{\sigma(\text{Importance})}{\mu(\text{Importance})}\right)^2

where:

  • \mathcal{L}_{\text{importance}}: the importance-based auxiliary loss
  • \sigma(\text{Importance}): the standard deviation of importance scores across all experts
  • \mu(\text{Importance}): the mean importance score across all experts
  • \text{CV}: the coefficient of variation, measuring relative dispersion

The coefficient of variation measures spread relative to the mean. This relative measure is important because we want the loss to be meaningful regardless of batch size. A standard deviation of 10 means something very different when the mean is 25 versus when the mean is 1000. By dividing by the mean, we get a dimensionless ratio that measures proportional variation.

Squaring the coefficient of variation serves two purposes. First, it ensures the loss is differentiable at zero, which matters for gradient-based optimization. Second, it penalizes large deviations more strongly than small ones, creating a steeper gradient when imbalance is severe and a gentler gradient when balance is nearly achieved. When all experts have equal importance, \sigma = 0 and the loss is zero. When importance is concentrated on one expert, the loss becomes large.

This formulation has an elegant property: it's batch-level, not token-level. The model can route individual tokens wherever it wants, as long as the aggregate distribution remains balanced. This flexibility is crucial because different tokens genuinely benefit from different experts. The loss doesn't micromanage individual routing decisions; instead, it sets a constraint on the overall outcome.
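A minimal sketch of this importance loss, using the squared-CV formula above (the helper name is ours, not from a specific library):

```python
import torch


def importance_loss(router_probs: torch.Tensor) -> torch.Tensor:
    """Squared coefficient of variation of per-expert importance."""
    importance = router_probs.sum(dim=0)  # (E,)
    # CV^2 = variance / mean^2 (population variance, hence unbiased=False)
    return importance.var(unbiased=False) / importance.mean() ** 2


# Perfect balance: every token spreads probability uniformly -> loss is 0
uniform = torch.full((8, 4), 0.25)
print(importance_loss(uniform).item())  # 0.0

# Concentration on one expert -> large loss
skewed = torch.tensor([[0.97, 0.01, 0.01, 0.01]]).repeat(8, 1)
print(importance_loss(skewed).item())  # ≈ 2.76
```

Because the loss depends only on aggregated importance, both inputs here are free to route individual tokens however they like; only the batch-level distribution is penalized.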

Load Balancing Loss Formulation

The Switch Transformer paper introduced a refined formulation that became the standard in modern MoE architectures. This newer approach addresses some limitations of the importance-based loss by more directly measuring actual routing behavior rather than just router preferences. Instead of using coefficient of variation, it directly measures two quantities and penalizes their correlation.

For a batch of N tokens and E experts, we define two complementary measures of expert usage. The first captures what actually happens during routing, while the second captures what the router intended to happen.

Fraction of tokens routed to expert i:

f_i = \frac{1}{N} \sum_{x \in \text{batch}} \mathbf{1}[\text{argmax}(g(x)) = i]

where:

  • f_i: the fraction of tokens routed to expert i
  • N: the total number of tokens in the batch
  • \mathbf{1}[\cdot]: the indicator function, which is 1 if the condition is true and 0 otherwise
  • \text{argmax}(g(x)): the index of the expert with the highest probability for token x
  • \sum_{x \in \text{batch}} \mathbf{1}[\ldots]: the total count of tokens in the batch assigned to expert i

This quantity measures the actual routing outcome. When the router makes its final decision for each token, selecting the expert with the highest probability, some experts will be chosen more frequently than others. The fraction f_i counts how many tokens were ultimately sent to expert i, normalized by the total token count. It represents the actual load on each expert.

The crucial aspect of f_i is that it reflects discrete decisions. A token is either routed to an expert or it isn't, and this binary nature means f_i captures the ground truth of computational load. If expert 3 processes 40% of tokens, then f_3 = 0.4, regardless of whether those routing decisions were made with high confidence (probability 0.9) or narrow margins (probability 0.26 versus 0.25 for the runner-up).

Fraction of router probability assigned to expert i:

P_i = \frac{1}{N} \sum_{x \in \text{batch}} g(x)_i

where:

  • P_i: the fraction of router probability assigned to expert i
  • g(x)_i: the routing probability assigned to expert i for token x
  • N: the total number of tokens in the batch
  • \sum_{x \in \text{batch}} g(x)_i: the sum of routing probabilities for expert i across all tokens

This quantity measures the router's overall preference for each expert. Unlike f_i, which counts discrete decisions, P_i aggregates the continuous probability values: it is the average probability the router assigns to expert i across all tokens. It represents the router's "intention" to use each expert, capturing not just which expert wins the argmax competition, but also how confident the router is in its choices.

The distinction between f_i and P_i is subtle but important. Consider a batch where every token gives expert 1 probability 0.51 and expert 2 probability 0.49. The averages P_1 = 0.51 and P_2 = 0.49 look nearly balanced, reflecting balanced intentions. But the argmax selects expert 1 for every token, so f_1 = 1 while f_2 = 0, revealing severe imbalance in actual routing.
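A small sketch makes this divergence concrete. Suppose every token assigns expert 0 probability 0.51 and expert 1 probability 0.49 (hypothetical values):

```python
import torch

# Every token slightly prefers expert 0: 0.51 vs 0.49
num_tokens = 100
router_probs = torch.tensor([[0.51, 0.49]]).repeat(num_tokens, 1)

# P (intended load) looks almost balanced ...
P = router_probs.mean(dim=0)
print(P)  # tensor([0.5100, 0.4900])

# ... but f (actual load) is completely one-sided, since argmax always
# picks expert 0
assignments = router_probs.argmax(dim=-1)
f = torch.bincount(assignments, minlength=2).float() / num_tokens
print(f)  # tensor([1., 0.])
```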

The load balancing loss is the scaled dot product of these vectors:

\mathcal{L}_{\text{balance}} = E \cdot \sum_{i=1}^{E} f_i \cdot P_i

where:

  • \mathcal{L}_{\text{balance}}: the auxiliary balancing loss
  • E: the total number of experts
  • f_i: the fraction of tokens routed to expert i (actual load)
  • P_i: the average probability assigned to expert i (intended load)
  • \sum_{i=1}^{E}: the summation over all E experts

The factor E normalizes the loss so that perfect balance gives \mathcal{L}_{\text{balance}} = 1. This normalization is a thoughtful design choice that makes the loss interpretable. A value of 1.0 means perfect balance, values above 1.0 indicate imbalance, and the magnitude tells us how severe the imbalance is. Under perfect balance, each expert receives fraction 1/E of both tokens and probability, so f_i = P_i = 1/E for all i. We can verify the loss value:

\begin{aligned}
\mathcal{L}_{\text{balance}} &= E \cdot \sum_{i=1}^{E} \left(\frac{1}{E} \cdot \frac{1}{E}\right) && \text{(substitute balanced values)} \\
&= E \cdot \sum_{i=1}^{E} \frac{1}{E^2} && \text{(simplify terms)} \\
&= E \cdot \left( E \cdot \frac{1}{E^2} \right) && \text{(sum constant $E$ times)} \\
&= 1
\end{aligned}

When load is imbalanced, both f_i and P_i grow for popular experts and shrink for unpopular ones. The correlation between these quantities drives the loss increase. Since both terms are larger for popular experts, their product is even larger, and the sum exceeds 1.
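The balanced-case derivation is easy to check numerically. With f_i = P_i = 1/E for every expert, the loss E · Σ f_i P_i comes out to 1 for any expert count:

```python
import torch

# Numeric check: perfect balance gives loss 1.0 regardless of E
for E in (2, 4, 8, 64):
    f = torch.full((E,), 1.0 / E)  # actual load: 1/E each
    P = torch.full((E,), 1.0 / E)  # intended load: 1/E each
    loss = E * (f * P).sum()
    print(E, loss.item())  # 1.0 for every E
```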

Why This Formulation Works

This loss formulation works because of what it penalizes. The product structure creates a natural feedback mechanism that targets exactly the behavior we want to discourage. Consider what happens when expert i is overused:

  1. More tokens are routed to it, so f_i increases
  2. The router assigns higher probabilities to it, so P_i increases
  3. The product f_i \cdot P_i grows quadratically

This quadratic growth is the key insight. When an expert is slightly overused, both f_i and P_i are slightly elevated, and their product reflects this doubly. When an expert is heavily overused, both quantities are substantially elevated, and their product amplifies the signal even further. The loss creates a gradient that pushes the router to reduce probabilities for overused experts.

Critically, only P_i contributes gradients (since f_i involves a discrete argmax, which is non-differentiable). But f_i amplifies the gradient signal: popular experts get stronger pushback. This asymmetry is actually beneficial. The non-differentiable f_i acts as a weighting factor, telling the optimization process where to focus its attention. Experts with high actual load receive stronger gradient signals on their probability terms, while experts with low actual load receive weaker signals. This adaptive weighting helps the optimizer prioritize fixing the most severe imbalances.

The formulation also avoids a subtle failure mode. If we only penalized P_i variance, the router could game the system by assigning nearly-uniform probabilities while still routing most tokens to one expert (through tiny differences that all favor the same expert). Imagine a scenario where the router learns to give expert 1 probability 0.251 and all other experts probability 0.2497. The probability distribution looks nearly uniform, so a variance penalty would be small. Yet the argmax would consistently select expert 1, creating severe imbalance. By including f_i, we measure actual routing decisions, not just intentions.
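This gaming scenario can be sketched directly, using the hypothetical 0.251 vs. 0.2497 probabilities:

```python
import torch

# Near-uniform probabilities that nonetheless send every token to expert 0
E, N = 4, 100
probs = torch.tensor([0.251, 0.2497, 0.2497, 0.2497]).repeat(N, 1)

P = probs.mean(dim=0)
f = torch.bincount(probs.argmax(dim=-1), minlength=E).float() / N

# A variance penalty on P barely reacts to this pathology ...
print(P.var(unbiased=False).item())  # ~0: probabilities look uniform

# ... but the f-weighted loss sees f_0 = 1.0: its value exceeds 1 and,
# more importantly, the gradient on the popular expert's probability
# term is scaled by its full actual load
loss = E * (f * P).sum()
print(f.tolist())   # [1.0, 0.0, 0.0, 0.0]
print(loss.item())  # ≈ 1.004
```

The loss value itself rises only slightly here, but the f-weighting concentrates the entire gradient on the overused expert's probability, which is what drives the correction.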

Out[5]:
Visualization
Load balancing loss as a function of expert concentration. The loss equals 1.0 under perfect balance (concentration=0) and increases quadratically as tokens concentrate on fewer experts, providing a strong gradient signal to correct imbalance.
Out[6]:
Visualization
Token and probability fractions under different imbalance levels. In the balanced scenario (left), distributions are uniform. In the severe imbalance scenario (right), one expert dominates both actual routing (f) and router probability (P).

Loss Coefficient Tuning

The coefficient \alpha in \mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \alpha \cdot \mathcal{L}_{\text{balance}} controls the strength of the balancing incentive. This single hyperparameter has outsized importance in MoE training, as it determines how the model allocates its optimization effort between learning useful representations and maintaining computational balance. Setting \alpha correctly is crucial: too small and experts collapse; too large and the model sacrifices task performance to achieve perfect balance.

The key considerations for tuning \alpha are:

  • Typical values: Most implementations use \alpha between 0.01 and 0.1. Switch Transformer used 0.01; Mixtral uses 0.01 by default
  • Scale dependence: The appropriate \alpha depends on how large \mathcal{L}_{\text{task}} typically is. If task loss is around 2.0 and balancing loss is around 1.0, \alpha = 0.01 means balancing contributes about 0.5% of the total gradient signal
  • Number of experts: More experts require careful tuning, as the minimum achievable imbalance increases with expert count
  • Top-k value: With top-2 routing, each token contributes to two experts' load, which naturally improves balance compared to top-1

Effects of Extreme Coefficients

When \alpha is too small, the model effectively ignores the balancing signal. The gradient contribution from the auxiliary loss becomes negligible compared to the task loss gradients. Expert collapse proceeds as if no auxiliary loss existed. You'll observe a few experts receiving most tokens while others remain unused throughout training. The training curves may look healthy in terms of task loss, but inspection of expert utilization reveals the underlying pathology.

When \alpha is too large, the model prioritizes balance over task performance. The optimizer, faced with a strong gradient signal to equalize expert usage, adjusts the router weights primarily to reduce the balancing loss. In the extreme, the router learns to assign exactly uniform probabilities to all experts, achieving perfect balance but terrible task performance. The model loses its ability to specialize experts for different types of inputs. Every expert becomes equally mediocre at handling all tokens, squandering the capacity benefits that MoE architectures are designed to provide.

The sweet spot achieves "soft" balance: experts receive roughly similar token counts, but the router retains freedom to assign non-uniform probabilities based on input characteristics. In this regime, the auxiliary loss prevents pathological collapse without overly constraining the router's flexibility.

Out[8]:
Visualization
Line plot of task loss vs training steps for 4 alpha values.
Task loss evolution over training steps. Large coefficients (alpha=0.5) significantly degrade task performance, while smaller values (alpha=0.01) maintain performance comparable to the baseline.
Line plot of expert imbalance ratio vs training steps for 4 alpha values.
Expert imbalance ratio (max/min usage) over training. An alpha of 0.01 achieves stability with a ratio near 1.0 without forcing rigid uniformity, whereas alpha=0.001 fails to prevent expert collapse.

Practical Tuning Strategy

A reliable approach is to start with \alpha = 0.01 and monitor both metrics during training:

  1. Track expert utilization (fraction of tokens per expert) across training
  2. Track the ratio \mathcal{L}_{\text{balance}} / \mathcal{L}_{\text{task}}
  3. If experts collapse (one expert gets >50% of tokens), increase \alpha by 2-3x
  4. If task loss plateaus while balance loss keeps dropping, decrease \alpha
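The utilization check in this strategy can be sketched with a small helper (the function name and threshold handling are ours, not from a specific library):

```python
import torch


def expert_utilization(expert_indices: torch.Tensor, num_experts: int) -> torch.Tensor:
    """Fraction of tokens routed to each expert."""
    counts = torch.bincount(expert_indices.view(-1), minlength=num_experts)
    return counts.float() / expert_indices.numel()


# Example: a partially collapsed router sends 60% of tokens to expert 0
indices = torch.tensor([0] * 60 + [1] * 20 + [2] * 15 + [3] * 5)
util = expert_utilization(indices, num_experts=4)
print(util)  # tensor([0.6000, 0.2000, 0.1500, 0.0500])

# Apply the >50% rule from the checklist above
if util.max() > 0.5:
    print("collapse detected: consider increasing alpha by 2-3x")
```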

We'll explore a related technique called Router Z-Loss in the next chapter, which provides additional stabilization with less impact on task performance.

Balancing vs Task Loss Tradeoffs

There's a fundamental tension between balancing and task performance that no formulation can eliminate. This tension isn't a flaw in the auxiliary loss approach; it reflects a genuine tradeoff inherent to MoE architectures. Understanding this tradeoff helps set realistic expectations and informs practical decisions about coefficient values.

Why perfect balance hurts performance: Different types of inputs genuinely benefit from different experts. A language model might naturally develop experts for code, dialogue, technical writing, and creative text. This specialization emerges because experts can become particularly good at certain patterns when they see them repeatedly. If code represents only 10% of training data, the code expert should receive only 10% of tokens. Forcing equal distribution means either routing non-code tokens to the code expert (hurting their performance) or routing code tokens away from the code expert (hurting code performance). In both cases, we sacrifice task quality for an arbitrary notion of fairness.

Why some imbalance is tolerable: In practice, moderate imbalance (say, 2:1 ratio between most and least popular experts) has minimal impact on inference efficiency. Modern accelerators handle this gracefully through techniques like dynamic batching and load balancing at the infrastructure level. The goal is preventing pathological collapse, not achieving perfect uniformity. A model where all eight experts receive between 10% and 15% of tokens is functioning well, even though the distribution isn't perfectly uniform.

Empirical findings: The Switch Transformer paper found that \alpha = 0.01 achieved a good balance between utilization and task performance. At this setting, expert utilization variance was significantly reduced compared to no auxiliary loss, while perplexity degradation was minimal (less than 0.5%). Higher values of \alpha continued improving balance but with diminishing returns and increasing task performance cost.

The Capacity Factor Interaction

As we discussed in the Load Balancing chapter, the capacity factor C limits how many tokens each expert can process. This creates an important interaction with the auxiliary loss, as both mechanisms influence expert utilization but through different means. The auxiliary loss and capacity factor work together:

  • The auxiliary loss encourages the router to spread tokens evenly
  • The capacity factor enforces a hard cap, dropping tokens when an expert is overloaded

These mechanisms are complementary. The auxiliary loss provides a soft incentive through gradients, nudging the router toward balance without forcing specific outcomes. The capacity factor provides a hard constraint that takes effect when soft incentives are insufficient. With a well-tuned auxiliary loss, fewer tokens hit the capacity cap, reducing wasted computation. But if the auxiliary loss is too weak, the capacity factor does most of the work, causing token dropping and information loss.

The interplay between these mechanisms suggests a practical principle: the auxiliary loss should be strong enough that the capacity factor rarely needs to drop tokens. Token dropping represents a failure mode where computation is wasted and information is lost. By tuning \alpha to achieve reasonable balance, we can keep most tokens below the capacity threshold while still allowing meaningful expert specialization.
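A rough sketch of this interaction, under our own assumptions about capacity (capacity = C · N / E per expert, overflow tokens dropped), shows how residual imbalance translates into dropped tokens:

```python
import torch

# Hypothetical setup: 256 tokens, 8 experts, capacity factor 1.25
torch.manual_seed(0)
N, E, C = 256, 8, 1.25
capacity = int(C * N / E)  # 40 tokens per expert

# Imbalanced routing: bias the logits so expert 0 is popular
logits = torch.randn(N, E)
logits[:, 0] += 1.5
load = torch.bincount(logits.argmax(dim=-1), minlength=E)

# Tokens beyond each expert's capacity would be dropped
dropped = (load - capacity).clamp(min=0).sum()
print(load.tolist())
print(f"dropped: {dropped.item()} of {N} tokens")
```

A stronger auxiliary loss would shrink the bias on the popular expert's logits over training, pulling its load back under the capacity cap and reducing the dropped count.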

Implementation

Let's implement the auxiliary balancing loss step by step. We'll create a module that can be integrated into any MoE training loop.

In[9]:
Code
import torch
import torch.nn.functional as F


def compute_load_balancing_loss(
    router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int
) -> torch.Tensor:
    """
    Compute the auxiliary load balancing loss.

    Args:
        router_probs: Gate probabilities, shape (batch_size, seq_len, num_experts)
        expert_indices: Selected expert indices, shape (batch_size, seq_len)
        num_experts: Number of expert networks

    Returns:
        Scalar auxiliary loss
    """
    # Flatten batch and sequence dimensions
    batch_size, seq_len, _ = router_probs.shape
    num_tokens = batch_size * seq_len

    router_probs_flat = router_probs.view(-1, num_experts)  # (N, E)
    expert_indices_flat = expert_indices.view(-1)  # (N,)

    # Compute fraction of tokens routed to each expert (f_i)
    # One-hot encode the expert selections and average
    expert_mask = F.one_hot(expert_indices_flat, num_experts).float()  # (N, E)
    tokens_per_expert = expert_mask.sum(dim=0)  # (E,)
    f = tokens_per_expert / num_tokens  # (E,)

    # Compute fraction of router probability per expert (P_i)
    P = router_probs_flat.mean(dim=0)  # (E,)

    # Load balancing loss: scaled dot product
    load_balancing_loss = num_experts * (f * P).sum()

    return load_balancing_loss

This function computes the exact formulation from the Switch Transformer paper. Let's verify it works correctly with a simple example.

In[10]:
Code
import torch
import torch.nn.functional as F

# Create example data: 4 tokens, 3 experts
batch_size, seq_len, num_experts = 2, 2, 3
num_tokens = batch_size * seq_len

# Simulate router probabilities (softmax output)
torch.manual_seed(42)
router_logits = torch.randn(batch_size, seq_len, num_experts)
router_probs = F.softmax(router_logits, dim=-1)

# Select top expert for each token
expert_indices = router_probs.argmax(dim=-1)

# Compute loss and breakdown for inspection
loss = compute_load_balancing_loss(router_probs, expert_indices, num_experts)

router_probs_flat = router_probs.view(-1, num_experts)
expert_indices_flat = expert_indices.view(-1)
expert_mask = F.one_hot(expert_indices_flat, num_experts).float()
f = expert_mask.sum(dim=0) / num_tokens
P = router_probs_flat.mean(dim=0)
f_P_product = f * P

print("Router probabilities:")
print(router_probs_flat)
print(f"\nSelected experts: {expert_indices_flat.tolist()}")
print(f"\nLoad balancing loss: {loss.item():.4f}")
print(f"\nFraction of tokens per expert (f): {[f'{v:.4f}' for v in f]}")
print(f"Average probability per expert (P): {[f'{v:.4f}' for v in P]}")
print(f"f * P products: {[f'{v:.4f}' for v in f_P_product]}")
Out[11]:
Console
Router probabilities:
tensor([[0.3683, 0.2992, 0.3325],
        [0.5215, 0.1348, 0.3438],
        [0.8114, 0.0471, 0.1415],
        [0.2484, 0.3246, 0.4271]])

Selected experts: [0, 0, 0, 2]

Load balancing loss: 1.3300

Fraction of tokens per expert (f): ['0.7500', '0.0000', '0.2500']
Average probability per expert (P): ['0.4874', '0.2014', '0.3112']
f * P products: ['0.3655', '0.0000', '0.0778']

The loss is above 1.0 because expert 0 receives most tokens (high f_0) and also has the highest average probability (high P_0). With perfect balance, we'd expect f = P = [0.333, 0.333, 0.333] and loss = 1.0.

Handling Top-K Routing

When using top-2 or top-k routing, each token contributes to multiple experts' load. We need to adjust the computation:

In[12]:
Code
import torch


def compute_load_balancing_loss_topk(
    router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int
) -> torch.Tensor:
    """
    Compute auxiliary loss for top-k routing.

    Args:
        router_probs: Gate probabilities, shape (batch_size, seq_len, num_experts)
        expert_indices: Selected expert indices, shape (batch_size, seq_len, k)
        num_experts: Number of expert networks

    Returns:
        Scalar auxiliary loss
    """
    batch_size, seq_len, k = expert_indices.shape
    num_tokens = batch_size * seq_len

    router_probs_flat = router_probs.view(-1, num_experts)  # (N, E)
    expert_indices_flat = expert_indices.view(-1, k)  # (N, k)

    # Count how many times each expert is selected across all tokens
    # Each token contributes k selections total
    tokens_per_expert = torch.zeros(num_experts, device=router_probs.device)
    for i in range(k):
        expert_counts = torch.bincount(
            expert_indices_flat[:, i], minlength=num_experts
        ).float()
        tokens_per_expert += expert_counts

    # Normalize by total selections (N * k)
    f = tokens_per_expert / (num_tokens * k)

    # Average router probability per expert
    P = router_probs_flat.mean(dim=0)

    return num_experts * (f * P).sum()
In[13]:
Code
# Example with top-2 routing
k = 2
_, topk_indices = router_probs.topk(k, dim=-1)

loss_topk = compute_load_balancing_loss_topk(
    router_probs, topk_indices, num_experts
)

print("Top-2 expert indices:")
print(topk_indices)
print(f"\nTop-k load balancing loss: {loss_topk.item():.4f}")
Out[14]:
Console
Top-2 expert indices:
tensor([[[0, 2],
         [0, 2]],

        [[0, 2],
         [2, 1]]])

Top-k load balancing loss: 1.0907

The top-k loss is typically lower than top-1 because selecting two experts per token naturally spreads load more evenly.

Integrating with Training

Here's how to incorporate the auxiliary loss into a training loop:

In[15]:
Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    """Simplified MoE layer with auxiliary loss tracking."""

    def __init__(self, hidden_size: int, num_experts: int, expert_size: int):
        super().__init__()
        self.num_experts = num_experts

        # Router network
        self.router = nn.Linear(hidden_size, num_experts)

        # Expert networks (simplified as linear layers)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_size, expert_size),
                    nn.GELU(),
                    nn.Linear(expert_size, hidden_size),
                )
                for _ in range(num_experts)
            ]
        )

        # Store auxiliary loss
        self.aux_loss = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, hidden_size = x.shape

        # Compute router probabilities
        router_logits = self.router(x)  # (B, S, E)
        router_probs = F.softmax(router_logits, dim=-1)

        # Select top-1 expert
        expert_indices = router_probs.argmax(dim=-1)  # (B, S)

        # Compute auxiliary loss
        self.aux_loss = compute_load_balancing_loss(
            router_probs, expert_indices, self.num_experts
        )

        # Route tokens to experts (simplified dense computation for clarity)
        output = torch.zeros_like(x)
        x_flat = x.view(-1, hidden_size)
        output_flat = output.view(-1, hidden_size)
        indices_flat = expert_indices.view(-1)

        for i in range(self.num_experts):
            mask = indices_flat == i
            if mask.any():
                output_flat[mask] = self.experts[i](x_flat[mask])

        return output.view(batch_size, seq_len, hidden_size)
In[16]:
Code
import torch
import torch.nn.functional as F


def train_step(model, x, y, optimizer, aux_loss_coef=0.01):
    """Single training step with auxiliary loss."""
    optimizer.zero_grad()

    # Forward pass
    output = model(x)

    # Task loss (e.g., cross-entropy for language modeling)
    task_loss = F.mse_loss(output, y)  # Simplified example

    # Collect auxiliary losses from all MoE layers
    total_aux_loss = 0
    for module in model.modules():
        if isinstance(module, MoELayer) and module.aux_loss is not None:
            total_aux_loss += module.aux_loss

    # Combined loss
    total_loss = task_loss + aux_loss_coef * total_aux_loss

    # Backward and update
    total_loss.backward()
    optimizer.step()

    return {
        "task_loss": task_loss.item(),
        "aux_loss": total_aux_loss.item()
        if isinstance(total_aux_loss, torch.Tensor)
        else total_aux_loss,
        "total_loss": total_loss.item(),
    }
In[17]:
Code
import torch

# Create model and run training step
hidden_size, num_experts, expert_size = 64, 4, 128
model = MoELayer(hidden_size, num_experts, expert_size)

# Random input and target
x = torch.randn(2, 8, hidden_size)
y = torch.randn(2, 8, hidden_size)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
metrics = train_step(model, x, y, optimizer, aux_loss_coef=0.01)
Out[18]:
Console
Training metrics:
  task_loss: 1.0804
  aux_loss: 1.0224
  total_loss: 1.0906

The total loss is the sum of task loss and weighted auxiliary loss. The auxiliary loss is initially high (>1.0), reflecting the random initialization of the router.
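As a quick sanity check, the printed total matches the combination rule total = task + α · aux applied to the reported values:

```python
# Values reported in the training-step output above
task_loss, aux_loss, aux_loss_coef = 1.0804, 1.0224, 0.01

# total_loss = task_loss + alpha * aux_loss
total_loss = task_loss + aux_loss_coef * aux_loss
print(round(total_loss, 4))  # 1.0906
```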

Monitoring Expert Utilization

Tracking expert utilization during training helps diagnose load balancing issues:

In[19]:
Code
import torch


def compute_expert_utilization(
    router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int
) -> dict:
    """Compute expert utilization statistics."""
    num_tokens = expert_indices.numel()
    expert_indices_flat = expert_indices.view(-1)

    # Count tokens per expert
    counts = torch.bincount(expert_indices_flat, minlength=num_experts).float()
    fractions = counts / num_tokens

    # Compute statistics
    mean_fraction = 1.0 / num_experts  # Expected under perfect balance
    max_fraction = fractions.max().item()
    min_fraction = fractions.min().item()

    # Imbalance ratio: max/min (1.0 = perfect balance)
    imbalance_ratio = max_fraction / max(min_fraction, 1e-8)

    return {
        "tokens_per_expert": counts.tolist(),
        "fraction_per_expert": fractions.tolist(),
        "expected_fraction": mean_fraction,
        "max_fraction": max_fraction,
        "min_fraction": min_fraction,
        "imbalance_ratio": imbalance_ratio,
    }
In[20]:
Code
import torch
import torch.nn.functional as F

# Generate more data to see utilization patterns
torch.manual_seed(123)
batch_size, seq_len, num_experts = 8, 32, 4
router_logits = torch.randn(batch_size, seq_len, num_experts)
router_probs = F.softmax(router_logits, dim=-1)
expert_indices = router_probs.argmax(dim=-1)

utilization = compute_expert_utilization(
    router_probs, expert_indices, num_experts
)
Out[21]:
Console
Expert utilization statistics:
  Tokens per expert: [77.0, 65.0, 62.0, 52.0]
  Fraction per expert: ['0.301', '0.254', '0.242', '0.203']
  Max/min fractions: 0.301 / 0.203
  Imbalance ratio: 1.48x

With random initialization, some imbalance naturally occurs. The auxiliary loss should push this ratio closer to 1.0 during training.
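To build intuition for what the imbalance ratio captures, here is a minimal standalone sketch (independent of the helper above) comparing a perfectly balanced assignment with a fully collapsed one:

```python
import torch


def imbalance_ratio(expert_indices: torch.Tensor, num_experts: int) -> float:
    """Max/min fraction of tokens per expert; 1.0 means perfect balance."""
    counts = torch.bincount(
        expert_indices.view(-1), minlength=num_experts
    ).float()
    fractions = counts / counts.sum()
    return (fractions.max() / fractions.min().clamp_min(1e-8)).item()


balanced = torch.arange(32) % 4                # 8 tokens to each of 4 experts
collapsed = torch.zeros(32, dtype=torch.long)  # every token routed to expert 0

print(imbalance_ratio(balanced, 4))   # 1.0
print(imbalance_ratio(collapsed, 4))  # astronomically large: min count is zero
```

The collapsed case illustrates why the denominator is clamped: with an unused expert the min fraction is exactly zero, and the ratio diverges.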

Visualizing the Balancing Effect

Let's visualize how the auxiliary loss affects routing over training iterations:

In[22]:
Code
import torch
import torch.nn.functional as F
import numpy as np

# Simulate training with auxiliary loss
torch.manual_seed(42)
np.random.seed(42)

hidden_size, num_experts, expert_size = 32, 4, 64
model = MoELayer(hidden_size, num_experts, expert_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

aux_losses = []
imbalance_ratios = []

for step in range(100):
    x = torch.randn(4, 16, hidden_size)
    y = torch.randn(4, 16, hidden_size)

    optimizer.zero_grad()
    output = model(x)

    task_loss = F.mse_loss(output, y)
    total_loss = (
        task_loss + 0.1 * model.aux_loss
    )  # Higher coef for faster visualization

    total_loss.backward()
    optimizer.step()

    # Track metrics
    aux_losses.append(model.aux_loss.item())

    # Compute utilization
    with torch.no_grad():
        router_probs = F.softmax(model.router(x), dim=-1)
        expert_indices = router_probs.argmax(dim=-1)
        util = compute_expert_utilization(
            router_probs, expert_indices, num_experts
        )
        imbalance_ratios.append(util["imbalance_ratio"])
Out[23]:
Visualization
Line plot of auxiliary loss vs training step.
Auxiliary loss evolution during training. The loss decreases from its initial value toward the theoretical minimum of 1.0, indicating that the router is learning to balance both its probabilities and its routing decisions.
Line plot of max/min expert ratio vs training step.
Expert imbalance ratio evolution. The max/min expert-usage ratio falls alongside the auxiliary loss, moving from high imbalance toward a more even distribution.

The auxiliary loss pushes the model toward more balanced routing. As training progresses, the auxiliary loss approaches 1.0 (perfect balance), and the imbalance ratio decreases.
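That 1.0 floor follows directly from the loss definition: under perfect balance both f_i and P_i equal 1/E, so E · Σ_i (1/E)(1/E) = 1. A one-line verification:

```python
import torch

num_experts = 4
f = torch.full((num_experts,), 1.0 / num_experts)  # uniform actual load
P = torch.full((num_experts,), 1.0 / num_experts)  # uniform router probability
loss = num_experts * (f * P).sum()                 # E * sum_i f_i * P_i
print(loss.item())  # 1.0
```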

Key Parameters

The key parameters for the auxiliary balancing loss implementation are:

  • num_experts: Number of experts in the MoE layer.
  • aux_loss_coef: Coefficient α scaling the auxiliary loss (typically 0.01 to 0.1).
  • k: Number of experts selected per token (top-k routing).
  • expert_size: Hidden dimension size of each expert network.
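In practice these parameters might be gathered into a single configuration object. The sketch below is illustrative only; the specific values (8 experts, expert_size of 2048) are hypothetical choices, not recommendations from the text:

```python
# Hypothetical MoE hyperparameter configuration (illustrative values)
moe_config = {
    "num_experts": 8,       # experts per MoE layer
    "aux_loss_coef": 0.01,  # alpha; typical range 0.01-0.1
    "k": 2,                 # experts selected per token (top-k routing)
    "expert_size": 2048,    # hidden dimension of each expert network
}

# Basic consistency checks
assert moe_config["k"] <= moe_config["num_experts"]
```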

Limitations and Practical Considerations

While the auxiliary balancing loss is essential for stable MoE training, it has important limitations.

Batch-level balancing only: The loss encourages balance within each batch but doesn't guarantee global balance across the entire dataset. If certain input types cluster in specific batches, expert specialization patterns may still emerge unevenly. Larger batch sizes help mitigate this issue by providing more representative samples.

Gradient signal limitations: The fraction of tokens routed (f_i) doesn't contribute gradients because it involves a non-differentiable argmax operation. Only the router probabilities (P_i) are differentiable. This means the loss influences the router through probability adjustments, not through direct feedback about routing decisions. Some tokens may still be misrouted if their probability distribution is nearly uniform.
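This asymmetry is easy to verify: the count-based f_i passes through argmax and bincount, which have no gradient path, so all gradient reaching the router logits must flow through P_i. A minimal check (assuming the same E · Σ f_i · P_i formulation):

```python
import torch
import torch.nn.functional as F

num_experts = 4
logits = torch.randn(64, num_experts, requires_grad=True)
probs = F.softmax(logits, dim=-1)

# f_i: fraction of tokens routed to each expert (argmax + bincount, no grad)
indices = probs.argmax(dim=-1)
f = torch.bincount(indices, minlength=num_experts).float() / indices.numel()

# P_i: mean router probability per expert (differentiable)
P = probs.mean(dim=0)

loss = num_experts * (f * P).sum()
loss.backward()

print(f.requires_grad)          # False: no gradient through routing counts
print(logits.grad is not None)  # True: gradient flows only via P
```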

Tension with expert specialization: An overly strong auxiliary loss can prevent meaningful expert specialization. If the model is forced to route all input types equally across experts, each expert becomes a generalist rather than developing specific competencies. The optimal balance point depends on the data distribution and task requirements.

Sensitivity to expert count: The loss formulation scales with the number of experts, but optimal α values may still need adjustment as expert count changes. With 8 experts versus 64 experts, the same α may have different practical effects on routing behavior.

Interaction with capacity factor: When combined with capacity-based token dropping, the auxiliary loss and capacity factor can work at cross purposes. The loss pushes for balance while the capacity factor enforces hard limits. If α is too weak, the capacity factor does most of the balancing work by dropping tokens, which wastes computation and loses information.

These limitations motivate additional techniques like Router Z-Loss (covered in the next chapter), which provides complementary stabilization by penalizing extreme router logits.

Summary

The auxiliary balancing loss is a critical component of MoE training that prevents expert collapse and ensures efficient computation. The key concepts are:

  • Problem: Gating networks naturally create feedback loops where successful experts receive more tokens and improve further, while unused experts stagnate
  • Solution: Add an auxiliary loss term L_balance = E · Σ_i f_i · P_i that penalizes imbalanced routing
  • Components: The loss combines token fraction (f_i, actual load) with probability fraction (P_i, intended load), reaching its minimum value of 1.0 under perfect balance
  • Coefficient tuning: Typical values of α = 0.01 to 0.1 balance load distribution against task performance; too low causes collapse, too high forces unproductive uniformity
  • Tradeoff: Some imbalance is natural and desirable when data distributions are non-uniform; the goal is preventing pathological collapse, not perfect uniformity
  • Implementation: Track both auxiliary loss and expert utilization during training to diagnose load balancing issues

With this foundation, the next chapter explores Router Z-Loss, which addresses router training stability by penalizing extreme logit values, complementing the auxiliary balancing loss in modern MoE architectures.

