Multi-Head Attention: Parallel Attention for Richer Representations

Michael Brenndoerfer · Updated May 30, 2025 · 36 min read

Learn how multi-head attention runs multiple attention operations in parallel, enabling transformers to capture diverse relationships like syntax, semantics, and coreference simultaneously.


Multi-Head Attention

In the previous chapters, we built up the complete picture of scaled dot-product attention: how queries match with keys, how values get aggregated, and how masking controls information flow. A single attention mechanism can capture one type of relationship between tokens. But language is rich with multiple simultaneous relationships. When reading "The cat sat on the mat because it was tired," you simultaneously track syntactic structure (subject-verb agreement), semantic relationships (what "it" refers to), and positional patterns (nearby words). Can we design attention to capture all of these at once?

Multi-head attention answers this question by running multiple attention operations in parallel, each with its own learned projections. Instead of one attention mechanism looking for one type of pattern, we have several "heads," each specializing in different aspects of the input. This architectural choice is central to transformer performance and appears in every major language model from BERT to GPT-4.

The Limitation of Single-Head Attention

Single-head attention computes one set of attention weights per position. Each token decides how to weight all other tokens, producing a single blended representation. This works well for capturing one dominant relationship, but it forces a choice: should "sat" attend mostly to its subject "cat," its location "mat," or its tense indicators?

Consider a more complex example: "The lawyer who the witness saw left." To parse this sentence correctly, a model must simultaneously track multiple dependencies. "Who" refers back to "lawyer" (its antecedent). "Saw" connects to "witness" (as its subject) and to "who" (as its object). "Left" connects back to "lawyer" (not "witness") as its subject. A single attention head must somehow encode all these relationships in one set of weights.

Multi-Head Attention

Multi-head attention runs several attention operations in parallel, each with independent learned projections. The outputs are concatenated and projected to produce a final representation that captures multiple types of relationships simultaneously.

The solution is straightforward: instead of computing attention once, compute it multiple times with different projections. Each "head" learns to focus on different aspects of the input. One head might specialize in syntactic relationships, another in coreference, and another in local context. By combining their outputs, the model builds a richer representation than any single head could achieve.

How Multi-Head Attention Works

Multi-head attention divides the model's capacity across $h$ parallel attention heads. Each head operates on a reduced dimension, so the total computation remains comparable to single-head attention with the full dimension.

The process unfolds in four steps:

  1. Project the input into $h$ separate query, key, and value representations
  2. Compute scaled dot-product attention independently for each head
  3. Concatenate all head outputs into a single tensor
  4. Project the concatenated output back to the model dimension

Let's trace through each step with concrete mathematics and then implement it in code.

Step 1: Parallel Projections

Given an input sequence $\mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}}$ with $n$ tokens and model dimension $d_{\text{model}}$, we project it into queries, keys, and values for each head. If we have $h$ heads, we typically set the head dimension to $d_k = d_{\text{model}} / h$. This division ensures that the total computation across all heads remains comparable to a single attention operation over the full dimension.

For head ii, we compute the query, key, and value matrices through learned linear projections:

$$\mathbf{Q}_i = \mathbf{X} \mathbf{W}^Q_i, \quad \mathbf{K}_i = \mathbf{X} \mathbf{W}^K_i, \quad \mathbf{V}_i = \mathbf{X} \mathbf{W}^V_i$$

where:

  • $\mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}}$: the input sequence with $n$ tokens, each represented as a $d_{\text{model}}$-dimensional vector
  • $\mathbf{W}^Q_i \in \mathbb{R}^{d_{\text{model}} \times d_k}$: the learned query projection matrix for head $i$, transforming each token into a $d_k$-dimensional query vector
  • $\mathbf{W}^K_i \in \mathbb{R}^{d_{\text{model}} \times d_k}$: the learned key projection matrix for head $i$, transforming each token into a $d_k$-dimensional key vector
  • $\mathbf{W}^V_i \in \mathbb{R}^{d_{\text{model}} \times d_v}$: the learned value projection matrix for head $i$ (often $d_v = d_k$)
  • $\mathbf{Q}_i, \mathbf{K}_i \in \mathbb{R}^{n \times d_k}$: the resulting query and key matrices for head $i$
  • $\mathbf{V}_i \in \mathbb{R}^{n \times d_v}$: the resulting value matrix for head $i$
  • $d_k$: the per-head dimension, typically $d_{\text{model}} / h$

Each head gets its own projection matrices, allowing it to learn different transformations of the input. Head 1 might project tokens to emphasize syntactic features, while head 2 emphasizes semantic similarity. The key insight is that different projection matrices create different "views" of the same input, each potentially highlighting different aspects of the token relationships.
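
To make the shapes concrete, here is a minimal NumPy sketch of this projection step. The dimensions and random matrices are purely illustrative; the full implementation appears later in this chapter.

Code
import numpy as np

# Illustrative sizes (not tied to any particular model)
n, d_model, h = 6, 64, 4
d_k = d_model // h

np.random.seed(0)
X = np.random.randn(n, d_model)        # input sequence, one row per token

# One (W_Q, W_K, W_V) triple per head, each mapping d_model -> d_k
W_Q = np.random.randn(h, d_model, d_k)
W_K = np.random.randn(h, d_model, d_k)
W_V = np.random.randn(h, d_model, d_k)

# Q_i = X W_Q_i for every head i at once
Q = np.einsum("nd,hdk->hnk", X, W_Q)   # (h, n, d_k)
K = np.einsum("nd,hdk->hnk", X, W_K)
V = np.einsum("nd,hdk->hnk", X, W_V)
print(Q.shape)                          # (4, 6, 16)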

Step 2: Independent Attention Computation

Each head computes standard scaled dot-product attention using its own projections. The attention mechanism compares queries against keys to determine relevance weights, then uses those weights to aggregate values:

$$\text{head}_i = \text{Attention}(\mathbf{Q}_i, \mathbf{K}_i, \mathbf{V}_i) = \text{softmax}\left(\frac{\mathbf{Q}_i \mathbf{K}_i^T}{\sqrt{d_k}}\right) \mathbf{V}_i$$

where:

  • $\text{head}_i \in \mathbb{R}^{n \times d_v}$: the output of attention head $i$, containing contextual representations for all $n$ tokens
  • $\mathbf{Q}_i \mathbf{K}_i^T \in \mathbb{R}^{n \times n}$: the matrix of raw attention scores, where entry $(j, k)$ measures how much the query at position $j$ matches the key at position $k$
  • $\sqrt{d_k}$: the scaling factor that prevents dot products from growing too large as $d_k$ increases (covered in the previous chapter on scaled dot-product attention)
  • $\text{softmax}(\cdot)$: applied row-wise to convert scores into attention weights that sum to 1 for each query position
  • $\mathbf{V}_i$: the value matrix whose rows are aggregated according to the attention weights

This produces $h$ separate outputs, each of shape $\mathbb{R}^{n \times d_v}$. The heads operate completely independently at this stage; there is no information sharing between them. This independence is what allows different heads to specialize in detecting different types of relationships.
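
Continuing the sketch from Step 1, each head's attention can be computed independently with a plain loop. This is deliberately naive; the implementation later in the chapter batches all heads into a single tensor operation.

Code
def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

heads = []
for i in range(h):
    scores = Q[i] @ K[i].T / np.sqrt(d_k)   # (n, n) raw attention scores
    weights = softmax(scores, axis=-1)       # each row sums to 1
    heads.append(weights @ V[i])             # (n, d_k) output for head i

print(len(heads), heads[0].shape)            # 4 (6, 16)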

Step 3: Concatenation

After computing all heads, we concatenate their outputs along the feature dimension. Each head produces an $n \times d_v$ matrix, and stacking these side by side creates a larger matrix:

$$\text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) \in \mathbb{R}^{n \times (h \cdot d_v)}$$

where:

  • $\text{head}_1, \ldots, \text{head}_h$: the $h$ individual head outputs, each of shape $\mathbb{R}^{n \times d_v}$
  • $h \cdot d_v$: the concatenated feature dimension, combining all heads' outputs

If $d_v = d_k = d_{\text{model}} / h$, then $h \cdot d_v = d_{\text{model}}$, and the concatenated output has the same dimension as the original input. This dimensional consistency is intentional: it allows multi-head attention to be used as a drop-in replacement for single-head attention. The concatenation preserves each head's contribution separately, letting the subsequent projection learn how to combine them.

Step 4: Output Projection

Finally, we apply a learned linear projection to the concatenated output. This projection maps the combined head outputs back to the model dimension:

$$\text{MultiHead}(\mathbf{X}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, \mathbf{W}^O$$

where:

  • $\text{Concat}(\text{head}_1, \ldots, \text{head}_h) \in \mathbb{R}^{n \times (h \cdot d_v)}$: the concatenated outputs from all $h$ attention heads
  • $\mathbf{W}^O \in \mathbb{R}^{(h \cdot d_v) \times d_{\text{model}}}$: the learned output projection matrix
  • $\text{MultiHead}(\mathbf{X}) \in \mathbb{R}^{n \times d_{\text{model}}}$: the final output, matching the input dimension

This projection serves two purposes. First, it combines information from all heads, allowing the model to learn how to mix their contributions. A head specializing in syntax might contribute more heavily for certain tokens, while a semantic head dominates for others. Second, it ensures the output dimension matches the input dimension, enabling residual connections in the transformer architecture.
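
Continuing the sketch, concatenation and the output projection close the loop, returning a tensor with the same shape as the input. W_O below is a random placeholder standing in for the learned projection.

Code
# Step 3: concatenate head outputs along the feature dimension
concat = np.concatenate(heads, axis=-1)      # (n, h * d_k) = (6, 64)

# Step 4: learned output projection back to the model dimension
W_O = np.random.randn(h * d_k, d_model)
out = concat @ W_O                           # (6, 64), same shape as X
print(out.shape == X.shape)                  # True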

The Complete Formula

Now that we've walked through each step conceptually, let's see how they combine into a complete formulation. Understanding this picture helps clarify how multi-head attention works: by running multiple specialized attention operations in parallel and learning to combine their insights.

The outer structure of multi-head attention takes the outputs from all heads and merges them:

$$\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, \mathbf{W}^O$$

This formula reads from inside out: first compute each head, then concatenate their outputs, then apply the output projection. But what exactly is each head doing? Each head applies its own learned projections before computing attention:

$$\text{head}_i = \text{Attention}(\mathbf{Q} \mathbf{W}^Q_i, \mathbf{K} \mathbf{W}^K_i, \mathbf{V} \mathbf{W}^V_i)$$

where:

  • $\mathbf{Q}, \mathbf{K}, \mathbf{V}$: the query, key, and value inputs (in self-attention, all three equal the input $\mathbf{X}$)
  • $\mathbf{W}^Q_i, \mathbf{W}^K_i, \mathbf{W}^V_i$: the per-head projection matrices for head $i$, each transforming the full $d_{\text{model}}$-dimensional input into a smaller $d_k$-dimensional subspace
  • $\text{Attention}(\cdot)$: the scaled dot-product attention function from the previous chapter
  • $\mathbf{W}^O$: the output projection that learns how to combine all heads' contributions

The design is flexible. In self-attention, $\mathbf{Q} = \mathbf{K} = \mathbf{V} = \mathbf{X}$, so the same input sequence projects to all three roles. But the same multi-head attention mechanism also supports cross-attention in encoder-decoder models, where queries come from the decoder and keys/values come from the encoder. The mathematical structure remains identical; only the inputs differ.
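
A small sketch of that flexibility, reusing the softmax helper and $d_k$ from the sketches above: the same attention function handles both cases, and only the sources of the query and key/value tensors change. The decoder and encoder arrays here are arbitrary stand-ins, not outputs of a real model.

Code
def attention(Q, K, V):
    w = softmax(Q @ K.T / np.sqrt(Q.shape[-1]), axis=-1)
    return w @ V

decoder_states = np.random.randn(4, d_k)   # 4 query positions (e.g., decoder)
encoder_states = np.random.randn(6, d_k)   # 6 key/value positions (e.g., encoder)

self_out = attention(decoder_states, decoder_states, decoder_states)   # (4, d_k)
cross_out = attention(decoder_states, encoder_states, encoder_states)  # (4, d_k)
print(self_out.shape, cross_out.shape)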

Why This Works: The Subspace Intuition

Why should projecting to lower-dimensional subspaces and computing attention there be useful? Think of it this way. The full $d_{\text{model}}$-dimensional embedding space contains many different types of information about each token: its syntactic role, semantic meaning, position in the sentence, and relationship to neighboring words. A single attention operation in the full space must somehow balance all of these considerations when deciding what to attend to.

By projecting to a lower-dimensional subspace, each head can focus on a particular "slice" of the information. One head's projection matrices might learn to emphasize dimensions that capture syntactic structure, causing that head to attend based on grammatical relationships. Another head's projections might emphasize semantic dimensions, causing it to attend based on meaning similarity. The output projection then learns to weigh and combine these different perspectives.

This is analogous to how ensemble methods work in machine learning: multiple weak learners, each capturing part of the pattern, combine to form a stronger predictor than any single model could achieve.

Parameter Count Analysis

A natural question arises: if we're running $h$ separate attention operations, doesn't that multiply our parameter count by $h$? Actually, no. The dimension splitting ensures that multi-head attention uses approximately the same number of parameters as single-head attention with the full dimension.

Let's trace through the calculation:

  1. Per head: Each head has three projection matrices ($\mathbf{W}^Q_i$, $\mathbf{W}^K_i$, $\mathbf{W}^V_i$), each of size $d_{\text{model}} \times d_k$. That's $3 \times d_{\text{model}} \times d_k$ parameters per head.

  2. All heads: With $h$ heads, we multiply by $h$: total QKV parameters are $h \times 3 \times d_{\text{model}} \times d_k$. Here's the key insight: since $d_k = d_{\text{model}} / h$, we have $h \times d_k = d_{\text{model}}$. This simplifies the expression to $3 \times d_{\text{model}}^2$.

  3. Output projection: The matrix $\mathbf{W}^O$ has size $(h \cdot d_v) \times d_{\text{model}}$. With $d_v = d_k = d_{\text{model}} / h$, the input dimension is $h \cdot (d_{\text{model}} / h) = d_{\text{model}}$, giving $d_{\text{model}} \times d_{\text{model}} = d_{\text{model}}^2$ parameters.

  4. Grand total: $3 \times d_{\text{model}}^2 + d_{\text{model}}^2 = 4 \times d_{\text{model}}^2$ parameters.

This is exactly what we'd have with single-head attention using the full dimension, plus one additional output projection. The key benefit of multi-head attention is that we get $h$ different attention patterns, each potentially specializing in different relationships, for essentially the same parameter budget as a single larger head.
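
A quick numeric check of this accounting, using the original transformer's configuration of $d_{\text{model}} = 512$ and $h = 8$:

Code
d_model, h = 512, 8
d_k = d_model // h

qkv_params = h * 3 * d_model * d_k   # all per-head W_Q, W_K, W_V matrices
out_params = (h * d_k) * d_model     # output projection W_O
total = qkv_params + out_params

print(total, 4 * d_model**2)         # 1048576 1048576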

Implementation

With the mathematics clear, let's translate these ideas into working code. We'll build multi-head attention from scratch using NumPy, making every operation explicit so you can see exactly how the formulas map to implementation. This hands-on approach will solidify your understanding and prepare you to work with production implementations in PyTorch or TensorFlow.

Our implementation needs two building blocks: first, the scaled dot-product attention we covered in the previous chapter; second, the multi-head wrapper that orchestrates multiple attention operations in parallel.

Building Block 1: Scaled Dot-Product Attention

We start with a function that computes attention for tensors that already include a head dimension. This allows us to process all heads simultaneously through batched matrix multiplication:

In[2]:
Code
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    exp_x = np.exp(x - x.max(axis=axis, keepdims=True))
    return exp_x / exp_x.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        Q: Queries, shape (batch, heads, seq_len, d_k)
        K: Keys, shape (batch, heads, seq_len, d_k)
        V: Values, shape (batch, heads, seq_len, d_v)
        mask: Optional mask, shape broadcastable to (batch, heads, seq_len, seq_len)

    Returns:
        output: Attention output, shape (batch, heads, seq_len, d_v)
        weights: Attention weights, shape (batch, heads, seq_len, seq_len)
    """
    d_k = Q.shape[-1]

    # Compute attention scores
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)

    # Apply mask if provided
    if mask is not None:
        scores = np.where(mask, scores, -1e9)

    # Softmax to get attention weights
    weights = softmax(scores, axis=-1)

    # Weighted sum of values
    output = weights @ V

    return output, weights

Notice the tensor shape: (batch, heads, seq_len, d_k). By placing the head dimension second, we can compute attention scores for all heads at once. The matrix multiplication Q @ K.transpose(...) produces a (batch, heads, seq_len, seq_len) tensor containing $h$ separate attention matrices, one per head. This is the key to efficient parallel computation.

Building Block 2: The Multi-Head Attention Module

Now we build the main module. The MultiHeadAttention class encapsulates all the projection matrices and orchestrates the four-step process we described earlier.

A key implementation detail: instead of storing separate $\mathbf{W}^Q_i$, $\mathbf{W}^K_i$, and $\mathbf{W}^V_i$ matrices for each head, we use a single combined projection matrix W_qkv of shape (d_model, 3 * d_model). This computes all queries, keys, and values for all heads in one matrix multiplication. The result is then split and reshaped to separate the heads. This approach is more efficient on GPUs because it maximizes parallelism and minimizes memory transfers.

In[3]:
Code
class MultiHeadAttention:
    """
    Multi-head attention implementation.

    Args:
        d_model: Model dimension
        num_heads: Number of attention heads
    """

    def __init__(self, d_model, num_heads, seed=42):
        assert d_model % num_heads == 0, (
            "d_model must be divisible by num_heads"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Initialize projection matrices with small random values
        np.random.seed(seed)
        scale = np.sqrt(2.0 / d_model)

        # Combined QKV projections: (d_model, 3 * d_model)
        self.W_qkv = np.random.randn(d_model, 3 * d_model) * scale

        # Output projection: (d_model, d_model)
        self.W_o = np.random.randn(d_model, d_model) * scale

    def split_heads(self, x):
        """
        Split the last dimension into (num_heads, d_k).

        Args:
            x: Input tensor, shape (batch, seq_len, d_model)

        Returns:
            Reshaped tensor, shape (batch, num_heads, seq_len, d_k)
        """
        batch, seq_len, _ = x.shape
        x = x.reshape(batch, seq_len, self.num_heads, self.d_k)
        return x.transpose(0, 2, 1, 3)

    def combine_heads(self, x):
        """
        Inverse of split_heads.

        Args:
            x: Input tensor, shape (batch, num_heads, seq_len, d_k)

        Returns:
            Reshaped tensor, shape (batch, seq_len, d_model)
        """
        batch, _, seq_len, _ = x.shape
        x = x.transpose(0, 2, 1, 3)
        return x.reshape(batch, seq_len, self.d_model)

    def forward(self, x, mask=None):
        """
        Compute multi-head self-attention.

        Args:
            x: Input tensor, shape (batch, seq_len, d_model)
            mask: Optional attention mask

        Returns:
            output: Attention output, shape (batch, seq_len, d_model)
            weights: Attention weights per head, shape (batch, num_heads, seq_len, seq_len)
        """
        batch, seq_len, _ = x.shape

        # Step 1: Project to Q, K, V (combined for efficiency)
        qkv = x @ self.W_qkv  # (batch, seq_len, 3 * d_model)

        # Split into Q, K, V
        Q, K, V = np.split(qkv, 3, axis=-1)

        # Step 2: Split into multiple heads
        Q = self.split_heads(Q)  # (batch, num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # Step 3: Compute attention for all heads in parallel
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Step 4: Combine heads
        combined = self.combine_heads(attn_output)  # (batch, seq_len, d_model)

        # Step 5: Output projection
        output = combined @ self.W_o

        return output, attn_weights

The split_heads and combine_heads methods handle the tensor reshaping that makes parallel head computation possible. split_heads takes a tensor of shape (batch, seq_len, d_model) and reorganizes it to (batch, num_heads, seq_len, d_k), separating the head dimension. combine_heads reverses this transformation after attention is computed.

Testing the Implementation

Let's verify our implementation produces the expected output shapes and that the attention weights are properly normalized:

In[4]:
Code
# Create a small example
batch_size = 1
seq_len = 6
d_model = 64
num_heads = 4

# Random input sequence
np.random.seed(123)
x = np.random.randn(batch_size, seq_len, d_model)

# Create multi-head attention module
mha = MultiHeadAttention(d_model, num_heads)

# Forward pass
output, attention_weights = mha.forward(x)
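
# Print a shape summary (these statements produce the console output shown below)
print(f"Input shape:             {x.shape}")
print(f"Output shape:            {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print()
print(f"Head dimension (d_k):    {mha.d_k}")
print(f"Number of heads:         {mha.num_heads}")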
Out[5]:
Console
Input shape:             (1, 6, 64)
Output shape:            (1, 6, 64)
Attention weights shape: (1, 4, 6, 6)

Head dimension (d_k):    16
Number of heads:         4

The output tensor has shape (1, 6, 64), matching the input exactly. This dimensional consistency is essential for residual connections in transformer architectures. The attention weights have shape (1, 4, 6, 6), representing a separate $6 \times 6$ attention matrix for each of the 4 heads.

Out[6]:
Console
Attention weights sum per query position (head 1):
[1. 1. 1. 1. 1. 1.]

Each row sums to exactly 1.0, confirming that softmax normalization is working correctly. This is expected: for each query position, the attention weights across all key positions must form a valid probability distribution.

How Multi-Head Attention Transforms Representations

Let's visualize how multi-head attention changes the token representations. We'll use PCA to project the 64-dimensional embeddings down to 2D, allowing us to see how the representations shift after passing through the attention layer.

In[7]:
Code
from sklearn.decomposition import PCA

# Get the input and output representations
input_embeds = x[0]  # Shape: (6, 64)
output_embeds = output[0]  # Shape: (6, 64)

# Combine for PCA fitting
combined = np.vstack([input_embeds, output_embeds])

# Fit PCA on combined data
pca = PCA(n_components=2)
combined_2d = pca.fit_transform(combined)

# Split back
input_2d = combined_2d[:6]
output_2d = combined_2d[6:]

# Token labels for this example
tokens_example = [f"Pos {i}" for i in range(6)]
Out[8]:
Visualization
2D scatter plot showing input and output embeddings with arrows showing transformation.
Embedding transformation through multi-head attention. Hollow circles show input positions, filled circles show outputs after attention. Arrows indicate how each position's representation shifts as it incorporates contextual information from other positions.

The visualization shows how multi-head attention repositions each token's representation in embedding space. The arrows indicate the direction and magnitude of change. Tokens that attended strongly to semantically related positions move toward each other, while tokens that focused on different aspects of the input shift in different directions. This is the contextual enrichment in action: each position now carries information gathered from across the sequence.

Visualizing Attention Heads

The true power of multi-head attention becomes apparent when we visualize what different heads are doing. Even with random initialization (before any training), the different projection matrices cause each head to compute different attention patterns. After training on language data, these patterns become meaningful: heads specialize to detect specific types of relationships.

Let's visualize the attention patterns for each head on a simple sentence. We'll create random embeddings and see how the four heads distribute their attention differently:

In[9]:
Code
# Create a simple sentence for visualization
sentence = ["The", "cat", "sat", "on", "the", "mat"]
seq_len = len(sentence)

# Create embeddings (using random for demonstration)
np.random.seed(42)
d_model = 32
num_heads = 4

# Random embeddings for demonstration purposes only
# (untrained, so they carry no real word-similarity structure)
embeddings = np.random.randn(1, seq_len, d_model)

# Create MHA and compute attention
mha = MultiHeadAttention(d_model, num_heads, seed=42)
output, attention_weights = mha.forward(embeddings)
Out[10]:
Visualization
Heatmap showing attention weights for head 1.
Head 1 attention pattern showing how each token distributes attention across the sequence.
Heatmap showing attention weights for head 2.
Head 2 attention pattern with a different learned focus than head 1.
Heatmap showing attention weights for head 3.
Head 3 attention pattern demonstrating further specialization.
Heatmap showing attention weights for head 4.
Head 4 attention pattern completing the four-head perspective.

The heatmaps reveal that even with random initialization, each head produces a distinct attention pattern. Some heads distribute attention more uniformly, while others concentrate it on specific positions. In a trained model, these differences become meaningful and interpretable: one head might consistently attend to the previous word (useful for language modeling), another to syntactically related words (useful for parsing), and another to semantically similar words (useful for coreference resolution).

The key observation is that the same input, "The cat sat on the mat," produces four different perspectives on token relationships. This diversity is precisely what we wanted when we motivated multi-head attention: the ability to capture multiple types of dependencies simultaneously.

Quantifying Head Focus with Entropy

We can quantify how "focused" or "distributed" each head's attention is using entropy. A head with low entropy concentrates attention on a few positions, while a head with high entropy distributes attention more uniformly. This provides a numerical measure of the diversity we observed visually.

In[11]:
Code
def attention_entropy(weights):
    """
    Compute the average entropy of attention distributions.

    Higher entropy means more uniform attention (less focused).
    Lower entropy means attention is concentrated on fewer positions.
    """
    # Avoid log(0) by adding small epsilon
    eps = 1e-10
    # Compute entropy for each query position, average over positions
    entropy = -np.sum(weights * np.log(weights + eps), axis=-1)
    return entropy.mean()


# Compute entropy for each head
head_entropies = []
for h in range(num_heads):
    entropy = attention_entropy(attention_weights[0, h])
    head_entropies.append(entropy)

# Maximum possible entropy (uniform distribution over 6 positions)
max_entropy = np.log(seq_len)
Out[12]:
Visualization
Bar chart comparing entropy values for four attention heads.
Attention entropy by head. Lower entropy indicates more focused attention patterns, while higher entropy indicates more distributed attention. The dashed line shows maximum possible entropy (uniform attention across all 6 positions).
Out[13]:
Console
Entropy analysis by head:
  Maximum possible entropy (uniform over 6 positions): 1.792

  Head 1: entropy = 1.206 (32.7% focused)
  Head 2: entropy = 0.925 (48.4% focused)
  Head 3: entropy = 1.259 (29.7% focused)
  Head 4: entropy = 0.843 (53.0% focused)

Heads with lower entropy relative to the maximum are more "focused," concentrating their attention on specific positions. Heads closer to the maximum entropy distribute attention more uniformly. This variation in focus is another dimension of head diversity: some heads act as sharp selectors, while others aggregate information more broadly.

Head Specialization

Research on trained transformers reveals that attention heads develop distinct specializations. Some commonly observed patterns include:

  • Positional heads: Attend to fixed relative positions (e.g., always attend to the previous token, or always attend to the first token)
  • Syntactic heads: Track grammatical relationships like subject-verb or modifier-noun connections
  • Coreference heads: Link pronouns to their antecedents across the sequence
  • Rare word heads: Focus on infrequent tokens that carry high information content
  • Separator heads: Attend to punctuation and sentence boundaries

This specialization emerges purely from training on language modeling objectives. No explicit supervision tells heads what to focus on; they discover useful patterns through gradient descent.

In[14]:
Code
# Simulate different head specialization patterns
np.random.seed(42)
seq_len = 8
tokens = ["[CLS]", "The", "quick", "fox", "jumps", "over", "dogs", "[SEP]"]


# Create simulated attention patterns for different head types
def create_positional_head(n, offset=-1):
    """Head that attends to a fixed relative position."""
    weights = np.zeros((n, n))
    for i in range(n):
        target = max(0, min(n - 1, i + offset))
        weights[i, target] = 0.8
        # Some attention to self
        weights[i, i] = 0.1
        # Distribute remaining mass
        remaining = 0.1 / (n - 2) if n > 2 else 0
        for j in range(n):
            if j != i and j != target:
                weights[i, j] = remaining
    # Normalize rows so each forms a valid attention distribution
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights


def create_first_token_head(n):
    """Head that always attends to [CLS] token."""
    weights = np.zeros((n, n))
    weights[:, 0] = 0.7  # Strong attention to first token
    # Distribute remaining mass
    for i in range(n):
        weights[i, i] += 0.2  # Some self-attention
        remaining = 0.1 / (n - 2) if n > 2 else 0
        for j in range(1, n):
            if j != i:
                weights[i, j] += remaining
    # Normalize rows so each forms a valid attention distribution
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights


def create_local_context_head(n, window=2):
    """Head that attends to nearby tokens."""
    weights = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            distance = abs(i - j)
            if distance <= window:
                weights[i, j] = 1.0 / (distance + 1)
        weights[i] = weights[i] / weights[i].sum()
    return weights


# Generate example patterns
prev_token_head = create_positional_head(seq_len, offset=-1)
cls_head = create_first_token_head(seq_len)
local_head = create_local_context_head(seq_len, window=2)
Out[15]:
Visualization
Heatmap showing diagonal attention pattern for previous token head.
Previous Token Head: attends strongly to position i-1, capturing local sequential dependencies.
Heatmap showing column attention pattern for CLS aggregation head.
[CLS] Aggregation Head: all tokens attend to the first token, useful for sequence-level classification.
Heatmap showing banded attention pattern for local context head.
Local Context Head: attention concentrated within a sliding window, capturing phrase-level context.

The three patterns illustrate different specializations:

  • Previous token head: The diagonal pattern shows each token attending primarily to its predecessor. This captures local sequential dependencies, useful for language modeling where the previous word strongly predicts the next.

  • [CLS] aggregation head: The first column lights up, showing all tokens attending to the special classification token. This pattern aggregates information from the entire sequence into one position, commonly used for classification tasks.

  • Local context head: The banded pattern shows attention concentrated within a local window. This captures phrase-level relationships without being distracted by distant tokens.

Multi-Head vs Single-Head: An Empirical Comparison

How much does multi-head attention actually help? Let's compare single-head and multi-head attention on a simple measure: how diverse the attention patterns are when the total model dimension is held fixed and only the number of heads changes.

In[16]:
Code
def compute_attention_diversity(weights, num_heads):
    """
    Measure how different the attention patterns are across heads.
    Higher values indicate more diverse (specialized) heads.
    """
    if num_heads == 1:
        return 0.0

    # Compute pairwise Jensen-Shannon divergence between heads
    from scipy.spatial.distance import jensenshannon

    total_div = 0
    count = 0
    for i in range(num_heads):
        for j in range(i + 1, num_heads):
            # Average JS divergence across all query positions
            for q in range(weights.shape[2]):
                div = jensenshannon(weights[0, i, q], weights[0, j, q])
                if not np.isnan(div):
                    total_div += div
                    count += 1

    return total_div / count if count > 0 else 0.0


# Compare 1, 2, 4, 8 heads with the same total dimension
d_model = 64
head_configs = [1, 2, 4, 8]
diversity_scores = []

np.random.seed(42)
x = np.random.randn(1, 8, d_model)

for num_heads in head_configs:
    mha = MultiHeadAttention(d_model, num_heads, seed=42)
    _, weights = mha.forward(x)
    diversity = compute_attention_diversity(weights, num_heads)
    diversity_scores.append(diversity)
Out[17]:
Visualization
Bar chart comparing diversity scores for 1, 2, 4, and 8 attention heads.
Attention pattern diversity by number of heads. Moving from one head to several immediately produces distinct attention patterns, but the average pairwise divergence levels off rather than continuing to grow as more heads are added.
Out[18]:
Console
Diversity scores by number of heads:
  1 head(s): 0.0000
  2 head(s): 0.5962
  4 head(s): 0.5770
  8 head(s): 0.5774

The diversity metric reveals a clear pattern. With a single head, the diversity score is 0 by definition since there is nothing to compare. With 2 heads, the model already produces noticeably distinct attention patterns. Increasing to 4 and 8 heads keeps the heads distinct but does not raise the average pairwise divergence further. This suggests that beyond a certain point, additional heads may begin to learn redundant patterns, a theme that returns in the discussion of head pruning below.

Practical Considerations

When implementing multi-head attention in production systems, several practical considerations come into play.

Choosing the Number of Heads

The original transformer paper used 8 heads with a 512-dimensional model, giving each head 64 dimensions. Modern large models often use more heads, but the per-head dimension typically stays between 64 and 128.

In[19]:
Code
# Real model configurations
model_configs = {
    "Transformer (orig.)": {"d_model": 512, "num_heads": 8},
    "BERT-base": {"d_model": 768, "num_heads": 12},
    "GPT-2 Small": {"d_model": 768, "num_heads": 12},
    "BERT-large": {"d_model": 1024, "num_heads": 16},
    "GPT-2 Medium": {"d_model": 1024, "num_heads": 16},
    "GPT-2 Large": {"d_model": 1280, "num_heads": 20},
    "GPT-3 (175B)": {"d_model": 12288, "num_heads": 96},
}

# Compute d_k for each
for name, config in model_configs.items():
    config["d_k"] = config["d_model"] // config["num_heads"]
Out[20]:
Visualization
Scatter plot showing d_model vs num_heads for transformer models.
Multi-head attention configurations across popular transformer models. The x-axis shows model dimension, y-axis shows number of heads. Point size indicates the per-head dimension (d_k). Despite the wide range in model sizes, per-head dimensions cluster around 64-128.

The visualization reveals a consistent design pattern: as model dimension increases, the number of heads scales proportionally to maintain per-head dimensions between 64 and 128. This suggests an empirically validated "sweet spot" for head expressiveness. Too few dimensions per head limit each head's ability to capture complex patterns; too many heads with too few dimensions fragment the representation unhelpfully.

Computational Efficiency

Although multi-head attention computes $h$ separate attention operations, modern implementations parallelize this efficiently. The key insight is that we can stack all heads into a single tensor and use batched matrix multiplication.

In[21]:
Code
# Efficient batched implementation
def efficient_multihead_attention(x, W_qkv, W_o, num_heads):
    """
    Efficient multi-head attention using batched operations.

    All heads are computed in parallel using a single matrix multiplication,
    then reshaped to separate the head dimension.
    """
    batch, seq_len, d_model = x.shape
    d_k = d_model // num_heads

    # Single projection for all heads: (batch, seq, 3*d_model)
    qkv = x @ W_qkv

    # Reshape to (batch, seq, 3, num_heads, d_k)
    qkv = qkv.reshape(batch, seq_len, 3, num_heads, d_k)

    # Transpose to (3, batch, num_heads, seq, d_k) and unpack
    qkv = qkv.transpose(2, 0, 3, 1, 4)
    Q, K, V = qkv[0], qkv[1], qkv[2]

    # Batched attention: (batch, num_heads, seq, seq)
    scores = (Q @ K.transpose(0, 1, 3, 2)) / np.sqrt(d_k)
    weights = softmax(scores)

    # (batch, num_heads, seq, d_k)
    attn_output = weights @ V

    # Reshape back: (batch, seq, d_model)
    output = attn_output.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)

    return output @ W_o

This batched approach processes all heads with the same matrix operations, leveraging GPU parallelism. The computational cost is essentially equivalent to single-head attention with some reshaping overhead.

Head Pruning

Research has shown that not all attention heads are equally important. Some heads can be removed after training with minimal impact on performance. This head pruning technique reduces computation and memory for inference.

Studies on BERT found that many heads could be pruned without significant accuracy loss, suggesting some redundancy in the learned representations. This has practical implications for deploying models in resource-constrained environments.
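
As a rough illustration of the idea (a simplified sketch, not a production pruning method), one can zero out selected heads' outputs before the output projection and measure how much the result changes. The helper below reuses the MultiHeadAttention class and scaled_dot_product_attention function from this chapter; the model, input, and pruned head indices are arbitrary.

Code
def forward_with_pruned_heads(mha, x, heads_to_prune):
    """Run multi-head attention with selected heads' outputs zeroed (illustrative only)."""
    qkv = x @ mha.W_qkv
    Q, K, V = np.split(qkv, 3, axis=-1)
    Q, K, V = mha.split_heads(Q), mha.split_heads(K), mha.split_heads(V)
    attn_output, _ = scaled_dot_product_attention(Q, K, V)
    attn_output[:, heads_to_prune, :, :] = 0.0   # drop the pruned heads' contributions
    return mha.combine_heads(attn_output) @ mha.W_o

np.random.seed(0)
mha_demo = MultiHeadAttention(d_model=64, num_heads=8, seed=0)
x_demo = np.random.randn(1, 10, 64)

full_output, _ = mha_demo.forward(x_demo)
pruned_output = forward_with_pruned_heads(mha_demo, x_demo, heads_to_prune=[1, 3])
print(np.abs(full_output - pruned_output).mean())  # how far the output shifts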

Limitations and Impact

Multi-head attention is one of the key innovations that made transformers successful. By allowing the model to jointly attend to information from different representation subspaces at different positions, it provides a flexible mechanism for capturing complex relationships in language. Each head learns its own projection matrices, creating specialized views of the input that combine to form rich contextual representations.

The primary limitation is computational. Multi-head attention still has $O(n^2)$ complexity in sequence length $n$, and the overhead of multiple heads adds constant factors to both computation and memory. For a sequence of 4,096 tokens with 32 heads, the model must compute and store 32 separate attention matrices, each of size $4096 \times 4096$. This quadratic scaling remains a bottleneck for very long documents, motivating research into efficient attention variants like sparse attention (which attends to a subset of positions), linear attention (which reformulates the computation to avoid explicit attention matrices), and methods like FlashAttention (which optimizes memory access patterns).

Another consideration is interpretability. While we can visualize attention patterns, understanding what each head has learned remains challenging. Research has shown that some heads appear redundant, attending to similar patterns, while others develop specializations that don't correspond to human-interpretable categories. This makes it difficult to debug unexpected model behavior or explain predictions. The observation that many heads can be pruned without significant accuracy loss suggests that multi-head attention may include built-in redundancy, which could be beneficial for robustness but represents inefficiency during inference.

Despite these limitations, multi-head attention works well across virtually all NLP tasks. The ability to capture multiple types of relationships simultaneously, without explicit supervision about what those relationships should be, is a key reason why transformers generalize so well. From named entity recognition to machine translation to question answering, multi-head attention provides the representational flexibility needed to handle diverse linguistic phenomena.

Key Parameters

When configuring multi-head attention, the following parameters have the most significant impact on model behavior:

  • num_heads (h): The number of parallel attention heads. Common values range from 8 to 96 depending on model size. More heads enable greater specialization but require the model dimension to be evenly divisible. Start with 8 heads for smaller models and scale proportionally with model dimension.

  • d_model: The model dimension, which determines the total representation capacity. Must be divisible by num_heads. Standard values include 512 (original transformer), 768 (BERT-base), 1024 (BERT-large), and 12288 (GPT-3).

  • d_k (head dimension): The dimension per head, computed as d_model // num_heads. Values between 64 and 128 work well in practice. Smaller values allow more heads but limit each head's expressiveness.

  • d_v (value dimension): Often set equal to d_k. Can be set independently if you want asymmetric query-key versus value dimensions, though this is uncommon.

  • Initialization scale: Projection matrices are typically initialized with variance $\frac{2}{d_{\text{model}}}$ or similar schemes. Poor initialization can cause attention weights to become too uniform or too peaked early in training.
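
A small, hypothetical helper can make these constraints explicit when configuring a model. The function name and the 64-128 heuristic check are illustrative, not part of any library.

Code
def head_dim(d_model, num_heads):
    """Derive d_k and sanity-check the configuration (illustrative helper)."""
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    if not 64 <= d_k <= 128:
        print(f"note: d_k = {d_k} is outside the typical 64-128 range")
    return d_k

print(head_dim(768, 12))    # BERT-base: 64
print(head_dim(1024, 16))   # BERT-large: 64
print(head_dim(12288, 96))  # GPT-3: 128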

Summary

Multi-head attention extends single-head attention by running multiple attention operations in parallel, each with its own learned projections. This architectural choice significantly impacts model capacity and performance.

Key takeaways from this chapter:

  • Parallel heads: Multi-head attention computes $h$ independent attention operations, each learning different relationship patterns between tokens. Each head can specialize in detecting specific linguistic phenomena.
  • Dimension splitting: The model dimension $d_{\text{model}}$ is divided among heads, with each head operating on dimension $d_k = d_{\text{model}} / h$. This keeps total computation comparable to single-head attention while enabling multiple perspectives.
  • Four-step process: (1) Project input to Q, K, V for each head using learned matrices $\mathbf{W}^Q_i$, $\mathbf{W}^K_i$, $\mathbf{W}^V_i$; (2) compute scaled dot-product attention independently per head; (3) concatenate head outputs; (4) apply the final output projection $\mathbf{W}^O$.
  • Head specialization: Different heads learn to focus on different aspects: positional patterns, syntactic relationships, semantic similarity, or coreference links. This emerges from training without explicit supervision.
  • Parameter efficiency: The total parameter count is $4 \times d_{\text{model}}^2$, similar to single-head attention with an output projection. The gain is in representational diversity, not parameter count.
  • Practical implementations: Efficient implementations batch all heads into single tensor operations, leveraging GPU parallelism through combined QKV projections.

In the next chapter, we'll examine the computational complexity of attention in detail. Understanding the $O(n^2)$ scaling (where $n$ is sequence length) and its implications is essential for working with long sequences and motivates the efficient attention variants used in modern long-context models.

