Top-K Routing: Expert Selection in Mixture of Experts Models

Michael Brenndoerfer · November 15, 2025 · 35 min read

Learn how top-K routing selects experts in MoE architectures. Understand top-1 vs top-2 trade-offs, implementation details, and weighted output combination.

Top-K Routing

In the previous chapter, we explored gating networks that compute a probability distribution over all experts for each input token. Mixture of Experts stays computationally tractable because each token is processed only by the K experts with the highest routing scores, not by all of them. This selective activation allows MoE models to scale to trillions of parameters while keeping per-token computational cost low.

Top-K routing answers a fundamental question: given a token and its routing scores across all experts, which experts should actually process that token, and how should we combine their outputs? The choice of K affects model capacity, stability, and inference efficiency. This chapter explains how top-K routing works, why different values of K lead to different trade-offs, and how to implement the routing and combination process correctly.

Top-1 Routing

The simplest form of expert selection routes each token to exactly one expert: the one with the highest gating score. This approach, called top-1 routing or hard routing, maximizes computational efficiency because each token passes through only a single expert network. It is simple, and it increases model capacity without raising per-token computation: since each token visits only one expert, we can add experts without changing the expert computation spent on any individual token (memory and communication costs do still grow with the number of experts).

To understand how top-1 routing works, we will walk through the process step by step. First, we compute the routing scores (logits) for all experts by projecting the token representation $\mathbf{x}$ through the router's learnable weights $W_r$:

$$\mathbf{h} = W_r \mathbf{x}$$

where:

  • $\mathbf{h}$: the vector of routing logits (scores) for all experts
  • $W_r$: the router weight matrix
  • $\mathbf{x}$: the input token representation vector

This linear transformation serves as a compatibility function. Each row of the weight matrix $W_r$ can be thought of as a "query vector" for one expert, and the dot product between this query and the token representation measures how well-suited that expert is for processing this particular token. Higher scores indicate stronger affinity between the token and the expert.

Then, top-1 routing selects the single expert with the highest score:

$$k^* = \arg\max_i h_i$$

where:

  • $k^*$: the index of the selected expert
  • $\arg\max_i$: the operation that finds the index $i$ maximizing the value
  • $h_i$: the routing logit (score) for the $i$-th expert

The argmax operation identifies the expert with the highest score. This expert then processes the token.

The output for that token is simply the output of the selected expert:

$$\mathbf{y} = E_{k^*}(\mathbf{x})$$

where:

  • $\mathbf{y}$: the output vector for the token
  • $E_{k^*}$: the feed-forward network function of the selected expert $k^*$
  • $\mathbf{x}$: the input token representation vector
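The two equations above can be traced in a few lines of plain Python. This is a minimal sketch under toy assumptions: the names router_logits and top1_route are hypothetical, the three "experts" are simple lambdas standing in for feed-forward networks, and W_r holds made-up values. Each row of W_r plays the role of one expert's query vector, matching the compatibility-function view of the router.

```python
def router_logits(W_r, x):
    """Compute h = W_r x: one dot product per expert (one row of W_r per expert)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W_r]


def top1_route(W_r, x, experts):
    """Hard top-1 routing: pick k* = argmax_i h_i and return (k*, E_{k*}(x))."""
    h = router_logits(W_r, x)
    k_star = max(range(len(h)), key=lambda i: h[i])
    return k_star, experts[k_star](x)


# Hypothetical toy setup: 3 experts acting on a 2-dimensional token.
W_r = [
    [1.0, 0.0],    # query vector for expert 0
    [0.0, 1.0],    # query vector for expert 1
    [-1.0, -1.0],  # query vector for expert 2
]
experts = [
    lambda v: [2 * t for t in v],  # expert 0 doubles the input
    lambda v: [t + 1 for t in v],  # expert 1 shifts the input
    lambda v: [-t for t in v],     # expert 2 negates the input
]

x = [0.5, 2.0]
k_star, y = top1_route(W_r, x, experts)
print(k_star, y)  # 1 [1.5, 3.0]  (logits are [0.5, 2.0, -2.5])
```

Only the winning expert's function is ever called, which is exactly where top-1 routing saves computation.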

This is efficient. If we have 64 experts and use top-1 routing, each token uses exactly 1/64th of the expert capacity. A model with 64 experts, each containing 1 billion parameters in their feed-forward networks, would have 64 billion expert parameters total but only activate 1 billion of them for any given token. Separating total capacity from per-token costs allows MoE architectures to scale efficiently.

However, top-1 routing introduces a significant challenge: the $\arg\max$ operation is not differentiable, so how do gradients reach the router through a hard selection? A common solution is the straight-through estimator. During the forward pass, we use the hard $\arg\max$ selection, so the token genuinely goes through only one expert. During the backward pass, we treat the selection as if it had been the softmax probabilities, which lets gradients reach the router logits of every expert, not just the winner. This creates a mismatch between forward and backward computations but works surprisingly well in practice. Note that this soft signal reaches only the router: unselected experts never ran on the token, so their own parameters receive no gradient from it. A closely related approach, used in Switch Transformer, multiplies the selected expert's output by its softmax probability, $\mathbf{y} = p_{k^*}(\mathbf{x}) \, E_{k^*}(\mathbf{x})$, which gives the router a differentiable path through that probability.
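The straight-through idea can be sketched in PyTorch with the widely used `hard - soft.detach() + soft` pattern. This is a minimal illustration, not code from any particular MoE model: straight_through_top1 is a hypothetical helper, the sizes are toy values, and a four-entry lookup table stands in for the experts' outputs. Only the selected expert's output is used in the forward pass, yet every router logit receives a gradient.

```python
import torch
import torch.nn.functional as F


def straight_through_top1(router_logits):
    """Forward: hard one-hot selection. Backward: gradient of the softmax probabilities."""
    probs = F.softmax(router_logits, dim=-1)
    hard = F.one_hot(probs.argmax(dim=-1), num_classes=probs.shape[-1]).to(probs.dtype)
    # Numerically equal to `hard` (up to rounding); autograd only sees the `probs` term.
    return hard - probs.detach() + probs


torch.manual_seed(0)
logits = torch.randn(3, 4, requires_grad=True)  # 3 tokens, 4 experts (toy sizes)
gates = straight_through_top1(logits)

k_star = logits.argmax(dim=-1)  # selected expert per token
selected_gate = gates.gather(-1, k_star.unsqueeze(-1)).squeeze(-1)  # forward value ~1.0

# Stand-in for E_{k*}(x): only the selected expert's (scalar) output is "computed".
expert_table = torch.tensor([1.0, -2.0, 0.5, 3.0])
selected_output = expert_table[k_star]

loss = (selected_gate * selected_output).sum()
loss.backward()

print(gates)        # one-hot rows in the forward pass
print(logits.grad)  # dense: every expert's logit receives a gradient
```

Because the selected gate's forward value is 1 (up to rounding), the output matches $E_{k^*}(\mathbf{x})$ from the top-1 equation, while autograd differentiates through the softmax probability instead of the argmax.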

Straight-Through Estimator

A technique for backpropagating through discrete operations: the discrete value is used in the forward pass, while the backward pass treats the operation as continuous, typically by using the gradient of the softmax probabilities.

Top-1 routing has another characteristic that can be either a feature or a bug depending on your perspective: it forces complete specialization. Each token must commit fully to one expert's computation, with no blending of perspectives. This can lead to sharper expert specialization but also means the model cannot hedge its bets when a token might benefit from multiple experts' knowledge. Consider a token that sits at the boundary between two semantic categories. With top-1 routing, the model must make a definitive choice, potentially losing information that a second expert could have contributed.

Figure: Router logits and routing weights for a single token over 8 experts.
  • Router logits. E3 (2.1) and E7 (1.8) show the highest affinity scores for the current token.
  • Top-1 routing weights. The entire probability mass is assigned to the single highest-scoring expert (E3), resulting in a sparse binary vector.
  • Top-2 routing weights. Probability mass is distributed between the two highest-scoring experts (E3 and E7) via softmax, allowing for gradient flow to both experts.

Top-2 Routing

The most common choice in modern MoE architectures is top-2 routing, which selects the two experts with the highest routing scores and combines their outputs. This change from top-1 to top-2 significantly affects model behavior and training. By allowing each token to benefit from two different experts, we introduce a form of ensemble learning at the token level, where multiple specialized perspectives contribute to the final representation.

With top-2 routing, given routing logits $\mathbf{h}$, we select indices $k_1$ and $k_2$ corresponding to the two highest values. The process begins identically to top-1 routing: we compute the same routing logits using the same linear transformation. The difference lies in what we do with these logits. Instead of selecting just the maximum, we identify both the highest and second-highest scoring experts.

The gating weights for these experts are computed by applying softmax only over the selected experts:

$$\begin{aligned} g_1 &= \frac{\exp(h_{k_1})}{\exp(h_{k_1}) + \exp(h_{k_2})} \\ g_2 &= \frac{\exp(h_{k_2})}{\exp(h_{k_1}) + \exp(h_{k_2})} \end{aligned}$$

where:

  • $g_1, g_2$: the normalized gating weights for the first and second selected experts
  • $h_{k_1}, h_{k_2}$: the raw routing logits for the selected experts
  • $\exp(\cdot)$: the exponential function, which ensures all weights are positive and amplifies differences between values
  • $\exp(h_{k_1}) + \exp(h_{k_2})$: the sum of exponentials, serving as a normalizing constant to ensure weights sum to 1

Notice that we apply softmax only over the two selected logits, not over all expert logits. This is a crucial design choice. If we had applied softmax over all experts first and then kept the top-2 probabilities, those two probabilities would generally sum to less than 1, because some probability mass sits on the unselected experts. By computing softmax over just the selected pair, we ensure that the two weights always sum exactly to 1, providing a proper convex combination of the two expert outputs.
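A quick numerical check of this argument, using hypothetical logits for four experts and a plain-Python softmax (an illustrative helper, not a library function). It also shows a useful corollary: renormalizing the kept top-2 probabilities gives exactly the softmax over the selected logits, because the shared normalizing constant cancels.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)  # subtract the max so exp() cannot overflow
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical logits for 4 experts
top2 = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:2]

# (a) Softmax over all experts, then keep the top-2 probabilities: they sum to < 1.
full = softmax(logits)
kept = [full[i] for i in top2]

# (b) Softmax over only the selected logits: sums to 1 by construction.
selected = softmax([logits[i] for i in top2])

# (c) Renormalizing the kept probabilities from (a) recovers (b).
renormalized = [p / sum(kept) for p in kept]

print(top2, sum(kept), sum(selected))
print(selected, renormalized)
```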

Note that $g_1 + g_2 = 1$. The final output combines both expert outputs weighted by these normalized scores:

$$\mathbf{y} = g_1 \cdot E_{k_1}(\mathbf{x}) + g_2 \cdot E_{k_2}(\mathbf{x})$$

where:

  • $\mathbf{y}$: the combined output vector for the token
  • $g_1, g_2$: the normalized scalar weights for the two experts
  • $E_{k_1}, E_{k_2}$: the functions computed by the selected experts
  • $\mathbf{x}$: the input token representation
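The combination step can be sketched in plain Python. The logits 2.1 and 1.8 are the two winners (E3 and E7) from the router-logit figure earlier in this section; the expert output vectors are made-up placeholders, and top2_combine is a hypothetical helper name. Subtracting the larger logit before exponentiating leaves the weights unchanged but avoids overflow for large logits.

```python
import math


def top2_combine(h_k1, h_k2, e_k1, e_k2):
    """y = g1 * E_k1(x) + g2 * E_k2(x), with softmax over the two selected logits."""
    m = max(h_k1, h_k2)  # stabilize the exponentials
    a, b = math.exp(h_k1 - m), math.exp(h_k2 - m)
    g1, g2 = a / (a + b), b / (a + b)
    y = [g1 * u + g2 * v for u, v in zip(e_k1, e_k2)]
    return g1, g2, y


# Winning logits from the figure (E3 = 2.1, E7 = 1.8); expert outputs are placeholders.
g1, g2, y = top2_combine(2.1, 1.8, e_k1=[1.0, 0.0, 2.0], e_k2=[0.0, 1.0, -2.0])
print(round(g1, 3), round(g2, 3))  # 0.574 0.426
print([round(v, 3) for v in y])
```

A logit gap of only 0.3 already tilts the blend to roughly 57/43 in favor of the top expert.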

This weighted combination means that the final representation is a blend of two expert perspectives. If expert $k_1$ had a much higher routing score than expert $k_2$, then $g_1$ will be close to 1 and $g_2$ will be close to 0, making the output dominated by the first expert. Conversely, if both experts had similar scores, the weights will be closer to 0.5 each, giving both experts roughly equal influence. This adaptive weighting allows the model to smoothly interpolate between relying on a single dominant expert and equally blending two perspectives.
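Dividing the numerator and denominator of $g_1$ by $\exp(h_{k_1})$ gives $g_1 = 1 / (1 + \exp(-(h_{k_1} - h_{k_2})))$, the logistic sigmoid of the logit gap. The top-2 split therefore depends only on the difference between the two selected logits, not on their absolute values. A small sketch (the helper name is hypothetical) tabulates this for a few gaps:

```python
import math


def top2_weights_from_gap(gap):
    """g1 = 1 / (1 + exp(-(h_k1 - h_k2))); g2 = 1 - g1. Only the gap matters."""
    g1 = 1.0 / (1.0 + math.exp(-gap))
    return g1, 1.0 - g1


for gap in [0.0, 0.5, 1.0, 2.0, 4.0]:
    g1, g2 = top2_weights_from_gap(gap)
    print(f"gap={gap:.1f}  g1={g1:.3f}  g2={g2:.3f}")
# gap=0.0 -> g1=0.500, gap=1.0 -> g1=0.731, gap=4.0 -> g1=0.982
```

This is the curve in the logit-difference figure below: equal weights at a zero gap, and near winner-take-all once the gap reaches a few units.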

Figure: Softmax weights for the top two experts as a function of their logit difference. As the gap between the top two logits increases, the weight of the top expert (blue) approaches 1.0, while the second expert's weight decreases toward zero.

Figure: Weight distribution examples for three specific logit gaps. A zero gap results in equal weighting between experts, while larger gaps lead to a winner-take-all behavior where one expert receives nearly all the probability mass.

Why does top-2 work better than top-1 in many settings? Several factors contribute:

The first is gradient flow. With two active experts per token, gradients reach twice as many expert parameters during each training step. This improves training efficiency and helps experts learn faster, particularly in the early stages of training when routing decisions are noisy. During early training, the router has not yet learned which experts are best for which tokens, so routing decisions are essentially random. With top-2 routing, more experts receive gradient signal during this critical period, helping the entire expert ensemble learn meaningful specializations more quickly.

The second is representation flexibility. Tokens often don't fit neatly into single categories. A token representing a technical term in a legal document might benefit from both a "technical/scientific" expert and a "legal language" expert. Top-2 routing allows this blending. Language is ambiguous, so forcing tokens into one category can be restrictive. The ability to combine two expert perspectives provides a richer representational palette.

The third is training stability. When only one expert processes each token, small changes in routing can cause dramatic shifts in which expert learns from which data. With two experts active, there's inherent smoothing that stabilizes training, as we'll see when we discuss load balancing in the next chapter. If the router makes a slightly different decision, the token might still visit at least one of the same experts, providing continuity in which experts receive which training signal.

The computational cost doubles compared to top-1, since each token now passes through two expert networks instead of one. For a model with 64 experts where each expert's feed-forward network has 1 billion parameters, top-2 routing activates 2 billion parameters per token instead of 1 billion. This is still a dramatic reduction from the 64 billion total expert parameters. Doubling the computation is usually worth the gains in stability and model quality.

Selecting K: Trade-offs

Choosing $K$ involves several trade-offs. Let's examine what happens as $K$ increases from $1$ toward the total number of experts $N$. Understanding these trade-offs helps you choose an architecture based on performance and capacity needs.

Computational cost scales linearly with $K$. If each expert's feed-forward network requires $F$ FLOPs per token, and you select $K$ experts, the expert computation costs $K \times F$ FLOPs per token. This is the most direct trade-off: higher $K$ means more computation. For a fixed per-token compute budget, a larger $K$ leaves less compute for each expert, so each expert must be smaller; at a fixed number of experts, this shrinks the total parameter count and with it the model's capacity advantage over dense architectures.

Model capacity utilization increases with $K$. With $K=1$, each token sees $1/N$ of your expert capacity. With $K=2$, it sees $2/N$. At $K=N$ (all experts), you've effectively built a very expensive dense model with $N$ times the feed-forward computation. The power of MoE comes from keeping $K \ll N$. The ratio $K/N$ represents what fraction of the expert capacity any given token can access. Keeping this ratio small is what enables the favorable scaling properties of MoE architectures.
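The linear relationships in the two paragraphs above can be made concrete with a small helper. This is an illustrative sketch: moe_costs is a hypothetical name, $F$ is treated as one abstract unit of FLOPs, and the expert size reuses the 64-expert, 1-billion-parameter example from earlier in the chapter.

```python
def moe_costs(k, num_experts, flops_per_expert, params_per_expert):
    """Per-token expert FLOPs (K * F), active/total parameters, and the fraction K / N."""
    return {
        "flops_per_token": k * flops_per_expert,
        "active_params": k * params_per_expert,
        "total_params": num_experts * params_per_expert,
        "capacity_fraction": k / num_experts,
    }


# 64 experts with 1 billion parameters each; F is one abstract FLOP unit.
for k in (1, 2, 4, 64):
    c = moe_costs(k, num_experts=64, flops_per_expert=1, params_per_expert=1_000_000_000)
    print(k, c["flops_per_token"], f'{c["active_params"]:,}', c["capacity_fraction"])
```

At $K = 64$ the capacity fraction reaches 1.0 and every token pays for all experts, which is the "very expensive dense model" end of the spectrum.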

Training stability generally improves with moderate $K$. The original Shazeer et al. MoE work used $K=4$. The GShard paper found $K=2$ worked well. Switch Transformer pushed to $K=1$ but required careful auxiliary losses to maintain stability. Lower $K$ values are more prone to training instability because each routing decision carries more weight: a routing mistake has larger downstream effects, and the router receives noisier gradient signals because fewer experts are active for each token.

Specialization sharpness decreases with higher $K$. When $K$ is small, experts must specialize to win the routing competition. When $K$ is large, experts can be more generalist since tokens see many of them anyway. This affects what kind of knowledge structure emerges in the expert networks. With aggressive specialization (low $K$), experts may develop distinct "personalities" focused on narrow domains. With milder specialization (higher $K$), experts may develop more overlapping capabilities.

Figure: Computational cost and capacity utilization versus K. Both metrics scale linearly with K, as selecting more experts increases both inference FLOPs and parameter access.

Figure: Training stability and specialization sharpness versus K. Moderate values like K=2 provide a balance where training is stable (green) while experts still maintain distinct specializations (purple).

Here's a summary of common choices:

Common K values and their trade-offs in MoE architectures.

| K value | Use case | Trade-off |
|---|---|---|
| K = 1 | Maximum efficiency (Switch Transformer) | Requires careful balancing; training can be unstable |
| K = 2 | Standard choice (GShard, Mixtral) | Good balance of efficiency and stability |
| K = 4 | Early MoE work | More computation but smoother training |
| K ≥ 8 | Rare in practice | Diminishing returns; approaches dense computation |

The Mixtral model uses $K=2$ with $N=8$ experts, meaning each token activates 2 out of 8 experts, or 25% of expert capacity. This keeps per-token inference cost comparable to a dense model whose feed-forward layers are roughly $2\times$ the size of a single expert, while having $8\times$ the expert parameters available in total. This configuration works well, increasing capacity while keeping computation manageable.

Routing Implementation

Implementing top-K routing correctly requires handling several details: selecting the top K indices, computing normalized weights, dealing with numerical stability, and enabling gradient flow. Let's build this step by step. A robust implementation must be efficient and handle training edge cases.

The core routing operation takes router logits and produces both the selected expert indices and their corresponding weights:

In[6]:
Code
import torch
import torch.nn.functional as F


def top_k_routing(router_logits, k=2):
    """
    Select top-k experts and compute their normalized weights.

    Args:
        router_logits: Tensor of shape (batch_size, seq_len, num_experts)
        k: Number of experts to select per token

    Returns:
        expert_weights: Normalized weights for selected experts (batch, seq, k)
        expert_indices: Indices of selected experts (batch, seq, k)
    """
    # Get top-k logits and their indices
    top_k_logits, expert_indices = torch.topk(router_logits, k, dim=-1)

    # Normalize weights using softmax over only the selected experts
    expert_weights = F.softmax(top_k_logits, dim=-1)

    return expert_weights, expert_indices

The torch.topk function efficiently finds the K largest values along the expert dimension. We then apply softmax only to these selected logits, ensuring the weights sum to 1 across the K active experts. This two-step process, first selecting then normalizing, is essential for obtaining proper convex combination weights.

Let's verify this works as expected:

In[7]:
Code
## Simulate routing for a small batch
batch_size, seq_len, num_experts = 2, 4, 8
router_logits = torch.randn(batch_size, seq_len, num_experts)

## Apply top-2 routing
weights, indices = top_k_routing(router_logits, k=2)
Out[8]:
Console
Router logits shape: torch.Size([2, 4, 8])
Expert weights shape: torch.Size([2, 4, 2])
Expert indices shape: torch.Size([2, 4, 2])

For first token in batch:
  All router logits: [1.9759124517440796, -0.20113429427146912, 0.8982502222061157, 0.38522207736968994, -2.1745169162750244, -0.168917715549469, -0.31404799222946167, -0.6442866921424866]
  Selected expert indices: [0, 2]
  Normalized weights: [0.7460513710975647, 0.2539486885070801]
  Weights sum to: 1.0000

The output shows that for each token, we get exactly K = 2 expert indices and their corresponding normalized weights that sum to 1.

For top-1 routing specifically, we often want to preserve the routing probability for use in auxiliary losses while still making a hard selection. The challenge is that we need the discrete selection for the forward pass but the full probability distribution to compute load balancing penalties:

In[9]:
Code
def top_1_routing_with_probs(router_logits):
    """
    Top-1 routing that preserves full probability distribution.

    Returns:
        expert_index: Index of selected expert (batch, seq)
        routing_probs: Full softmax probabilities (batch, seq, num_experts)
    """
    # Full probability distribution for auxiliary losses
    routing_probs = F.softmax(router_logits, dim=-1)

    # Select the argmax expert
    expert_index = torch.argmax(router_logits, dim=-1)

    return expert_index, routing_probs

We keep the full routing probability distribution because we'll need it for load balancing losses, which we cover in the next chapter. This design pattern, where the routing function returns auxiliary information beyond just the selected experts, is common in practical MoE implementations.

Combining Expert Outputs

Once we've selected experts and computed their weights, we need to actually route tokens through the selected experts and combine their outputs. This is where implementation complexity increases, because we're dealing with different experts for different tokens. Unlike a standard feed-forward layer where all tokens pass through the same parameters, MoE requires dynamically dispatching different tokens to different experts and then gathering and combining the results.

The simplest approach processes tokens individually:

In[10]:
Code
def combine_expert_outputs_simple(
    hidden_states, expert_weights, expert_indices, experts
):
    """
    Simple (but slow) expert combination by iterating over tokens.

    Args:
        hidden_states: Input tensor (batch, seq, hidden_dim)
        expert_weights: Weights for selected experts (batch, seq, k)
        expert_indices: Indices of selected experts (batch, seq, k)
        experts: List of expert modules

    Returns:
        combined_output: Weighted combination of expert outputs (batch, seq, hidden_dim)
    """
    batch_size, seq_len, hidden_dim = hidden_states.shape
    k = expert_weights.shape[-1]

    output = torch.zeros_like(hidden_states)

    for b in range(batch_size):
        for s in range(seq_len):
            for i in range(k):
                expert_idx = expert_indices[b, s, i].item()
                weight = expert_weights[b, s, i]
                expert_output = experts[expert_idx](
                    hidden_states[b, s].unsqueeze(0)
                )
                output[b, s] += weight * expert_output.squeeze(0)

    return output

This explicit loop makes the logic clear: for each token, iterate through its K selected experts, compute each expert's output, and add them together weighted by the routing weights. However, this approach is extremely slow because it processes tokens sequentially and can't leverage GPU parallelism. Modern GPUs are designed to process thousands of operations in parallel, and this nested loop structure forces sequential execution that drastically underutilizes the hardware.

The efficient approach groups tokens by their assigned experts:

In[11]:
Code
def combine_expert_outputs_efficient(
    hidden_states, expert_weights, expert_indices, experts
):
    """
    Efficient expert combination using batched operations.

    Groups tokens by expert and processes each expert's tokens in parallel.
    """
    batch_size, seq_len, hidden_dim = hidden_states.shape
    num_experts = len(experts)
    k = expert_weights.shape[-1]

    # Flatten batch and sequence dimensions
    flat_hidden = hidden_states.view(-1, hidden_dim)  # (batch*seq, hidden)
    flat_weights = expert_weights.view(-1, k)  # (batch*seq, k)
    flat_indices = expert_indices.view(-1, k)  # (batch*seq, k)

    # Initialize output
    output = torch.zeros_like(flat_hidden)

    # Process each expert position (first selected, second selected, etc.)
    for pos in range(k):
        pos_indices = flat_indices[:, pos]  # Which expert for this position
        pos_weights = flat_weights[:, pos]  # Weight for this position

        # Process each expert's assigned tokens
        for expert_idx in range(num_experts):
            # Find tokens assigned to this expert at this position
            mask = pos_indices == expert_idx
            if mask.any():
                expert_input = flat_hidden[mask]
                expert_output = experts[expert_idx](expert_input)
                output[mask] += pos_weights[mask].unsqueeze(-1) * expert_output

    return output.view(batch_size, seq_len, hidden_dim)

This version groups all tokens assigned to each expert and processes them in a single batched forward pass through that expert. The key insight is that even though different tokens go to different experts, all tokens assigned to the same expert can be processed together. The outer loop over the $K$ expert positions and the inner loop over experts mean we make at most $K \times N$ forward passes through expert networks (experts with no assigned tokens at a given position are skipped by the mask.any() check), and each pass processes all relevant tokens in parallel. This batching strategy transforms the problem from per-token sequential processing to per-expert parallel processing.
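One possible refinement, sketched here under stated assumptions, visits each expert only once by collecting every (token, slot) pair routed to it with torch.where and scattering the weighted results back with Tensor.index_add_. That caps the work at $N$ expert forward passes per layer instead of up to $K \times N$. The function name combine_by_expert is hypothetical, nn.Linear layers stand in for real experts, and the nested loop at the end is a slow reference used only to check that both orderings agree.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def combine_by_expert(hidden_states, expert_weights, expert_indices, experts):
    """Visit each expert once, processing every (token, slot) pair routed to it."""
    batch, seq, hidden = hidden_states.shape
    k = expert_weights.shape[-1]
    flat_hidden = hidden_states.reshape(-1, hidden)
    flat_weights = expert_weights.reshape(-1, k)
    flat_indices = expert_indices.reshape(-1, k)
    output = torch.zeros_like(flat_hidden)

    for e, expert in enumerate(experts):
        # token_ids: flattened token positions; slot_ids: which of the k slots chose expert e.
        token_ids, slot_ids = torch.where(flat_indices == e)
        if token_ids.numel() == 0:
            continue  # no tokens routed to this expert: skip its forward pass
        expert_out = expert(flat_hidden[token_ids])
        weighted = flat_weights[token_ids, slot_ids].unsqueeze(-1) * expert_out
        output.index_add_(0, token_ids, weighted)  # accumulate into each token's row

    return output.reshape(batch, seq, hidden)


# Usage with toy sizes, checked against a slow per-token loop.
torch.manual_seed(0)
hidden_dim, num_experts, k = 16, 4, 2
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])
x = torch.randn(2, 5, hidden_dim)
top_logits, indices = torch.topk(torch.randn(2, 5, num_experts), k, dim=-1)
weights = F.softmax(top_logits, dim=-1)

with torch.no_grad():
    fast = combine_by_expert(x, weights, indices, experts)
    slow = torch.zeros_like(x)
    for b in range(x.shape[0]):
        for s in range(x.shape[1]):
            for i in range(k):
                e = indices[b, s, i].item()
                slow[b, s] += weights[b, s, i] * experts[e](x[b, s])

print((fast - slow).abs().max().item())  # only floating-point-level differences
```

Because top-K indices for a token are distinct, each token appears at most once per expert, so index_add_ simply adds the $K$ weighted contributions into the right rows.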

Let's create a complete working example:

In[12]:
Code
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """Simple feed-forward expert network."""

    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, intermediate_dim)
        self.w2 = nn.Linear(intermediate_dim, hidden_dim)

    def forward(self, x):
        return self.w2(F.gelu(self.w1(x)))


class Router(nn.Module):
    """Gating network that produces routing logits."""

    def __init__(self, hidden_dim, num_experts):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)

    def forward(self, x):
        return self.gate(x)
In[13]:
Code
import torch.nn as nn


class TopKMoELayer(nn.Module):
    """Complete top-K MoE layer combining routing and expert computation."""

    def __init__(self, hidden_dim, intermediate_dim, num_experts, k=2):
        super().__init__()
        self.num_experts = num_experts
        self.k = k

        self.router = Router(hidden_dim, num_experts)
        self.experts = nn.ModuleList(
            [Expert(hidden_dim, intermediate_dim) for _ in range(num_experts)]
        )

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Get routing decisions
        router_logits = self.router(hidden_states)
        expert_weights, expert_indices = top_k_routing(router_logits, self.k)

        # Combine expert outputs
        output = combine_expert_outputs_efficient(
            hidden_states, expert_weights, expert_indices, self.experts
        )

        return output, router_logits  # Return logits for auxiliary losses

Let's test the complete layer:

In[14]:
Code
## Create MoE layer
hidden_dim = 256
intermediate_dim = 512
num_experts = 8
k = 2

moe_layer = TopKMoELayer(hidden_dim, intermediate_dim, num_experts, k)

## Create sample input
batch_size, seq_len = 4, 16
x = torch.randn(batch_size, seq_len, hidden_dim)

## Forward pass
output, router_logits = moe_layer(x)
Out[15]:
Console
Input shape: torch.Size([4, 16, 256])
Output shape: torch.Size([4, 16, 256])
Router logits shape: torch.Size([4, 16, 8])

Total expert parameters: 2,103,296
Parameters used per token: 525,824

The output confirms that our MoE layer maintains the expected dimensions while activating only K experts per token. The total expert parameters scale with the number of experts, but each token only uses K experts worth of computation. Separating total parameters from per-token computation is the main advantage of MoE architectures.
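The console numbers can be reproduced by hand. Each Expert has two nn.Linear layers with biases, so it holds $2 \cdot d \cdot d_{ff} + d_{ff} + d$ parameters, where $d$ is hidden_dim and $d_{ff}$ is intermediate_dim. The router's bias-free gate adds $d \cdot N$ more, which the "Total expert parameters" line does not include. The helper below is a hypothetical name used only for this check.

```python
def ffn_expert_params(hidden_dim, intermediate_dim):
    """Parameter count of the Expert module: two nn.Linear layers with biases."""
    w1 = hidden_dim * intermediate_dim + intermediate_dim  # weight + bias
    w2 = intermediate_dim * hidden_dim + hidden_dim        # weight + bias
    return w1 + w2


per_expert = ffn_expert_params(256, 512)
total = 8 * per_expert   # all 8 experts
active = 2 * per_expert  # K = 2 experts per token
router = 256 * 8         # router gate (bias=False), not counted above
print(f"{per_expert:,} {total:,} {active:,} {router:,}")
# 262,912 2,103,296 525,824 2,048
```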

Worked Example

Let's trace through a concrete example of top-2 routing to solidify understanding. Consider a small MoE layer with 4 experts processing a single token. This example shows how routing scores translate into expert selection and weight assignment.

The token has hidden representation $\mathbf{x}$ with dimension 4. The router weight matrix is $W_r \in \mathbb{R}^{4 \times 4}$:

In[16]:
Code
## Token representation
x = torch.tensor([[0.5, -0.3, 0.8, 0.1]])

## Router weights (4 experts, hidden dim 4)
W_r = torch.tensor(
    [
        [0.2, -0.1, 0.4, 0.3],  # Expert 0 weights
        [-0.3, 0.5, 0.1, -0.2],  # Expert 1 weights
        [0.1, 0.2, -0.3, 0.6],  # Expert 2 weights
        [0.4, -0.4, 0.2, 0.1],  # Expert 3 weights
    ]
)

## Compute router logits: h = x @ W_r.T
router_logits = x @ W_r.T
Out[17]:
Console
Token representation x: [0.5, -0.30000001192092896, 0.800000011920929, 0.10000000149011612]

Router logits h = x @ W_r.T:
  Expert 0: 0.4800
  Expert 1: -0.2400
  Expert 2: -0.1900
  Expert 3: 0.4900

Now we select the top-2 experts:

In[18]:
Code
## Select top-2
top_k_vals, top_k_idx = torch.topk(router_logits, k=2, dim=-1)
Out[19]:
Console
Top-2 expert indices: [3, 0]
Top-2 logit values: [0.49000000953674316, 0.48000001907348633]

We see that experts 3 ($0.49$) and 0 ($0.48$) have the highest logits. These two experts "won" the routing competition for this token. Notice that the two winners are nearly tied (a gap of only 0.01), while the losing experts 1 and 2 trail expert 0 by 0.67 or more, so the choice of which two experts to use is fairly clear-cut but the split between them is not. Now we compute normalized weights using softmax over only these two selected logits:

In[20]:
Code
## Normalize weights over selected experts only
normalized_weights = F.softmax(top_k_vals, dim=-1)
Out[21]:
Console
Normalized weights: [0.5024999976158142, 0.4975000321865082]
Sum of weights: 1.0000

Interpretation:
  Expert 3 gets 50.2% weight
  Expert 0 gets 49.8% weight
  Experts 1 and 2 are not used (0% weight)
Figure: Worked example of top-2 routing for one token.
  • Router logits for the example token. Expert 3 (0.49) and Expert 0 (0.48) have the highest scores.
  • Top-2 expert selection. Experts 0 and 3 are selected (green) to process the token, while others are bypassed.
  • Final normalized weights. Softmax applied to the selected logits yields weights of roughly 0.50 and 0.50 due to the small logit difference.

The routing has decided that expert 3 should contribute roughly 50.2% and expert 0 should contribute roughly 49.8% to this token's output. Experts 1 and 2 are completely bypassed; they don't compute anything for this token. Because the two selected logits were very close in value, the resulting weights are also close to equal. If one expert had dominated with a much higher logit, its weight would be closer to 1.0.

If expert 3 produces output $\mathbf{e}_3$ and expert 0 produces output $\mathbf{e}_0$, the final output would be:

$$\mathbf{y} \approx 0.502 \cdot \mathbf{e}_3 + 0.498 \cdot \mathbf{e}_0$$

where:

  • $\mathbf{y}$: the final combined output vector
  • $\mathbf{e}_3, \mathbf{e}_0$: the output vectors computed by expert 3 and expert 0

This nearly equal weighting means both experts contribute almost equally to the final representation. In contrast, if the logits had been more differentiated (say, 0.9 for expert 3 and 0.2 for expert 0), the softmax would produce weights closer to 0.67 and 0.33, giving expert 3 twice as much influence.

Key Parameters

The key parameters for the Top-K Routing implementation are:

  • num_experts: Total number of experts in the MoE layer. This parameter determines the potential capacity of the MoE layer. More experts mean more total parameters and potentially more specialized knowledge, but also higher memory requirements and more complexity in managing load balance.
  • k: Number of experts to select for each token (typically 1 or 2). This controls the trade-off between computational cost and model expressiveness. Lower values of K yield better efficiency, while higher values provide smoother training and richer token representations.
  • hidden_dim: Dimensionality of the input and output representations. This must match the hidden dimension of the surrounding transformer architecture. The router uses this dimension to compute compatibility scores between tokens and experts.
  • intermediate_dim: Dimensionality of the expert feed-forward networks. Larger intermediate dimensions increase the capacity of each individual expert but also increase the per-expert computational cost.

Visualizing Routing Patterns

To understand how tokens get distributed across experts, let's visualize the routing decisions for a sequence:

Figure: Heatmap of top-2 routing weights for a 16-token sequence across 8 experts. The sparse activation pattern reveals that each token (row) activates exactly two experts (colored cells), while unselected experts remain white.

The heatmap reveals the sparsity pattern of top-K routing. Each token (row) activates exactly 2 experts (colored cells), with the color intensity indicating the routing weight. White cells represent experts that receive zero weight for that token. This sparse activation pattern is what makes MoE computationally efficient: instead of all 8 experts processing every token, only 2 do. Across the sequence, different tokens select different expert combinations; in a trained model, these choices reflect which experts the router has learned to associate with which input representations.
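The matrix behind such a heatmap can be built directly from router logits. The sketch below is a pure-Python illustration that assumes softmax over the selected logits only and breaks ties in favor of the lower expert index; `top_k_weight_matrix` is a hypothetical helper and the logits are made up, not taken from a trained router.

```python
import math

def top_k_weight_matrix(logits, k):
    """Turn per-token router logits into a dense tokens x experts weight matrix.

    Unselected experts get weight 0.0; the k selected experts share weights
    from a softmax over their logits only.
    """
    rows = []
    for token_logits in logits:
        # Indices of the k largest logits (sorted is stable, so ties keep the lower index).
        top = sorted(range(len(token_logits)), key=lambda e: -token_logits[e])[:k]
        m = max(token_logits[e] for e in top)
        exps = {e: math.exp(token_logits[e] - m) for e in top}
        total = sum(exps.values())
        rows.append([exps[e] / total if e in exps else 0.0
                     for e in range(len(token_logits))])
    return rows

# Three tokens, four experts (illustrative logits).
logits = [
    [2.0, 0.5, 1.0, -1.0],
    [0.1, 0.3, 0.2, 3.0],
    [1.0, 1.0, -2.0, 0.0],
]
for row in top_k_weight_matrix(logits, k=2):
    print([round(w, 2) for w in row])
```

Every row has exactly `k` nonzero entries that sum to 1, which is the structure the heatmap displays as colored cells against a white background.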

Let's also examine how tokens are distributed across experts:

Token assignment counts across 8 experts for a sample sequence. Deviations from the uniform distribution line (red dashed) indicate load imbalance, where some experts are utilized more than others.

This distribution shows a common challenge with top-K routing: without explicit encouragement, tokens may not distribute evenly across experts. Some experts receive many tokens while others receive few. This load imbalance is a critical issue that we'll address in the next chapter on load balancing. The red dashed line shows what a perfectly uniform distribution would look like; deviations from this line represent inefficiency in how we use our expert capacity.
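A load count like the one plotted above is straightforward to compute. The sketch below assumes routing decisions are available as a list of selected expert indices per token; `expert_load` is a hypothetical helper and the assignments are made up for illustration.

```python
from collections import Counter

def expert_load(assignments, num_experts):
    """Count how many (token, slot) assignments each expert receives.

    `assignments` holds one list of selected expert indices per token.
    """
    counts = Counter(e for chosen in assignments for e in chosen)
    loads = [counts.get(e, 0) for e in range(num_experts)]
    # Under perfectly uniform routing each expert gets N * k / E assignments.
    uniform = sum(loads) / num_experts
    return loads, uniform

# Hypothetical top-2 choices for 8 tokens over 4 experts.
assignments = [[0, 1], [0, 2], [0, 1], [0, 3], [1, 0], [0, 2], [1, 2], [0, 1]]
loads, uniform = expert_load(assignments, num_experts=4)
print(loads, uniform, max(loads) / uniform)  # [7, 5, 3, 1] 4.0 1.75
```

The ratio of the busiest expert's load to the uniform load (here 1.75) is a convenient single-number summary of imbalance: if each expert sits on its own device, a step can finish no faster than the busiest expert does.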

Histogram of normalized routing weights. The distribution peaks near 0.5 and 1.0, indicating that the router alternates between splitting the weight roughly evenly across its selected experts and assigning a dominant expert.
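The two peaks have a simple explanation under top-2 routing with softmax over the selected logits: the stronger expert's weight equals the sigmoid of the gap between the two selected logits, so it always lies between 0.5 and 1. A small gap gives a near-even split, and a large gap gives a dominant expert. The helper name below is hypothetical.

```python
import math

def top2_primary_weight(logit_gap):
    """Weight of the stronger expert under top-2 softmax over the selected logits.

    softmax([z1, z2])[0] = 1 / (1 + exp(-(z1 - z2))), i.e. the sigmoid of the gap.
    """
    return 1.0 / (1.0 + math.exp(-logit_gap))

for gap in [0.0, 0.5, 2.0, 5.0]:
    w = top2_primary_weight(gap)
    print(f"gap {gap:.1f} -> weights {w:.3f} / {1 - w:.3f}")
```

Because the second weight is just one minus the first, a top-2 router's entire weighting decision per token reduces to a single number: how far apart its two best logits are.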

Limitations and Impact

Top-K routing enables the core promise of MoE architectures: scaling model capacity without proportionally scaling computation. However, this selective activation introduces several challenges that you must address.

The load imbalance problem. As our visualization showed, naive top-K routing can lead to severe load imbalance where some experts are overwhelmed with tokens while others sit idle. In the extreme case, a phenomenon called "expert collapse" can occur where the router learns to send all tokens to just one or two experts, effectively wasting the capacity of the other experts. This is particularly problematic during distributed training, where each expert typically resides on a different accelerator. If expert 0 receives 50% of tokens while expert 7 receives 5%, you've created a massive bottleneck: with 8 experts the uniform share is 12.5%, so expert 0 handles four times its share, and the other devices must wait for it to finish. The next chapter introduces auxiliary losses that encourage balanced routing.

Discrete selection and gradient approximation. The top-K selection is fundamentally discrete: no gradient flows through the choice of which experts are selected. The router learns only through the weights it assigns to the selected experts, or through approximations such as the straight-through estimator. While these approaches work well in practice, they create a mismatch between the forward pass (hard selection) and the backward pass (gradients that see only the selected experts). This can occasionally cause training instabilities, particularly with $K=1$ routing, where the router receives the sparsest learning signal.
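The $K=1$ case is sharpest under the scheme used in this chapter, where softmax is applied only over the selected logits: with a single selected expert, that softmax always returns exactly 1, so the gate carries no gradient back to the router. The sketch below checks this numerically and contrasts it with keeping the expert's probability from a softmax over all experts, the choice made by the Switch Transformer. Function names are hypothetical, and gradients are estimated with finite differences rather than autograd.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def renormalized_gate(logits, i):
    # Top-1 with softmax over the selected logit only: always exactly 1.0.
    return softmax([logits[i]])[0]

def full_softmax_gate(logits, i):
    # Switch-style gate: probability of expert i under a softmax over all experts.
    return softmax(logits)[i]

def finite_diff(gate, logits, i, eps=1e-6):
    # Central-difference estimate of d gate(logits, i) / d logits[i].
    up, down = list(logits), list(logits)
    up[i] += eps
    down[i] -= eps
    return (gate(up, i) - gate(down, i)) / (2 * eps)

logits = [1.5, 0.2, -0.3, 0.8]
best = max(range(len(logits)), key=lambda e: logits[e])  # top-1 expert (index 0)
p = full_softmax_gate(logits, best)
print("renormalized gate:", renormalized_gate(logits, best),
      "d/dlogit:", finite_diff(renormalized_gate, logits, best))
print("full-softmax gate:", round(p, 4),
      "d/dlogit:", round(finite_diff(full_softmax_gate, logits, best), 4),
      "p * (1 - p):", round(p * (1 - p), 4))
```

The renormalized gate's derivative is exactly zero, while the full-softmax gate's derivative matches the analytic value $p(1-p)$. This is one reason top-1 designs keep the unnormalized probability as the gate value.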

Token dropping under capacity constraints. In distributed settings, each expert can only process a limited number of tokens per batch due to memory constraints. If the router sends more tokens to an expert than it can handle, some tokens must be dropped. A dropped token skips that expert's computation and passes through the layer unchanged via the residual connection, which can degrade model quality. This creates tension between model quality (wanting to route tokens optimally) and computational constraints (needing balanced, bounded expert loads).
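A minimal sketch of capacity-limited dispatch follows. It assumes one common convention for the capacity, ceil(capacity_factor × tokens × k / experts), and a simple first-come, first-served policy in token order; production systems differ in both (for example, some prioritize every token's first choice before any second choice). `dispatch_with_capacity` is a hypothetical helper.

```python
import math

def dispatch_with_capacity(assignments, num_experts, capacity_factor=1.25):
    """Place (token, expert) pairs into fixed-size expert buffers.

    Assumptions (illustrative, not a specific library's policy):
    capacity = ceil(capacity_factor * num_tokens * k / num_experts), and
    tokens claim slots in batch order, first come, first served.
    """
    num_tokens, k = len(assignments), len(assignments[0])
    capacity = math.ceil(capacity_factor * num_tokens * k / num_experts)
    kept = {e: [] for e in range(num_experts)}
    dropped = []
    for token, chosen in enumerate(assignments):
        for e in chosen:
            if len(kept[e]) < capacity:
                kept[e].append(token)
            else:
                # Expert e is full: this token gets no contribution from it.
                dropped.append((token, e))
    return capacity, kept, dropped

# Hypothetical top-2 choices for 8 tokens over 4 experts.
assignments = [[0, 1], [0, 2], [0, 1], [0, 3], [1, 0], [0, 2], [1, 2], [0, 1]]
capacity, kept, dropped = dispatch_with_capacity(assignments, num_experts=4, capacity_factor=1.0)
print("capacity:", capacity)
print("dropped (token, expert):", dropped)
```

With a capacity of 4, token 7 loses both of its experts and reaches the next layer only through the residual path, while tokens 4 and 5 each lose expert 0 but keep their other expert. Raising `capacity_factor` reduces drops at the cost of more padded, unused slots.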

Inference complexity. During inference, top-K routing requires dynamic batching where tokens are grouped by their selected experts. This is more complex than standard dense model inference and can be harder to optimize. When $K > 1$, you also need to gather and combine outputs from multiple experts per token, adding overhead.
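The grouping step can be sketched in a few lines. This is a pure-Python illustration under simplifying assumptions: experts are stand-in callables that process a list of vectors, routing decisions and normalized weights are already computed, and `moe_forward_grouped` and `make_expert` are hypothetical names. Real implementations perform the same dispatch, batched compute, and weighted scatter with tensor gather/scatter operations on accelerators.

```python
def moe_forward_grouped(tokens, assignments, weights, experts):
    """Group tokens by expert, run each expert once on its batch, then combine.

    tokens:      list of per-token vectors (lists of floats)
    assignments: per token, the indices of its k selected experts
    weights:     per token, the matching normalized routing weights
    experts:     list of callables mapping a batch (list of vectors) to a batch
    """
    dim = len(tokens[0])
    # 1. Dispatch: collect (token index, weight) pairs for each expert.
    groups = {}
    for t, (chosen, ws) in enumerate(zip(assignments, weights)):
        for e, w in zip(chosen, ws):
            groups.setdefault(e, []).append((t, w))
    # 2. Run each expert once on its whole batch.
    outputs = [[0.0] * dim for _ in tokens]
    for e, members in groups.items():
        batch = [tokens[t] for t, _ in members]
        results = experts[e](batch)
        # 3. Combine: scatter weighted results back to their tokens.
        for (t, w), r in zip(members, results):
            outputs[t] = [o + w * x for o, x in zip(outputs[t], r)]
    return outputs

# Toy "experts" that scale their inputs (stand-ins for real FFNs).
def make_expert(scale):
    return lambda batch: [[scale * x for x in v] for v in batch]

experts = [make_expert(s) for s in (1.0, 2.0, 3.0)]
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
assignments = [[0, 1], [2, 0], [1, 2]]
weights = [[0.75, 0.25], [0.5, 0.5], [0.6, 0.4]]
print(moe_forward_grouped(tokens, assignments, weights, experts))
```

Each expert runs once per batch regardless of how many tokens chose it, which is what makes grouping efficient; the price is the bookkeeping needed to route results back to the right tokens.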

Despite these challenges, top-K routing has proven remarkably effective. The Switch Transformer demonstrated that $K=1$ can work with proper auxiliary losses. Mixtral showed that $K=2$ with only 8 experts achieves excellent performance. The key insight is that the routing mechanism itself is not the whole story. It must be combined with load balancing techniques to realize the full potential of sparse expert architectures. We'll explore these techniques in detail in the upcoming chapters on load balancing and auxiliary losses.

Summary

Top-K routing makes Mixture of Experts computationally practical. Rather than using all experts for every token, we select only the K experts with the highest routing scores and ignore the rest.

The choice of $K$ involves fundamental trade-offs. Top-1 routing maximizes efficiency (only one expert per token) but can be unstable during training and requires careful load balancing. Top-2 routing, used by models like Mixtral, doubles the expert computation per token but provides smoother training dynamics and allows tokens to benefit from multiple experts' perspectives. Values of $K$ beyond 2 show diminishing returns, and as $K$ approaches the total number of experts, the cost approaches that of a dense model.

The implementation involves two key steps: selecting the top-K experts with a top-k operation on the routing logits (which reduces to argmax when $K=1$), and combining expert outputs using normalized weights computed by applying softmax only over the selected logits. Efficient implementations group tokens by their assigned experts to enable batched processing through each expert network.

The main limitation of top-K routing is load imbalance. Without additional mechanisms, some experts may receive far more tokens than others, wasting capacity and creating computational bottlenecks. The next chapter addresses this directly with load balancing techniques and auxiliary losses that encourage more uniform token distribution across experts.


Reference

BibTeX
@misc{topkroutingexpertselectioninmixtureofexpertsmodels, author = {Michael Brenndoerfer}, title = {Top-K Routing: Expert Selection in Mixture of Experts Models}, year = {2025}, url = {https://mbrenndoerfer.com/writing/top-k-routing-mixture-of-experts-expert-selection}, organization = {mbrenndoerfer.com}, note = {Accessed: 2025-01-01} }