Teacher Forcing: Training Seq2Seq Models with Ground Truth Context

Michael Brenndoerfer · December 16, 2025 · 43 min read

Learn how teacher forcing accelerates sequence-to-sequence training by providing correct context, understand exposure bias, and explore mitigation strategies like scheduled sampling.


Teacher Forcing

Training sequence-to-sequence models presents a unique challenge: the decoder must learn to generate outputs one token at a time, where each prediction depends on all previous predictions. But during training, should the decoder see its own (potentially incorrect) predictions, or should we give it the correct answers? This question lies at the heart of teacher forcing, a training technique that dramatically accelerates learning but introduces subtle problems that practitioners must understand and address.

In this chapter, we'll explore teacher forcing in depth: how it works, why it's so effective for training, what problems it creates, and the strategies researchers have developed to mitigate its drawbacks. Understanding these trade-offs is essential for anyone training sequence generation models, from machine translation systems to text summarizers.

The Training Dilemma

Consider training a decoder to translate "I love cats" into French: "J'aime les chats." At each timestep, the decoder predicts the next word based on the encoder's representation and all previously generated words. The question is: during training, what should those "previously generated words" be?

Out[3]:
Visualization
Diagram showing two training approaches for a decoder generating French translation.
The decoder training dilemma: should the model condition on its own predictions (which may be wrong early in training) or on the ground truth targets? Teacher forcing chooses the latter, providing correct context at each step.

Two approaches emerge from this dilemma. In autoregressive training, the decoder conditions on its own previous predictions. Early in training, these predictions are mostly wrong, meaning the decoder learns from incorrect context. This creates a compounding error problem: one wrong prediction leads to unusual context, which leads to another wrong prediction, and so on.

In teacher forcing, we sidestep this problem entirely by always feeding the decoder the correct previous token from the ground truth sequence. The decoder never sees its own mistakes during training, allowing it to learn from perfect context at every step.

Teacher Forcing

Teacher forcing is a training strategy where the decoder receives the ground truth output from the previous timestep as input, rather than its own prediction. The "teacher" provides the correct answer, "forcing" the model to learn from ideal conditions.

How Teacher Forcing Works

To understand teacher forcing mathematically, we need to think carefully about what information flows into the decoder at each timestep. The decoder's job is to predict the next token, but to do so, it needs context: what tokens came before? This seemingly simple question leads us to two fundamentally different training approaches.

The Decoder's Input: A Critical Choice

At each timestep $t$, the decoder must produce a prediction for the next token. But the decoder doesn't operate in isolation. It receives two crucial pieces of information:

  1. The hidden state $h_t$: This encapsulates everything the decoder has learned from the encoder's representation of the source sequence, plus any information accumulated from previous decoding steps.

  2. The previous token: This is where the critical choice arises. Should we give the decoder the token it actually predicted at the previous step, or the token it should have predicted?

This choice fundamentally shapes the learning dynamics.

Autoregressive Training: Learning from Your Own Mistakes

In standard autoregressive generation, the approach used at inference time, the decoder conditions on its own previous output:

$$\hat{y}_t = \text{Decoder}(h_t, \hat{y}_{t-1})$$

where:

  • $\hat{y}_t$: the model's prediction at timestep $t$
  • $h_t$: the decoder's hidden state at timestep $t$, encoding information from the encoder and all previous decoding steps
  • $\hat{y}_{t-1}$: the model's own prediction from the previous timestep

This formulation mirrors what happens during inference: the model generates a token, then uses that token as context to generate the next one. The appeal is clear: train the model the same way you'll use it.

But there's a problem. Early in training, the model's predictions are essentially random. If $\hat{y}_{t-1}$ is wrong (which it usually is), the decoder receives misleading context. This incorrect context produces a poor hidden state $h_t$, which leads to another incorrect prediction $\hat{y}_t$, which becomes incorrect context for the next step, and so on. Errors don't just occur; they compound.

Teacher Forcing: Learning from Perfect Context

Teacher forcing sidesteps this compounding problem by replacing the model's prediction with the ground truth:

$$\hat{y}_t = \text{Decoder}(h_t, y_{t-1})$$

where:

  • $y_{t-1}$: the ground truth token from the target sequence at position $t-1$

The change appears minor (we've simply swapped $\hat{y}_{t-1}$ for $y_{t-1}$), but the implications are profound. By providing the correct previous token, the decoder always operates with ideal context. The learning problem transforms from "given whatever (probably wrong) token you just predicted, what should you predict next?" to "given the correct previous token, what should you predict next?"

This is a much easier problem. The decoder learns the true conditional distribution $P(y_t | y_{<t})$ directly, without the noise introduced by its own errors.
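
To make the contrast concrete, here is a minimal runnable sketch: a toy GRU decoder step built just for this illustration (not the model used later in the chapter), run once autoregressively and once with teacher forcing.

Code
import torch
import torch.nn as nn

# Toy decoder step for illustration only: (hidden, previous token) -> (logits, new hidden)
vocab_size, embed_dim, hidden_dim = 20, 8, 16
embedding = nn.Embedding(vocab_size, embed_dim)
rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
out_proj = nn.Linear(hidden_dim, vocab_size)


def decoder_step(hidden, prev_token):
    output, hidden = rnn(embedding(prev_token), hidden)
    return out_proj(output), hidden


targets = torch.randint(0, vocab_size, (1, 6))  # a toy target sequence
hidden0 = torch.zeros(1, 1, hidden_dim)         # stands in for the encoder state

# Autoregressive: condition each step on the model's OWN previous prediction
prev_token, hidden = targets[:, :1], hidden0    # first target token as start symbol
for _ in range(targets.shape[1] - 1):
    logits, hidden = decoder_step(hidden, prev_token)
    prev_token = logits.argmax(dim=-1)          # may be wrong; errors can compound

# Teacher forcing: condition each step on the GROUND TRUTH previous token
hidden = hidden0
for t in range(targets.shape[1] - 1):
    logits, hidden = decoder_step(hidden, targets[:, t:t + 1])  # correct context, always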

Computing the Loss: Measuring Prediction Quality

With teacher forcing providing clean context, we can now measure how well the decoder learns. At each timestep, we want to know: given the correct context, how close is the model's prediction to the target?

We use cross-entropy loss, which measures the negative log probability the model assigns to the correct token:

$$\mathcal{L}_t = -\log P(y_t | h_t, y_{<t})$$

where:

  • $\mathcal{L}_t$: the loss at timestep $t$
  • $P(y_t | h_t, y_{<t})$: the probability the model assigns to the correct token $y_t$, given the hidden state $h_t$ and all previous ground truth tokens $y_{<t} = (y_1, y_2, \ldots, y_{t-1})$

Why negative log probability? Consider what happens as the model improves:

  • If the model assigns high probability to the correct token (say, $P = 0.9$), then $-\log(0.9) \approx 0.1$, a small loss.
  • If the model assigns low probability (say, $P = 0.1$), then $-\log(0.1) \approx 2.3$, a large loss.

The negative log transforms probabilities into a loss function that penalizes confident wrong predictions heavily while rewarding confident correct predictions.
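
As a quick numerical check (a minimal sketch, independent of the models in this chapter), we can compute these values directly and confirm that PyTorch's cross_entropy returns exactly the negative log probability of the correct token:

Code
import math

import torch
import torch.nn.functional as F

# Negative log probability for confident-correct vs. confident-wrong predictions
print(-math.log(0.9))  # ~0.105: small loss when the correct token gets 90% probability
print(-math.log(0.1))  # ~2.303: large loss when the correct token gets only 10%

# F.cross_entropy applies softmax to raw logits and returns -log P(correct token)
logits = torch.tensor([[2.0, 0.5, -1.0]])  # unnormalized scores over a 3-token vocabulary
target = torch.tensor([0])                 # the correct token is index 0
loss = F.cross_entropy(logits, target)
manual = -torch.log_softmax(logits, dim=-1)[0, 0]
print(loss.item(), manual.item())          # the two values match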

Out[4]:
Visualization
Line plot showing cross-entropy loss increasing as predicted probability decreases, with annotations at key probability values.
Cross-entropy loss as a function of predicted probability for the correct token. When the model assigns high probability to the correct answer, loss is low. As probability decreases, loss increases sharply, heavily penalizing confident wrong predictions.

Total Sequence Loss: Aggregating Across Timesteps

A sequence contains multiple tokens, each with its own prediction and loss. To train the model, we need a single scalar loss value. The natural choice is to sum the per-timestep losses:

$$\mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_t = -\sum_{t=1}^{T} \log P(y_t | h_t, y_{<t})$$

where:

  • $\mathcal{L}$: the total loss for the entire sequence
  • $T$: the length of the target sequence
  • The sum aggregates prediction errors across all positions, treating each timestep's contribution equally

This formulation has an elegant interpretation: we're computing the negative log probability of the entire target sequence under the model's distribution. Minimizing this loss is equivalent to maximizing the probability the model assigns to generating the correct sequence, which is exactly what we want.
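
In code, this summation usually appears as a single cross-entropy call over all positions at once. A minimal sketch with made-up shapes (not the chapter's model) showing that summing per-timestep losses and using a flattened cross-entropy are the same computation:

Code
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 5, 10
logits = torch.randn(batch_size, seq_len, vocab_size)          # one prediction per position
targets = torch.randint(0, vocab_size, (batch_size, seq_len))  # ground truth tokens

# Per-timestep losses, summed as in the formula above
per_step = F.cross_entropy(
    logits.reshape(-1, vocab_size), targets.reshape(-1), reduction="none"
)
total_loss = per_step.sum()

# The default "mean" reduction is the same sum divided by the number of tokens
mean_loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
assert torch.allclose(total_loss / (batch_size * seq_len), mean_loss)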

Out[5]:
Visualization
Flow diagram showing teacher forcing with ground truth inputs and prediction outputs at each timestep.
Teacher forcing data flow during training. The decoder receives ground truth tokens as input (green arrows) while predicting the next token at each step. The loss is computed between predictions and targets, but input context always comes from the ground truth sequence.

The diagram above illustrates the complete data flow during teacher forcing training. Notice the key insight: the ground truth tokens (green) flow into the decoder as inputs, while the decoder's predictions (blue) flow out for comparison with targets (red). The decoder never sees its own predictions during training. It always receives the correct answer from the previous timestep.

This architecture has a profound consequence: teacher forcing decouples each timestep's learning problem. The model at timestep $t$ learns independently of whether it got timestep $t-1$ correct, because it always receives the correct $y_{t-1}$ regardless. Instead of trying to learn "given my previous (possibly wrong) prediction, what should I predict next?", the model learns "given the correct previous token, what should I predict next?"

This decoupling is what makes teacher forcing so effective, especially early in training when the model would otherwise be learning from a cascade of errors.
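
Concretely, the decoder's inputs and its prediction targets are the same ground truth sequence offset by one position. A small illustration using the French example from earlier (plain strings here, just to show the pairing):

Code
# Teacher forcing pairs each ground-truth input token with the NEXT token as its target
target_tokens = ["<sos>", "J'", "aime", "les", "chats", "<eos>"]

decoder_inputs = target_tokens[:-1]   # what the decoder is fed at each step
decoder_targets = target_tokens[1:]   # what it must predict at each step

for step, (inp, tgt) in enumerate(zip(decoder_inputs, decoder_targets)):
    print(f"step {step}: input={inp!r:>9} -> predict {tgt!r}")
# The inputs come from the ground truth regardless of what the model predicted,
# which is exactly the decoupling described above.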

Why Teacher Forcing Works So Well

Teacher forcing accelerates training for several interconnected reasons. Understanding these helps explain both its effectiveness and, as we'll see later, its limitations.

Stable Gradient Flow

When training with the model's own predictions, errors compound across timesteps. An early mistake creates unusual context, leading to another mistake, which creates even more unusual context. By the end of a long sequence, the decoder might be operating in a completely unfamiliar regime. Gradients computed from such sequences are noisy and can point in unhelpful directions.

Teacher forcing eliminates this compounding. Every timestep receives correct context, so the gradients at each position reflect the true learning signal: "given this correct context, how should I adjust my weights to predict the next token better?" These cleaner gradients lead to faster, more stable convergence.

Out[6]:
Visualization
Two panels comparing error propagation with and without teacher forcing across sequence timesteps.
Error compounding in autoregressive training versus teacher forcing. Without teacher forcing (left), a single early error cascades into increasingly divergent predictions. With teacher forcing (right), each timestep receives correct context regardless of previous predictions.

Parallel Computation

A less obvious but practically important benefit is computational efficiency. With teacher forcing, we know all the inputs to the decoder before training begins: they're just the ground truth sequence shifted by one position. This means we can compute all timesteps in parallel using matrix operations, rather than sequentially waiting for each prediction.

In frameworks like PyTorch, this translates to feeding the entire target sequence (minus the last token) as input and computing all predictions simultaneously. The speedup is substantial, especially for long sequences where sequential processing would be prohibitively slow.

In[7]:
Code
# Imports for this block (earlier cells in the original notebook may already provide them)
import torch
import torch.nn as nn


class DecoderWithTeacherForcing(nn.Module):
    """Decoder that supports parallel computation with teacher forcing."""

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, targets, encoder_hidden):
        """
        Forward pass with teacher forcing.

        Args:
            targets: Ground truth sequence [batch, seq_len]
            encoder_hidden: Encoder final hidden state [1, batch, hidden]

        Returns:
            logits: Predictions for each position [batch, seq_len, vocab_size]
        """
        # Embed all target tokens at once
        embedded = self.embedding(targets)  # [batch, seq_len, embed_dim]

        # Process entire sequence in parallel
        output, _ = self.rnn(
            embedded, encoder_hidden
        )  # [batch, seq_len, hidden]

        # Compute all logits at once
        logits = self.output(output)  # [batch, seq_len, vocab_size]

        return logits

Let's see this parallel computation in action by processing a batch of sequences:

In[8]:
Code
# Create decoder and test data
vocab_size, embed_dim, hidden_dim = 1000, 128, 256
batch_size, seq_len = 32, 50

decoder = DecoderWithTeacherForcing(vocab_size, embed_dim, hidden_dim)

# Simulate encoder output and target sequence
encoder_hidden = torch.randn(1, batch_size, hidden_dim)
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

# Forward pass processes all 50 timesteps at once
logits = decoder(targets, encoder_hidden)
Out[9]:
Console
Input shape:  (32, 50)  (batch_size × seq_len)
Output shape: (32, 50, 1000)  (batch_size × seq_len × vocab_size)

All 50 timesteps computed in a single forward pass!

The decoder processed all 32 sequences of 50 tokens each in a single forward pass, producing predictions over the entire 1000-word vocabulary for every position. This demonstrates the key efficiency gain of teacher forcing: because we know all inputs ahead of time (the ground truth tokens), we can leverage matrix operations to compute everything in parallel rather than waiting for each prediction sequentially.

The parallel computation advantage is significant. Processing 50 timesteps sequentially would require 50 separate forward passes through the RNN, each waiting for the previous one to complete. With teacher forcing, we compute everything in one pass, leveraging GPU parallelism for massive speedups.
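
A rough way to see this in practice is to time the single parallel pass against a Python-level loop over the same GRU. This sketch reuses the decoder, targets, and encoder_hidden created above; absolute timings vary by hardware.

Code
import time

import torch

# One parallel teacher-forced pass over all timesteps
start = time.perf_counter()
with torch.no_grad():
    _ = decoder(targets, encoder_hidden)
parallel_time = time.perf_counter() - start

# The same computation, one timestep at a time
start = time.perf_counter()
with torch.no_grad():
    hidden = encoder_hidden
    for t in range(seq_len):
        embedded = decoder.embedding(targets[:, t : t + 1])
        output, hidden = decoder.rnn(embedded, hidden)
        _ = decoder.output(output)
sequential_time = time.perf_counter() - start

print(f"parallel:   {parallel_time * 1000:.1f} ms")
print(f"sequential: {sequential_time * 1000:.1f} ms")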

Better Credit Assignment

Teacher forcing also improves credit assignment during backpropagation. When the model makes a wrong prediction, we want the gradients to teach it "you should have predicted X instead of Y given this context." With autoregressive training, the context itself might be wrong, confusing the learning signal. With teacher forcing, the context is always correct, so the model learns exactly what it should predict in each situation.

This cleaner credit assignment is particularly important for learning rare patterns. If a specific context only appears a few times in the training data, the model needs to learn from each occurrence efficiently. Teacher forcing ensures these learning opportunities aren't wasted on noisy gradients from compounded errors.

The Exposure Bias Problem

Teacher forcing's greatest strength is also its greatest weakness. By always providing correct context during training, we create a mismatch between training and inference conditions. During inference, the model must use its own predictions as context, but it has never practiced doing so. This mismatch is called exposure bias.

Exposure Bias

Exposure bias occurs when a model is trained on a different distribution of inputs than it encounters during inference. In teacher forcing, the model trains on ground truth context but must generate using its own (potentially erroneous) predictions at test time.

The consequences of exposure bias can be severe. A model might perform well on standard metrics like perplexity (which measure prediction quality given correct context) but generate poor outputs when running autoregressively. Small prediction errors that the model never learned to handle can cascade into completely incoherent outputs.

Out[10]:
Visualization
Two-panel diagram contrasting training with ground truth context versus inference with model predictions.
Exposure bias illustrated: during training, the model always sees correct context (top). During inference, it must use its own predictions, which may differ from anything seen in training (bottom). The model has no experience recovering from its own mistakes.

Consider a concrete example. Suppose during inference the model predicts "fast" instead of "quick" in "The quick brown fox." With teacher forcing, the model has never seen the context "The fast" followed by anything, so it has no idea what to predict next. It might output something reasonable by chance, or it might produce nonsense. The model simply wasn't trained to handle this situation.

The exposure bias problem becomes more severe as sequences get longer. Each timestep has some probability of error, and these probabilities compound. Consider a sequence of length $T$ where each token has probability $p$ of being predicted correctly. If we assume each prediction is independent (a simplification, but useful for intuition), the probability of generating the entire sequence without any errors is simply the product of getting each individual token correct:

$$P(\text{all correct}) = p^T$$

where:

  • $P(\text{all correct})$: the probability of generating the complete sequence without any errors
  • $p$: the probability of correctly predicting a single token (e.g., 0.95 for 95% accuracy)
  • $T$: the total sequence length (number of tokens to generate)
  • $p^T$: the probability that all $T$ independent predictions are correct, computed by multiplying $p$ by itself $T$ times

This formula reveals why exposure bias becomes increasingly problematic for longer sequences. For a 100-token sequence with 95% per-token accuracy, the probability of a completely correct sequence is only $(0.95)^{100} \approx 0.006$, or about 0.6%. The model will almost certainly encounter its own mistakes, yet it has no experience dealing with them.
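
The calculation is easy to reproduce, as in this small sketch of $p^T$ under the independence simplification:

Code
# Probability of an entirely error-free sequence, assuming independent per-token errors
for p in (0.90, 0.95, 0.99):
    for T in (10, 50, 100):
        print(f"p={p:.2f}, T={T:3d}: P(all correct) = {p ** T:.4f}")
# With p=0.95 and T=100 this gives roughly 0.0059, i.e. about 0.6%.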

Out[11]:
Visualization
Line plot showing exponential decay of perfect sequence probability for different per-token accuracies across sequence lengths.
Probability of generating an error-free sequence as a function of sequence length, for different per-token accuracies. Even with 95% per-token accuracy, the probability of a perfect 100-token sequence is less than 1%. This exponential decay explains why exposure bias becomes critical for long sequences.

Scheduled Sampling: A Middle Ground

Scheduled sampling, introduced by Bengio et al. (2015), offers a compromise between pure teacher forcing and pure autoregressive training. The idea is simple: during training, randomly decide at each timestep whether to use the ground truth token or the model's own prediction as input to the next step. Early in training, use ground truth most of the time (like teacher forcing). As training progresses, gradually increase the probability of using the model's own predictions.

Out[12]:
Visualization
Line plot showing three different scheduled sampling probability curves over training epochs.
Scheduled sampling probability schedules. The probability of using the model's own prediction (rather than ground truth) increases over training. Different schedules (linear, exponential, inverse sigmoid) control how quickly this transition happens.

The schedule can take various forms, each defined by a function that maps the current epoch to a sampling probability $\epsilon$:

Linear schedule: The probability increases at a constant rate until reaching 1.0:

$$\epsilon(e) = \min\left(1, \frac{e}{E_{\max}}\right)$$

where:

  • $\epsilon(e)$: the sampling probability at epoch $e$ (probability of using the model's own prediction rather than ground truth)
  • $e$: the current training epoch (starting from 0)
  • $E_{\max}$: the epoch at which to reach full autoregressive training (typically set to 80% of total epochs)
  • $\min(1, \cdot)$: ensures the probability never exceeds 1.0

The linear schedule increases the sampling probability by $\frac{1}{E_{\max}}$ each epoch. For example, if $E_{\max} = 80$ and total training is 100 epochs, the probability increases by 0.0125 per epoch until reaching 1.0 at epoch 80, then remains at 1.0 for the final 20 epochs.

Exponential schedule: The probability grows quickly at first, then slows as it approaches 1.0:

$$\epsilon(e) = 1 - k^e$$

where:

  • $\epsilon(e)$: the sampling probability at epoch $e$
  • $k$: a decay constant slightly less than 1 (e.g., 0.97), controlling how quickly teacher forcing is phased out
  • $e$: the current training epoch
  • $k^e$: the teacher forcing probability, which decays exponentially toward 0

The intuition is that $k^e$ represents the probability of still using teacher forcing. Since $k < 1$, this value shrinks exponentially: at epoch 0, $k^0 = 1$ (100% teacher forcing); as $e$ increases, $k^e \to 0$. Subtracting from 1 gives the probability of using the model's own predictions instead.

Inverse sigmoid schedule: Provides a smooth S-curve transition:

$$\epsilon(e) = \frac{1}{1 + \exp\left(\frac{c - e}{b}\right)}$$

where:

  • $\epsilon(e)$: the sampling probability at epoch $e$
  • $c$: the midpoint epoch where $\epsilon = 0.5$ (the center of the S-curve)
  • $b$: controls the steepness of the transition (larger $b$ means a gentler slope, smaller $b$ means a sharper transition)
  • $e$: the current training epoch
  • $\exp(\cdot)$: the exponential function

This is the standard logistic sigmoid function, shifted and scaled. When $e = c$, the exponent becomes 0, so $\exp(0) = 1$ and $\epsilon(c) = \frac{1}{1+1} = 0.5$. When $e \ll c$ (early training), the exponent is large and positive, making $\epsilon(e) \approx 0$. When $e \gg c$ (late training), the exponent is large and negative, making $\epsilon(e) \approx 1$.

The choice of schedule is a hyperparameter that may need tuning for different tasks. Linear schedules are simple and predictable, exponential schedules transition faster early on, and inverse sigmoid schedules provide the smoothest transition with gradual changes at both ends.

Let's implement scheduled sampling and see how it affects training:

In[13]:
Code
import numpy as np  # used by the inverse sigmoid schedule below


def scheduled_sampling_step(decoder, targets, encoder_hidden, sampling_prob):
    """
    One training step with scheduled sampling.

    Args:
        decoder: The decoder model
        targets: Ground truth sequence [batch, seq_len]
        encoder_hidden: Encoder hidden state
        sampling_prob: Probability of using model's own prediction

    Returns:
        logits: Model predictions
    """
    batch_size, seq_len = targets.shape
    vocab_size = decoder.output.out_features

    # Initialize
    hidden = encoder_hidden
    input_token = targets[:, 0:1]  # Start token
    all_logits = []

    for t in range(seq_len - 1):
        # Get prediction for this timestep
        embedded = decoder.embedding(input_token)
        output, hidden = decoder.rnn(embedded, hidden)
        logits = decoder.output(output)
        all_logits.append(logits)

        # Decide whether to use prediction or ground truth for next input
        if torch.rand(1).item() < sampling_prob:
            # Use model's own prediction
            predicted = logits.argmax(dim=-1)
            input_token = predicted
        else:
            # Use ground truth (teacher forcing)
            input_token = targets[:, t + 1 : t + 2]

    return torch.cat(all_logits, dim=1)


def get_sampling_probability(epoch, schedule="linear", max_epochs=100):
    """Get sampling probability based on schedule."""
    if schedule == "linear":
        return min(1.0, epoch / (max_epochs * 0.8))
    elif schedule == "exponential":
        k = 0.97
        return 1 - k**epoch
    elif schedule == "inverse_sigmoid":
        k, c, b = 1, max_epochs / 2, max_epochs / 10
        return k / (k + np.exp((c - epoch) / b))
    else:
        raise ValueError(f"Unknown schedule: {schedule}")
Out[14]:
Console
Scheduled Sampling Probabilities by Epoch:
--------------------------------------------------
Epoch      Linear       Exponential  Inv Sigmoid 
--------------------------------------------------
0          0.000        0.000        0.007       
10         0.125        0.263        0.018       
25         0.312        0.533        0.076       
50         0.625        0.782        0.500       
75         0.938        0.898        0.924       
100        1.000        0.952        0.993       

At epoch 0, all schedules start with probability 0.000 (pure teacher forcing), meaning the model always receives ground truth context. By epoch 50, the schedules diverge significantly: exponential reaches 0.782, linear is at 0.625, and inverse sigmoid sits exactly at its designed midpoint of 0.500. The exponential schedule transitions fastest early on, making it suitable when you want to reduce teacher forcing quickly. The inverse sigmoid provides the smoothest transition overall, with gradual changes at both the beginning and end of training.

The choice of schedule affects how the model experiences the transition from ideal training conditions to realistic inference conditions. Linear schedules are predictable and easy to reason about, exponential schedules front-load the transition, and inverse sigmoid schedules minimize abrupt changes.

Out[15]:
Visualization
Heatmap showing scheduled sampling probability increasing across epochs and timesteps.
Visualization of scheduled sampling decisions across training. Each cell shows the probability of using the model's own prediction (versus ground truth) at that epoch-timestep combination. Early in training (bottom), the model uses ground truth almost exclusively. As training progresses, more timesteps receive the model's own predictions, preparing it for autoregressive inference.

Scheduled sampling has both benefits and drawbacks:

  • Benefits: Reduces exposure bias by gradually exposing the model to its own predictions. The model learns to handle imperfect context before it must do so at inference time.
  • Drawbacks: Training becomes sequential (no parallel computation), significantly slower than pure teacher forcing. The schedule introduces additional hyperparameters. The method doesn't eliminate exposure bias entirely, just reduces it.

Curriculum Learning for Sequence Generation

Curriculum learning takes a different approach to the training-inference mismatch. Instead of changing what context the model sees, we change what sequences the model learns on. The idea is to start with "easy" examples and gradually introduce harder ones, mimicking how humans learn.

For sequence generation, "easy" might mean:

  • Shorter sequences: Easier because there are fewer opportunities for errors to compound
  • More common patterns: Easier because the model has seen them more often
  • Lower perplexity targets: Easier because the next token is more predictable
Out[16]:
Visualization
Diagram showing curriculum learning stages from short simple sequences to long complex ones.
Curriculum learning progression for sequence generation. Training begins with short, simple sequences and gradually introduces longer, more complex ones. This allows the model to build competence before facing challenging examples.

Implementing curriculum learning for sequence generation typically involves sorting training examples by difficulty and presenting them in order:

In[17]:
Code
def create_curriculum(dataset, difficulty_fn, num_stages=4):
    """
    Create a curriculum by sorting data by difficulty.

    Args:
        dataset: List of (source, target) pairs
        difficulty_fn: Function that returns difficulty score for each example
        num_stages: Number of curriculum stages

    Returns:
        List of datasets, one per stage
    """
    # Score each example
    scored = [(difficulty_fn(src, tgt), src, tgt) for src, tgt in dataset]
    scored.sort(key=lambda x: x[0])

    # Split into stages
    stage_size = len(scored) // num_stages
    stages = []

    for i in range(num_stages):
        start = 0  # Each stage includes all easier examples
        end = (i + 1) * stage_size if i < num_stages - 1 else len(scored)
        stage_data = [(src, tgt) for _, src, tgt in scored[start:end]]
        stages.append(stage_data)

    return stages


def length_difficulty(source, target):
    """Simple difficulty metric based on target length."""
    return len(target)


def perplexity_difficulty(source, target, language_model):
    """Difficulty based on how surprising the target is."""
    # Would compute perplexity using a pre-trained LM
    pass
In[18]:
Code
# Example translation pairs for curriculum learning
example_data = [
    ("Hi", "Salut"),
    ("Good morning everyone in the office", "Bonjour à tous au bureau"),
    ("Hello", "Bonjour"),
    ("The cat sat on the mat", "Le chat s'est assis sur le tapis"),
    ("Yes", "Oui"),
    (
        "How are you doing today my friend",
        "Comment allez-vous aujourd'hui mon ami",
    ),
    ("Thank you very much", "Merci beaucoup"),
    ("I love programming in Python", "J'aime programmer en Python"),
]

curriculum = create_curriculum(example_data, length_difficulty, num_stages=3)
Out[19]:
Console
Curriculum Learning Stages (sorted by target length):
============================================================

Stage 1 (2 examples):
  'Yes' → 'Oui' (len=3)
  'Hi' → 'Salut' (len=5)

Stage 2 (4 examples):
  'Yes' → 'Oui' (len=3)
  'Hi' → 'Salut' (len=5)
  'Hello' → 'Bonjour' (len=7)
  ... and 1 more

Stage 3 (8 examples):
  'Yes' → 'Oui' (len=3)
  'Hi' → 'Salut' (len=5)
  'Hello' → 'Bonjour' (len=7)
  ... and 5 more

The curriculum organizes examples by increasing difficulty. Stage 1 contains only the shortest translations (3-character targets like "Oui"), allowing the model to learn basic encoder-decoder mechanics before tackling complexity. Each subsequent stage includes all previous examples plus longer ones, ensuring the model doesn't forget simpler patterns while learning harder ones. By Stage 3, the model trains on the full dataset including long translations like "Comment allez-vous aujourd'hui mon ami".

Curriculum learning helps with exposure bias indirectly. By mastering short sequences first, the model builds robust representations before encountering the long sequences where exposure bias is most problematic. However, it doesn't directly address the training-inference mismatch; a simple way to wire the stages into training is sketched below.
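
One possible wiring is to train for a fixed number of epochs on each stage, in order of increasing difficulty. The train_with_curriculum helper below is hypothetical; it reuses the stage datasets from create_curriculum together with any per-stage training function, such as the teacher-forcing trainer defined in the next section.

Code
def train_with_curriculum(model, stages, train_stage_fn, epochs_per_stage=10):
    """
    Train on curriculum stages in order of increasing difficulty.

    Args:
        model: The seq2seq model to train
        stages: List of stage datasets from create_curriculum (easiest first)
        train_stage_fn: Function (model, data, epochs) -> list of per-epoch losses
        epochs_per_stage: Number of epochs to spend on each stage

    Returns:
        Combined loss history across all stages
    """
    history = []
    for stage_idx, stage_data in enumerate(stages):
        print(f"Stage {stage_idx + 1}: {len(stage_data)} examples")
        history.extend(train_stage_fn(model, stage_data, epochs_per_stage))
    return history


# Hypothetical usage, reusing functions defined in this chapter:
# stages = create_curriculum(train_pairs, length_difficulty, num_stages=3)
# history = train_with_curriculum(model, stages, train_with_teacher_forcing)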

Comparing Training Strategies

Let's compare the different training approaches empirically. We'll train simple sequence-to-sequence models with pure teacher forcing and with scheduled sampling, then evaluate how each behaves when generating autoregressively.

In[20]:
Code
import torch.nn.functional as F  # used for cross_entropy in the training loops below


class SimpleSeq2Seq(nn.Module):
    """Simple encoder-decoder for demonstration."""

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.decoder = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)

    def encode(self, source):
        embedded = self.embedding(source)
        _, hidden = self.encoder(embedded)
        return hidden

    def decode_teacher_forcing(self, targets, encoder_hidden):
        """Decode with teacher forcing (parallel)."""
        embedded = self.embedding(targets[:, :-1])
        output, _ = self.decoder(embedded, encoder_hidden)
        return self.output(output)

    def decode_autoregressive(self, encoder_hidden, max_len, start_token):
        """Decode autoregressively (sequential)."""
        batch_size = encoder_hidden.shape[1]
        hidden = encoder_hidden
        input_token = torch.full((batch_size, 1), start_token, dtype=torch.long)

        outputs = []
        for _ in range(max_len):
            embedded = self.embedding(input_token)
            output, hidden = self.decoder(embedded, hidden)
            logits = self.output(output)
            outputs.append(logits)
            input_token = logits.argmax(dim=-1)

        return torch.cat(outputs, dim=1)


def train_with_teacher_forcing(model, data_loader, epochs, lr=0.001):
    """Train using pure teacher forcing."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0
        for source, target in data_loader:
            optimizer.zero_grad()

            encoder_hidden = model.encode(source)
            logits = model.decode_teacher_forcing(target, encoder_hidden)

            # Loss against target (shifted by 1)
            loss = F.cross_entropy(
                logits.reshape(-1, model.vocab_size), target[:, 1:].reshape(-1)
            )

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        losses.append(epoch_loss / len(data_loader))

    return losses


def train_with_scheduled_sampling(model, data_loader, epochs, lr=0.001):
    """Train using scheduled sampling."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []

    for epoch in range(epochs):
        sampling_prob = get_sampling_probability(epoch, "linear", epochs)
        epoch_loss = 0

        for source, target in data_loader:
            optimizer.zero_grad()

            encoder_hidden = model.encode(source)

            # Sequential decoding with scheduled sampling
            batch_size, seq_len = target.shape
            hidden = encoder_hidden
            input_token = target[:, 0:1]
            all_logits = []

            for t in range(seq_len - 1):
                embedded = model.embedding(input_token)
                output, hidden = model.decoder(embedded, hidden)
                logits = model.output(output)
                all_logits.append(logits)

                if torch.rand(1).item() < sampling_prob:
                    input_token = logits.argmax(dim=-1)
                else:
                    input_token = target[:, t + 1 : t + 2]

            all_logits = torch.cat(all_logits, dim=1)
            loss = F.cross_entropy(
                all_logits.reshape(-1, model.vocab_size),
                target[:, 1:].reshape(-1),
            )

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        losses.append(epoch_loss / len(data_loader))

    return losses
In[21]:
Code
# Create synthetic data for comparison
def create_copy_task_data(num_samples, seq_len, vocab_size):
    """Create data for a simple copy task."""
    # Task: copy the input sequence
    data = []
    for _ in range(num_samples):
        seq = torch.randint(
            2, vocab_size, (seq_len,)
        )  # Avoid 0 (pad) and 1 (start)
        source = seq
        target = torch.cat([torch.tensor([1]), seq])  # Prepend start token
        data.append((source.unsqueeze(0), target.unsqueeze(0)))
    return data


# Create data
torch.manual_seed(42)
vocab_size, embed_dim, hidden_dim = 50, 32, 64
seq_len = 10
train_data = create_copy_task_data(200, seq_len, vocab_size)

# Train models
model_tf = SimpleSeq2Seq(vocab_size, embed_dim, hidden_dim)
model_ss = SimpleSeq2Seq(vocab_size, embed_dim, hidden_dim)

# Copy initial weights for fair comparison
model_ss.load_state_dict(model_tf.state_dict())

losses_tf = train_with_teacher_forcing(model_tf, train_data, epochs=50)
losses_ss = train_with_scheduled_sampling(model_ss, train_data, epochs=50)
Out[22]:
Console
Training Summary:
---------------------------------------------
Metric                    Teacher Forcing Scheduled   
---------------------------------------------
Final Loss                0.0753       0.8840      
Best Loss                 0.0753       0.8840      
Epochs to < 1.0 loss      23           48          

Teacher forcing achieves faster convergence, reaching low loss values earlier in training. The scheduled sampling model shows more variance in its loss curve because it increasingly uses its own (sometimes incorrect) predictions as training progresses.

Out[23]:
Visualization
Line plot comparing training loss curves for teacher forcing versus scheduled sampling over 50 epochs.
Training loss comparison between teacher forcing and scheduled sampling on a sequence copy task. Teacher forcing converges faster due to parallel computation and stable gradients, while scheduled sampling shows more variance as it gradually introduces the model's own predictions.

The training curves reveal the trade-offs between these approaches. Teacher forcing shows rapid, stable convergence because every timestep receives correct context and gradients flow cleanly. Scheduled sampling shows more variance, especially in later epochs when the model increasingly sees its own predictions. However, the scheduled sampling model may generalize better to the autoregressive inference setting.

Let's evaluate both models on autoregressive generation to see if scheduled sampling's additional training cost pays off:

In[24]:
Code
def evaluate_autoregressive(model, test_data, start_token=1):
    """Evaluate model using autoregressive generation."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for source, target in test_data:
            encoder_hidden = model.encode(source)
            generated = model.decode_autoregressive(
                encoder_hidden,
                max_len=target.shape[1] - 1,
                start_token=start_token,
            )

            # Compare generated to target (excluding start token)
            pred_tokens = generated.argmax(dim=-1)
            target_tokens = target[:, 1:]

            correct += (pred_tokens == target_tokens).sum().item()
            total += target_tokens.numel()

    return correct / total


# Create test data
test_data = create_copy_task_data(50, seq_len, vocab_size)

acc_tf = evaluate_autoregressive(model_tf, test_data)
acc_ss = evaluate_autoregressive(model_ss, test_data)
Out[25]:
Console
Autoregressive Generation Accuracy:
----------------------------------------
Teacher Forcing:     3.0%
Scheduled Sampling:  4.4%

Difference:          1.4%

Neither model reaches high accuracy when it must generate autoregressively: in this small setup, an early wrong token puts the decoder in unfamiliar context, and errors compound across the rest of the sequence. Scheduled sampling edges out teacher forcing by a small margin, consistent with it having practiced on its own (sometimes incorrect) predictions during training.

The gap between the teacher-forcing model's low training loss (0.075) and its 3.0% autoregressive accuracy is exposure bias made concrete. That said, the results depend on the specific task, model capacity, and training budget. For simple tasks like copying, sufficient teacher-forced training can still produce a robust mapping; for more complex tasks with longer sequences, scheduled sampling's exposure to its own predictions during training can provide meaningful benefits.

Practical Recommendations

Based on the trade-offs we've explored, here are practical guidelines for choosing and implementing training strategies:

Start with teacher forcing. It's simple, fast, and often sufficient. The parallel computation advantage alone makes it the default choice for initial experiments. Many successful production systems use pure teacher forcing without significant issues.

Consider scheduled sampling when:

  • Your sequences are long (50+ tokens)
  • You observe a large gap between training perplexity and generation quality
  • You have the computational budget for slower training
  • Your task involves open-ended generation where small errors can cascade

Implement curriculum learning when:

  • You have a natural difficulty ordering for your data
  • Training is unstable or slow to converge
  • You're working with limited data and need efficient learning

Monitor for exposure bias (a minimal monitoring sketch follows this list) by:

  • Comparing teacher-forced perplexity with autoregressive generation quality
  • Examining generated outputs for characteristic "drift" where early errors compound
  • Testing on sequences longer than those seen during training
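
A lightweight way to run the first of these checks is to log both numbers side by side. The sketch below reuses the SimpleSeq2Seq helpers and the evaluate_autoregressive function from the comparison above:

Code
import math

import torch
import torch.nn.functional as F


def monitor_exposure_bias(model, eval_data):
    """Compare teacher-forced perplexity with autoregressive token accuracy."""
    model.eval()
    total_loss, total_tokens = 0.0, 0
    with torch.no_grad():
        for source, target in eval_data:
            encoder_hidden = model.encode(source)
            logits = model.decode_teacher_forcing(target, encoder_hidden)
            loss = F.cross_entropy(
                logits.reshape(-1, model.vocab_size),
                target[:, 1:].reshape(-1),
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += target[:, 1:].numel()

    teacher_forced_ppl = math.exp(total_loss / total_tokens)
    autoregressive_acc = evaluate_autoregressive(model, eval_data)
    return teacher_forced_ppl, autoregressive_acc


# A low teacher-forced perplexity paired with low autoregressive accuracy is the
# characteristic signature of exposure bias, e.g.:
# ppl, acc = monitor_exposure_bias(model_tf, test_data)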
Visualization
Decision flowchart for choosing a training strategy. Start with teacher forcing. If there is no large train-test gap, deploy the model. If there is a gap and sequences are long, try scheduled sampling; if sequences are short but training is unstable, try curriculum learning; otherwise, consider increasing model capacity.

Limitations and Impact

Teacher forcing revolutionized sequence-to-sequence training by making it practical to train deep models on long sequences. Before teacher forcing, training decoders was slow and unstable, with errors compounding unpredictably. The technique enabled the first successful neural machine translation systems and remains the default training approach for most sequence generation models.

However, exposure bias remains a fundamental limitation. The mismatch between training and inference conditions means that models may behave unpredictably when they encounter their own errors. This is particularly problematic for open-ended generation tasks like story writing or dialogue, where there's no single correct output and the model must maintain coherence over many tokens.

The research community continues to develop alternatives. Reinforcement learning approaches like REINFORCE and actor-critic methods can train directly on generation quality metrics, avoiding the exposure bias problem entirely. However, these methods are harder to train and often less stable than teacher forcing. Minimum risk training optimizes expected loss under the model's own distribution, providing another path forward.

More recently, large language models trained with teacher forcing have shown remarkable generation quality despite the theoretical exposure bias concern. This suggests that with sufficient model capacity and data, the practical impact of exposure bias may be smaller than theory predicts. Nevertheless, understanding the trade-offs remains important for practitioners working with smaller models or specialized domains.

Summary

Teacher forcing is a training technique that provides the decoder with ground truth context at each timestep rather than its own predictions. This simple change has profound effects on training dynamics.

The key benefits of teacher forcing are:

  • Faster convergence: Clean gradients from correct context enable rapid learning
  • Parallel computation: Knowing all inputs in advance allows efficient batch processing
  • Stable training: No error compounding means consistent learning signal

The main drawback is exposure bias: the model trains on a different distribution than it encounters during inference. This can cause generation quality to degrade, especially for long sequences.

Mitigation strategies include:

  • Scheduled sampling: Gradually transition from teacher forcing to autoregressive training
  • Curriculum learning: Start with easy examples and progressively increase difficulty
  • Reinforcement learning: Train directly on generation metrics (covered in later chapters)

For most practical applications, start with pure teacher forcing. It's simple, fast, and often sufficient. Consider alternatives when you observe a significant gap between training metrics and generation quality, or when working with very long sequences where exposure bias is most pronounced.

The next chapter explores beam search, a decoding strategy that addresses a different aspect of sequence generation: how to find high-quality outputs from the model's learned distribution during inference.

Key Parameters

When implementing teacher forcing and its alternatives, several parameters significantly impact training behavior and model performance:

  • teacher_forcing_ratio (for scheduled sampling): The probability of using ground truth versus model predictions at each timestep. A value of 1.0 means pure teacher forcing, while 0.0 means pure autoregressive training. Start with 1.0 and decrease over training epochs.

  • schedule_type: Controls how the teacher forcing ratio changes over time. Options include "linear" (steady decrease), "exponential" (fast initial decrease), and "inverse_sigmoid" (smooth S-curve). Linear is simplest to tune, while inverse sigmoid provides the smoothest transition.

  • warmup_epochs: Number of epochs to use pure teacher forcing before starting the scheduled transition. Allows the model to learn basic patterns before introducing its own predictions. Typical values range from 5-20% of total training epochs.

  • k (exponential schedule): Decay constant controlling how quickly teacher forcing decreases. Values close to 1.0 (e.g., 0.97-0.99) provide gradual transitions, while smaller values accelerate the shift to autoregressive training.

  • c and b (inverse sigmoid schedule): The midpoint epoch ($c$) and steepness ($b$) of the sigmoid transition. Setting $c$ to half the total epochs and $b$ to 10% of total epochs provides a balanced S-curve.

  • curriculum_stages: Number of difficulty stages when using curriculum learning. More stages (4-6) provide finer-grained progression but require more careful difficulty scoring. Fewer stages (2-3) are simpler but may have abrupt difficulty jumps.


