Expert Parallelism: Distributed Computing for MoE Models

Michael Brenndoerfer · November 19, 2025 · 37 min read

Learn how expert parallelism distributes MoE experts across devices using all-to-all communication, enabling efficient training of trillion-parameter models.


Expert Parallelism

Mixture of Experts models achieve their remarkable efficiency by activating only a subset of parameters for each token. As we explored in previous chapters, a model with 64 experts might only use 2 experts per token, giving you the capacity of a massive model with the compute cost of a much smaller one. But this creates a fundamental systems challenge: where do all those experts physically reside, and how do tokens get to the right experts when they're spread across multiple devices?

Expert parallelism is the distributed computing strategy designed specifically for MoE architectures. Unlike data parallelism (which replicates the entire model) or tensor parallelism (which shards individual layers), expert parallelism distributes the expert networks themselves across devices while keeping tokens flowing to whichever experts they need. This chapter examines how this works in practice, from expert placement decisions to the all-to-all communication patterns that make it possible.

The Distribution Challenge

To understand why expert parallelism is necessary, consider a concrete scenario. You're training an MoE model with 64 experts, where each expert is a two-layer feed-forward network operating on a model dimension of 4096 with the standard 4x expansion (a feed-forward hidden dimension of 16384). This is a common configuration where each expert acts as an independent specialist. Each expert contains roughly 134 million parameters. With 64 experts, the expert layers alone require storing 8.6 billion parameters, not counting the attention layers, embeddings, and other components.

This parameter count comes from a simple calculation. Each expert consists of two linear transformations: one that projects from the model dimension to the feed-forward hidden dimension, and another that projects back. When you multiply the model dimension by the expanded hidden dimension for each of these two layers, and then multiply by the number of experts, the total parameter count grows rapidly. The mathematics is simple, but the implications for hardware are significant.

No single GPU can efficiently hold all of this. Even if memory weren't a constraint, we'd want to leverage multiple devices to parallelize computation. The question becomes: how do we distribute experts across devices while maintaining efficient training?

In[2]:
Code
## Example MoE configuration
model_dim = 4096
ffn_hidden = 4096 * 4  # Standard 4x expansion
num_experts = 64
num_gpus = 8

## Parameters per expert (two linear layers: up-projection and down-projection)
params_per_expert = model_dim * ffn_hidden + ffn_hidden * model_dim
total_expert_params = params_per_expert * num_experts
experts_per_gpu = num_experts // num_gpus

print(f"Parameters per expert: {params_per_expert / 1e6:.1f}M")
print(f"Total expert parameters: {total_expert_params / 1e9:.2f}B")
print(f"Experts per GPU (with expert parallelism): {experts_per_gpu}")
Out[3]:
Console
Parameters per expert: 134.2M
Total expert parameters: 8.59B
Experts per GPU (with expert parallelism): 8

With 8 GPUs, expert parallelism places 8 experts on each device. This is the core principle: partition the expert pool across the available devices rather than replicating it. Instead of asking each device to maintain a copy of every expert (which would be memory-prohibitive), each device hosts a specific subset of experts. Tokens that need those experts must travel to that device, but the memory burden is distributed evenly across the cluster.

Expert Placement Strategies

Once we commit to distributing experts across devices, we face a design decision: which experts should live on which devices? This seemingly simple question has meaningful implications for system performance, communication patterns, and load balance.

The simplest and most common strategy is uniform placement, where experts are evenly divided across devices. With $E$ experts and $N$ devices, each device holds $E/N$ experts. This approach requires no prior knowledge about how experts will be used and creates predictable, symmetric communication patterns.

In[4]:
Code
def uniform_expert_placement(num_experts: int, num_devices: int) -> dict:
    """
    Assign experts to devices uniformly.
    Returns mapping from expert_id to device_id.
    """
    assert num_experts % num_devices == 0, "Experts must divide evenly"
    experts_per_device = num_experts // num_devices

    placement = {}
    for expert_id in range(num_experts):
        device_id = expert_id // experts_per_device
        placement[expert_id] = device_id

    return placement


## Create placement for our example
placement = uniform_expert_placement(num_experts=64, num_devices=8)
Out[5]:
Console
Expert placement (expert_id -> device_id):
  Device 0: experts 0-7
  Device 1: experts 8-15
  Device 2: experts 16-23
  Device 3: experts 24-31
  Device 4: experts 32-39
  Device 5: experts 40-47
  Device 6: experts 48-55
  Device 7: experts 56-63
Out[6]:
Visualization
Expert placement grid for 64 experts distributed across 8 devices. Experts are assigned in contiguous blocks (0-7 to Device 0, 8-15 to Device 1), creating a deterministic mapping that requires no lookup table. This symmetric arrangement simplifies routing logic but ignores potential load imbalances from uneven expert popularity.

The placement mapping shown above illustrates the contiguous assignment strategy: experts 0 through 7 reside on device 0, experts 8 through 15 on device 1, and so forth. This contiguous assignment simplifies the mental model and makes it easy to compute which device holds any given expert through simple integer division.
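
Because the mapping is pure arithmetic, no lookup table is needed in either direction. A minimal sketch of both lookups, reusing the experts-per-device value from the configuration above:

experts_per_device = 64 // 8  # 8 experts per device in our running example


def device_for_expert(expert_id: int) -> int:
    # Which device hosts a given expert? Integer division suffices.
    return expert_id // experts_per_device


def experts_on_device(device_id: int) -> range:
    # Which experts does a given device host? A contiguous block.
    start = device_id * experts_per_device
    return range(start, start + experts_per_device)


print(device_for_expert(42))       # 5
print(list(experts_on_device(5)))  # experts 40 through 47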

Uniform placement works well when load balancing is effective and all experts receive roughly equal traffic. However, as we discussed in the chapters on load balancing and auxiliary losses, achieving perfect balance is challenging. Some experts inevitably become more popular than others, leading to computational hotspots where certain devices must process more tokens than their peers.

Capacity-Aware Placement

An alternative strategy accounts for expected expert utilization. If certain experts consistently receive more tokens, placing them on separate devices can better balance computation. The intuition is simple: if experts 0 and 1 are both heavily used, placing them on the same device would overload it while leaving others underused. By spreading popular experts across different devices, we can achieve better computational balance.

In[7]:
Code
import torch
from typing import Tuple


def capacity_aware_placement(
    num_experts: int,
    num_devices: int,
    expert_loads: torch.Tensor,  # Expected load per expert
) -> Tuple[dict, torch.Tensor]:
    """
    Place experts to balance computational load across devices.
    Greedy algorithm: assign each expert to the least-loaded device.
    """
    device_loads = torch.zeros(num_devices)
    placement = {}

    # Sort experts by load (descending) for better packing
    sorted_experts = torch.argsort(expert_loads, descending=True)

    for expert_id in sorted_experts.tolist():
        # Assign to device with minimum current load
        target_device = torch.argmin(device_loads).item()
        placement[expert_id] = target_device
        device_loads[target_device] += expert_loads[expert_id]

    return placement, device_loads


## Simulate non-uniform expert loads
torch.manual_seed(42)
expert_loads = torch.softmax(torch.randn(64), dim=0) * 64  # Normalized loads

placement_aware, final_loads = capacity_aware_placement(64, 8, expert_loads)
Out[8]:
Console
Capacity-aware placement results:
Load variance across devices: 0.0004
Uniform placement load variance: 5.0085
Out[9]:
Visualization
Uniform placement load distribution across 8 devices. The red dashed line indicates the mean load. High variance is observed due to uneven expert popularity.
Capacity-aware placement load distribution. By assigning popular experts to underutilized devices, load variance is significantly reduced compared to the uniform baseline.

The greedy algorithm presented here processes experts in order of their expected load, always assigning each expert to the device with the lowest current total load. This bin-packing approach, while not optimal in all cases, achieves a meaningful reduction in load variance compared to uniform placement. The variance reduction shown in the output demonstrates that even this simple heuristic can significantly improve balance when expert utilization patterns are non-uniform.

Capacity-aware placement reduces load imbalance but requires knowledge of expert utilization patterns, which may change during training. This creates a practical challenge: the routing patterns that determine expert popularity evolve as the model learns, meaning that a placement decision made at initialization may become suboptimal later in training. In practice, most implementations use uniform placement combined with strong load balancing losses, preferring to shape the routing behavior rather than adapt the physical placement.

All-to-All Communication

Expert parallelism relies on a communication primitive called all-to-all. This primitive is fundamentally different from the collective operations commonly used in dense model training, and understanding its behavior is essential for reasoning about MoE system performance.

Unlike all-reduce (which aggregates values across devices) or all-gather (which collects values from all devices), all-to-all performs a complete exchange where each device sends different data to each other device. Think of it as a postal system where every post office simultaneously sends letters to every other post office, and all the letters arrive at their destinations at roughly the same time. Each device has a different message for each recipient, and all messages are exchanged in a single coordinated operation.

All-to-All Collective

An all-to-all operation takes input tensors partitioned across devices and redistributes them so that each device receives the portions destined for it from all other devices. If device $i$ has data for device $j$, that data moves from $i$ to $j$, and this happens simultaneously for all device pairs. The total volume of data in the system remains constant; only its distribution changes.

Here's how all-to-all fits into the MoE forward pass. The process unfolds in four distinct phases, each essential to the overall computation (a code sketch of the communication skeleton follows the list):

  1. Routing decision: Each device computes which experts each of its tokens should visit. The gating network produces probabilities over all experts, and the top-k selection determines which experts will process each token.

  2. First all-to-all (dispatch): Tokens are sent to the devices holding their target experts. A token that was generated on device 0 but needs expert 42 (located on device 5) must travel across the network to device 5.

  3. Expert computation: Each device processes tokens using its local experts. At this point, each device has received all tokens that need its experts, regardless of where those tokens originated.

  4. Second all-to-all (combine): Results are sent back to the original devices. The processed token representations must return to the devices where they originated so that the model can continue processing them through subsequent layers.
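
The communication skeleton of these four phases can be sketched with PyTorch's collective API. This is an illustrative sketch, not a production implementation: it assumes an initialized process group, tokens already sorted by destination rank, and split sizes already exchanged between devices (bookkeeping we cover later in this chapter).

import torch
import torch.distributed as dist


def moe_layer_skeleton(
    sorted_tokens,      # [n_send, hidden_dim], sorted by destination rank
    send_split_sizes,   # tokens going to each rank (list of ints)
    recv_split_sizes,   # tokens arriving from each rank (list of ints)
    local_experts_fn,   # callable applying this rank's experts
    group=None,
):
    # Phase 2: dispatch tokens to the devices holding their experts
    received = torch.empty(
        sum(recv_split_sizes), sorted_tokens.shape[-1],
        dtype=sorted_tokens.dtype, device=sorted_tokens.device,
    )
    dist.all_to_all_single(
        received, sorted_tokens,
        output_split_sizes=recv_split_sizes,
        input_split_sizes=send_split_sizes,
        group=group,
    )

    # Phase 3: every received token is processed by a local expert
    expert_output = local_experts_fn(received)

    # Phase 4: combine -- the reverse exchange returns results to their origin
    combined = torch.empty_like(sorted_tokens)
    dist.all_to_all_single(
        combined, expert_output,
        output_split_sizes=send_split_sizes,
        input_split_sizes=recv_split_sizes,
        group=group,
    )
    return combined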

In[10]:
Code
def simulate_all_to_all_routing(
    tokens_per_device: int, num_devices: int, num_experts: int, top_k: int = 2
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Simulate the routing decisions in expert parallelism.
    Returns dispatch counts (tokens sent from device i to device j).
    """
    experts_per_device = num_experts // num_devices

    # Simulate random expert assignments for each token on each device
    total_tokens = tokens_per_device * num_devices

    # Each token selects top_k experts
    expert_indices = torch.randint(
        0, num_experts, (num_devices, tokens_per_device, top_k)
    )

    # Count how many tokens each device sends to each other device
    dispatch_counts = torch.zeros(num_devices, num_devices, dtype=torch.long)

    for src_device in range(num_devices):
        for token_idx in range(tokens_per_device):
            for k in range(top_k):
                expert_id = expert_indices[src_device, token_idx, k].item()
                dst_device = expert_id // experts_per_device
                dispatch_counts[src_device, dst_device] += 1

    return dispatch_counts, expert_indices


## Simulate routing for a batch
sim_tokens = 1024
sim_devices = 8
sim_experts = 64
sim_top_k = 2

dispatch_counts, _ = simulate_all_to_all_routing(
    tokens_per_device=sim_tokens,
    num_devices=sim_devices,
    num_experts=sim_experts,
    top_k=sim_top_k,
)
Out[11]:
Console
Dispatch matrix (rows=source device, cols=destination device):
Each entry shows token-expert pairs sent from source to destination

[[250 224 248 272 286 230 270 268]
 [257 262 267 252 256 245 250 259]
 [253 234 275 249 264 262 261 250]
 [225 264 263 257 235 257 256 291]
 [246 256 284 252 228 264 268 250]
 [256 258 257 258 269 237 259 254]
 [243 293 264 252 258 257 228 253]
 [262 257 254 266 230 247 259 273]]

Total token-expert pairs: 16384
Expected per device: 2048 (tokens × top_k)
Out[12]:
Visualization
Heatmap of token-expert pairs exchanged between devices during the all-to-all dispatch phase. Rows represent source devices and columns represent destinations; the diagonal indicates local routing. Uniform random routing results in a roughly balanced distribution of traffic (~256 pairs per cell), demonstrating how expert parallelism requires extensive network communication.

The dispatch matrix reveals the communication pattern in concrete terms. Each row represents a source device, and each column represents a destination device. The value at position (i, j) tells us how many token-expert pairs device i sends to device j during the dispatch phase. With uniform random routing and 8 devices, each device sends roughly equal traffic to all devices (including itself). The values along the diagonal are particularly interesting: these represent tokens that happen to route to experts on their original device, requiring no network communication. These "local" routings are essentially free from a communication perspective, though they still require computation.
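
We can measure that local fraction directly from the simulated dispatch matrix; with uniform routing across 8 devices, roughly 1/8 of all token-expert pairs should stay on their source device.

## Fraction of token-expert pairs that stay on their source device
local_pairs = dispatch_counts.diagonal().sum().item()
total_pairs = dispatch_counts.sum().item()
print(f"Local pairs: {local_pairs} / {total_pairs} "
      f"({local_pairs / total_pairs:.1%} of traffic avoids the network)")
# With 8 devices and uniform routing, expect roughly 1/8 = 12.5%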

The All-to-All Operation

PyTorch's distributed library provides all_to_all for this communication pattern. The operation coordinates all devices to simultaneously exchange data according to their individual send and receive specifications. This coordination is complex at scale, as it requires synchronization across all participating devices.

In[13]:
Code
def expert_parallel_dispatch(
    hidden_states: torch.Tensor,  # [batch_size, seq_len, hidden_dim]
    expert_indices: torch.Tensor,  # [batch_size, seq_len, top_k]
    expert_weights: torch.Tensor,  # [batch_size, seq_len, top_k]
    num_experts: int,
    world_size: int,
    rank: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Dispatch tokens to their target devices via all-to-all.

    This is a simplified version showing the key operations.
    Production implementations handle variable-length messages.
    """
    batch_size, seq_len, hidden_dim = hidden_states.shape
    experts_per_device = num_experts // world_size

    # Flatten to [num_tokens, hidden_dim]
    tokens = hidden_states.view(-1, hidden_dim)
    num_tokens = tokens.shape[0]
    top_k = expert_indices.shape[-1]

    # Determine destination device for each expert selection
    flat_indices = expert_indices.view(-1)  # [num_tokens * top_k]
    dest_devices = flat_indices // experts_per_device

    # Sort tokens by destination device for efficient all-to-all
    sort_indices = torch.argsort(dest_devices)

    # Count tokens going to each device
    send_counts = torch.bincount(dest_devices, minlength=world_size)

    # In actual implementation, we'd call:
    # recv_tokens = torch.empty(...)
    # dist.all_to_all_single(recv_tokens, send_tokens, recv_counts, send_counts)

    return tokens, send_counts, sort_indices


## Demonstrate dispatch logic
demo_states = torch.randn(1, 4, 4)
## Experts mapped to devices: 0-1->Dev0, 2-3->Dev1
demo_experts = torch.tensor([[[0, 2], [1, 3], [0, 1], [2, 3]]])
demo_weights = torch.ones_like(demo_experts, dtype=torch.float)

_, demo_counts, _ = expert_parallel_dispatch(
    hidden_states=demo_states,
    expert_indices=demo_experts,
    expert_weights=demo_weights,
    num_experts=4,
    world_size=2,
    rank=0,
)
Out[14]:
Console
Dispatch counts (2 devices): [4, 4]

The dispatch counts demonstrate the core accounting that drives the all-to-all operation. In this small example with 4 tokens and 2 experts per token, we see that 4 token-expert pairs are destined for device 0 and 4 for device 1, reflecting a balanced routing pattern across the two devices.

The critical insight is that all-to-all enables dynamic routing: each token goes exactly where it needs to go based on the gating network's decision. This differs fundamentally from static parallelism strategies where data movement patterns are fixed. In tensor parallelism, for example, the communication pattern is determined entirely by the model architecture and is identical for every input. In expert parallelism, the communication pattern depends on the content of the input itself, as mediated by the learned gating network. This dynamic behavior is both the source of MoE's flexibility and the root of its systems complexity.

Communication Overhead Analysis

Expert parallelism introduces communication costs that don't exist in dense models. Understanding these costs is essential for system design and performance optimization. While dense model training also involves communication (for gradient synchronization, for example), the all-to-all operations in MoE occur in the forward pass itself, directly adding latency to every training step.

Communication Volume

For each MoE layer, we perform two all-to-all operations. The first operation dispatches tokens from their originating devices to the devices holding their target experts. The second operation returns the processed results back to the originating devices. Each operation moves the same amount of data, as every token that goes out must come back.

The total communication volume $V_{\text{layer}}$ (in elements) per MoE layer is:

$$V_{\text{layer}} = 2 \times B \times S \times K \times D$$

Let's examine each component of this formula to understand what drives communication cost:

  • $B$: the batch size (number of sequences). Larger batches increase throughput but also increase the total number of tokens that must be routed.
  • $S$: the sequence length (number of tokens per sequence). Longer sequences mean more tokens per batch, directly scaling communication.
  • $K$: the top-$k$ parameter (number of experts per token). When $k=2$, each token creates two communication events rather than one.
  • $D$: the hidden dimension (size of each token vector). Larger models with wider hidden dimensions must send more data per token.
  • $2$: accounts for the two all-to-all operations (dispatch and combine). This factor is fixed by the algorithm structure.

The multiplicative relationship between these factors means that communication volume can grow quickly. Doubling any single factor doubles the communication cost. Doubling all of them together would increase communication by a factor of 16.

In[15]:
Code
def calculate_communication_volume(
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    top_k: int,
    num_moe_layers: int,
    bytes_per_element: int = 2,  # bf16
) -> dict:
    """
    Calculate total communication volume for MoE forward pass.
    """
    # Tokens communicated per MoE layer
    tokens_per_layer = batch_size * seq_len * top_k
    elements_per_layer = tokens_per_layer * hidden_dim

    # Two all-to-all operations per layer (dispatch + combine)
    elements_per_layer_total = elements_per_layer * 2

    # Total across all MoE layers
    total_elements = elements_per_layer_total * num_moe_layers
    total_bytes = total_elements * bytes_per_element

    return {
        "tokens_per_layer": tokens_per_layer,
        "elements_per_all_to_all": elements_per_layer,
        "total_elements": total_elements,
        "total_bytes": total_bytes,
        "total_gb": total_bytes / (1024**3),
    }


## Typical training configuration
comm_volume = calculate_communication_volume(
    batch_size=32,
    seq_len=2048,
    hidden_dim=4096,
    top_k=2,
    num_moe_layers=16,  # MoE in every other layer of 32-layer model
)
Out[16]:
Console
Communication volume analysis:
  Tokens per MoE layer: 131,072
  Elements per all-to-all: 536,870,912
  Total elements (forward pass): 17,179,869,184
  Total communication: 32.00 GB

This substantial volume of data movement, roughly 32 GB per forward pass for a single batch, highlights why network bandwidth is often the bottleneck in MoE training. To put this in perspective, even with high-bandwidth interconnects capable of 200 GB/s, transferring 32 GB would take around 160 milliseconds if done sequentially. The communication cost scales linearly with the number of MoE layers, making efficient routing essential. This is one reason why some architectures use MoE layers only in alternating positions rather than in every layer.

Communication vs Computation Trade-off

The key metric for understanding expert parallelism efficiency is the ratio of communication time to computation time. If communication dominates, adding more devices provides diminishing returns because each device spends more time waiting for data than actually computing. Conversely, if computation dominates, the system can efficiently utilize additional devices.

In[17]:
Code
def analyze_scaling_efficiency(
    hidden_dim: int,
    ffn_hidden: int,
    tokens_per_device: int,
    top_k: int,
    num_devices_list: list,
    bandwidth_gb_per_sec: float = 200,  # NVLink bandwidth
    tflops: float = 150,  # GPU compute throughput
) -> dict:
    """
    Analyze how communication overhead scales with device count.
    """
    results = []

    for num_devices in num_devices_list:
        # Communication: tokens × hidden_dim × 2 (dispatch + combine)
        # Each token sends/receives from other devices
        # Fraction of traffic that crosses device boundaries
        cross_device_fraction = (num_devices - 1) / num_devices

        comm_elements = (
            tokens_per_device * top_k * hidden_dim * 2 * cross_device_fraction
        )
        comm_bytes = comm_elements * 2  # bf16
        comm_time = comm_bytes / (bandwidth_gb_per_sec * 1e9)

        # Computation: expert forward pass
        # FLOPs = 2 * tokens * hidden_dim * ffn_hidden * 2 (two linear layers)
        flops = 2 * tokens_per_device * top_k * hidden_dim * ffn_hidden * 2
        compute_time = flops / (tflops * 1e12)

        efficiency = compute_time / (compute_time + comm_time)

        results.append(
            {
                "num_devices": num_devices,
                "comm_time_ms": comm_time * 1000,
                "compute_time_ms": compute_time * 1000,
                "efficiency": efficiency,
            }
        )

    return results


scaling_results = analyze_scaling_efficiency(
    hidden_dim=4096,
    ffn_hidden=4096 * 4,
    tokens_per_device=4096,
    top_k=2,
    num_devices_list=[2, 4, 8, 16, 32, 64],
)
Out[18]:
Console
Expert parallelism scaling efficiency:
 Devices    Comm (ms)   Compute (ms)   Efficiency
--------------------------------------------------
       2        0.336         14.660        97.8%
       4        0.503         14.660        96.7%
       8        0.587         14.660        96.1%
      16        0.629         14.660        95.9%
      32        0.650         14.660        95.8%
      64        0.661         14.660        95.7%
Out[19]:
Visualization
Time breakdown per MoE layer as device count increases. Communication overhead (orange) grows with the number of devices while computation time (blue) remains constant, as cross-device traffic increases.
Out[20]:
Visualization
Expert parallelism scaling efficiency. Efficiency declines as device count grows because a larger fraction of all-to-all traffic must cross device boundaries; with the 200 GB/s bandwidth assumed here the decline is gradual, and it steepens sharply on slower interconnects.

Several insights emerge from this analysis. First, communication overhead grows with the number of devices because more tokens must cross device boundaries. With 2 devices, half the tokens (on average) route to local experts and require no network transfer. With 64 devices, only 1/64 of tokens stay local, meaning the vast majority must travel across the network.

Second, efficiency remains high when computation dominates, but it erodes as we scale to many devices. The efficiency column shows the trend: 97.8% at 2 devices falling to 95.7% at 64 under the generous 200 GB/s assumption. The absolute drop looks small here, but the communication term keeps growing with device count while computation stays fixed, so slower links or smaller per-device workloads push efficiency down much faster. This suggests a practical limit to how far expert parallelism alone can scale.

Third, high-bandwidth interconnects like NVLink are essential for maintaining efficiency at scale. The analysis above assumes 200 GB/s bandwidth, which is achievable with modern GPU interconnects. Slower connections, such as those between nodes in a cluster connected by InfiniBand or Ethernet, would shift these efficiency curves dramatically downward.
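
To see how sensitive these figures are to the interconnect, we can rerun the same analysis with a slower, inter-node-class bandwidth. The 25 GB/s figure below is an assumption for illustration, not a measurement of any particular fabric:

## Same workload, slower link: assume 25 GB/s effective per-GPU bandwidth
slow_results = analyze_scaling_efficiency(
    hidden_dim=4096,
    ffn_hidden=4096 * 4,
    tokens_per_device=4096,
    top_k=2,
    num_devices_list=[2, 4, 8, 16, 32, 64],
    bandwidth_gb_per_sec=25,
)
for r in slow_results:
    print(
        f"{r['num_devices']:>3} devices: comm {r['comm_time_ms']:6.2f} ms, "
        f"efficiency {r['efficiency']:.1%}"
    )

Under this simple model, communication per layer grows to several milliseconds and efficiency falls from roughly 85% at 2 devices to about 73% at 64, which is why production systems work hard to keep all-to-all traffic within high-bandwidth domains or to overlap it with computation.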

Expert Parallelism Implementation

A complete expert parallelism implementation requires coordinating routing decisions, all-to-all communication, and expert computation. The implementation must handle several subtle challenges: variable-length messages between device pairs, preservation of token ordering for correct gradient computation, and efficient batching of expert computations. Here's a simplified but functional implementation that demonstrates the core concepts:

In[21]:
Code
import torch
import torch.nn as nn
from typing import Tuple


class ExpertParallelMoE(nn.Module):
    """
    MoE layer with expert parallelism.
    Each device holds a subset of experts.
    """

    def __init__(
        self,
        hidden_dim: int,
        ffn_hidden: int,
        num_experts: int,
        num_local_experts: int,
        top_k: int = 2,
        capacity_factor: float = 1.25,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor

        # Router (replicated on all devices)
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)

        # Local experts only
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_dim, ffn_hidden, bias=False),
                    nn.GELU(),
                    nn.Linear(ffn_hidden, hidden_dim, bias=False),
                )
                for _ in range(num_local_experts)
            ]
        )

    def forward(
        self, hidden_states: torch.Tensor, expert_parallel_group=None
    ) -> Tuple[torch.Tensor, dict]:
        """
        Forward pass with expert parallelism.

        Args:
            hidden_states: [batch, seq_len, hidden_dim]
            expert_parallel_group: Distributed process group for all-to-all

        Returns:
            output: [batch, seq_len, hidden_dim]
            aux_info: Dictionary with auxiliary information
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        num_tokens = batch_size * seq_len

        # Compute routing probabilities
        router_logits = self.gate(
            hidden_states
        )  # [batch, seq_len, num_experts]
        router_probs = torch.softmax(router_logits, dim=-1)

        # Select top-k experts
        top_k_probs, top_k_indices = torch.topk(
            router_probs, self.top_k, dim=-1
        )  # [batch, seq_len, top_k]

        # Normalize weights
        top_k_weights = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Flatten for processing
        flat_hidden = hidden_states.view(num_tokens, hidden_dim)
        flat_indices = top_k_indices.view(num_tokens, self.top_k)
        flat_weights = top_k_weights.view(num_tokens, self.top_k)

        # Process through experts (simplified: no actual distributed communication)
        # In production, this would involve all-to-all operations
        output = self._local_expert_forward(
            flat_hidden, flat_indices, flat_weights
        )

        return output.view(batch_size, seq_len, hidden_dim), {
            "router_probs": router_probs,
            "top_k_indices": top_k_indices,
        }

    def _local_expert_forward(
        self,
        hidden_states: torch.Tensor,
        expert_indices: torch.Tensor,
        expert_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Process tokens through local experts.
        Simplified version without capacity limits.
        """
        num_tokens = hidden_states.shape[0]
        output = torch.zeros_like(hidden_states)

        # Determine which experts are local
        world_size = self.num_experts // self.num_local_experts
        rank = 0  # Would be dist.get_rank() in distributed setting
        local_expert_start = rank * self.num_local_experts
        local_expert_end = local_expert_start + self.num_local_experts

        for k in range(self.top_k):
            for local_idx, expert in enumerate(self.experts):
                global_idx = local_expert_start + local_idx

                # Find tokens routed to this expert
                mask = expert_indices[:, k] == global_idx
                if mask.sum() == 0:
                    continue

                # Process tokens
                expert_input = hidden_states[mask]
                expert_output = expert(expert_input)

                # Weight and accumulate
                weights = expert_weights[mask, k : k + 1]
                output[mask] += weights * expert_output

        return output


## Create model instance
moe_layer = ExpertParallelMoE(
    hidden_dim=512,
    ffn_hidden=2048,
    num_experts=8,
    num_local_experts=8,  # All experts local for this demo
    top_k=2,
)
In[22]:
Code
## Test forward pass
test_input = torch.randn(2, 16, 512)  # batch=2, seq_len=16
output, aux_info = moe_layer(test_input)
Out[23]:
Console
Input shape: torch.Size([2, 16, 512])
Output shape: torch.Size([2, 16, 512])
Router probabilities shape: torch.Size([2, 16, 8])
Selected experts shape: torch.Size([2, 16, 2])

The output shapes confirm that the model processes the input sequence and produces an output of the same dimensionality. The router probabilities tensor has shape corresponding to the batch size, sequence length, and number of experts, providing a complete probability distribution over experts for each token. The selected indices tensor (top-k) has the same batch and sequence dimensions but with the last dimension equal to k, indicating which experts were chosen for each token. These auxiliary outputs provide visibility into the gating mechanism, allowing us to verify routing behavior, compute load balancing metrics, and debug potential issues.

Out[24]:
Visualization
Router probabilities (gating weights) for a single sequence across 8 experts. The heatmap visualizes the soft assignment of tokens (x-axis) to experts (y-axis), where brighter colors indicate higher probability. The distinct vertical patterns show how different tokens activate specific experts, illustrating the router's selective specialization.
Out[25]:
Visualization
Distribution of token counts assigned to each expert across the batch. The bars show the actual number of tokens routed to each expert based on top-k selection, while the red dashed line indicates the theoretical uniform load. Deviations from the mean illustrate the natural load imbalance that arises from content-based routing decisions.
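
The per-expert load shown in the figure can be recomputed directly from the routing indices returned in aux_info, which is also how load-balancing metrics are typically logged during training. A minimal sketch using the demo layer's output from above:

## Count how many top-k slots each expert received in the test batch
expert_counts = torch.bincount(
    aux_info["top_k_indices"].reshape(-1), minlength=moe_layer.num_experts
)
ideal_load = expert_counts.sum().item() / moe_layer.num_experts

print("Tokens per expert:", expert_counts.tolist())
print(f"Ideal uniform load: {ideal_load:.1f}")
print(f"Max/mean load ratio: {expert_counts.max().item() / ideal_load:.2f}")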

Handling Variable-Length Communication

A subtlety of expert parallelism is that all-to-all messages have variable lengths. One device might send 1000 tokens to device 0 but only 500 to device 1, depending on routing decisions. This variability arises naturally from the learned gating function: different inputs activate different experts, and the distribution of selected experts changes from batch to batch. This requires careful buffer management to ensure that receiving devices allocate sufficient space without wasting memory:

In[26]:
Code
def prepare_all_to_all_buffers(
    hidden_states: torch.Tensor,
    expert_indices: torch.Tensor,
    num_devices: int,
    experts_per_device: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Prepare send/receive buffers for variable-length all-to-all.

    Returns:
        send_buffer: Sorted tokens ready for dispatch
        send_counts: Number of elements to send to each device
        permutation: Indices to restore original order after combine
    """
    num_tokens, hidden_dim = hidden_states.shape
    top_k = expert_indices.shape[1]

    # Expand hidden states for each expert selection
    expanded_hidden = hidden_states.unsqueeze(1).expand(-1, top_k, -1)
    expanded_hidden = expanded_hidden.reshape(-1, hidden_dim)

    # Determine destination device for each token-expert pair
    flat_indices = expert_indices.view(-1)
    dest_devices = flat_indices // experts_per_device

    # Sort by destination device
    sort_indices = torch.argsort(dest_devices, stable=True)
    sorted_hidden = expanded_hidden[sort_indices]

    # Count elements for each destination
    send_counts = torch.bincount(dest_devices, minlength=num_devices)

    # Create inverse permutation for restoration
    inverse_indices = torch.empty_like(sort_indices)
    inverse_indices[sort_indices] = torch.arange(len(sort_indices))

    return sorted_hidden, send_counts, inverse_indices


## Test buffer preparation
num_test_tokens = 32
test_top_k = 2
test_hidden = torch.randn(num_test_tokens, 512)
test_experts = torch.randint(0, 64, (num_test_tokens, test_top_k))

send_buffer, send_counts, permutation = prepare_all_to_all_buffers(
    test_hidden, test_experts, num_devices=8, experts_per_device=8
)
Out[27]:
Console
Original hidden states: torch.Size([32, 512])
Send buffer (sorted): torch.Size([64, 512])
Send counts per device: [8, 3, 8, 8, 8, 9, 10, 10]
Total dispatched: 64 (expected: 64)
Permutation indices shape: torch.Size([64])

The buffer preparation function performs several important transformations. First, it expands the hidden states to account for the top-k selection, creating a separate copy of each token for each of its selected experts. Then, it sorts these expanded tokens by their destination device, grouping all tokens bound for device 0 together, followed by all tokens bound for device 1, and so on. This sorting is essential because the all-to-all operation expects contiguous chunks of data for each destination.

The permutation indices are crucial: after receiving results back via the second all-to-all, we use them to restore the original token ordering. Without this restoration step, the processed representations would be scrambled, breaking the correspondence between tokens and their positions in the sequence. The inverse permutation ensures that after the round-trip through expert computation, each token's result ends up in the correct position for subsequent layers.
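
A hedged sketch of that restoration step follows, reusing the buffers prepared above and standing in for the network round-trip with an identity pass; in a real system, the processed tensor would be the output of the second all-to-all.

def combine_expert_outputs(
    processed: torch.Tensor,        # [num_tokens * top_k, hidden_dim], sorted order
    inverse_indices: torch.Tensor,  # inverse permutation from buffer preparation
    expert_weights: torch.Tensor,   # [num_tokens, top_k] gating weights
) -> torch.Tensor:
    num_tokens, top_k = expert_weights.shape
    # Undo the destination sort: row i * top_k + k is token i, expert slot k
    restored = processed[inverse_indices].view(num_tokens, top_k, -1)
    # Weighted sum over the top-k expert outputs for each token
    return (expert_weights.unsqueeze(-1) * restored).sum(dim=1)


## Identity stand-in for dispatch -> expert compute -> return
test_weights = torch.full((num_test_tokens, test_top_k), 0.5)
combined = combine_expert_outputs(send_buffer, permutation, test_weights)
print(combined.shape)  # torch.Size([32, 512]), original token order restored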

Key Parameters

The key parameters for the Expert Parallel MoE implementation are:

  • num_experts: The total number of experts in the model. This determines the potential for specialization and affects both model capacity and communication patterns. Common values range from 8 (as in Mixtral) to thousands (as in Switch Transformer).

  • num_local_experts: The number of experts stored on the current device (typically num_experts / world_size). This value determines the memory footprint per device and must divide evenly into the total expert count for uniform placement.

  • top_k: The number of experts selected for each token (usually 1 or 2). Higher values increase computational cost but may improve model quality by allowing more experts to contribute to each token's representation.

  • capacity_factor: A multiplier to reserve extra buffer space for load imbalances (e.g., 1.25x expected load). This parameter provides headroom for the natural variation in routing patterns, preventing token dropping when some experts receive more traffic than expected.
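
The capacity_factor translates into a concrete per-expert token budget. A common convention, used in Switch Transformer-style implementations, scales the expected per-expert load by this factor:

import math


def expert_capacity(
    num_tokens: int, num_experts: int, top_k: int, capacity_factor: float
) -> int:
    """Per-expert token budget: expected load scaled by the capacity factor."""
    expected_per_expert = num_tokens * top_k / num_experts
    return math.ceil(capacity_factor * expected_per_expert)


## A 32 x 2048-token batch with 64 experts and top-2 routing
print(expert_capacity(num_tokens=32 * 2048, num_experts=64, top_k=2,
                      capacity_factor=1.25))  # 2560 token slots per expert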

Combining Expert Parallelism with Other Strategies

In practice, expert parallelism is combined with other parallelism strategies to scale MoE models to hundreds of GPUs. No single parallelism strategy can address all the challenges of training massive models: expert parallelism handles the distribution of experts, but other strategies are needed to handle large batch sizes, oversized individual layers, and the sequential dependencies between layers.

  • Expert + Data Parallelism: Replicate the expert parallel setup across multiple groups of devices. Each group has a complete set of experts, and different groups process different batches. This combination multiplies throughput by the number of data parallel replicas while keeping the communication complexity of expert parallelism contained within each replica group. A sketch of the corresponding device-group layout follows this list.

  • Expert + Tensor Parallelism: Shard individual experts across devices in addition to distributing experts. This handles cases where individual experts are too large for a single device. For example, if each expert has 7 billion parameters (as in Mixtral-scale models), tensor parallelism can split each expert across 2 or more devices, with expert parallelism then distributing these sharded experts across additional device groups.

  • Expert + Pipeline Parallelism: Distribute different layers across different pipeline stages, with expert parallelism within each MoE layer. This approach addresses the sequential nature of transformer computation, where each layer depends on the output of the previous layer. Pipeline parallelism overlaps computation across stages, while expert parallelism handles the within-layer distribution of experts.
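
The bookkeeping behind the first combination amounts to partitioning ranks into groups. The layout below is a minimal sketch of one common arrangement; real frameworks (DeepSpeed, Megatron-LM) build these process groups with their own utilities:

def build_parallel_groups(world_size: int, expert_parallel_size: int):
    """Partition ranks into expert-parallel groups (experts sharded within a
    group) and data-parallel groups (same expert shard, different batches)."""
    assert world_size % expert_parallel_size == 0
    num_replicas = world_size // expert_parallel_size

    # Consecutive ranks share one full set of experts
    ep_groups = [
        list(range(g * expert_parallel_size, (g + 1) * expert_parallel_size))
        for g in range(num_replicas)
    ]
    # Ranks holding the same local experts across replicas form a DP group
    dp_groups = [
        list(range(r, world_size, expert_parallel_size))
        for r in range(expert_parallel_size)
    ]
    return ep_groups, dp_groups


ep_groups, dp_groups = build_parallel_groups(world_size=16, expert_parallel_size=8)
print("Expert-parallel groups:", ep_groups)  # [[0..7], [8..15]]
print("Data-parallel groups:", dp_groups)    # [[0, 8], [1, 9], ..., [7, 15]]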

In[28]:
Code
import math


def calculate_device_requirements(
    num_experts: int,
    expert_params: int,
    model_params: int,  # Non-expert parameters
    memory_per_device_gb: float,
    bytes_per_param: int = 2,  # bf16
) -> dict:
    """
    Calculate minimum devices needed for different parallelism strategies.
    """
    total_expert_memory_gb = (
        num_experts * expert_params * bytes_per_param
    ) / 1e9
    total_model_memory_gb = (model_params * bytes_per_param) / 1e9
    total_memory_gb = total_expert_memory_gb + total_model_memory_gb

    # Account for optimizer states and activations (rough 3x multiplier)
    training_memory_gb = total_memory_gb * 3

    # Expert parallelism only
    min_ep_devices = max(
        1, math.ceil(total_expert_memory_gb / memory_per_device_gb)
    )

    # With tensor parallelism for non-expert params
    min_combined_devices = max(
        math.ceil(total_expert_memory_gb / memory_per_device_gb),
        math.ceil(total_model_memory_gb / memory_per_device_gb),
    )

    return {
        "total_expert_memory_gb": total_expert_memory_gb,
        "total_model_memory_gb": total_model_memory_gb,
        "training_memory_gb": training_memory_gb,
        "min_expert_parallel_devices": min_ep_devices,
        "min_combined_devices": min_combined_devices,
    }


## Mixtral-8x7B scale model
requirements = calculate_device_requirements(
    num_experts=8,
    expert_params=7_000_000_000,  # 7B params per expert
    model_params=7_000_000_000,  # Shared params (attention, embeddings)
    memory_per_device_gb=80,  # A100 80GB
)
Out[29]:
Console
Device requirements analysis (Mixtral-8x7B scale):
  Expert memory: 112.0 GB
  Other model memory: 14.0 GB
  Training memory (with optimizer): 378.0 GB
  Minimum devices (expert parallel): 2
Out[30]:
Visualization
Memory requirements for a Mixtral-8x7B scale model compared to A100 GPU capacity. Expert parameters (blue) and shared parameters (orange) consume significant memory, but training overheads (green), including optimizer states and activations, triple the requirement to over 350 GB, necessitating multi-device distribution.

The analysis shows that for a model of this scale, the expert parameters alone (112 GB) exceed the memory of a single 80GB device. We require at least 2 devices just to store the experts, and significantly more when training due to optimizer states and activations. The 3x multiplier for training memory accounts for the Adam optimizer (which maintains two additional tensors per parameter) and the activation checkpointing overhead. In practice, teams training models at this scale often use 8, 16, or more devices, combining expert parallelism with data parallelism to achieve both memory capacity and training throughput.

Limitations and Impact

Expert parallelism enables scaling MoE models that would otherwise be impossible to train, but it comes with significant challenges that you must understand and address.

The communication overhead from all-to-all operations can bottleneck training, especially when scaling to many devices or using slower interconnects. Models like Switch Transformer (which we'll explore in the next chapter) sometimes use expert parallelism across hundreds of devices, where communication costs become substantial. Techniques like capacity factors and dropping tokens help manage this, but at the cost of some model quality. The dropped tokens represent information that the model never processes, creating a subtle but real degradation in the model's ability to learn from its training data.

Load imbalance creates another persistent challenge. Even with auxiliary losses encouraging balanced routing, some experts inevitably receive more traffic than others. This means some devices finish their computation earlier and must wait, reducing overall efficiency. The problem worsens with more devices, as the probability of imbalance across any device pair increases. A system with 64 devices has 64 opportunities for one device to become a bottleneck, compared to just 2 opportunities with 2 devices.
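
A quick Monte Carlo sketch illustrates the straggler effect: even with perfectly uniform routing probabilities, the busiest device's load relative to the mean grows as the device count increases (an illustrative simulation, not a measurement of any real system):

## Simulate uniform random routing and measure the busiest device's excess load
torch.manual_seed(0)
pairs_per_step = 65536  # token-expert pairs routed per step

for num_devices in [2, 8, 32, 64]:
    assignments = torch.randint(0, num_devices, (pairs_per_step,))
    loads = torch.bincount(assignments, minlength=num_devices).float()
    print(f"{num_devices:>3} devices: max/mean load = "
          f"{(loads.max() / loads.mean()).item():.3f}")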

Despite these challenges, expert parallelism has been transformative for scaling language models. It allows models to have many more parameters than they could with dense architectures while maintaining reasonable compute costs. The Switch Transformer demonstrated training models with trillions of parameters using expert parallelism. More recent models like Mixtral use expert parallelism more conservatively, with 8 experts providing a balance between capacity and communication efficiency.

The impact extends beyond just scale. Expert parallelism enables a form of model specialization that dense models cannot achieve. Different experts can learn different capabilities, and the routing mechanism determines which capabilities to apply for each input. This architectural choice influences not just training efficiency but the kinds of representations these models learn. Experts tend to specialize in different ways: some focus on particular syntactic patterns, others on specific domains or languages. This emergent specialization is only possible because expert parallelism makes it practical to maintain many separate expert networks within a single model.

Summary

Expert parallelism is the distributed computing strategy that makes large-scale MoE models practical. Rather than replicating all experts on every device (expensive) or sharding experts arbitrarily (complex), expert parallelism assigns complete experts to devices and uses all-to-all communication to route tokens to wherever their selected experts reside.

The core concepts include:

  • Expert placement assigns experts to devices, typically uniformly, though capacity-aware placement can improve load balance when expert utilization patterns are known or can be estimated
  • All-to-all communication is the primitive that enables dynamic routing, where each device sends different data to each other device based on the gating network's decisions
  • Communication overhead grows with device count and can become a bottleneck at scale, making high-bandwidth interconnects essential for maintaining training efficiency
  • Implementation requires careful buffer management for variable-length messages and coordination with routing decisions to preserve token ordering

Expert parallelism combines naturally with other parallelism strategies. Production systems typically use expert parallelism for the MoE layers alongside data parallelism for throughput and potentially tensor parallelism for very large expert dimensions. The choice of which strategies to combine, and in what proportions, depends on the specific model architecture, available hardware, and performance requirements.

