Cross-Attention: Connecting Encoder and Decoder in Transformers

Michael Brenndoerfer · Updated June 18, 2025 · 36 min read

Master cross-attention, the mechanism that bridges encoder and decoder in sequence-to-sequence transformers. Learn how queries from the decoder attend to encoder keys and values for translation and summarization.

Cross-Attention

In the previous chapters, we explored encoder-only and decoder-only transformer architectures, each using self-attention to let tokens within a sequence attend to each other. But what happens when you need to connect two different sequences, such as when translating from English to French or summarizing a document into a shorter form? This is where cross-attention comes in, the mechanism that lets the decoder "look at" the encoder's output while generating new tokens.

Cross-attention is the bridge between encoder and decoder in sequence-to-sequence transformers. While self-attention computes relationships within a single sequence, cross-attention computes relationships between two sequences: the decoder queries the encoder's representations to extract relevant information for generating each output token.

Self-Attention vs. Cross-Attention

Self-attention and cross-attention share the same mathematical formulation: queries, keys, and values combined with scaled dot-product attention. The critical difference lies in where these components come from.

In self-attention, all three projections derive from the same sequence. Given an input matrix $X$ containing token representations, we compute queries, keys, and values as:

$$Q = X W^Q, \quad K = X W^K, \quad V = X W^V$$

where:

  • $X \in \mathbb{R}^{n \times d}$: the input sequence matrix, with $n$ tokens each represented as a $d$-dimensional vector
  • $W^Q \in \mathbb{R}^{d \times d_k}$: the learned query projection matrix
  • $W^K \in \mathbb{R}^{d \times d_k}$: the learned key projection matrix
  • $W^V \in \mathbb{R}^{d \times d_v}$: the learned value projection matrix
  • $Q, K \in \mathbb{R}^{n \times d_k}$: the resulting query and key matrices
  • $V \in \mathbb{R}^{n \times d_v}$: the resulting value matrix

Every token attends to every other token (and itself) within the same sequence.
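
To see this symmetry concretely, here is a minimal NumPy sketch. The dimensions and variable names are illustrative only (they are not tied to the examples later in the chapter); the point is that when queries and keys are projections of the same sequence, the score matrix is square.

import numpy as np

np.random.seed(0)
n, d_model, d_k = 5, 8, 8          # 5 tokens, model dimension 8 (illustrative)
X = np.random.randn(n, d_model)    # a single sequence supplies Q, K, and V

W_Q = np.random.randn(d_model, d_k) * 0.1
W_K = np.random.randn(d_model, d_k) * 0.1

Q, K = X @ W_Q, X @ W_K            # both projections of the same X
scores = Q @ K.T / np.sqrt(d_k)

print(scores.shape)                # (5, 5): a square score matrix, token vs. token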

In cross-attention, the queries come from one sequence while keys and values come from a different sequence. This separation is the key distinction: the decoder generates queries to ask "what do I need?" while the encoder provides keys and values to answer "here's what I have."

$$Q = X_{\text{dec}} W^Q, \quad K = X_{\text{enc}} W^K, \quad V = X_{\text{enc}} W^V$$

where:

  • $X_{\text{dec}} \in \mathbb{R}^{n_{\text{dec}} \times d}$: the decoder's current representations (what has been generated so far)
  • $X_{\text{enc}} \in \mathbb{R}^{n_{\text{enc}} \times d}$: the encoder's output representations (the processed source sequence)
  • $W^Q, W^K \in \mathbb{R}^{d \times d_k}$: the learned projection matrices for queries and keys
  • $W^V \in \mathbb{R}^{d \times d_v}$: the learned projection matrix for values
  • $n_{\text{dec}}$: the number of tokens in the decoder sequence (target length so far)
  • $n_{\text{enc}}$: the number of tokens in the encoder sequence (source length)
  • $d$: the model dimension (size of each token's representation)

Cross-Attention

Cross-attention allows tokens in one sequence (the decoder) to attend to tokens in a different sequence (the encoder). Queries come from the decoder, while keys and values come from the encoder, enabling information to flow from the source sequence to the target sequence during generation.

This asymmetry is what makes cross-attention powerful for tasks like translation. The decoder asks questions (queries) based on what it has generated so far, and the encoder's representations provide the answers (keys and values).

The Cross-Attention Formulation

To understand cross-attention deeply, we need to think about what problem it solves. Imagine you're translating "The cat sat on the mat" into French. You've already generated "Le chat" (the cat), and now you need to produce the next word. Which part of the English sentence should you focus on? The answer is "sat," because that's the verb that follows the subject.

This is precisely what cross-attention computes: for each position in the target sequence, it determines which positions in the source sequence are most relevant, then gathers information from those positions. The mechanism needs to answer two questions simultaneously:

  1. Where should I look? Each decoder position needs to identify which encoder positions contain relevant information.
  2. What should I extract? Once the relevant positions are identified, the decoder needs to pull out the appropriate information.

The query-key-value framework elegantly separates these concerns. Queries encode what the decoder is looking for, keys encode what each encoder position offers, and values encode the actual information to transmit.

Building the Formula Step by Step

Let's construct the cross-attention formula piece by piece, understanding why each component is necessary.

Step 1: Create queries from the decoder. Each decoder position needs to express what information it's seeking. We project the decoder representations into a "query space":

$$Q = X_{\text{dec}} W^Q$$

where $X_{\text{dec}} \in \mathbb{R}^{n_{\text{dec}} \times d}$ contains the decoder's current representations and $W^Q \in \mathbb{R}^{d \times d_k}$ is a learned projection matrix. The resulting $Q \in \mathbb{R}^{n_{\text{dec}} \times d_k}$ has one row per decoder position, each encoding "what am I looking for?"

Step 2: Create keys from the encoder. Each encoder position needs to advertise what information it contains. We project the encoder output into a "key space" using a different projection:

$$K = X_{\text{enc}} W^K$$

where $X_{\text{enc}} \in \mathbb{R}^{n_{\text{enc}} \times d}$ is the encoder output and $W^K \in \mathbb{R}^{d \times d_k}$ is another learned matrix. The resulting $K \in \mathbb{R}^{n_{\text{enc}} \times d_k}$ has one row per encoder position, each encoding "here's what I offer."

Step 3: Create values from the encoder. When attention flows to an encoder position, we need to specify what information actually transfers. The value projection captures this:

$$V = X_{\text{enc}} W^V$$

where $W^V \in \mathbb{R}^{d \times d_v}$ projects into a "value space." The resulting $V \in \mathbb{R}^{n_{\text{enc}} \times d_v}$ contains the content that will be aggregated.

Step 4: Compute similarity scores. Now we need to measure how well each query matches each key. The dot product is ideal for this: when two vectors point in similar directions, their dot product is large. We compute all pairwise scores with a single matrix multiplication:

$$S = Q K^T$$

The resulting score matrix $S \in \mathbb{R}^{n_{\text{dec}} \times n_{\text{enc}}}$ has entry $(i, j)$ equal to the dot product between decoder query $i$ and encoder key $j$. This is where the rectangular shape emerges: we're comparing $n_{\text{dec}}$ queries against $n_{\text{enc}}$ keys.

Step 5: Scale to prevent saturation. In high dimensions, dot products tend to have large magnitudes, which would push softmax into saturation (producing near-one-hot outputs with vanishing gradients). Dividing by $\sqrt{d_k}$ normalizes the scores:

$$S_{\text{scaled}} = \frac{Q K^T}{\sqrt{d_k}}$$
Out[2]:
Visualization
Histogram of raw attention scores showing wide spread from -4 to +4.
Raw dot product scores before scaling. In high dimensions, scores can have large magnitudes (here from roughly -4 to +4 with d=64).
Histogram of scaled attention scores showing compressed spread from -0.5 to +0.5.
Scores after dividing by sqrt(d_k). Scaling compresses the range, keeping softmax in its sensitive region where gradients flow well.

The histograms illustrate why scaling matters. With $d_k = 64$, raw dot products have a standard deviation around 8, producing scores that can easily reach $\pm 20$ or more. After scaling by $\sqrt{64} = 8$, the standard deviation drops to approximately 1, keeping scores in a range where softmax produces meaningful gradients.
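
As a quick sanity check on these numbers, the following standalone sketch draws random unit-variance queries and keys with $d_k = 64$ and compares the spread of raw and scaled dot products (the sample size and seed are arbitrary):

import numpy as np

np.random.seed(0)
d_k = 64
q = np.random.randn(10_000, d_k)
k = np.random.randn(10_000, d_k)

raw = (q * k).sum(axis=1)      # 10,000 random query-key dot products
scaled = raw / np.sqrt(d_k)

print(round(raw.std(), 2))     # ~8.0: std grows like sqrt(d_k)
print(round(scaled.std(), 2))  # ~1.0 after scaling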

Step 6: Convert scores to attention weights. We apply softmax row-wise to convert each row of scores into a probability distribution over encoder positions:

$$A = \text{softmax}(S_{\text{scaled}})$$

Each row of $A$ sums to 1.0, representing how one decoder position distributes its attention across all encoder positions.

Step 7: Aggregate values. Finally, we use the attention weights to compute a weighted sum of encoder values for each decoder position:

$$\text{Output} = A V$$

The output has shape $(n_{\text{dec}} \times d_v)$: one row per decoder position, each containing information gathered from the encoder.

The Complete Formula

Combining all steps, we arrive at the cross-attention formula:

$$\text{CrossAttention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$

where:

  • $Q = X_{\text{dec}} W^Q \in \mathbb{R}^{n_{\text{dec}} \times d_k}$: queries from the decoder, encoding "what information am I looking for?"
  • $K = X_{\text{enc}} W^K \in \mathbb{R}^{n_{\text{enc}} \times d_k}$: keys from the encoder, encoding "what information do I have to offer?"
  • $V = X_{\text{enc}} W^V \in \mathbb{R}^{n_{\text{enc}} \times d_v}$: values from the encoder, encoding "what content should I contribute?"
  • $Q K^T \in \mathbb{R}^{n_{\text{dec}} \times n_{\text{enc}}}$: raw similarity scores measuring alignment between decoder queries and encoder keys
  • $\sqrt{d_k}$: scaling factor that maintains healthy gradients during training
  • $\text{softmax}(\cdot)$: applied row-wise to produce probability distributions
  • $d_k$: query/key dimension (must match for the dot product)
  • $d_v$: value dimension (determines output size)

The attention weight matrix has shape $(n_{\text{dec}} \times n_{\text{enc}})$, fundamentally different from self-attention's square $(n \times n)$ matrix. This rectangular shape reflects the asymmetry of sequence-to-sequence tasks: we have $n_{\text{dec}}$ positions attending to $n_{\text{enc}}$ positions, and these lengths typically differ.

Tracing Through the Computation

Let's make this concrete by tracing the shapes through a realistic example. We'll simulate translating a 6-word English sentence into French, where we've generated 4 words so far.

In[3]:
Code
import numpy as np

# Example dimensions for translation
n_enc = 6  # Source: "The cat sat on the mat" (6 tokens)
n_dec = 4  # Target so far: "Le chat s'assit sur" (4 tokens)
d_model = 8  # Model dimension
d_k = 8  # Query/key dimension
d_v = 8  # Value dimension

# Encoder output (fixed representations of the source sentence)
np.random.seed(42)
encoder_output = np.random.randn(n_enc, d_model)

# Decoder state (evolving representations of generated target)
decoder_state = np.random.randn(n_dec, d_model)

# Learned projection matrices
W_Q = np.random.randn(d_model, d_k) * 0.1
W_K = np.random.randn(d_model, d_k) * 0.1
W_V = np.random.randn(d_model, d_v) * 0.1
Out[4]:
Console
Input shapes:
  Encoder output: (6, 8) (source_len × d_model)
  Decoder state:  (4, 8) (target_len × d_model)

The encoder has processed all 6 source tokens, producing a $(6 \times 8)$ matrix. The decoder has generated 4 tokens so far, giving us a $(4 \times 8)$ state matrix. These different sequence lengths are the essence of cross-attention.

Now we project into the query, key, and value spaces:

In[5]:
Code
# Q from decoder, K and V from encoder
Q = decoder_state @ W_Q  # (n_dec, d_k) = (4, 8)
K = encoder_output @ W_K  # (n_enc, d_k) = (6, 8)
V = encoder_output @ W_V  # (n_enc, d_v) = (6, 8)
Out[6]:
Console
Projection shapes:
  Q (from decoder): (4, 8)
  K (from encoder): (6, 8)
  V (from encoder): (6, 8)

Notice the asymmetry: $Q$ has 4 rows (one per target token) while $K$ and $V$ have 6 rows (one per source token). When we compute $Q K^T$, we multiply a $(4 \times 8)$ matrix by an $(8 \times 6)$ matrix, yielding a $(4 \times 6)$ score matrix. Each of the 4 decoder positions gets a similarity score against each of the 6 encoder positions.

In[7]:
Code
def softmax(x):
    """Numerically stable softmax."""
    exp_x = np.exp(x - x.max(axis=-1, keepdims=True))
    return exp_x / exp_x.sum(axis=-1, keepdims=True)


# Step-by-step attention computation
scores = Q @ K.T  # (4, 6) raw similarity scores
scores_scaled = scores / np.sqrt(d_k)  # Scale to prevent saturation
attention_weights = softmax(scores_scaled)  # (4, 6) probability distributions
output = attention_weights @ V  # (4, 8) gathered information
Out[8]:
Console
Attention computation shapes:
  Raw scores:        (4, 6) (target_len × source_len)
  Scaled scores:     (4, 6)
  Attention weights: (4, 6)
  Output:            (4, 8) (target_len × d_v)

Each of the 4 decoder positions attends to all 6 encoder positions
Row sums (should be 1.0): [1. 1. 1. 1.]

The output has shape $(4 \times 8)$, matching the decoder's sequence length but with the value dimension. Each of the 4 decoder positions now contains a weighted mixture of information from the encoder, with the weights determined by query-key similarity. This enriched representation flows to the next stage of the decoder, helping predict the next target token.

Out[9]:
Visualization
Heatmap showing a 4x6 attention weight matrix with decoder tokens on y-axis and encoder tokens on x-axis.
The rectangular shape of cross-attention weights. Unlike self-attention's square matrix where every token attends to every other token in the same sequence, cross-attention produces a rectangular matrix where each decoder position (rows) attends to all encoder positions (columns).

The heatmap reveals how each decoder token distributes its attention across the source sequence. Row sums equal 1.0 (each row is a probability distribution), but column sums can vary: some source tokens receive more total attention than others.
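
Continuing with the attention_weights computed in the trace above, a two-line check makes this asymmetry concrete:

# Rows are probability distributions; columns are not. Column sums show how
# much total attention each source token receives across all target positions.
print(attention_weights.sum(axis=1))  # row sums: all 1.0
print(attention_weights.sum(axis=0))  # column sums: vary by source token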

Why Queries from Decoder, Keys and Values from Encoder?

The choice of where $Q$, $K$, and $V$ come from is not arbitrary. It reflects the fundamental asymmetry of sequence-to-sequence tasks.

Queries represent what you're looking for. At each decoder position, the model is trying to generate the next token. The query encodes "what information do I need from the source to make this prediction?" The decoder has access to its own context (previous tokens, positional information) and uses this to formulate questions.

Keys represent what's available. The encoder has processed the entire source sequence and built representations that capture its meaning. Keys advertise "here's what I know about this position in the source." The encoder output is fixed once computed, so keys remain constant throughout decoding.

Values represent what gets transmitted. When attention flows from decoder to encoder, the values determine what information actually transfers. If the decoder strongly attends to a particular encoder position, that position's value vector contributes heavily to the decoder's output.

Consider translation from English to French. When generating the French word for "cat," the decoder's query might encode "I need information about an animal noun." The encoder's keys for the position containing "cat" would encode "animal, noun, subject." The high dot product between these creates a strong attention weight, and the encoder's value for "cat" (containing semantic features about cats) flows into the decoder's representation.

Out[10]:
Visualization
Diagram showing arrows from decoder positions to encoder positions representing attention flow, with Q from decoder and K,V from encoder.
Information flow in cross-attention. Queries (Q) from the decoder attend to keys (K) from the encoder to determine attention weights. These weights then aggregate values (V) from the encoder, allowing each decoder position to gather relevant source information.

The visualization shows "chat" (French for "cat") attending strongly to "cat" in the encoder. This alignment emerges naturally from the learned query and key projections, which encode semantic relationships between source and target tokens.

Cross-Attention Masking

Unlike causal self-attention in decoders, cross-attention typically does not require causal masking. The decoder can attend to any position in the encoder output because the encoder sequence is fully processed before decoding begins. There's no "future information" in the encoder to hide.

However, cross-attention does require padding masking when processing batches with variable-length source sequences. If the encoder sequences have different lengths, shorter sequences are padded to match the longest. The decoder should not attend to these padding positions, which carry no meaningful information.

In[11]:
Code
def cross_attention_with_mask(Q, K, V, mask=None):
    """
    Cross-attention with optional encoder padding mask.

    Args:
        Q: Queries from decoder (n_dec, d_k)
        K: Keys from encoder (n_enc, d_k)
        V: Values from encoder (n_enc, d_v)
        mask: Boolean mask (n_enc,) where True = valid, False = padding

    Returns:
        output: Attended representations (n_dec, d_v)
        weights: Attention weights (n_dec, n_enc)
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)

    # Apply padding mask
    if mask is not None:
        # Expand mask for broadcasting: (1, n_enc)
        mask = mask.reshape(1, -1)
        # Set padded positions to large negative value
        scores = np.where(mask, scores, -1e9)

    attention_weights = softmax(scores)
    output = attention_weights @ V

    return output, attention_weights

Let's see how masking affects attention:

In[12]:
Code
# Simulate a padded encoder sequence
# Original: "The cat sat" (3 tokens), padded to length 6
encoder_mask = np.array([True, True, True, False, False, False])

# Compute masked cross-attention
output_masked, weights_masked = cross_attention_with_mask(
    Q, K, V, mask=encoder_mask
)
Out[13]:
Console
Attention weights with padding mask:
(Columns 3-5 are padding, should have zero weight)

Weights shape: (4, 6)

Weight matrix (rounded):
[[0.357 0.342 0.301 0.    0.    0.   ]
 [0.316 0.331 0.353 0.    0.    0.   ]
 [0.324 0.327 0.349 0.    0.    0.   ]
 [0.31  0.33  0.36  0.    0.    0.   ]]

Row sums: [1. 1. 1. 1.]
Columns 3-5 sum: 0.0

The masking ensures that padded positions receive zero attention weight. The softmax is applied only over valid encoder positions (columns 0-2), and the probability mass distributes across those positions. This prevents the decoder from incorporating garbage information from padding tokens.

Out[14]:
Visualization
Heatmap showing attention weights spread across all 6 encoder positions.
Without masking, attention distributes across all 6 positions, including padding tokens that contain garbage information.
Heatmap showing attention weights concentrated on first 3 positions with zeros for padded positions.
With masking, attention concentrates entirely on valid tokens. Padded positions receive exactly zero weight.

The comparison makes the effect of masking clear. Without masking, attention bleeds into padding positions, corrupting the decoder's representations with meaningless information. With masking, all attention concentrates on the valid tokens, and the model ignores padding entirely.

Placement in the Decoder

In the original transformer architecture from "Attention Is All You Need," each decoder layer contains three sub-layers:

  1. Masked self-attention: Decoder tokens attend to previous decoder tokens
  2. Cross-attention: Decoder tokens attend to encoder output
  3. Feed-forward network: Position-wise transformation

The order matters. Self-attention comes first, allowing each decoder position to incorporate information from previously generated tokens. Then cross-attention brings in information from the source sequence. Finally, the feed-forward network transforms the combined representation.
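
The sketch below makes this ordering concrete. It is a schematic, single-head simplification: attend, layer_norm, and decoder_layer are illustrative helpers written for this sketch (identity projections inside attend, a bare nonlinearity in place of the feed-forward network), not the full multi-head sub-layers of the original architecture.

import numpy as np

def attend(x_q, x_kv, mask=None):
    # Stand-in single-head attention with identity projections:
    # queries come from x_q, keys and values are x_kv itself.
    d = x_q.shape[-1]
    scores = x_q @ x_kv.T / np.sqrt(d)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x_kv

def layer_norm(x, eps=1e-5):
    return (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + eps)

def decoder_layer(dec, enc):
    n_dec = dec.shape[0]
    causal = np.tril(np.ones((n_dec, n_dec), dtype=bool))

    # 1. Masked self-attention over previously generated tokens (Add & Norm)
    dec = layer_norm(dec + attend(dec, dec, mask=causal))
    # 2. Cross-attention: queries from the decoder, keys/values from the encoder
    dec = layer_norm(dec + attend(dec, enc))
    # 3. Position-wise feed-forward (a bare nonlinearity here for brevity)
    dec = layer_norm(dec + np.tanh(dec))
    return dec

np.random.seed(0)
out = decoder_layer(np.random.randn(4, 8), np.random.randn(6, 8))
print(out.shape)  # (4, 8): same shape as the decoder input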

Out[15]:
Visualization
Vertical diagram showing three sub-layers in a decoder block: masked self-attention, cross-attention, and feed-forward, each with residual connections and layer normalization.
Structure of a transformer decoder layer. Each position first attends to previous decoder positions (masked self-attention), then attends to encoder outputs (cross-attention), and finally passes through a feed-forward network. Add & Norm layers wrap each sub-layer.

The cross-attention layer is the only point where information flows from encoder to decoder. The encoder output is computed once and then used identically in every decoder layer. This means the decoder doesn't need to recompute the encoder at each step, making inference efficient.

Information Flow During Generation

During autoregressive generation, the cross-attention mechanism operates as follows:

  1. The encoder processes the entire source sequence once, producing encoder outputs
  2. For each decoder step $t$:
    • The decoder's self-attention sees positions $1, 2, \ldots, t$ (previous outputs plus current)
    • Cross-attention queries the encoder using the current decoder state
    • The gathered information helps predict the next token

The encoder representations stay fixed, while the decoder's queries evolve as it generates more tokens. Early in generation, the decoder might ask broad questions about the source. Later, it might ask more specific questions as it refines its translation.

Implementation

Let's implement a complete cross-attention module following the patterns used in production transformers:

In[16]:
Code
class CrossAttention:
    """
    Cross-attention module for encoder-decoder transformers.

    Queries come from the decoder, keys and values come from the encoder.
    """

    def __init__(self, d_model, d_k, d_v):
        """
        Initialize cross-attention with projection matrices.

        Args:
            d_model: Model dimension (size of input representations)
            d_k: Query/key dimension
            d_v: Value dimension
        """
        self.d_k = d_k
        self.scale = 1.0 / np.sqrt(d_k)

        # Query projection (applied to decoder state)
        self.W_Q = np.random.randn(d_model, d_k) * np.sqrt(
            2.0 / (d_model + d_k)
        )

        # Key and value projections (applied to encoder output)
        self.W_K = np.random.randn(d_model, d_k) * np.sqrt(
            2.0 / (d_model + d_k)
        )
        self.W_V = np.random.randn(d_model, d_v) * np.sqrt(
            2.0 / (d_model + d_v)
        )

    def __call__(self, decoder_state, encoder_output, encoder_mask=None):
        """
        Apply cross-attention.

        Args:
            decoder_state: Current decoder representations (n_dec, d_model)
            encoder_output: Encoder output representations (n_enc, d_model)
            encoder_mask: Boolean mask for valid encoder positions (n_enc,)

        Returns:
            output: Contextualized decoder representations (n_dec, d_v)
            attention_weights: Cross-attention weights (n_dec, n_enc)
        """
        # Q from decoder, K and V from encoder
        Q = decoder_state @ self.W_Q
        K = encoder_output @ self.W_K
        V = encoder_output @ self.W_V

        # Compute attention scores
        scores = Q @ K.T * self.scale

        # Apply encoder padding mask if provided
        if encoder_mask is not None:
            mask = encoder_mask.reshape(1, -1)
            scores = np.where(mask, scores, -1e9)

        # Softmax and aggregate
        attention_weights = softmax(scores)
        output = attention_weights @ V

        return output, attention_weights

Let's test the module with a translation-like example:

In[17]:
Code
# Simulate encoding "The quick brown fox"
np.random.seed(123)
source_tokens = ["The", "quick", "brown", "fox"]
n_source = len(source_tokens)

# Target tokens generated so far: "Le renard"
target_tokens = ["Le", "renard"]
n_target = len(target_tokens)

d_model = 16
d_k = d_v = 16

# Simulated representations
encoder_output = np.random.randn(n_source, d_model)
decoder_state = np.random.randn(n_target, d_model)

# Create and apply cross-attention
cross_attn = CrossAttention(d_model, d_k, d_v)
output, weights = cross_attn(decoder_state, encoder_output)
Out[18]:
Console
Cross-attention example: English → French
Source: ['The', 'quick', 'brown', 'fox']
Target so far: ['Le', 'renard']

Encoder output shape: (4, 16)
Decoder state shape:  (2, 16)
Cross-attention output shape: (2, 16)

Attention weights (target × source):
             ['The', 'quick', 'brown', 'fox']
  Le       [0.403  0.142  0.109  0.346]
  renard   [0.075  0.303  0.408  0.215]

Each target token distributes its attention across all source tokens. The weights indicate which parts of the source are most relevant for each target position. In a trained model, "renard" (French for "fox") would attend strongly to "fox" in the English source.

Out[19]:
Visualization
Heatmap showing attention weights between French target tokens and English source tokens, with darker cells indicating stronger attention.
Cross-attention weights for a translation example. Each row shows how a target token (French) attends to source tokens (English). Strong weights indicate which source words are most relevant for generating each target word.

Multi-Head Cross-Attention

Just like self-attention, cross-attention benefits from multiple attention heads. Each head can learn to attend to different aspects of the source sequence: one head might focus on syntactic alignment, another on semantic similarity, a third on positional patterns.

In[20]:
Code
class MultiHeadCrossAttention:
    """
    Multi-head cross-attention for encoder-decoder transformers.
    """

    def __init__(self, d_model, n_heads):
        """
        Initialize multi-head cross-attention.

        Args:
            d_model: Model dimension (must be divisible by n_heads)
            n_heads: Number of attention heads
        """
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.scale = 1.0 / np.sqrt(self.d_k)

        # Projections for all heads combined
        self.W_Q = np.random.randn(d_model, d_model) * np.sqrt(
            2.0 / (2 * d_model)
        )
        self.W_K = np.random.randn(d_model, d_model) * np.sqrt(
            2.0 / (2 * d_model)
        )
        self.W_V = np.random.randn(d_model, d_model) * np.sqrt(
            2.0 / (2 * d_model)
        )
        self.W_O = np.random.randn(d_model, d_model) * np.sqrt(
            2.0 / (2 * d_model)
        )

    def __call__(self, decoder_state, encoder_output, encoder_mask=None):
        """
        Apply multi-head cross-attention.

        Args:
            decoder_state: (n_dec, d_model)
            encoder_output: (n_enc, d_model)
            encoder_mask: Optional (n_enc,)

        Returns:
            output: (n_dec, d_model)
            attention_weights: (n_heads, n_dec, n_enc)
        """
        n_dec = decoder_state.shape[0]
        n_enc = encoder_output.shape[0]

        # Project and reshape to (n_heads, n, d_k)
        Q = (
            (decoder_state @ self.W_Q)
            .reshape(n_dec, self.n_heads, self.d_k)
            .transpose(1, 0, 2)
        )
        K = (
            (encoder_output @ self.W_K)
            .reshape(n_enc, self.n_heads, self.d_k)
            .transpose(1, 0, 2)
        )
        V = (
            (encoder_output @ self.W_V)
            .reshape(n_enc, self.n_heads, self.d_k)
            .transpose(1, 0, 2)
        )

        # Compute attention for all heads: (n_heads, n_dec, n_enc)
        scores = np.matmul(Q, K.transpose(0, 2, 1)) * self.scale

        # Apply mask if provided
        if encoder_mask is not None:
            mask = encoder_mask.reshape(1, 1, -1)
            scores = np.where(mask, scores, -1e9)

        attention_weights = softmax(scores)

        # Aggregate values: (n_heads, n_dec, d_k)
        head_outputs = np.matmul(attention_weights, V)

        # Concatenate heads and project: (n_dec, d_model)
        concatenated = head_outputs.transpose(1, 0, 2).reshape(n_dec, -1)
        output = concatenated @ self.W_O

        return output, attention_weights
In[21]:
Code
# Test multi-head cross-attention
n_heads = 4
d_model_mh = 32

# Create larger representations
encoder_output_mh = np.random.randn(n_source, d_model_mh)
decoder_state_mh = np.random.randn(n_target, d_model_mh)

mh_cross_attn = MultiHeadCrossAttention(d_model_mh, n_heads)
output_mh, weights_mh = mh_cross_attn(decoder_state_mh, encoder_output_mh)
Out[22]:
Console
Multi-head cross-attention with 4 heads:
  Input encoder shape:  (4, 32)
  Input decoder shape:  (2, 32)
  Output shape:         (2, 32)
  Attention weights:    (4, 2, 4) (heads × target × source)

The output maintains the same shape as the decoder input (2 tokens, 32 dimensions), but now each position has gathered information from the encoder through 4 independent attention computations. The attention weights tensor has shape (4, 2, 4), meaning each of the 4 heads produces its own 2×4 attention pattern.

Out[23]:
Visualization
Heatmap showing attention weights for head 1.
Head 1 attention pattern.
Heatmap showing attention weights for head 2.
Head 2 attention pattern.
Heatmap showing attention weights for head 3.
Head 3 attention pattern.
Heatmap showing attention weights for head 4.
Head 4 attention pattern.

Each head develops its own attention pattern. In trained translation models, researchers have observed heads that focus on positional alignment (target position $i$ attends to source position $i$), heads that track syntactic relationships, and heads that capture semantic similarities. The diversity across heads allows the model to capture multiple aspects of source-target alignment simultaneously.

KV Caching in Cross-Attention

During autoregressive generation, cross-attention has a computational advantage over decoder self-attention. The encoder output is fixed, meaning the keys and values from the encoder can be computed once and reused for every decoding step.

In[24]:
Code
class CachedCrossAttention:
    """
    Cross-attention with KV caching for efficient inference.

    The encoder's keys and values are computed once and reused
    for all decoder steps.
    """

    def __init__(self, d_model, d_k, d_v):
        self.d_k = d_k
        self.scale = 1.0 / np.sqrt(d_k)

        self.W_Q = np.random.randn(d_model, d_k) * np.sqrt(
            2.0 / (d_model + d_k)
        )
        self.W_K = np.random.randn(d_model, d_k) * np.sqrt(
            2.0 / (d_model + d_k)
        )
        self.W_V = np.random.randn(d_model, d_v) * np.sqrt(
            2.0 / (d_model + d_v)
        )

        # Cache for encoder K and V
        self.cached_K = None
        self.cached_V = None

    def cache_encoder(self, encoder_output):
        """
        Compute and cache encoder keys and values.
        Called once after encoding.
        """
        self.cached_K = encoder_output @ self.W_K
        self.cached_V = encoder_output @ self.W_V

    def forward(self, decoder_state, encoder_mask=None):
        """
        Apply cross-attention using cached encoder KV.

        Args:
            decoder_state: Current decoder position(s) (n_new, d_model)
            encoder_mask: Optional boolean mask for valid encoder positions (n_enc,)

        Returns:
            output: (n_new, d_v)
            attention_weights: (n_new, n_enc)
        """
        if self.cached_K is None:
            raise ValueError("Must call cache_encoder() first")

        # Only compute Q for new decoder positions
        Q = decoder_state @ self.W_Q

        # Use cached K and V
        scores = Q @ self.cached_K.T * self.scale

        if encoder_mask is not None:
            mask = encoder_mask.reshape(1, -1)
            scores = np.where(mask, scores, -1e9)

        attention_weights = softmax(scores)
        output = attention_weights @ self.cached_V

        return output, attention_weights
In[25]:
Code
# Simulate incremental decoding
np.random.seed(456)
d_model = 16
d_k = d_v = 16

# Encoder processes source once
source_len = 8
encoder_output = np.random.randn(source_len, d_model)

# Create cached cross-attention and cache encoder KV
cached_cross_attn = CachedCrossAttention(d_model, d_k, d_v)
cached_cross_attn.cache_encoder(encoder_output)

# Generate tokens one at a time
generated_tokens = []
for step in range(4):
    # Get representation for current position only
    current_state = np.random.randn(1, d_model)

    # Cross-attention uses cached K, V from encoder
    output, weights = cached_cross_attn.forward(current_state)
    generated_tokens.append(f"token_{step}")
Out[26]:
Console
Incremental generation with KV caching:
  Encoder length: 8
  Cached K shape: (8, 16)
  Cached V shape: (8, 16)

At each step, only Q is computed for the new token.
K and V are reused from cache, avoiding redundant computation.

This caching is particularly important for long source sequences. Without caching, generating 100 tokens from a 1000-token source would require recomputing the encoder's KV projections 100 times. With caching, we compute them once.
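
A rough back-of-the-envelope sketch shows the scale of the savings. It counts only the multiply-adds of the encoder K and V projections, with illustrative dimensions matching the numbers in the paragraph above (everything else about decoding is ignored):

# Illustrative sizes: 1000-token source, 100 generated tokens, d_model = d_k = 512
n_enc, n_steps, d_model, d_k = 1000, 100, 512, 512

kv_projection_macs = 2 * n_enc * d_model * d_k  # K and V: two (n_enc x d_model) @ (d_model x d_k)
without_cache = n_steps * kv_projection_macs    # recomputed at every decoding step
with_cache = kv_projection_macs                 # computed once, reused at every step

print(f"without caching: {without_cache:,.0f} multiply-adds")
print(f"with caching:    {with_cache:,.0f} multiply-adds")
print(f"savings factor:  {without_cache / with_cache:.0f}x")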

A Worked Example: Translation Step by Step

Let's trace through cross-attention during a translation example to see how all the pieces fit together:

In[27]:
Code
# Translation example: "I love cats" → "J'aime les chats"
source_sentence = ["I", "love", "cats"]
target_prefix = ["J'", "aime"]  # Already generated

np.random.seed(789)
n_src = len(source_sentence)
n_tgt = len(target_prefix)
d_model = 8

# Simulated encoder output (in practice, from transformer encoder)
encoder_out = np.random.randn(n_src, d_model)

# Simulated decoder state after self-attention
decoder_state = np.random.randn(n_tgt, d_model)

# Cross-attention
cross_attn = CrossAttention(d_model, d_k=8, d_v=8)
output, weights = cross_attn(decoder_state, encoder_out)
Out[28]:
Console
Translation: 'I love cats' → 'J'aime les chats'

Source tokens: ['I', 'love', 'cats']
Target prefix: ["J'", 'aime']

Cross-attention weights:
              I    love    cats
  J'       [0.216  0.099  0.685]
  aime     [0.082  0.044  0.874]

In this example, "aime" (love) should ideally attend strongly to "love" in the source. While our random weights don't show this (since we haven't trained the model), a trained model would learn to align related words across languages.

The cross-attention output for each target position now contains a weighted mixture of encoder information. The next step in the decoder would combine this with the self-attention output and pass through the feed-forward network, ultimately producing logits for predicting the next token ("les").

Limitations and Impact

Cross-attention is the mechanism that makes encoder-decoder transformers work. It provides a direct, differentiable connection between source and target sequences, enabling end-to-end training of translation, summarization, and other sequence-to-sequence models.

Several characteristics of cross-attention deserve consideration when designing or deploying encoder-decoder models:

The computational complexity of cross-attention is $O(n_{\text{dec}} \times n_{\text{enc}})$, where $n_{\text{dec}}$ is the number of decoder tokens and $n_{\text{enc}}$ is the number of encoder tokens. This arises because each of the $n_{\text{dec}}$ decoder positions must compute attention scores against all $n_{\text{enc}}$ encoder positions. While this is typically less problematic than the $O(n^2)$ self-attention in decoders (since source sequences are often shorter than the total generated output), the cost can become significant for very long source documents. Techniques like sparse cross-attention or retrieval-augmented approaches address this by attending to only a subset of encoder positions.
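
A quick count of score-matrix entries, using illustrative lengths only, shows how the per-layer cost grows linearly with the source length for a fixed number of generated tokens:

# Rough sketch: entries in the cross-attention score matrix, per layer and per head
n_dec = 200                       # fixed number of generated tokens (illustrative)
for n_enc in (500, 2000, 8000):   # increasingly long source documents
    print(f"source length {n_enc:>5}: {n_dec * n_enc:>9,} scores")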

Cross-attention assumes the entire encoder output is available before decoding begins. This makes it unsuitable for streaming applications where source and target are produced simultaneously. For such cases, architectures like streaming transformers or monotonic attention provide alternatives.

The fixed encoder output means the decoder cannot "ask follow-up questions" that change how the source is encoded. Each decoder layer sees the same encoder representations. Some architectures address this by adding encoder layers that receive decoder feedback, though this increases complexity.

Despite these considerations, cross-attention has proven remarkably effective. It underlies the success of models like T5, BART, and mBART in translation, summarization, and question answering. The mechanism's simplicity, matching the same scaled dot-product attention used in self-attention, makes it easy to implement and optimize.

Summary

Cross-attention bridges encoder and decoder in sequence-to-sequence transformers, enabling the decoder to gather information from the encoded source sequence while generating output tokens.

Key takeaways from this chapter:

  • Q from decoder, K and V from encoder: This asymmetry defines cross-attention. Queries represent what the decoder is looking for; keys and values represent what the encoder offers.

  • Rectangular attention matrix: Unlike self-attention's $(n \times n)$ square matrix, cross-attention produces a $(n_{\text{dec}} \times n_{\text{enc}})$ rectangular weight matrix, where $n_{\text{dec}}$ is the target sequence length and $n_{\text{enc}}$ is the source sequence length. Each row represents how one decoder position distributes attention across all encoder positions.

  • Padding masks, not causal masks: Cross-attention masks padding tokens in the encoder but doesn't need causal masking since the encoder sequence is fully available.

  • Fixed encoder output: The encoder is computed once, and its keys and values can be cached for efficient inference across all decoder steps.

  • Placement in decoder layers: Cross-attention appears between masked self-attention and the feed-forward network in each decoder layer, gathering source information after the decoder has processed its own context.

  • Multi-head diversity: Like self-attention, cross-attention benefits from multiple heads that can specialize in different alignment patterns.

The next chapter explores weight tying, a technique for sharing parameters between embedding layers and output projections, reducing model size while maintaining performance.

Key Parameters

When implementing cross-attention in encoder-decoder transformers, several parameters control the mechanism's behavior and capacity:

  • d_model: The model dimension, which is the size of input token representations. Both encoder outputs and decoder states should have this dimension. Common values range from 256 to 1024, with larger models using 2048 or more.

  • d_k (query/key dimension): The dimension of the projected queries and keys. This controls the capacity of the attention scoring mechanism. Typically set to d_model // n_heads in multi-head attention, giving each head a portion of the full dimensionality.

  • d_v (value dimension): The dimension of the projected values. Often equal to d_k, but can differ. This determines the size of information transmitted when attention flows from encoder to decoder.

  • n_heads: The number of parallel attention heads in multi-head cross-attention. More heads allow the model to attend to different aspects of the source sequence simultaneously. Typical values are 8, 12, or 16 heads.

  • encoder_mask: A boolean mask indicating which encoder positions are valid (True) versus padding (False). Essential for batched processing of variable-length source sequences to prevent attending to meaningless padding tokens.

  • scale factor: The scaling factor $1/\sqrt{d_k}$ applied before softmax. This is automatically determined by d_k and prevents attention scores from becoming too large in high dimensions, which would cause softmax to saturate.
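
As a quick illustration of how the parameters above relate, here is the standard arithmetic for an illustrative configuration (d_model = 512 with 8 heads; the values are examples, not tied to a specific model):

d_model, n_heads = 512, 8
d_k = d_v = d_model // n_heads   # 64 dimensions per head
scale = 1.0 / (d_k ** 0.5)       # 0.125, applied to every score before softmax
print(d_k, d_v, scale)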


