Sparse Models: Conditional Computation & Efficiency

Michael Brenndoerfer · Updated December 31, 2025 · 44 min read

Discover how sparse models decouple capacity from compute using conditional computation and mixture of experts to achieve efficient scaling.


Sparse Models

The scaling laws we explored in Part XXI revealed a fundamental insight: model performance improves predictably with more compute, more data, and more parameters. This created an arms race toward ever-larger models, with GPT-3's 175 billion parameters quickly becoming a stepping stone rather than a ceiling. But this scaling approach faces a harsh economic reality: every parameter must be processed for every token, making the cost of inference grow linearly with model size.

What if we could have the capacity of a massive model while only paying the computational cost of a much smaller one? This is the promise of sparse models. They break the assumption that bigger models must be proportionally more expensive to run. Instead of activating all parameters for every input, sparse models selectively engage different subsets of their parameters based on what the input requires.

This chapter introduces the foundational concepts behind sparse computation in neural networks. We'll examine why dense models hit efficiency walls, how conditional computation offers an escape, and what challenges arise when we abandon the all-parameters-all-the-time paradigm.

Dense Models and Their Limitations

Every transformer architecture we've studied so far follows the same pattern. When a token enters the model, it flows through every layer, every attention head, and every feed-forward network. The model's full parameter count participates in every computation. This uniformity of computation, where the same operations apply regardless of the input's content, has been both a strength and a limitation of the transformer architecture.

Dense Model

A model where all parameters are activated and participate in the forward pass for every input. The computational cost scales directly with the total parameter count.

The "dense" property means every parameter contributes to every prediction, creating a direct relationship between model size and computational cost. If we double the parameters, we double the compute required per token. This tight coupling between capacity and computation may seem inevitable, but as we'll see, it's actually a design choice rather than a fundamental constraint of neural networks.

Consider the feed-forward network we examined in Part XII. For a model with hidden dimension $d_{\text{model}}$ and FFN intermediate dimension $d_{\text{ff}}$, each token passes through two linear transformations. The feed-forward network is where sparse models make their intervention.

The feed-forward network transforms each token through an expansion to a larger intermediate dimension, applies a non-linearity, then projects back to the original dimension. This two-stage transformation allows the network to learn complex non-linear mappings. The network first projects into a higher-dimensional space where patterns are easier to separate, then projects back to the model dimension. This expansion and contraction pattern is remarkably effective at capturing complex relationships in the data. Given an input token $x$, the FFN computes:

$$\text{FFN}(x) = W_2 \cdot \sigma(W_1 x + b_1) + b_2$$

where:

  • $x$: input token representation with dimension $d_{\text{model}}$
  • $W_1 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$: first weight matrix that projects to the intermediate dimension
  • $b_1 \in \mathbb{R}^{d_{\text{ff}}}$: bias vector for the first transformation
  • $\sigma$: activation function (typically ReLU or GELU) applied element-wise
  • $W_2 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$: second weight matrix that projects back to $d_{\text{model}}$
  • $b_2 \in \mathbb{R}^{d_{\text{model}}}$: bias vector for the output transformation

The activation function serves two purposes. It introduces non-linearity into the transformation, and it allows the network to learn complex patterns by projecting into a space where features can be more easily separated before projecting back to the model dimension. Without this non-linearity, the composition of two linear transformations would simply collapse into a single linear transformation, severely limiting the network's expressive power.
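To see why, a minimal NumPy sketch (with arbitrary toy dimensions) confirms that two stacked linear maps without an activation collapse into one, while inserting a ReLU breaks that equivalence:

import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff = 4, 16                     # toy dimensions for illustration
x = rng.normal(size=d_model)
W1 = rng.normal(size=(d_ff, d_model))
W2 = rng.normal(size=(d_model, d_ff))

# Without an activation, W2 @ (W1 @ x) is the same map as the single matrix (W2 @ W1)
print(np.allclose(W2 @ (W1 @ x), (W2 @ W1) @ x))                  # True

# With a ReLU in between, no single matrix reproduces the mapping in general
print(np.allclose(W2 @ np.maximum(W1 @ x, 0.0), (W2 @ W1) @ x))   # False (almost surely)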

We can derive the total parameter count for these two weight matrices as follows. This derivation helps us understand exactly where the computational burden lies in a transformer layer:

$$\begin{aligned} \text{Total FFN parameters} &= |W_1| + |W_2| && \text{(sum of both weight matrices)} \\ &= (d_{\text{ff}} \times d_{\text{model}}) + (d_{\text{model}} \times d_{\text{ff}}) && \text{(dimensions of each matrix)} \\ &= 2 \times d_{\text{ff}} \times d_{\text{model}} && \text{(combine terms)} \end{aligned}$$

In a typical transformer where $d_{\text{ff}} = 4 \times d_{\text{model}}$, substituting into our formula gives $2 \times d_{\text{ff}} \times d_{\text{model}} = 2 \times (4 \times d_{\text{model}}) \times d_{\text{model}} = 8 \times d_{\text{model}}^2$ parameters per FFN layer. To put this in concrete terms, for GPT-3 with $d_{\text{model}} = 12288$, that's over 1.2 billion parameters per layer, and every single parameter gets used for every single token. This enormous parameter count per layer, multiplied across all layers, explains why large language models require such substantial hardware resources.
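As a quick sanity check of this arithmetic, the snippet below plugs in a GPT-3-scale hidden dimension (using $d_{\text{model}} = 12288$ and the conventional $d_{\text{ff}} = 4 \times d_{\text{model}}$):

d_model = 12288              # GPT-3-scale hidden dimension
d_ff = 4 * d_model           # conventional 4x expansion

ffn_params = 2 * d_ff * d_model  # parameters in W1 and W2
print(f"{ffn_params:,} parameters per FFN layer")  # 1,207,959,552 (~1.2 billion)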

This creates a coupling between three quantities that we might prefer to scale independently:

  • Model capacity: The ability to store and process complex patterns, which is tied to parameter count
  • Training compute: The FLOPs required to update parameters during training
  • Inference compute: The FLOPs required to process each token at runtime

In dense models, these three quantities move together in lockstep. Want more capacity? Add parameters. More parameters means more compute for both training and inference. The Chinchilla scaling laws help us balance compute and data efficiently, but they don't break this fundamental coupling. This coupling represents a significant constraint: we cannot increase the model's knowledge capacity without simultaneously increasing the cost of using that knowledge.

The practical impact is severe. A 175 billion parameter dense model requires roughly 350 billion FLOPs per token (using the approximation that FLOPs = 2 times parameters per forward pass). At scale, this determines your hardware requirements, inference latency, and operating costs. For organizations deploying these models to millions of users, even small inefficiencies translate into substantial infrastructure expenses.

In[2]:
Code
def estimate_dense_model_cost(
    params_billions: float, tokens_per_second: float = 1000
) -> dict:
    """Estimate computational costs for a dense model."""
    params = params_billions * 1e9

    # FLOPs per token: forward pass only, approximately 2x parameters
    flops_per_token = 2 * params

    # Total FLOPs per second
    flops_per_second = flops_per_token * tokens_per_second

    # A100 GPU theoretical peak: 312 TFLOPS (FP16)
    a100_tflops = 312e12

    # Realistic utilization typically ranges from 30-50%
    utilization = 0.4
    effective_tflops = a100_tflops * utilization

    # GPUs needed for real-time inference
    gpus_needed = flops_per_second / effective_tflops

    return {
        "params_billions": params_billions,
        "flops_per_token": flops_per_token,
        "tflops_per_second": flops_per_second / 1e12,
        "a100_gpus_needed": gpus_needed,
    }


model_sizes = [7, 70, 175, 540]

# Report cost estimates for a range of dense model sizes
for size in model_sizes:
    stats = estimate_dense_model_cost(size)
    print(
        f"{size}B parameters: {stats['flops_per_token']:.2e} FLOPs/token, "
        f"{stats['a100_gpus_needed']:.1f} A100 GPUs needed"
    )
Out[3]:
Console
7B parameters: 1.40e+10 FLOPs/token, 0.1 A100 GPUs needed
70B parameters: 1.40e+11 FLOPs/token, 1.1 A100 GPUs needed
175B parameters: 3.50e+11 FLOPs/token, 2.8 A100 GPUs needed
540B parameters: 1.08e+12 FLOPs/token, 8.7 A100 GPUs needed
Out[4]:
Visualization
FLOPs per token scaling with model size, reaching approximately 1.1 TFLOPs for 540B parameter models, demonstrating the tight coupling between capacity and compute.
Hardware requirements scaling with model size, requiring nearly 9 A100 GPUs for real-time inference at 1,000 tokens per second.

The results reveal the scaling problem. As model size increases from 7B to 540B parameters, computational requirements grow proportionally. The 7B model sustains 1,000 tokens per second with only about a tenth of an A100 GPU, while the 540B model requires nearly 9 GPUs for the same throughput. This linear scaling between parameters and compute creates a fundamental efficiency barrier. Even with optimized hardware running at 40% utilization, the infrastructure costs become prohibitive at large scales. This growth in hardware requirements motivates the search for sparse architectures that can decouple model capacity from computational cost.

The Conditional Computation Paradigm

The insight behind sparse models is elegant: not all parameters need to contribute to every prediction. Different inputs may benefit from different computational patterns. A question about chemistry might activate different knowledge than a question about history, even within the same model. This observation suggests that the dense model's uniform treatment of all inputs might be wasteful: we're applying computational resources indiscriminately rather than directing them where they're most needed.

Conditional Computation

A computational paradigm where the operations performed depend on the input. Different inputs follow different computational paths through the network, activating different subsets of parameters based on the input's characteristics.

Human experts provide an intuitive analogy. When a hospital faces a complex case, they don't have every specialist examine the patient. Instead, a routing process directs the case to relevant experts: perhaps a cardiologist and neurologist for certain symptoms, or an oncologist and radiologist for others. Each expert has deep knowledge in their domain, but only the relevant experts participate in each case. The hospital maintains a large total capacity (many specialists) while each patient receives focused attention from only the most relevant subset.

Sparse models implement this principle mathematically. Instead of a single large FFN layer, imagine dividing those parameters into multiple separate "expert" networks. A learned routing mechanism examines each input and decides which experts should process it. This routing decision is the key innovation: rather than treating all parameters uniformly, we learn to selectively activate the most relevant subset for each input.

The sparse layer computes its output as a weighted combination of expert outputs. This combination allows the model to blend specialized knowledge from different experts based on the input's characteristics. The weighting scheme ensures that more relevant experts contribute more strongly to the final output, while less relevant experts contribute proportionally less or not at all. For an input xx, the output is computed as:

$$y = \sum_{i=1}^{N} g_i(x) \cdot E_i(x)$$

where:

  • $y$: the output of the sparse layer for input $x$
  • $N$: the total number of expert networks available in the layer
  • $E_i(x)$: the output of expert network $i$ when processing input $x$
  • $g_i(x)$: the gating weight for expert $i$ given input $x$, which determines how much expert $i$ contributes to the final output
  • $\sum_{i=1}^{N}$: summation operator that sums over all $N$ experts (from expert 1 to expert $N$)
  • $g_i(x) \cdot E_i(x)$: the weighted contribution of expert $i$, scaling the expert's output by how relevant it is to this input

This weighted sum allows the model to combine specialized knowledge from different experts in a flexible, input-dependent manner. The key property is that $g_i(x)$ is sparse: most values are zero, meaning most experts don't participate for any given input. This sparsity is what makes the computation efficient: we only compute $E_i(x)$ for experts where $g_i(x) > 0$. Typically, only the top-k experts with highest gating weights are activated, with $g_i(x) = 0$ for all others. This selection mechanism transforms a potentially expensive computation over all experts into a much cheaper computation over just a few.

To ensure computational efficiency while maintaining proper probability distributions, we impose two constraints on the gating weights. The first constraint enforces sparsity, ensuring that only k experts are active for any given input. The second ensures the weights form a valid probability distribution, so the output is a proper weighted average. These constraints can be formalized as:

$$\begin{aligned} g_i(x) &> 0 \text{ only for the } k \text{ experts with highest router scores} \\ g_i(x) &= 0 \text{ for all other experts} \\ \sum_{i=1}^{N} g_i(x) &= 1 && \text{(weights sum to 1, forming a proper mixture)} \end{aligned}$$

These two constraints work together to enable efficient sparse computation. The first constraint (sparsity) ensures that most $g_i(x)$ values are exactly zero, so we can skip computing $E_i(x)$ for those experts entirely. This is where the computational savings originate: if only 2 out of 64 experts are active, we perform approximately 1/32 of the computation we would need for a dense layer with equivalent total parameters. The second constraint (normalization) ensures that the non-zero weights form a proper probability distribution, making the output a valid weighted average of expert outputs. This normalization preserves the mathematical properties that allow the model to be trained effectively with standard gradient-based optimization.

The routing mechanism implements these constraints through a three-step process that transforms continuous router scores into sparse, normalized gating weights. First, compute unnormalized router scores $r_i(x)$ for all experts, typically via a learned linear projection followed by softmax. This linear projection learns which features of the input are relevant for selecting each expert. Second, select the top-k experts with highest scores and set all other weights to zero, creating a sparse mask. This hard selection step is what enables the computational efficiency of sparse models. Third, renormalize only the selected weights so they sum to 1:

$$g_i(x) = \begin{cases} \frac{r_i(x)}{\sum_{j \in \text{top-k}} r_j(x)} & \text{if } i \in \text{top-k experts} \\ 0 & \text{otherwise} \end{cases}$$

where:

  • $r_i(x)$: the unnormalized router score for expert $i$ on input $x$
  • $\text{top-k}$: the set of $k$ experts with the highest router scores
  • $\sum_{j \in \text{top-k}} r_j(x)$: normalization constant that ensures the selected weights sum to 1

This process ensures that only k experts perform forward passes, dramatically reducing computation while maintaining the mathematical properties of a weighted mixture.
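The NumPy sketch below walks through these three steps for a single token. It is a minimal illustration of the gating formula, with randomly initialized router weights standing in for a trained router:

import numpy as np

def top_k_gates(x, W_router, k=2):
    """Compute sparse gating weights g_i(x) for a single token x."""
    logits = W_router @ x                       # one unnormalized logit per expert
    scores = np.exp(logits - logits.max())
    scores /= scores.sum()                      # step 1: softmax -> router scores r_i(x)

    top_k_idx = np.argsort(scores)[-k:]         # step 2: keep the k highest-scoring experts
    gates = np.zeros_like(scores)
    gates[top_k_idx] = scores[top_k_idx]
    gates /= gates.sum()                        # step 3: renormalize selected weights to sum to 1
    return gates

rng = np.random.default_rng(0)
d_model, n_experts = 8, 4
gates = top_k_gates(rng.normal(size=d_model), rng.normal(size=(n_experts, d_model)))
print(gates)  # exactly two non-zero entries, summing to 1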

This mechanism fundamentally breaks the dense model's coupling between capacity and compute. We can have many experts (large total parameter count), but only activate a few per token (small computational cost). The model gains capacity to store diverse knowledge while maintaining efficient inference. The key insight is that this doesn't sacrifice quality: different tokens genuinely benefit from different experts, so selective activation isn't just an approximation but rather a better match to how knowledge should be applied.

In[5]:
Code
%pip install torch
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConceptualSparseLayer(nn.Module):
    """Demonstrates the sparse computation concept (simplified)."""

    def __init__(self, d_model: int, d_ff: int, n_experts: int, top_k: int):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k

        # Multiple expert FFN networks
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model),
                )
                for _ in range(n_experts)
            ]
        )

        # Router that decides which experts to use
        self.router = nn.Linear(d_model, n_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, d_model = x.shape

        # Compute routing scores for each token
        router_logits = self.router(x)  # [batch, seq, n_experts]

        # Select top-k experts per token
        top_k_logits, top_k_indices = torch.topk(
            router_logits, self.top_k, dim=-1
        )
        top_k_weights = F.softmax(top_k_logits, dim=-1)

        # Initialize output
        output = torch.zeros_like(x)

        # Only compute outputs for selected experts (sparse computation)
        for i in range(self.top_k):
            expert_idx = top_k_indices[:, :, i]  # [batch, seq]
            weight = top_k_weights[:, :, i : i + 1]  # [batch, seq, 1]

            # Process each expert (simplified - real implementations batch this)
            for e in range(self.n_experts):
                mask = expert_idx == e
                if mask.any():
                    tokens_for_expert = x[mask]
                    expert_output = self.experts[e](tokens_for_expert)
                    output[mask] += (
                        weight[mask].squeeze(-1).unsqueeze(-1) * expert_output
                    )

        return output
Out[6]:
Console
Configuration: 8 experts total, 2 active per token (25.0% sparsity)

Sample routing decisions:
  Batch 0, Token 0: Experts [3, 7]
  Batch 0, Token 1: Experts [5, 0]
  Batch 0, Token 2: Experts [2, 3]
  Batch 0, Token 3: Experts [0, 7]
  Batch 1, Token 0: Experts [6, 2]
  Batch 1, Token 1: Experts [4, 3]
  Batch 1, Token 2: Experts [7, 0]
  Batch 1, Token 3: Experts [4, 2]
Out[7]:
Visualization
Routing decisions for a batch of 8 tokens using top-2 expert selection. The heatmap displays active experts (dark blue) versus inactive ones (light blue) for each token (rows). The sparse and varied activation pattern illustrates how different tokens utilize different subsets of the available experts.

The configuration demonstrates the core efficiency principle of sparse models. Each token activates only 2 out of 8 experts, using just 25% of available capacity. The routing decisions show that different tokens select different experts, with each making independent routing choices based on the learned router weights. While the layer contains 8 experts worth of parameters, any single token only triggers computation in 2 of them. This achieves the key sparse model trade-off: maintaining high total capacity (all 8 experts available) while keeping per-token compute cost low (only 2 experts activate per forward pass).

Efficiency Analysis: Decoupling Parameters from Compute

The power of sparse models becomes clear when we quantify the efficiency gains. We can compare a dense model to a sparse model with equivalent computational cost to see the capacity advantage. This analysis reveals the mathematical foundation for why sparse models can achieve better performance per unit of compute than their dense counterparts.

Consider a dense FFN with intermediate dimension $d_{\text{ff}}$. This layer has two weight matrices: $W_1 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ and $W_2 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$. These matrices define the expansion into a higher-dimensional space and the subsequent projection back to the model dimension.

The dense FFN has the following computational characteristics:

  • Total parameters: $2 \times d_{\text{model}} \times d_{\text{ff}}$. This count comes from two weight matrices: $W_1 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ (containing $d_{\text{ff}} \times d_{\text{model}}$ parameters) and $W_2 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ (containing $d_{\text{model}} \times d_{\text{ff}}$ parameters)
  • FLOPs per token: $4 \times d_{\text{model}} \times d_{\text{ff}}$, because each of the two matrix multiplications requires approximately $2 \times d_{\text{model}} \times d_{\text{ff}}$ FLOPs (standard approximation that a matrix multiplication of dimensions $m \times n$ by $n \times p$ requires $2mnp$ FLOPs)

where:

  • $d_{\text{model}}$: the dimension of the model's hidden representations (the size of token embeddings throughout the network)
  • $d_{\text{ff}}$: the intermediate dimension of the feed-forward network, typically set to $4 \times d_{\text{model}}$
  • $2mnp$: the FLOP count for matrix multiplication comes from computing $m \times p$ output elements, each requiring $n$ multiplications and $n$ additions, giving approximately $2mnp$ FLOPs

Now consider a sparse layer with $N$ experts, where each expert has intermediate dimension $d_{\text{ff}}/N$, using top-2 routing. This configuration divides the total parameters across multiple smaller experts, each specialized for different types of inputs. The top-2 routing means each token will be processed by exactly two experts, combining their outputs with learned weights.

For this sparse configuration, we can compute the total parameter count, active parameters per token, and FLOPs per token. These derivations illuminate exactly where the efficiency gains originate.

Total parameters: Each expert is a smaller FFN with intermediate dimension $d_{\text{ff}}/N$ instead of the full $d_{\text{ff}}$. Since each expert has two weight matrices, it contains $2 \times d_{\text{model}} \times (d_{\text{ff}}/N)$ parameters. The expert's reduced size means it specializes in a narrower range of computations, but collectively the experts cover the same representational space as the original dense layer. Summing across all $N$ experts:

$$\begin{aligned} \text{Total sparse parameters} &= N \times 2 \times d_{\text{model}} \times (d_{\text{ff}}/N) && \text{(sum over all experts)} \\ &= 2 \times d_{\text{model}} \times d_{\text{ff}} \times \frac{N}{N} && \text{(rearrange)} \\ &= 2 \times d_{\text{model}} \times d_{\text{ff}} && \text{(cancel } N\text{)} \end{aligned}$$

This matches the dense model exactly, since the $N$ terms cancel. This result is important because it shows that we haven't reduced capacity: the total number of learnable parameters remains the same. The parameters are simply distributed across multiple experts rather than concentrated in a single network.

Active parameters per token: The key efficiency gain comes from activating only a subset of experts. With top-2 routing, only 2 out of $N$ experts activate for any given token. Each active expert contributes its $2 \times d_{\text{model}} \times (d_{\text{ff}}/N)$ parameters to the computation. This selective activation is the heart of sparse efficiency:

$$\begin{aligned} \text{Active parameters} &= \text{top-k} \times 2 \times d_{\text{model}} \times (d_{\text{ff}}/N) && \text{(k active experts)} \\ &= 2 \times 2 \times d_{\text{model}} \times (d_{\text{ff}}/N) && \text{(substitute top-k = 2)} \\ &= \frac{4 \times d_{\text{model}} \times d_{\text{ff}}}{N} && \text{(simplify)} \end{aligned}$$

FLOPs per token: The computational cost follows the active parameters, since we only perform computations for the experts that are activated. Each active expert must perform two matrix multiplications (expansion and contraction), costing $4 \times d_{\text{model}} \times (d_{\text{ff}}/N)$ FLOPs. With 2 active experts:

$$\begin{aligned} \text{Sparse FLOPs} &= \text{top-k} \times 4 \times d_{\text{model}} \times (d_{\text{ff}}/N) && \text{(k active experts, each costs } 4 \times d_{\text{model}} \times (d_{\text{ff}}/N) \text{ FLOPs)} \\ &= 2 \times 4 \times d_{\text{model}} \times (d_{\text{ff}}/N) && \text{(substitute top-k = 2)} \\ &= \frac{8 \times d_{\text{model}} \times d_{\text{ff}}}{N} && \text{(simplify)} \end{aligned}$$

With $N = 8$ experts, the sparse model uses only $1/4$ the FLOPs per token while maintaining the same total parameter count. This is a remarkable result: we get the same capacity at a fraction of the computational cost. We can verify this by computing the ratio:

$$\begin{aligned} \text{Sparse FLOPs ratio} &= \frac{\text{Sparse FLOPs per token}}{\text{Dense FLOPs per token}} && \text{(definition)} \\ &= \frac{8 \times d_{\text{model}} \times d_{\text{ff}}/N}{4 \times d_{\text{model}} \times d_{\text{ff}}} && \text{(substitute formulas)} \\ &= \frac{8/N}{4} && \text{(cancel } d_{\text{model}} \times d_{\text{ff}}\text{)} \\ &= \frac{2}{N} && \text{(simplify)} \\ &= \frac{2}{8} = \frac{1}{4} && \text{(substitute } N = 8\text{)} \end{aligned}$$

This $1/4$ ratio reveals the key insight: with 8 experts and top-2 routing, we use only 25% of the dense model's compute per token. The general formula $2/N$ shows that, at a fixed total parameter count, per-token compute shrinks as the number of experts grows. This relationship between expert count and compute cost is the mathematical foundation for sparse model efficiency.
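A few lines of Python confirm the $2/N$ pattern for several expert counts, keeping the total parameter count fixed as in the derivation above:

d_model, d_ff, top_k = 4096, 16384, 2

dense_flops = 4 * d_model * d_ff
for n_experts in [4, 8, 16, 64]:
    # Each expert has intermediate dimension d_ff / n_experts
    sparse_flops = top_k * 4 * d_model * (d_ff // n_experts)
    print(f"N={n_experts:2d}: sparse/dense FLOPs = {sparse_flops / dense_flops:.4f} "
          f"(formula 2/N = {2 / n_experts:.4f})")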

A significant advantage of sparse models is the ability to increase total parameters without increasing compute. Rather than keeping the same total parameter count as the dense model, we can make each expert the same size as the original dense FFN, giving us $N$ times more parameters while paying only $k/N$ of the compute that a dense layer with that full parameter count would require. This is the real power of sparse architectures: they let us scale capacity independently of computational cost.

In[8]:
Code
def compare_dense_sparse(
    d_model: int = 4096,
    d_ff_dense: int = 16384,
    n_experts: int = 8,
    top_k: int = 2,
):
    """Compare dense and sparse model efficiency."""

    # Dense model statistics
    dense_params = 2 * d_model * d_ff_dense
    dense_flops = 4 * d_model * d_ff_dense  # 2 matmuls, each costs 2*in*out

    # Sparse model: each expert has same size as dense FFN
    # This gives us N times more total parameters
    expert_ff_dim = d_ff_dense  # Each expert is full-sized
    sparse_total_params = n_experts * 2 * d_model * expert_ff_dim
    sparse_active_params = top_k * 2 * d_model * expert_ff_dim
    sparse_flops = top_k * 4 * d_model * expert_ff_dim

    return {
        "dense_params": dense_params,
        "dense_flops": dense_flops,
        "sparse_total_params": sparse_total_params,
        "sparse_active_params": sparse_active_params,
        "sparse_flops": sparse_flops,
        "param_ratio": sparse_total_params / dense_params,
        "flops_ratio": sparse_flops / dense_flops,
        "efficiency_gain": (sparse_total_params / dense_params)
        / (sparse_flops / dense_flops),
    }
Out[9]:
Console
Dense Model: 134,217,728 params, 268,435,456 FLOPs/token
Sparse Model (8 experts, top-2): 1,073,741,824 total params, 268,435,456 active/token, 536,870,912 FLOPs/token

Efficiency: 8.0x parameters, 2.0x FLOPs, 4.0x capacity per FLOP
Out[10]:
Visualization
Comparison of parameter counts and computational costs between a dense model and a sparse model. While the sparse model (orange) has 8x the total parameters of the dense model (blue), its active parameters and FLOPs per token are much lower relative to total capacity, achieving a 4x efficiency gain in capacity per FLOP.

The sparse model achieves significantly higher parameter count while maintaining reasonable computational cost. With 8 experts and top-2 routing, the total parameter count is 8x larger than the dense baseline, but FLOPs per token only double because each token activates just 2 of the 8 experts. This creates a 4x improvement in capacity per unit of compute. The dense model's single FFN processes every token, while the sparse model distributes capacity across multiple experts and activates only a subset per token. This fundamental trade-off allows sparse models to scale capacity beyond what would be computationally feasible with dense architectures.

In practice, this enables a new training and deployment paradigm:

Comparison of parameter counts and effective capacity between Dense 7B and Sparse 47B models.
| Metric | Dense 7B | Sparse 47B (8 experts, top-2) |
| --- | --- | --- |
| Total Parameters | 7B | 47B |
| Active Parameters | 7B | ~12B |
| FLOPs per Token | ~14 GFLOP | ~24 GFLOP |
| Effective Capacity | 1x | ~4-6x |

The sparse model processes each token with similar cost to a ~12B dense model but has the knowledge capacity of a much larger model. This has significant implications for the scaling laws we discussed in Part XXI: we can push the Pareto frontier of performance vs. compute. The sparse architecture essentially lets us move along a different efficiency curve, one where we can trade increased memory (for storing more expert parameters) for improved quality at fixed compute.
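The FLOP figures in the table follow directly from the 2 × active-parameters approximation used earlier in this chapter. The sketch below reproduces them, taking the ~12B active-parameter figure for the sparse model as given (the exact value depends on how many parameters sit outside the expert layers):

def gflops_per_token(active_params_billions: float) -> float:
    # Forward-pass approximation used throughout this chapter: FLOPs ≈ 2 x active parameters
    return 2 * active_params_billions  # result is in GFLOPs because params are in billions

for name, total_b, active_b in [("Dense 7B", 7, 7), ("Sparse 47B, top-2", 47, 12)]:
    print(f"{name}: {total_b}B total params, {active_b}B active, "
          f"~{gflops_per_token(active_b):.0f} GFLOP per token")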

Why Replace the Feed-Forward Network?

You might wonder why sparse architectures typically replace the FFN rather than the attention mechanism. As we explored in Part X, attention's computational cost comes from the quadratic relationship between sequence length and compute. But there's another consideration: the role each component plays in the model. Understanding these distinct roles helps explain why the FFN is the natural candidate for sparsification.

Attention mechanisms handle information routing, determining which tokens should attend to which other tokens. This function is inherently about relationships within the input and benefits from seeing all tokens simultaneously. Making attention sparse (as in Longformer and BigBird from Part XIV) requires careful design to ensure important relationships aren't missed. The attention mechanism needs to maintain a global view of the sequence to effectively aggregate information across positions.

Feed-forward networks, by contrast, operate on each token position independently. They're often interpreted as key-value memories that store factual knowledge and learned transformations. This independence is crucial: since each token's FFN computation doesn't depend on other tokens in the sequence, we can freely route different tokens to different experts without disrupting information flow. This makes them natural candidates for specialization: different experts can store different types of knowledge without interfering with the information routing that attention provides.

The FFN also dominates the parameter count in typical transformers. In a standard architecture with $d_{\text{ff}} = 4 \times d_{\text{model}}$, the FFN contains roughly two-thirds of each layer's parameters. This makes the FFN the natural target for sparsification. Converting it to a mixture of experts yields the largest efficiency gains. To see why the FFN dominates, let's calculate the parameter counts step by step, starting with the FFN:

$$\begin{aligned} \text{FFN parameters} &= 2 \times d_{\text{model}} \times d_{\text{ff}} && \text{(two weight matrices)} \\ &= 2 \times d_{\text{model}} \times (4 \times d_{\text{model}}) && \text{(substitute } d_{\text{ff}} = 4 \times d_{\text{model}}\text{)} \\ &= 8 \times d_{\text{model}}^2 && \text{(simplify)} \end{aligned}$$

For comparison, the attention mechanism contains four projection matrices (Q, K, V, and output projection), each with dimension $d_{\text{model}} \times d_{\text{model}}$. These projections transform the input into queries, keys, and values for the attention computation, then project the attention output back to the model dimension. This gives:

$$\text{Attention parameters} = 4 \times d_{\text{model}} \times d_{\text{model}} = 4 \times d_{\text{model}}^2$$

where:

  • $4$: the number of projection matrices in multi-head attention (query, key, value, and output projections)
  • $d_{\text{model}} \times d_{\text{model}}$: the dimensions of each projection matrix, mapping from the model dimension to itself

To compute the FFN's share of total layer parameters, we divide FFN parameters by the sum of FFN and attention parameters. This ratio tells us what fraction of each layer's capacity resides in the feed-forward network:

$$\begin{aligned} \text{FFN fraction} &= \frac{\text{FFN parameters}}{\text{FFN parameters} + \text{Attention parameters}} && \text{(definition)} \\ &= \frac{8d_{\text{model}}^2}{8d_{\text{model}}^2 + 4d_{\text{model}}^2} && \text{(substitute parameter counts)} \\ &= \frac{8d_{\text{model}}^2}{12d_{\text{model}}^2} && \text{(combine denominators)} \\ &= \frac{8}{12} && \text{(cancel } d_{\text{model}}^2\text{)} \\ &= \frac{2}{3} && \text{(simplify fraction)} \end{aligned}$$

This means the FFN accounts for approximately 67% of each layer's parameters. This two-thirds dominance makes the FFN the natural target for sparsification: converting it to a mixture of experts allows us to scale the parameter count (by adding more experts) while keeping computational cost proportional to the number of active experts rather than total experts. Since the FFN already holds most of the parameters, this is where we get the biggest capacity gains from sparse computation. If we were to sparsify the attention mechanism instead, we would affect only one-third of the layer's parameters, yielding proportionally smaller efficiency gains.

In[11]:
Code
def analyze_layer_parameters(
    d_model: int = 4096, n_heads: int = 32, d_ff: int = 16384
):
    """Break down parameter distribution in a transformer layer."""

    d_head = d_model // n_heads

    # Self-attention parameters: Q, K, V projections and output projection
    attention_params = 4 * d_model * d_model  # 4 matrices of d_model × d_model

    # FFN parameters: two matrices d_model by d_ff and d_ff by d_model
    ffn_params = 2 * d_model * d_ff

    # Layer norms: small contribution, included for completeness
    norm_params = 4 * d_model  # 2 layer norms with scale and bias

    total = attention_params + ffn_params + norm_params

    return {
        "attention": attention_params,
        "ffn": ffn_params,
        "norms": norm_params,
        "total": total,
        "ffn_fraction": ffn_params / total,
    }
Out[12]:
Console
Attention: 67,108,864 params (33.3%)
FFN: 134,217,728 params (66.7%)
Total: 201,342,976 params

FFN accounts for 66.7% of layer parameters
Out[13]:
Visualization
Proportional breakdown of parameters in a standard transformer layer. The Feed-Forward Network (FFN) accounts for approximately two-thirds (66.7%) of total parameters, identifying it as the high-leverage target for sparsification strategies.

The FFN's dominance at approximately two-thirds of layer parameters makes it the natural target for sparsification. In typical transformer configurations, the FFN contains roughly twice as many parameters as the attention mechanism. This 2:1 ratio means that converting the FFN to a mixture of experts yields the largest efficiency gains. When we add more experts to the FFN, we scale the parameter count significantly while keeping computational cost proportional to the number of active experts. Since the FFN already holds most parameters, this is where sparse computation provides the biggest capacity increase per unit of additional compute. Replacing the dense FFN with a sparse mixture of experts while keeping attention dense preserves the model's ability to route information effectively while dramatically increasing its knowledge capacity.

Challenges of Sparse Architectures

Sparse models introduce complexity that dense models don't face. Decoupling capacity from compute comes at the price of several engineering challenges.

Load Imbalance

The most immediate challenge is load imbalance. If the router learns to send most tokens to a small subset of experts, the model degrades to an expensive dense network where popular experts become bottlenecks while unused experts contribute nothing. This scenario represents a failure to achieve the sparse model's promise: we pay for all the expert parameters in memory but only use a fraction of them effectively.

This happens naturally during training without intervention. Early in training, some experts may randomly produce slightly better outputs for common inputs. The router learns to prefer these experts, which then receive more gradient updates and improve further. Meanwhile, neglected experts fall behind, creating a rich-get-richer dynamic that can collapse the model to using just one or two experts. This collapse is self-reinforcing: once an expert dominates, it receives most training signal and becomes even more dominant, while underused experts stagnate.

The load balance coefficient quantifies this imbalance mathematically, providing a single number that summarizes how evenly tokens are distributed across experts. For $N$ experts and $T$ total tokens, perfect balance would send $T/N$ tokens to each expert. The load balance coefficient is defined as:

$$\text{Load balance coefficient} = \frac{N \times \min_i(\text{tokens}_i)}{T}$$

where:

  • $N$: the total number of experts in the layer
  • $T$: the total number of tokens being processed
  • $\min_i(\text{tokens}_i)$: the minimum number of tokens sent to any single expert (the most underutilized expert)
  • $T/N$: the ideal load per expert under perfect balance

This coefficient equals 1.0 when all experts receive exactly $T/N$ tokens (perfect balance), and approaches 0 as the distribution becomes more skewed. A coefficient near 0 indicates severe imbalance where some experts are heavily overused while others sit idle. The focus on the minimum load is intentional: it identifies the weakest link, the expert receiving the least training signal, which determines how effectively the model uses its full capacity.
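A small worked example, using a hypothetical distribution of 1,000 tokens over 8 experts, makes the coefficient concrete:

import numpy as np

tokens_per_expert = np.array([400, 200, 150, 100, 50, 40, 35, 25])  # hypothetical counts
N, T = len(tokens_per_expert), tokens_per_expert.sum()               # 8 experts, 1,000 tokens

coefficient = N * tokens_per_expert.min() / T
print(f"Ideal load per expert: {T / N:.0f} tokens")    # 125
print(f"Load balance coefficient: {coefficient:.2f}")  # 0.20, far below the ideal of 1.0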

In[14]:
Code
import numpy as np


def simulate_load_imbalance(
    n_tokens: int = 1000, n_experts: int = 8, imbalance_factor: float = 0.5
):
    """Simulate expert load distribution with varying imbalance."""

    # Create biased routing probabilities via a Dirichlet draw.
    # The concentration (imbalance_factor) controls skew: small values produce
    # a highly skewed distribution, large values a nearly uniform one.
    base_probs = np.random.dirichlet(np.full(n_experts, imbalance_factor))

    # Sample routing decisions
    expert_choices = np.random.choice(n_experts, size=n_tokens, p=base_probs)

    # Count tokens per expert
    counts = np.bincount(expert_choices, minlength=n_experts)

    # Compute load balance metrics
    ideal_load = n_tokens / n_experts
    max_load = counts.max()
    min_load = counts.min()

    # Load balance coefficient: 1.0 means perfectly balanced
    load_balance_coef = (n_experts * counts.min()) / n_tokens

    return {
        "counts": counts,
        "ideal_load": ideal_load,
        "max_load": max_load,
        "min_load": min_load,
        "load_balance_coef": load_balance_coef,
        "max_overload_ratio": max_load / ideal_load,
    }
Out[15]:
Visualization
Imbalanced routing showing a collapsed state where one expert processes the majority of tokens (low balance coefficient), creating a bottleneck.
Balanced routing demonstrating effective load distribution where tokens are distributed nearly evenly across all experts, maximizing parallel processing efficiency.

The visualization demonstrates two load distribution scenarios. The imbalanced case (left) shows a coefficient near 0, where Expert 0 receives the majority of tokens while other experts are severely underutilized. This creates a bottleneck where one expert must process most inputs sequentially. The balanced case (right) achieves a coefficient near 1.0, with all experts receiving roughly equal loads near the ideal line of 125 tokens per expert. This even distribution enables true parallel processing across all experts. Real sparse models require auxiliary loss functions during training to prevent collapse toward the imbalanced state, as the routing mechanism naturally tends toward specialization that can create bottlenecks.

Communication Overhead in Distributed Settings

Sparse models shine for large-scale systems, but they introduce unique distributed computing challenges. In a dense model, each GPU holds a copy of the full layer and processes its assigned batch of tokens independently. The communication pattern is regular and predictable: gradients aggregate across devices during backward passes, but forward passes require no cross-device communication within a layer. In a sparse model with experts distributed across GPUs, tokens must travel to whichever GPU holds their assigned expert, fundamentally changing the communication pattern.

This creates an all-to-all communication pattern: every GPU may need to send tokens to every other GPU and receive tokens back. The communication volume depends on how tokens route, which depends on the input data. This is fundamentally different from the predictable, structured communication patterns in tensor parallelism or pipeline parallelism for dense models. Network bandwidth becomes a potential bottleneck, and communication latency adds to the overall processing time.

In[16]:
Code
def estimate_communication_cost(
    n_tokens: int, n_experts: int, n_gpus: int, d_model: int, top_k: int = 2
):
    """Estimate all-to-all communication for expert routing."""

    # Assume experts evenly distributed across GPUs
    experts_per_gpu = n_experts // n_gpus

    # Each token goes to top_k experts
    # In worst case, all tokens go to experts on different GPUs

    # Bytes per token, assuming bfloat16 precision
    bytes_per_token = d_model * 2  # Each element is 2 bytes in bfloat16 format

    # Each token must be sent to top_k experts, potentially on different GPUs,
    # and results must be returned
    send_bytes = n_tokens * top_k * bytes_per_token
    receive_bytes = n_tokens * top_k * bytes_per_token

    # With good load balancing, communication is more local
    # Estimate: fraction of tokens going to non-local experts
    non_local_fraction = 1 - (experts_per_gpu * top_k / n_experts)
    non_local_fraction = max(0, non_local_fraction)

    actual_send = send_bytes * non_local_fraction

    return {
        "worst_case_gb": (send_bytes + receive_bytes) / 1e9,
        "estimated_gb": actual_send * 2 / 1e9,
        "non_local_fraction": non_local_fraction,
    }
Out[17]:
Console
8E/8GPU, top-2: 0.13 GB worst case, 0.10 GB estimated
64E/64GPU, top-2: 0.13 GB worst case, 0.13 GB estimated
8E/4GPU, top-2: 0.13 GB worst case, 0.07 GB estimated
Out[18]:
Visualization
Estimated all-to-all communication volume per forward pass for varying expert and GPU configurations. Deployments with fewer GPUs than experts (e.g., 8E/4GPU) incur lower communication costs because local routing is possible, whereas one-expert-per-GPU setups (64E/64GPU) face significant bandwidth requirements due to cross-device token movement.

This reveals an important trade-off in sparse model deployment. Configurations with fewer GPUs relative to experts show lower communication costs because multiple experts can reside on the same GPU, reducing cross-GPU traffic. Conversely, configurations with one expert per GPU or many distributed GPUs experience higher communication overhead as tokens must be routed across the network. The communication cost per layer can range from negligible to several gigabytes depending on the deployment configuration. These results demonstrate that sparse models trade reduced computation for increased communication, and this trade-off becomes favorable only when the computational savings outweigh the communication overhead.

Training Instability

Sparse models introduce non-differentiable routing decisions into the computational graph. The choice of which experts to activate is discrete: tokens either go to an expert or they don't. This discreteness complicates gradient computation because the standard backpropagation algorithm requires differentiable operations to compute gradients.

Several techniques address this challenge, each with its own trade-offs:

  • Soft routing: Instead of hard selection, use weighted combinations of experts (though this partially defeats the computational savings)
  • Straight-through estimators: Approximate gradients through discrete choices
  • Auxiliary losses: Add terms to the training objective that encourage desirable routing behavior

The router also creates a circular dependency: the router should learn to send tokens to experts that process them well, but experts can only learn to process tokens they receive. Early training can be unstable as this chicken-and-egg problem resolves. This interplay between router learning and expert learning requires careful initialization and sometimes curriculum-based training strategies.
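As a concrete illustration of the auxiliary-loss approach, here is a minimal PyTorch sketch of the load-balancing penalty popularized by the Switch Transformer: per expert, multiply the fraction of tokens dispatched to that expert by the mean router probability it receives, sum over experts, and scale by the number of experts. The coefficient value is an illustrative assumption, not a prescribed setting:

import torch

def load_balancing_loss(router_probs: torch.Tensor, expert_indices: torch.Tensor,
                        n_experts: int, coef: float = 0.01) -> torch.Tensor:
    """Switch-Transformer-style auxiliary load-balancing loss.

    router_probs:   [n_tokens, n_experts] softmax outputs of the router
    expert_indices: [n_tokens] index (int64) of the top-1 expert chosen per token
    """
    # f_i: fraction of tokens dispatched to each expert (hard assignments)
    counts = torch.zeros(n_experts).scatter_add_(
        0, expert_indices, torch.ones_like(expert_indices, dtype=torch.float)
    )
    f = counts / expert_indices.numel()

    # P_i: mean router probability mass assigned to each expert (soft scores)
    P = router_probs.mean(dim=0)

    # Minimized when both f and P are uniform (1/N per expert)
    return coef * n_experts * torch.sum(f * P)

probs = torch.softmax(torch.randn(32, 8), dim=-1)   # router outputs for 32 tokens, 8 experts
aux = load_balancing_loss(probs, probs.argmax(dim=-1), n_experts=8)

During training this term is added to the main language-modeling loss, so the gradient nudges the router toward more even dispatch while the primary objective continues to drive quality.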

Inference Complexity

Dense model inference is straightforward: run the forward pass, and you're done. Sparse model inference requires additional bookkeeping that adds engineering complexity. For each layer, the system must:

  1. Run the router to determine expert assignments
  2. Group tokens by their assigned experts
  3. Process each expert's assigned tokens
  4. Reassemble results in the correct order

When serving requests with different inputs, the routing patterns vary, making batching less efficient than in dense models. Some inputs might activate experts heavily while others need different experts entirely. This variability complicates capacity planning and can lead to uneven GPU utilization across a serving cluster.
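To make the grouping and reassembly concrete, here is a minimal sketch of the bookkeeping for top-1 routing in NumPy: sort tokens by their assigned expert, run one batched call per expert, and scatter the results back into the original order. Production systems perform these steps with fused scatter/gather kernels rather than Python loops, so treat this only as an illustration of the data movement:

import numpy as np

def route_process_restore(x, expert_ids, experts):
    """x: [n_tokens, d_model]; expert_ids: [n_tokens]; experts: list of callables."""
    order = np.argsort(expert_ids, kind="stable")    # group tokens by assigned expert
    grouped, ids_sorted = x[order], expert_ids[order]

    out_sorted = np.empty_like(grouped)
    for e in range(len(experts)):
        mask = ids_sorted == e
        if mask.any():
            out_sorted[mask] = experts[e](grouped[mask])  # one batched call per expert

    out = np.empty_like(x)
    out[order] = out_sorted                          # scatter back to original token order
    return out

# Toy usage: 6 tokens of dimension 2, routed to 2 "experts" that scale inputs differently
x = np.arange(12, dtype=float).reshape(6, 2)
ids = np.array([1, 0, 1, 0, 0, 1])
print(route_process_restore(x, ids, [lambda t: 2 * t, lambda t: -t]))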

In[19]:
Code
def analyze_batch_efficiency(
    batch_router_decisions: np.ndarray, n_experts: int
):
    """Analyze how well a batch of inputs can be processed together."""

    # batch_router_decisions: shape [batch_size, seq_len, top_k] containing expert indices
    batch_size, seq_len, top_k = batch_router_decisions.shape

    # Total number of tokens across the entire batch
    total_tokens = batch_size * seq_len

    # Count tokens per expert
    flat_decisions = batch_router_decisions.reshape(-1, top_k)
    expert_counts = np.zeros(n_experts)

    for k in range(top_k):
        for expert in range(n_experts):
            expert_counts[expert] += (flat_decisions[:, k] == expert).sum()

    # Compute padding overhead
    # In a batched implementation, we pad to the maximum load
    max_load = expert_counts.max()
    total_with_padding = max_load * n_experts
    total_actual = expert_counts.sum()

    padding_overhead = (total_with_padding - total_actual) / total_actual

    return {
        "expert_counts": expert_counts,
        "max_load": max_load,
        "mean_load": expert_counts.mean(),
        "padding_overhead": padding_overhead,
        "utilization": total_actual / total_with_padding,
    }
Out[20]:
Console
Configuration: batch_size=16, seq_len=128, 8 experts, top-2

Expert load distribution:
  Expert 0:  1143 ██████████████████████
  Expert 1:   734 ██████████████
  Expert 2:   524 ██████████
  Expert 3:   483 █████████
  Expert 4:   354 ███████
  Expert 5:   340 ██████
  Expert 6:   267 █████
  Expert 7:   251 █████

Max load: 1143, Mean load: 512
Padding overhead: 123.2%, Compute utilization: 44.8%
Out[21]:
Visualization
Expert load distribution showing varying queue lengths per expert, where the maximum load determines the batch processing time.
Compute utilization breakdown showing that a significant portion of hardware cycles are wasted on padding overhead to synchronize the parallel expert computations.

The analysis reveals significant inefficiency from uneven expert loads. The expert receiving the most tokens processes roughly twice as many as the average expert, creating substantial load imbalance. This imbalance translates to high padding overhead: when we batch process all experts in parallel, we must pad shorter queues to match the longest queue, wasting computational cycles. The compute utilization metric quantifies this waste by showing what fraction of hardware cycles perform useful computation versus processing padding. This demonstrates why load balancing is critical in production systems. Without it, the theoretical efficiency gains from sparsity disappear as we waste compute matching uneven loads across experts.

The Historical Arc

The idea of conditional computation predates the transformer era. Early work in the 1990s explored mixture of experts for various machine learning tasks. However, these early systems used small numbers of experts and faced challenges with training stability and load balancing.

The key enabler for modern sparse models was scale. Small models don't benefit much from sparsity because the routing overhead dominates. But as models grew to billions of parameters, the economics shifted. Google's research on sparsely activated transformers, including the Switch Transformer, demonstrated that sparse models could match dense model quality while using a fraction of the compute.

Recent architectures like Mixtral, which we'll examine later in this part, have brought sparse models into the mainstream. By carefully engineering the routing mechanism and balancing losses, these models achieve compelling quality-to-compute ratios that make them practical for both training and deployment.

Summary

Sparse models represent a fundamental shift in how we think about neural network scaling. Rather than accepting that more parameters always means more compute, sparse architectures decouple these quantities through conditional computation.

The core concepts we've covered lay the groundwork for understanding mixture of experts:

  • Dense models activate all parameters for every input, creating a tight coupling between capacity and cost
  • Conditional computation activates different parameter subsets based on the input, breaking this coupling
  • Sparse efficiency comes from having many experts (high capacity) while using few per token (low compute)
  • Feed-forward networks are natural targets for sparsification due to their per-token independence and large parameter share

The challenges of sparse models, including load imbalance, communication overhead, training instability, and inference complexity, have driven a rich body of research into routing mechanisms and training techniques. The next chapter introduces expert networks in detail, followed by chapters on gating mechanisms, load balancing strategies, and landmark architectures that have made sparse models practical.

Sparse models don't replace dense models entirely. They represent a new point on the Pareto frontier of capability vs. compute, particularly valuable when you need more capacity than you can afford to run densely. Understanding when and how to leverage sparsity is becoming an essential skill as models continue to scale.

Key Parameters

The key parameters for sparse models are:

  • n_experts: Number of expert networks in the sparse layer. More experts increase total model capacity but also increase memory requirements and communication overhead.
  • top_k: Number of experts to activate per token. Lower values reduce compute cost but may limit the model's ability to combine diverse knowledge.
  • d_ff (or expert_dim): Hidden dimension of each expert's feed-forward network. Larger dimensions increase expert capacity but also increase per-expert compute cost.
  • load_balance_coefficient: Metric measuring how evenly tokens distribute across experts. Values near 1.0 indicate good balance, while values near 0 suggest severe imbalance requiring intervention.

