Bahdanau Attention: Dynamic Context for Neural Machine Translation

Michael Brenndoerfer · December 16, 2025 · 53 min read

Learn how Bahdanau attention solves the encoder-decoder bottleneck with dynamic context vectors, softmax alignment, and interpretable attention weights for sequence-to-sequence models.

Bahdanau Attention

In the previous chapter, we developed an intuition for attention as a mechanism that allows models to selectively focus on different parts of the input when generating each output. We saw how attention acts as a soft lookup, assigning weights to encoder hidden states based on their relevance to the current decoding step. Now it's time to formalize this intuition into a concrete algorithm.

Bahdanau attention, introduced in the 2014 paper "Neural Machine Translation by Jointly Learning to Align and Translate," was the first attention mechanism to achieve broad success in sequence-to-sequence models. The key innovation was replacing the fixed context vector bottleneck with a dynamic, position-dependent context that changes at each decoder timestep. This change led to significant improvements in translation quality, especially for longer sentences.

The Alignment Problem

Before diving into the mathematics, let's understand the problem that Bahdanau attention solves. In a standard encoder-decoder architecture, the encoder processes the entire input sequence and compresses it into a single fixed-length vector. The decoder then generates the output sequence using only this compressed representation.

Out[3]:
Visualization
Diagram showing encoder states being funneled into a single context vector that feeds the decoder.
The bottleneck problem in standard encoder-decoder models. The entire input sequence must be compressed into a single fixed-length context vector, which becomes the sole source of information for the decoder. For long sequences, this compression loses critical details.

This bottleneck creates a fundamental problem: the context vector must somehow encode everything about the input that the decoder might need. For short sentences, this works reasonably well. But as input length increases, the fixed-size context vector becomes increasingly inadequate. Information gets lost, compressed, or muddled together.

The insight behind attention is that different parts of the output depend on different parts of the input. When translating "The cat sat on the mat" to French, generating "Le" depends mostly on "The," while generating "chat" depends mostly on "cat." Instead of forcing all this information through a single bottleneck, why not let the decoder directly access the relevant encoder states at each step?

Alignment

In the context of attention, alignment refers to the correspondence between positions in the input and output sequences. Good alignment means the model correctly identifies which input positions are relevant for generating each output position. Bahdanau attention learns this alignment jointly with translation, rather than relying on external alignment tools.

The Attention Mechanism: From Intuition to Formulas

With the alignment problem clearly defined, we can now develop the mathematical machinery that Bahdanau attention uses to solve it. Our journey will take us through three interconnected concepts: computing relevance scores, converting those scores to probabilities, and aggregating information based on those probabilities. Each piece builds naturally on the previous one, culminating in a mechanism that allows the decoder to "look back" at exactly the right parts of the input at each generation step.

The Challenge of Measuring Relevance

Attention addresses a simple question: when the decoder is about to generate output token $i$, how relevant is each encoder position $j$ to that decision? If we could answer this question with a single number for each encoder position, we'd have a way to rank and weight the encoder states.

But what should this "relevance score" capture? Consider translating "The cat sat on the mat" to French. When generating "chat" (cat), the decoder needs to know:

  • What the decoder is currently looking for (encoded in its hidden state $s_{i-1}$)
  • What information each encoder position offers (encoded in hidden states $h_1, h_2, \ldots, h_{T_x}$)

The relevance of position $j$ depends on both of these: it's not just about what $h_j$ contains, but whether that content matches what the decoder currently needs. This suggests we need a function that takes both the decoder state and an encoder state as inputs and produces a scalar score.

The Additive Score Function

Bahdanau and colleagues designed a solution called the additive (or concatenative) alignment model. Rather than directly comparing the decoder and encoder states, which might have different dimensions and represent different types of information, they project both into a shared "alignment space" where comparison becomes meaningful.

The alignment score between decoder position $i$ and encoder position $j$ is computed as:

$e_{ij} = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$

where:

  • $e_{ij}$: the alignment score, a scalar indicating how relevant encoder position $j$ is for generating output $i$
  • $s_{i-1} \in \mathbb{R}^{d_s}$: the decoder hidden state at the previous timestep, encoding what the decoder "knows" and "needs"
  • $h_j \in \mathbb{R}^{d_h}$: the encoder hidden state at position $j$, encoding information about input token $j$ and its context
  • $W_a \in \mathbb{R}^{d_a \times d_s}$: a learnable weight matrix that projects the decoder state into the alignment space
  • $U_a \in \mathbb{R}^{d_a \times d_h}$: a learnable weight matrix that projects encoder states into the same alignment space
  • $v_a \in \mathbb{R}^{d_a}$: a learnable vector that reduces the combined representation to a scalar
  • $d_a$: the dimension of the alignment space (a hyperparameter, typically matching the hidden dimensions)

This formula might look complex at first glance, but it follows a clear logic. Let's trace through it step by step to understand why each component is necessary:

Step 1: Project into alignment space. The decoder state $s_{i-1}$ and encoder state $h_j$ live in potentially different vector spaces with different dimensions. Before we can meaningfully compare them, we need to transform them into a common representation. The matrices $W_a$ and $U_a$ learn these transformations during training:

  • $W_a s_{i-1}$ produces a $d_a$-dimensional vector representing "what the decoder is looking for"
  • $U_a h_j$ produces a $d_a$-dimensional vector representing "what encoder position $j$ offers"

Step 2: Combine additively. With both vectors in the same space, we add them element-wise: $W_a s_{i-1} + U_a h_j$. This additive combination allows the model to detect when the decoder's query and the encoder's content are compatible. If certain dimensions align well (both positive or both negative), they reinforce each other; if they conflict, they cancel out.

Step 3: Apply nonlinearity. The $\tanh$ function serves two purposes. First, it bounds the values to $[-1, 1]$, preventing any single dimension from dominating. Second, it introduces nonlinearity, enabling the model to learn complex, non-linear relationships between decoder queries and encoder contents. Without this nonlinearity, the entire alignment model would collapse to a simple linear function.

Out[4]:
Visualization
Plot of the tanh function showing bounds at plus and minus one with shaded saturation regions.
The tanh function compresses extreme values while preserving values near zero. Saturation regions (shaded) show where inputs are bounded to ±1.
Histogram comparing score distributions before and after tanh transformation.
Distribution of combined scores before and after tanh transformation. The tanh function prevents extreme values from dominating the alignment scores.

Step 4: Reduce to scalar. Finally, the dot product with $v_a$ compresses the $d_a$-dimensional result into a single scalar score. The vector $v_a$ learns which dimensions of the alignment space are most important for determining relevance. Some dimensions might capture semantic similarity, others syntactic compatibility; $v_a$ learns to weight these appropriately.

Out[5]:
Visualization
Flow diagram showing decoder state and encoder state being projected, added, passed through tanh, and multiplied by v to produce score.
The additive score function in Bahdanau attention. The decoder state and each encoder state are projected into a shared alignment space, combined additively, passed through tanh, and reduced to a scalar score.

Why "additive" rather than some other approach? Additive combination with separate projection matrices gives the model maximum flexibility. The decoder and encoder might have different hidden dimensions, and even if they're the same size, they encode different types of information. By learning separate projections WaW_a and UaU_a, the model can transform each representation into whatever form makes comparison most effective. Bahdanau and colleagues found this approach worked well in practice, and it became the foundation for subsequent attention mechanisms.

From Scores to Probabilities: The Softmax Transformation

We now have a way to compute alignment scores $e_{ij}$ for each encoder position. But these raw scores can be any real number: positive, negative, or zero. To use them for weighting encoder states, we need to transform them into something more interpretable: a probability distribution over encoder positions.

This is where the softmax function enters the picture. Softmax takes a vector of arbitrary real numbers and converts it into a valid probability distribution where all values are positive and sum to 1:

$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}$

where:

  • $\alpha_{ij}$: the attention weight for encoder position $j$ when generating output $i$
  • $e_{ij}$: the raw alignment score from our scoring function
  • $\exp(\cdot)$: the exponential function ($e^x$)
  • $T_x$: the length of the source sequence
  • $\sum_{k=1}^{T_x} \exp(e_{ik})$: the normalizing constant that ensures weights sum to 1

The exponential function is the key to softmax's behavior. It has two essential properties that make it ideal for attention:

  1. Positivity: $\exp(x) > 0$ for all real $x$, guaranteeing non-negative weights regardless of the input scores
  2. Monotonicity with amplification: larger scores produce exponentially larger values, which amplifies differences between scores

This second property is particularly important. Consider two scores: $e_{i1} = 1.2$ and $e_{i2} = 3.5$. The difference looks modest, but after exponentiation, $\exp(3.5) \approx 33$ while $\exp(1.2) \approx 3.3$, a 10x ratio. Softmax creates a "winner-take-most" dynamic where the highest-scoring positions dominate, while still allowing lower-scoring positions to contribute.

The attention weight $\alpha_{ij}$ has a clear interpretation: it represents the probability that encoder position $j$ contains the information most relevant for generating output token $i$. When we compute attention, we're asking: "Given what the decoder currently needs, where in the input should it look?"

Out[6]:
Visualization
Raw alignment scores for seven encoder positions. Position 4 has the highest score (2.8), but the differences appear modest.
After softmax normalization, position 4 receives disproportionately more weight, demonstrating the 'winner-take-most' amplification effect.

The visualization above shows this amplification in action. Position 4 has a raw score of 2.8, only 2.0 higher than position 1's score of 0.8. But because softmax weight ratios grow exponentially with score differences, position 4 ends up with roughly $e^{2.0} \approx 7$ times the attention weight of position 1. This amplification helps the model make decisive choices about where to focus.
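This amplification is easy to reproduce with a few lines of NumPy. The scores for positions 1 and 4 match the values quoted above; the remaining scores are hypothetical, chosen only to fill out the example.

import numpy as np

# Hypothetical alignment scores; positions 1 and 4 use the values quoted above
scores = np.array([0.8, 1.5, 1.1, 2.8, 0.9, 1.3, 0.6])

weights = np.exp(scores) / np.exp(scores).sum()  # softmax

print(np.round(weights, 3))
print(f"raw score gap (pos 4 - pos 1): {scores[3] - scores[0]:.1f}")
print(f"weight ratio  (pos 4 / pos 1): {weights[3] / weights[0]:.1f}x")  # about 7.4x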

Beyond the mathematical properties, softmax has two important advantages for learning:

  1. Differentiability: Gradients flow smoothly through softmax, enabling end-to-end training with backpropagation.
  2. Soft selection: Unlike "hard" attention (which would pick exactly one position), soft attention maintains gradients to all positions, making optimization much easier.

The "softness" of attention can be controlled by scaling the scores before applying softmax. Dividing scores by a temperature parameter τ\tau changes how sharply the attention focuses:

Out[7]:
Visualization
Low temperature (τ=0.3) produces nearly hard attention, with 97% weight on the highest-scoring position.
Standard temperature (τ=1.0) balances focus with coverage. This is the Bahdanau default.
High temperature (τ=3.0) produces nearly uniform attention, spreading weight across all positions.

The entropy values in the figure quantify the "spread" of attention: lower entropy means more focused attention, higher entropy means more distributed. Standard Bahdanau attention uses $\tau = 1$, which provides a good balance between focusing on relevant positions and maintaining gradient flow to all positions during training.
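Temperature scaling is not part of the original Bahdanau formulation, but it is straightforward to sketch. The snippet below applies softmax at the three temperatures shown in the figure to a set of hypothetical scores:

import numpy as np

def softmax_with_temperature(scores, tau=1.0):
    """Softmax over scores / tau; tau < 1 sharpens the distribution, tau > 1 flattens it."""
    scaled = scores / tau
    scaled = scaled - scaled.max()   # subtract the max for numerical stability
    exp_scores = np.exp(scaled)
    return exp_scores / exp_scores.sum()

scores = np.array([0.8, 1.5, 1.1, 2.8, 0.9])  # hypothetical alignment scores
for tau in (0.3, 1.0, 3.0):
    print(f"tau = {tau}: {np.round(softmax_with_temperature(scores, tau), 3)}")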

The "sharpness" of the attention distribution depends on the magnitude of score differences. When scores are similar, attention spreads across many positions; when one score is much higher, attention concentrates sharply:

Out[8]:
Visualization
Spread attention when scores are similar (low variance). Weights distribute relatively evenly across positions (H = 1.91).
Focused attention when one score dominates (high variance). Position 4 captures most of the attention (H = 0.98).

The entropy values quantify this difference: higher entropy indicates more spread-out attention, while lower entropy indicates more concentrated attention. A well-trained model learns to produce focused attention when one input position is clearly relevant (e.g., translating a content word) and spread attention when multiple positions contribute (e.g., translating a phrase that spans multiple source words).

Aggregating Information: The Context Vector

We've now computed attention weights $\alpha_{ij}$ that tell us how much the decoder should attend to each encoder position. The final step is to use these weights to create a single vector that summarizes the relevant information from the entire input sequence. This is the context vector.

The context vector for decoder position ii is simply a weighted sum of the encoder hidden states:

$c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j$

where:

  • $c_i \in \mathbb{R}^{d_h}$: the context vector, a $d_h$-dimensional summary of the relevant input information
  • $\alpha_{ij}$: the attention weight for encoder position $j$ (how much to attend to that position)
  • $h_j \in \mathbb{R}^{d_h}$: the encoder hidden state at position $j$
  • $T_x$: the length of the source sequence

This formula is simple, but its implications are significant. Because the attention weights sum to 1, the context vector is a convex combination of the encoder states. It lies somewhere "between" them in the vector space, closer to the states with higher weights.

Consider what happens in different scenarios:

  • Focused attention: If $\alpha_{i2} = 0.95$ and all other weights are near zero, then $c_i \approx h_2$. The context vector essentially copies the encoder state at position 2.
  • Distributed attention: If weights are spread across multiple positions, the context vector blends information from all of them. This is useful when the decoder needs information from multiple parts of the input.
  • Uniform attention: If all weights are equal ($\alpha_{ij} = 1/T_x$), the context vector is the simple average of all encoder states, similar to a mean-pooling operation.
Out[9]:
Visualization
Focused attention places the context vector (star) near $h_3$, which receives 85% of the attention weight.
Distributed attention places the context vector between multiple encoder states, blending information from all positions.

This geometric interpretation helps build intuition: the context vector "moves" through the encoder state space based on attention weights, getting pulled toward whichever states receive more attention. The model learns to position the context vector where it captures the most relevant information for the current generation step.

The key insight is that each decoder step gets its own context vector $c_i$, computed fresh based on the decoder state $s_{i-1}$. This is the main improvement over standard encoder-decoder models: instead of forcing all information through a single fixed-size bottleneck, we dynamically select and aggregate the relevant parts of the input at each generation step.
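A tiny NumPy sketch, using made-up two-dimensional encoder states, shows how the attention weights position the context vector between the encoder states:

import numpy as np

# Hypothetical 2-D encoder states so the geometry is easy to read off
H = np.array([[1.0, 0.0],    # h_1
              [0.0, 1.0],    # h_2
              [1.0, 1.0]])   # h_3

focused = np.array([0.05, 0.90, 0.05])   # nearly all weight on h_2
uniform = np.ones(3) / 3                 # equal weight on every position

print("focused context:", focused @ H)   # close to h_2 = [0, 1]
print("uniform context:", uniform @ H)   # the mean of the encoder states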

Out[10]:
Visualization
Diagram showing encoder states being multiplied by attention weights and summed to produce context vector.
The context vector is computed as a weighted sum of encoder hidden states: $c_i = \sum_{j} \alpha_{ij} h_j$. Positions with higher attention weights contribute more to the final context. This dynamic weighting allows the decoder to focus on relevant parts of the input at each step.

Putting It All Together: The Complete Attention Mechanism

We've now developed all three components of Bahdanau attention:

  1. Scoring: $e_{ij} = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$ measures relevance
  2. Normalization: $\alpha_{ij} = \text{softmax}(e_{ij})$ converts scores to probabilities
  3. Aggregation: $c_i = \sum_j \alpha_{ij} h_j$ creates a weighted summary

But how does this attention mechanism integrate with the decoder RNN? The context vector $c_i$ contains a summary of the relevant input information, but the decoder also needs to know what it has generated so far and maintain its own internal state.

In Bahdanau's formulation, these three sources of information are combined by feeding them into the decoder RNN:

$s_i = f(s_{i-1}, y_{i-1}, c_i)$

where:

  • $s_i \in \mathbb{R}^{d_s}$: the new decoder hidden state after processing step $i$
  • $s_{i-1} \in \mathbb{R}^{d_s}$: the previous decoder hidden state (the decoder's "memory")
  • $y_{i-1} \in \mathbb{R}^{d_e}$: the embedding of the previous output token
  • $c_i \in \mathbb{R}^{d_h}$: the context vector from attention
  • $f$: a recurrent cell function (typically GRU or LSTM)

In practice, $y_{i-1}$ and $c_i$ are concatenated into a single vector of dimension $d_e + d_h$ before being fed to the RNN. This combines information about the previous prediction with the attention-weighted summary of the input.

The complete decoding process at step $i$ follows this sequence:

  1. Compute alignment scores: For each encoder position $j$, calculate $e_{ij} = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$
  2. Normalize to attention weights: Apply softmax to get $\alpha_{ij}$
  3. Compute context vector: Calculate $c_i = \sum_j \alpha_{ij} h_j$
  4. Update decoder state: Compute $s_i = f(s_{i-1}, y_{i-1}, c_i)$
  5. Generate output: Predict the next token using $s_i$ (and often $c_i$ as well)

One subtle but important detail: attention is computed using the previous decoder state $s_{i-1}$, not the current state $s_i$. This makes sense because we need to know what the decoder is looking for before we can compute the context, and we need the context before we can update the decoder state.

Out[11]:
Visualization
Flow diagram showing attention computation integrated into decoder RNN at each timestep.
The complete attention mechanism in the decoder. At each step, the decoder uses its previous state to compute attention over encoder states, produces a context vector, and combines this with the previous output to generate the next hidden state and output prediction.

A Worked Example: Attention in Action

The formulas we've developed might seem abstract, so let's ground them with a concrete numerical example. We'll trace through every step of the attention computation, watching the numbers flow through each transformation.

Imagine we're building a translation system for English to French. We're translating "I love cats" and have just generated "J'" (the French equivalent of "I"). Now we need to generate the next word, which should be "aime" (love). The question is: which part of the input should the decoder focus on?

For this example, we'll use deliberately small dimensions so we can see every number:

  • Hidden dimension: 4 (real systems use 256-1024)
  • Alignment dimension: 3
  • Source sequence: 3 words ("I", "love", "cats")
In[12]:
Code
# Example dimensions for illustration
hidden_dim = 4
align_dim = 3
num_encoder_positions = 3

# Simulated encoder hidden states (normally from bidirectional RNN)
np.random.seed(42)
h = np.array(
    [
        [0.2, -0.5, 0.8, 0.1],  # h_1: "I"
        [0.9, 0.3, -0.2, 0.7],  # h_2: "love"
        [-0.1, 0.6, 0.4, -0.3],  # h_3: "cats"
    ]
)

# Previous decoder state (looking for "love" -> "aime")
s_prev = np.array([0.5, 0.1, -0.3, 0.8])

# Learnable parameters (normally trained)
W_a = np.array(
    [
        [0.3, -0.2, 0.5, 0.1],
        [0.4, 0.6, -0.1, 0.2],
        [-0.2, 0.3, 0.4, -0.5],
    ]
)  # Shape: (align_dim, hidden_dim)

U_a = np.array(
    [
        [0.2, 0.4, -0.3, 0.1],
        [-0.1, 0.5, 0.2, 0.3],
        [0.6, -0.2, 0.1, 0.4],
    ]
)  # Shape: (align_dim, hidden_dim)

v_a = np.array([0.5, -0.3, 0.4])  # Shape: (align_dim,)

Each encoder hidden state $h_j$ captures information about word $j$ and its surrounding context. The decoder state $s_{i-1}$ encodes what the decoder has generated so far and, implicitly, what it's looking for next. The weight matrices $W_a$, $U_a$, and vector $v_a$ would normally be learned during training; here we use fixed values for illustration.

Step 1: Project the Decoder State

The first step is to transform the decoder state into the alignment space. This projection happens once per decoder step and will be reused when comparing against each encoder position:

In[13]:
Code
# Project decoder state: W_a @ s_prev
W_s = W_a @ s_prev
Out[14]:
Console
W_a @ s_prev = [ 0.06  0.45 -0.59]

This 3-dimensional vector represents "what the decoder is looking for" in the alignment space. Think of it as a query that we'll compare against each encoder position.

Step 2: Compute Alignment Scores

Now we compute a score for each encoder position. For each position $j$, we project $h_j$ into the alignment space, add it to the projected decoder state, apply tanh, and compute the final score with $v_a$:

In[15]:
Code
scores = []
for j in range(num_encoder_positions):
    # Project encoder state
    U_h = U_a @ h[j]

    # Combine and apply tanh
    combined = np.tanh(W_s + U_h)

    # Compute scalar score
    e_j = v_a @ combined
    scores.append(e_j)

scores = np.array(scores)
Out[16]:
Console
Alignment scores:
  e_1 = -0.3634
  e_2 = 0.1092
  e_3 = -0.4023

The scores reveal that position 2 ("love") has the highest alignment with what the decoder is looking for. This makes intuitive sense: to generate "aime" (the French word for love), the decoder should focus on the English word "love." Position 3 ("cats") has the lowest score, indicating it's least relevant for this particular generation step.

Step 3: Normalize with Softmax

Raw scores can be any real number, but we need probabilities. Softmax transforms these scores into a proper probability distribution:

In[17]:
Code
# Softmax normalization
exp_scores = np.exp(scores)
attention_weights = exp_scores / exp_scores.sum()
Out[18]:
Console
Attention weights:
  α_1 = 0.2804
  α_2 = 0.4499
  α_3 = 0.2697
Sum: 1.0000

Position 2 receives about 45% of the attention, while positions 1 and 3 receive roughly 28% and 27% respectively. Notice that even though position 2 had the highest score, the attention isn't completely focused on it. This "soft" attention allows the model to hedge its bets, pulling in information from multiple positions when useful. The weights sum to exactly 1.0, confirming we have a valid probability distribution.

Step 4: Compute the Context Vector

Finally, we use these attention weights to create a weighted combination of the encoder hidden states:

In[19]:
Code
# Weighted sum of encoder states
context = np.zeros(hidden_dim)
for j in range(num_encoder_positions):
    context += attention_weights[j] * h[j]
Out[20]:
Console
Context vector c_i = [0.43398358 0.15657815 0.24225493 0.26202583]

For comparison:
  h_2 ("love") = [ 0.9  0.3 -0.2  0.7]
  Difference  = 0.3724 (mean abs)

The context vector is a blend of all three encoder states, weighted by their relevance. It's closest to $h_2$ (the representation of "love") since that position received the highest attention weight, but it also incorporates information from the other positions. This context vector now gets fed into the decoder RNN along with the previous output embedding to generate the next token.

We can measure how the context vector relates to the encoder states using cosine similarity:

Context vector similarity to encoder states. The context vector is geometrically closest to $h_2$ ("love"), which received the highest attention weight (0.45).

Encoder State    Word      Attention Weight    Cosine Similarity
$h_1$            "I"       0.28                0.404
$h_2$            "love"    0.45                0.821
$h_3$            "cats"    0.27                0.150

The similarity analysis confirms that the context vector is closest to the encoder state that received the highest attention weight. However, the context isn't identical to any single encoder state; it's a weighted blend that can capture information from multiple positions when needed.
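The similarities in the table can be reproduced in a few lines, reusing the `context`, `h`, and `attention_weights` arrays defined earlier in this example:

def cosine_similarity(a, b):
    """Cosine of the angle between vectors a and b."""
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

for j, word in enumerate(["I", "love", "cats"]):
    sim = cosine_similarity(context, h[j])
    print(f"h_{j + 1} ({word}): attention = {attention_weights[j]:.2f}, similarity = {sim:.3f}")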

This worked example illustrates the key insight of attention: instead of forcing the decoder to work with a single fixed representation of the entire input, we dynamically construct a context that's tailored to each generation step. When generating "aime," the context emphasizes "love"; when generating "chats" later, the context would shift to emphasize "cats."

From Theory to Code: Implementing Bahdanau Attention

Having traced through the mathematics by hand, let's now implement Bahdanau attention as a PyTorch module. The code will mirror the formulas exactly, making it easy to see how each mathematical operation translates to tensor operations.

In[21]:
Code
class BahdanauAttention(nn.Module):
    """
    Bahdanau (additive) attention mechanism.

    Args:
        encoder_dim: Dimension of encoder hidden states
        decoder_dim: Dimension of decoder hidden states
        attention_dim: Dimension of the alignment space
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()

        # Projection matrices
        self.W_a = nn.Linear(decoder_dim, attention_dim, bias=False)
        self.U_a = nn.Linear(encoder_dim, attention_dim, bias=False)
        self.v_a = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, decoder_state, encoder_outputs):
        """
        Compute attention weights and context vector.

        Args:
            decoder_state: Previous decoder hidden state (batch, decoder_dim)
            encoder_outputs: All encoder hidden states (batch, seq_len, encoder_dim)

        Returns:
            context: Context vector (batch, encoder_dim)
            attention_weights: Attention distribution (batch, seq_len)
        """
        # Project decoder state: (batch, attention_dim)
        decoder_proj = self.W_a(decoder_state)

        # Project encoder outputs: (batch, seq_len, attention_dim)
        encoder_proj = self.U_a(encoder_outputs)

        # Add projections (broadcast decoder over sequence)
        # decoder_proj: (batch, 1, attention_dim) after unsqueeze
        # encoder_proj: (batch, seq_len, attention_dim)
        combined = torch.tanh(decoder_proj.unsqueeze(1) + encoder_proj)

        # Compute scores: (batch, seq_len, 1) -> (batch, seq_len)
        scores = self.v_a(combined).squeeze(-1)

        # Normalize to attention weights
        attention_weights = F.softmax(scores, dim=-1)

        # Compute context vector as weighted sum
        # attention_weights: (batch, seq_len) -> (batch, 1, seq_len)
        # encoder_outputs: (batch, seq_len, encoder_dim)
        # Result: (batch, encoder_dim)
        context = torch.bmm(
            attention_weights.unsqueeze(1), encoder_outputs
        ).squeeze(1)

        return context, attention_weights

The implementation follows our mathematical formulation exactly. Let's trace through the key operations:

  1. Projection layers: W_a, U_a, and v_a are implemented as nn.Linear layers without bias terms, matching our weight matrices
  2. Broadcasting: When we add decoder_proj.unsqueeze(1) to encoder_proj, PyTorch broadcasts the decoder projection across all sequence positions
  3. Batch processing: The code handles batches of sequences simultaneously, which is essential for efficient training
  4. Efficient aggregation: torch.bmm (batch matrix multiplication) computes the weighted sum $\sum_j \alpha_{ij} h_j$ efficiently as a matrix operation

Let's verify that our implementation produces the expected outputs:

In[22]:
Code
# Test the attention module
batch_size = 2
seq_len = 5
encoder_dim = 64
decoder_dim = 64
attention_dim = 32

attention = BahdanauAttention(encoder_dim, decoder_dim, attention_dim)

# Create dummy inputs
encoder_outputs = torch.randn(batch_size, seq_len, encoder_dim)
decoder_state = torch.randn(batch_size, decoder_dim)

# Compute attention
context, weights = attention(decoder_state, encoder_outputs)
Out[23]:
Console
Encoder outputs shape: torch.Size([2, 5, 64])
Decoder state shape: torch.Size([2, 64])
Context vector shape: torch.Size([2, 64])
Attention weights shape: torch.Size([2, 5])

Attention weights sum to 1: tensor([1.0000, 1.0000], grad_fn=<SumBackward1>)

The attention weights sum to 1 for each batch element, confirming that softmax normalization is working correctly. The context vector has the same dimension as the encoder hidden states, ready to be fed into the decoder RNN.

Building the Complete Decoder

The attention module computes context vectors, but we still need to integrate it into a full decoder that can generate output sequences. The decoder must coordinate several components: embedding the previous output token, computing attention, updating its hidden state, and predicting the next token.

In[24]:
Code
class AttentionDecoder(nn.Module):
    """
    Decoder with Bahdanau attention for sequence-to-sequence models.

    Args:
        vocab_size: Size of output vocabulary
        embed_dim: Dimension of word embeddings
        encoder_dim: Dimension of encoder hidden states
        decoder_dim: Dimension of decoder hidden states
        attention_dim: Dimension of attention alignment space
    """

    def __init__(
        self, vocab_size, embed_dim, encoder_dim, decoder_dim, attention_dim
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.attention = BahdanauAttention(
            encoder_dim, decoder_dim, attention_dim
        )

        # GRU input: embedded token + context vector
        self.gru = nn.GRU(
            embed_dim + encoder_dim, decoder_dim, batch_first=True
        )

        # Output projection
        self.output_proj = nn.Linear(decoder_dim, vocab_size)

        self.decoder_dim = decoder_dim

    def forward_step(self, prev_token, prev_hidden, encoder_outputs):
        """
        Single decoding step with attention.

        Args:
            prev_token: Previous output token indices (batch,)
            prev_hidden: Previous decoder hidden state (1, batch, decoder_dim)
            encoder_outputs: Encoder hidden states (batch, seq_len, encoder_dim)

        Returns:
            output: Log probabilities over vocabulary (batch, vocab_size)
            hidden: New decoder hidden state (1, batch, decoder_dim)
            attention_weights: Attention distribution (batch, seq_len)
        """
        # Embed previous token: (batch, embed_dim)
        embedded = self.embedding(prev_token)

        # Compute attention using previous hidden state
        # prev_hidden shape: (1, batch, decoder_dim) -> (batch, decoder_dim)
        context, attention_weights = self.attention(
            prev_hidden.squeeze(0), encoder_outputs
        )

        # Concatenate embedding and context: (batch, 1, embed_dim + encoder_dim)
        gru_input = torch.cat([embedded, context], dim=-1).unsqueeze(1)

        # Update decoder state
        output, hidden = self.gru(gru_input, prev_hidden)

        # Project to vocabulary: (batch, vocab_size)
        logits = self.output_proj(output.squeeze(1))

        return F.log_softmax(logits, dim=-1), hidden, attention_weights

This decoder implements the complete Bahdanau architecture. The forward_step method executes one decoding step, following the exact sequence we outlined earlier:

  1. Embed: Convert the previous token ID to a dense vector
  2. Attend: Use the previous hidden state to compute attention over encoder outputs
  3. Concatenate: Combine the embedding and context vector
  4. Update: Pass through the GRU to get the new hidden state
  5. Project: Transform the hidden state to vocabulary logits

Let's verify that the decoder produces outputs of the expected shapes:

In[25]:
Code
# Test the decoder
vocab_size = 1000
embed_dim = 128
encoder_dim = 256
decoder_dim = 256
attention_dim = 128

decoder = AttentionDecoder(
    vocab_size, embed_dim, encoder_dim, decoder_dim, attention_dim
)

# Simulate encoder outputs (from a bidirectional LSTM, for example)
batch_size = 4
src_len = 10
encoder_outputs = torch.randn(batch_size, src_len, encoder_dim)

# Initialize decoder
prev_token = torch.zeros(batch_size, dtype=torch.long)  # Start token
prev_hidden = torch.zeros(1, batch_size, decoder_dim)

# One decoding step
log_probs, hidden, attn_weights = decoder.forward_step(
    prev_token, prev_hidden, encoder_outputs
)
Out[26]:
Console
Output log probabilities shape: torch.Size([4, 1000])
New hidden state shape: torch.Size([1, 4, 256])
Attention weights shape: torch.Size([4, 10])

Sample attention distribution (first batch element):
  [0.138 0.123 0.084 0.109 0.112 0.054 0.066 0.083 0.14  0.093]

Visualizing Attention Alignments

Attention mechanisms are interpretable. We can visualize the attention weights to see what the model is "looking at" when generating each output token.

Out[27]:
Visualization
Attention alignment visualization for English to French translation. Each row shows the attention distribution when generating a particular French word. The model learns to align 'Le' with 'The', 'chat' with 'cat', and so on, without explicit supervision.

The attention heatmap reveals several interesting patterns:

  • Monotonic alignment: For this simple sentence, attention roughly follows the diagonal, reflecting the similar word order between English and French.
  • Word-to-word correspondence: "Le" attends strongly to "The," "chat" to "cat," and "tapis" to "mat."
  • Many-to-one mapping: Both "était" and "assis" attend primarily to "sat," since the single English word requires two French words.
  • Soft attention: Even when one position dominates, other positions receive small but non-zero weights.
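A heatmap like the one above takes only a few lines of matplotlib to produce. This is a sketch, assuming `attn` is an array of attention weights with shape (number of target words, number of source words):

import matplotlib.pyplot as plt

def plot_alignment(attn, source_words, target_words):
    """Plot an attention alignment heatmap: one row per target word, one column per source word."""
    fig, ax = plt.subplots()
    ax.imshow(attn, cmap="Blues", aspect="auto")
    ax.set_xticks(range(len(source_words)))
    ax.set_xticklabels(source_words, rotation=45)
    ax.set_yticks(range(len(target_words)))
    ax.set_yticklabels(target_words)
    ax.set_xlabel("Source (English)")
    ax.set_ylabel("Target (French)")
    fig.tight_layout()
    plt.show()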

We can also visualize how the attention distribution evolves step by step during decoding. Each row in the heatmap above represents a single decoding step, but viewing them as individual distributions makes the shifting focus more apparent:

Out[28]:
Visualization
Generating 'Le': attention focuses on 'The' (α=0.80).
Generating 'chat': attention focuses on 'cat' (α=0.85).
Generating 'était': attention focuses on 'sat' (α=0.75).
Generating 'assis': attention focuses on 'sat' (α=0.70).
Generating 'sur': attention focuses on 'on' (α=0.80).
Generating 'le': attention focuses on 'the' (α=0.80).
Generating 'tapis': attention focuses on 'mat' (α=0.82).

This interpretability is valuable for debugging and understanding model behavior. If a translation is wrong, we can inspect the attention weights to see if the model was looking at the right parts of the input.

Attention Entropy: Measuring Focus

A useful metric for understanding attention behavior is entropy, which quantifies how spread out or focused the attention distribution is. Low entropy indicates concentrated attention on few positions; high entropy indicates diffuse attention across many positions.

Out[29]:
Visualization
Bar chart showing attention entropy for each decoding step, with bars colored by entropy level.
Attention entropy varies across decoding steps. Content words (nouns, verbs) typically receive focused attention (low entropy), while function words and multi-word expressions may require distributed attention (higher entropy). Most words in this translation have entropy below 1.0, indicating clear word-to-word correspondences.

The entropy analysis reveals that most decoding steps in this translation have relatively focused attention (entropy below 1.0), indicating clear word-to-word correspondences. Steps with higher entropy might indicate cases where the model needs to aggregate information from multiple source positions, such as translating idiomatic expressions or handling word reordering.
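Entropy is simple to compute from the attention weights. A short sketch with two hypothetical distributions:

import numpy as np

def attention_entropy(weights, eps=1e-12):
    """Shannon entropy (in nats) of an attention distribution."""
    return float(-np.sum(weights * np.log(weights + eps)))

focused = np.array([0.02, 0.03, 0.85, 0.05, 0.05])  # sharp: most weight on one position
diffuse = np.array([0.22, 0.18, 0.20, 0.21, 0.19])  # spread out across positions

print(f"focused: H = {attention_entropy(focused):.2f} nats")
print(f"diffuse: H = {attention_entropy(diffuse):.2f} nats")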

Non-Monotonic Alignment Patterns

While English-French translation often exhibits roughly monotonic alignment (words appear in similar order), attention can handle more complex patterns. Languages with different word orders, or sentences with long-distance dependencies, produce non-monotonic attention patterns.

Out[30]:
Visualization
Heatmap showing attention alignment for English to German translation with non-diagonal attention pattern due to verb placement differences.
Non-monotonic attention patterns in English-German translation. German verb placement differs from English, requiring the attention mechanism to 'jump' across the source sentence. The model learns to attend to 'has' and 'eaten' when generating the German verb 'gegessen' at the end.

This example demonstrates one of attention's key strengths: it can learn arbitrary alignment patterns without explicit supervision. The model discovers that German places the past participle ("gegessen") at the end of the clause, and learns to attend to the corresponding English words ("has eaten") regardless of their position.

Computational Considerations

Bahdanau attention introduces additional computation compared to a standard encoder-decoder. Let's analyze the complexity by examining each operation required at every decoder step.

For each decoder step, we must perform the following operations:

  1. Project the decoder state: Computing $W_a s_{i-1}$ requires $O(d_a \times d_s)$ operations (matrix-vector multiplication)
  2. Project all encoder states: Computing $U_a h_j$ for all $j$ requires $O(T_x \times d_a \times d_h)$ operations
  3. Compute scores: The addition, tanh, and dot product with $v_a$ require $O(T_x \times d_a)$ operations
  4. Apply softmax: Normalizing scores requires $O(T_x)$ operations
  5. Compute weighted sum: Computing $\sum_j \alpha_{ij} h_j$ requires $O(T_x \times d_h)$ operations

where:

  • $T_x$: source sequence length
  • $T_y$: target sequence length
  • $d_s$: decoder hidden dimension
  • $d_h$: encoder hidden dimension
  • $d_a$: alignment dimension

The dominant term is the encoder projection, which is $O(T_x \times d_a \times d_h)$ per decoder step. Over an entire output sequence of length $T_y$, the total complexity becomes $O(T_y \times T_x \times d_a \times d_h)$. This quadratic dependence on sequence lengths ($T_y \times T_x$) becomes significant for long sequences.
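A back-of-the-envelope calculation makes this growth visible. The function below simply multiplies out the per-step terms listed above; the dimension defaults are illustrative:

def attention_ops(T_x, T_y, d_s=512, d_h=512, d_a=512):
    """Rough count of multiply-add operations for attention over a full output sequence."""
    per_step = (
        d_a * d_s           # project the decoder state
        + T_x * d_a * d_h   # project all encoder states
        + T_x * d_a         # add, tanh, dot with v_a
        + T_x               # softmax
        + T_x * d_h         # weighted sum of encoder states
    )
    return T_y * per_step

for n in (20, 50, 200, 1000):
    print(f"T_x = T_y = {n:4d}: ~{attention_ops(n, n):.2e} operations")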

Out[31]:
Visualization
Computational cost of Bahdanau attention grows quadratically with sequence length. For very long sequences, this becomes a bottleneck, motivating more efficient attention variants.

For typical machine translation tasks with sequences of 20-50 tokens, this overhead is manageable. However, for very long sequences (documents, conversations), the quadratic scaling becomes problematic. This motivated later work on efficient attention variants, which we'll explore in subsequent chapters.

A practical optimization addresses the redundant computation in step 2. Notice that the encoder projections $U_a h_j$ depend only on the encoder outputs, which remain constant throughout decoding. By precomputing these projections once after encoding and caching them, we avoid repeating this $O(T_x \times d_a \times d_h)$ computation at every decoder step:

In[32]:
Code
class EfficientBahdanauAttention(nn.Module):
    """Bahdanau attention with precomputed encoder projections."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.W_a = nn.Linear(decoder_dim, attention_dim, bias=False)
        self.U_a = nn.Linear(encoder_dim, attention_dim, bias=False)
        self.v_a = nn.Linear(attention_dim, 1, bias=False)

        # Cache for precomputed encoder projections
        self.encoder_proj_cache = None

    def precompute_encoder_proj(self, encoder_outputs):
        """Precompute encoder projections once per sequence."""
        self.encoder_proj_cache = self.U_a(encoder_outputs)

    def forward(self, decoder_state, encoder_outputs):
        """Compute attention using cached encoder projections."""
        decoder_proj = self.W_a(decoder_state)

        # Use cached projection if available
        if self.encoder_proj_cache is not None:
            encoder_proj = self.encoder_proj_cache
        else:
            encoder_proj = self.U_a(encoder_outputs)

        combined = torch.tanh(decoder_proj.unsqueeze(1) + encoder_proj)
        scores = self.v_a(combined).squeeze(-1)
        attention_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(
            attention_weights.unsqueeze(1), encoder_outputs
        ).squeeze(1)

        return context, attention_weights
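In use, the cache would be filled once right after encoding and then reused at every decoding step. A brief usage sketch with illustrative dimensions:

import torch

attention = EfficientBahdanauAttention(encoder_dim=256, decoder_dim=256, attention_dim=128)

encoder_outputs = torch.randn(2, 10, 256)           # (batch, src_len, encoder_dim)
attention.precompute_encoder_proj(encoder_outputs)  # computed once per source sequence

decoder_state = torch.zeros(2, 256)
for step in range(5):                               # a few illustrative decoder steps
    context, weights = attention(decoder_state, encoder_outputs)
    # ... a decoder RNN would update decoder_state here before the next step ...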

Limitations and Impact

Bahdanau attention changed how we think about sequence-to-sequence models. However, it has several limitations that motivated subsequent research.

The most significant limitation is the sequential nature of decoding. Because attention at step $i$ depends on the decoder state $s_{i-1}$, and $s_{i-1}$ depends on the attention at step $i-1$, we cannot parallelize across decoder timesteps. Each step must wait for the previous step to complete. This sequential bottleneck limits training throughput on modern parallel hardware like GPUs, where we'd prefer to process many positions simultaneously.

Another limitation is the additive score function's computational cost. Computing $\tanh(W_a s_{i-1} + U_a h_j)$ requires more operations than simpler alternatives like dot-product attention. While the difference is small for individual computations, it adds up over millions of training examples. Luong attention, which we'll cover in the next chapter, addresses this with more efficient score functions.

The attention mechanism also introduces an additional hyperparameter (the alignment dimension $d_a$) and additional learnable parameters ($W_a$, $U_a$, $v_a$). While these provide flexibility, they also increase the risk of overfitting on small datasets and require careful tuning.

Despite these limitations, Bahdanau attention had a major impact on the field. It showed that attention mechanisms could improve performance on sequence-to-sequence tasks, reducing the BLEU score gap between neural and phrase-based machine translation systems. It also introduced the core ideas that would evolve into the self-attention mechanism at the heart of Transformers.

The key insights that carried forward include:

  • Dynamic context: Computing a fresh context vector at each decoding step, rather than using a fixed representation
  • Soft alignment: Learning to align input and output positions without explicit supervision
  • Interpretability: Attention weights provide insight into model behavior
  • Differentiable lookup: Treating attention as a soft, differentiable lookup table

These ideas laid the groundwork for the attention-based models that would reshape NLP over the following years.

How Attention Learns: From Random to Meaningful

A natural question is: how does attention learn to produce meaningful alignments? At initialization, the attention weights are random. Through training, the model gradually discovers which source positions are relevant for each target position.

Out[33]:
Visualization
Heatmap showing nearly uniform attention weights at early training.
Early training (Epoch 1): Attention is nearly uniform and uninformative, with weights spread randomly across positions.
Heatmap showing emerging diagonal structure at mid training.
Mid training (Epoch 10): The model begins discovering word correspondences, with diagonal structure emerging.
Heatmap showing strong diagonal alignment at convergence.
Converged (Epoch 50): Clear, focused alignments between corresponding words. The diagonal pattern shows learned word mappings.

The learning progression illustrates a key property of attention: it's trained end-to-end with the translation objective. The model receives no explicit supervision about which source words correspond to which target words. Instead, it discovers these alignments because they help minimize the translation loss. This emergent alignment is one of the most interesting aspects of attention mechanisms.

Summary

Bahdanau attention solves the bottleneck problem in encoder-decoder models by allowing the decoder to dynamically focus on relevant parts of the input at each generation step. The mechanism works through a learned alignment model that scores how well each encoder position matches the current decoder state.

The key components are:

  • Additive score function: $e_{ij} = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$ computes alignment scores by projecting the decoder state $s_{i-1}$ and encoder states $h_j$ into a shared alignment space, combining them additively, and reducing to a scalar
  • Softmax normalization: $\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{ik})}$ converts raw scores to attention weights that form a probability distribution over encoder positions
  • Context vector: $c_i = \sum_j \alpha_{ij} h_j$ computes a weighted sum of encoder states, where the weights $\alpha_{ij}$ determine how much each position contributes to the context for generating output $i$

The attention weights are interpretable, showing which input positions the model considers relevant for each output position. This interpretability, combined with strong empirical performance, made Bahdanau attention a foundational technique in neural machine translation and sequence-to-sequence modeling more broadly.

In the next chapter, we'll explore Luong attention, which offers alternative score functions with different computational trade-offs.

Key Parameters

When implementing Bahdanau attention, several parameters significantly impact model performance (an illustrative configuration follows this list):

  • attention_dim (d_a): The dimension of the alignment space where encoder and decoder states are projected before comparison. Typical values range from 64 to 512. Larger values increase model capacity but also computation cost. A common heuristic is to set this equal to the encoder or decoder hidden dimension.

  • encoder_dim (d_h): The dimension of encoder hidden states. For bidirectional encoders, this is typically twice the base RNN hidden size (forward + backward concatenated). Values of 256-1024 are common in practice.

  • decoder_dim (d_s): The dimension of decoder hidden states. Often set equal to the encoder dimension for simplicity, though they can differ. The decoder dimension affects both the attention computation and the output projection.

  • embed_dim: The dimension of word embeddings fed to the decoder. Typical values are 128-512. Smaller embeddings reduce parameters but may limit expressiveness for large vocabularies.

  • vocab_size: The size of the output vocabulary. Larger vocabularies require more parameters in the output projection layer and can slow down training due to the softmax computation over all tokens.
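Putting these together, one illustrative configuration (the values are hypothetical but fall within the ranges above) might look like this:

# Illustrative hyperparameters; adjust for your dataset and hardware
config = {
    "attention_dim": 256,  # d_a: alignment space
    "encoder_dim": 512,    # d_h: e.g., 2 x 256 for a bidirectional encoder
    "decoder_dim": 512,    # d_s: matched to the encoder for simplicity
    "embed_dim": 256,      # word embedding size
    "vocab_size": 30000,   # output vocabulary (task- and tokenizer-dependent)
}

decoder = AttentionDecoder(
    config["vocab_size"],
    config["embed_dim"],
    config["encoder_dim"],
    config["decoder_dim"],
    config["attention_dim"],
)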

