Replaced Token Detection: ELECTRA's Efficient Pretraining Objective

Michael Brenndoerfer · Updated July 12, 2025 · 35 min read

Learn how replaced token detection trains language models 4x more efficiently than masked language modeling by learning from every position, not just masked tokens.


Replaced Token Detection

Masked language modeling wastes compute. When BERT masks 15% of tokens and predicts only those positions, 85% of the forward pass produces no training signal. The model processes the entire sequence but learns from a small fraction of it. For expensive pretraining runs consuming millions of GPU hours, this inefficiency is costly.

Replaced Token Detection (RTD) solves this problem with a different formulation. Instead of predicting masked tokens, the model classifies every token as either original or replaced. A small "generator" network produces plausible replacements for masked positions, and a larger "discriminator" network learns to detect which tokens have been swapped. Every position contributes to the loss, making training much more sample-efficient.

This approach, introduced in the ELECTRA paper in 2020, achieves BERT-level performance with a fraction of the compute. The key insight is that detection is easier than generation: distinguishing real from fake requires less capacity than predicting the exact original token, allowing the model to learn from every position rather than just masked ones.

The Generator-Discriminator Setup

RTD uses two networks working together: a small generator that creates replacements and a larger discriminator that detects them. This setup resembles generative adversarial networks (GANs) but with an important difference: training uses maximum likelihood, not adversarial objectives.

Replaced Token Detection

A pretraining objective where some tokens are replaced with plausible alternatives, and the model learns to classify each token as original or replaced. Unlike MLM, which predicts tokens only at masked positions, RTD produces a binary classification signal at every position, greatly improving sample efficiency.

The architecture works as follows. First, we mask some positions in the input sequence, exactly as in MLM. The generator, typically a small transformer, predicts tokens for the masked positions. We sample from the generator's output distribution to create replacement tokens. The discriminator, a larger transformer, receives the corrupted sequence and must identify which tokens were replaced.

Out[3]:
Visualization
The ELECTRA architecture for replaced token detection. A small generator (left) predicts tokens at masked positions. Sampled replacements are inserted into the sequence. The larger discriminator (right) classifies each token as original or replaced. Both networks are trained simultaneously, but only the discriminator is used for downstream tasks.

The key point is that the generator and discriminator have different roles. The generator only needs to produce plausible replacements, not perfect ones. It can be small, perhaps one-quarter or one-third the size of the discriminator. The discriminator, which will be used for downstream tasks, receives the bulk of the parameters and learns rich representations from the detection task.

The RTD Objective

To understand how RTD trains both networks effectively, we need to think about what each network must learn and how to measure its progress. The generator must learn to produce plausible replacements for masked tokens, while the discriminator must learn to detect which tokens have been swapped. These are different tasks requiring different loss functions, yet they must work together in a single training loop.

Let's build up the complete objective by examining each component, understanding why it takes its particular form, and seeing how the pieces combine into an efficient training signal.

Generator Loss: Learning to Replace

The generator faces a familiar task: predict the original token at each masked position. This is identical to masked language modeling. We mask some positions in the input, show the generator the corrupted sequence, and ask it to predict what was hidden.

Why use MLM for the generator? The goal is to produce plausible replacements that will challenge the discriminator. A generator that simply outputs random tokens would be trivial to detect. By training on MLM, the generator learns the statistical patterns of language, producing replacements that fit grammatically and semantically into their contexts. The discriminator then must develop nuanced understanding to distinguish these plausible fakes from originals.

Formally, given a set of masked positions $\mathcal{M}$, we minimize the negative log-likelihood of the original tokens:

$$\mathcal{L}_{\text{Gen}} = -\sum_{i \in \mathcal{M}} \log P_G(x_i | \tilde{x})$$

where:

  • $\mathcal{L}_{\text{Gen}}$ is the generator loss, identical to the standard MLM loss
  • $\mathcal{M}$ is the set of masked position indices (typically 15% of positions)
  • $x_i$ is the original token at position $i$ that we want to predict
  • $\tilde{x}$ is the corrupted input sequence with [MASK] tokens at positions in $\mathcal{M}$
  • $P_G(x_i | \tilde{x})$ is the probability the generator assigns to the correct token $x_i$, conditioned on seeing the masked sequence

To understand this formula intuitively, consider what happens as training progresses. Initially, the generator assigns roughly uniform probability across all vocabulary tokens, making $P_G(x_i | \tilde{x}) \approx 1/V$, where $V$ is the vocabulary size. The log of this small probability is a large negative number, producing high loss. As the generator learns, it concentrates probability on likely tokens, pushing $P_G(x_i | \tilde{x})$ toward 1. The log approaches zero, and the loss decreases.
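
To make the shape of this loss concrete, here is a small illustrative sketch (plain Python, not part of the article's notebook) that evaluates the negative log-likelihood at a few predicted probabilities:

import math

# Negative log-likelihood at a single masked position for a few
# predicted probabilities of the correct token.
for p in [1 / 30522, 0.01, 0.1, 0.5, 0.9, 0.99]:
    print(f"P_G(correct) = {p:.5f}  ->  loss = {-math.log(p):.3f}")

# A near-uniform prediction (p close to 1/V) costs about log(V) ≈ 10.3 nats;
# a confident correct prediction drives the loss toward zero.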

Out[4]:
Visualization
Generator loss as a function of predicted probability. When the generator assigns low probability to the correct token, loss is high. As probability increases toward 1.0, loss approaches zero. The steep curve near zero creates strong gradients for learning when predictions are poor.

The summation iterates only over masked positions because those are the only positions where we have a prediction task. Unmasked positions pass through unchanged, contributing nothing to generator learning. This sparsity is the inefficiency that the discriminator's loss will address.

Discriminator Loss: Learning to Detect

Now we arrive at the key innovation. The discriminator receives a sequence where masked positions have been filled with generator samples, and must classify every token as original or replaced. This binary classification task applies to all positions, providing the dense training signal that makes RTD efficient.

Think about what the discriminator must learn. For original tokens, it needs to recognize that they fit naturally into their context. For replaced tokens, even plausible ones, it must detect subtle mismatches. Perhaps the replacement "waiter" in "The waiter barked loudly" is grammatically acceptable but semantically wrong. The discriminator learns these nuances by processing the entire sequence and making a decision at each position.

We create the corrupted sequence $\hat{x}$ by sampling from the generator's output distribution at masked positions. The discriminator then outputs a probability for each position, and we compute binary cross-entropy against the ground-truth labels:

$$\mathcal{L}_{\text{Disc}} = -\sum_{i=1}^{n} \left[ y_i \log D(\hat{x})_i + (1 - y_i) \log \left(1 - D(\hat{x})_i\right) \right]$$

where:

  • $\mathcal{L}_{\text{Disc}}$ is the discriminator loss, a sum over all positions
  • $n$ is the total sequence length (not just the masked positions)
  • $y_i$ is the binary label at position $i$ (1 if original, 0 if replaced)
  • $\hat{x}$ is the corrupted sequence with generator samples inserted at previously masked positions
  • $D(\hat{x})_i$ is the discriminator's predicted probability that position $i$ contains an original token

This binary cross-entropy formula rewards correct predictions at both extremes. Let's trace through what each term contributes (a small numeric check follows this list):

  1. For original tokens ($y_i = 1$): The second term vanishes since $(1 - y_i) = 0$. The loss becomes $-\log D(\hat{x})_i$. When the discriminator correctly predicts high probability (near 1.0), the log is near zero, contributing little loss. When it incorrectly predicts low probability, the log is a large negative number, producing high loss.

  2. For replaced tokens ($y_i = 0$): The first term vanishes. The loss becomes $-\log(1 - D(\hat{x})_i)$. When the discriminator correctly predicts low probability (near 0), $(1 - D(\hat{x})_i)$ is near 1, and the log is near zero. When it incorrectly predicts high probability, the loss is large.
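
As a quick numeric check of both cases (a sketch, not one of the article's cells):

import math

def bce(y, p):
    """Binary cross-entropy at one position: y is the label, p is the predicted P(original)."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Original token (y = 1): small loss when P(original) is high.
print(bce(1, 0.95), bce(1, 0.10))   # ~0.05 vs ~2.30

# Replaced token (y = 0): small loss when P(original) is low.
print(bce(0, 0.05), bce(0, 0.90))   # ~0.05 vs ~2.30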

Out[5]:
Visualization
Discriminator loss for original tokens: loss decreases as the predicted 'is original' probability increases toward 1.0. Discriminator loss for replaced tokens: loss decreases as the predicted probability decreases toward 0.0.

The key point is that this loss applies to every position in the sequence. Most positions contain original tokens, so the discriminator primarily learns what "normal" tokens look like in context. A small fraction contain replacements, teaching the discriminator what "wrong" looks like. Both types of positions contribute to learning, giving 6-7x more training signal than MLM's masked-only approach.

Combining the Losses

With both component losses defined, we need a strategy for combining them into a single objective that trains both networks. The simplest approach is a weighted sum:

$$\mathcal{L}_{\text{RTD}} = \mathcal{L}_{\text{Gen}} + \lambda \cdot \mathcal{L}_{\text{Disc}}$$

where:

  • $\mathcal{L}_{\text{RTD}}$ is the total loss used to update both networks
  • $\mathcal{L}_{\text{Gen}}$ is the generator's MLM loss (computed only at masked positions)
  • $\mathcal{L}_{\text{Disc}}$ is the discriminator's binary classification loss (computed at all positions)
  • $\lambda$ is a weighting factor that balances the two losses

Why weight the losses differently? The answer lies in what we ultimately care about. After pretraining, we discard the generator and use only the discriminator for downstream tasks. The generator exists solely to create challenging training examples. From this perspective, generator learning is a means to an end: we want it good enough to produce useful replacements, but not so good that it dominates the training budget.

The ELECTRA paper uses $\lambda = 50$, heavily weighting the discriminator loss. This asymmetry reflects the different roles of the two networks. With this weighting, most of the gradient signal flows to the discriminator, which receives the bulk of the learning capacity. The generator still improves through its share of the gradient, producing increasingly plausible replacements that keep the discriminator challenged.
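
To see how the weighting plays out, here is a small illustrative calculation; the loss values are invented for demonstration, not taken from a real run:

# Hypothetical per-batch loss values, chosen only for illustration.
gen_loss = 5.0     # MLM loss over masked positions
disc_loss = 0.4    # binary cross-entropy over all positions
lam = 50.0         # discriminator weight from the ELECTRA paper

total = gen_loss + lam * disc_loss
print(f"total loss: {total:.1f}")                             # 25.0
print(f"discriminator share: {lam * disc_loss / total:.0%}")  # 80%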

Out[6]:
Visualization
Effect of the discriminator weight on loss contribution. With typical loss values and $\lambda = 50$, the discriminator loss dominates the total, directing most gradient signal to the discriminator. Lower values of $\lambda$ would shift more learning toward the generator.

This design creates a beneficial training dynamic. As the generator improves, its replacements become harder to detect. The discriminator must develop more sophisticated representations to keep up. But because the generator is smaller and receives less gradient signal, it cannot outpace the discriminator. The asymmetry maintains a productive difficulty level throughout training.

Out[7]:
Visualization
MLM computes loss only at masked positions (15% of tokens); the remaining positions provide no training signal. RTD computes discriminator loss at every position, providing 6-7x more training signal per forward pass.

Why Detection Is Easier Than Generation

The efficiency gain comes from the difference between detection and generation. Predicting the exact token that was masked requires learning fine-grained distinctions in a vocabulary of 30,000+ tokens. Detecting whether a token was replaced requires only a binary decision at each position.

Consider a masked position in the sentence "The [MASK] barked loudly." To predict the correct token, the model must assign probability to "dog" over thousands of alternatives. But to detect a replacement, the model only needs to recognize that "cat" or "piano" in that position feels wrong, even if it cannot pinpoint exactly what should be there.

This asymmetry allows the discriminator to learn useful representations from every position. Original tokens that fit the context well should score high. Replaced tokens, even plausible ones, often have subtle mismatches with surrounding context. The model learns to detect these mismatches, developing representations that capture semantic coherence.

In[8]:
Code
def detection_vs_generation_example():
    """Illustrate the difference between detection and generation tasks."""

    sentence = "The chef cooked the meal"
    masked_position = 1  # "chef" is masked

    # Generation: must predict exact token from 30k vocabulary
    generation_task = {
        "input": "The [MASK] cooked the meal",
        "target": "chef",
        "vocabulary_size": 30522,
        "task": "Predict exactly which token was masked",
    }

    # Detection: binary classification at each position
    detection_task = {
        "input": "The waiter cooked the meal",
        "labels": ["original", "replaced", "original", "original", "original"],
        "task": "Is each token original or replaced?",
    }

    return generation_task, detection_task


gen_task, det_task = detection_vs_generation_example()
Out[9]:
Console
Generation Task (MLM):
  Input: The [MASK] cooked the meal
  Must predict: 'chef' from 30,522 options

Detection Task (RTD):
  Input: The waiter cooked the meal
  Labels: ['original', 'replaced', 'original', 'original', 'original']
  Task: Binary classification at each position

The generator deliberately produces challenging replacements. If it simply inserted random tokens, detection would be trivial. By sampling from a language model, the generator creates replacements that are semantically plausible but contextually imperfect. This forces the discriminator to develop nuanced understanding of context.
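
One way to picture this is to compare uniform corruption with sampling from a peaked distribution. The probabilities below are made up purely for illustration:

import torch

torch.manual_seed(0)

# Toy replacements for the masked position in "The [MASK] cooked the meal".
candidates = ["chef", "waiter", "cook", "piano", "the"]

# Uniform corruption: implausible tokens like "piano" appear as often as
# plausible ones, so detection would be trivial.
uniform = torch.full((len(candidates),), 1 / len(candidates))

# A hypothetical trained generator concentrates mass on plausible tokens,
# producing replacements that are harder to detect.
generator_like = torch.tensor([0.55, 0.25, 0.15, 0.01, 0.04])

for name, probs in [("uniform", uniform), ("generator", generator_like)]:
    samples = torch.multinomial(probs, num_samples=10, replacement=True)
    print(name, [candidates[i] for i in samples.tolist()])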

Implementing the Generator

The generator is a small transformer that performs MLM. It masks positions, predicts token distributions, and samples replacements.

In[10]:
Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class RTDGenerator(nn.Module):
    """Small generator network for producing replacement tokens."""

    def __init__(
        self, vocab_size, d_model=64, n_heads=2, n_layers=2, max_len=128
    ):
        super().__init__()

        # Embeddings
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

        # Small transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Output projection to vocabulary
        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        batch_size, seq_len = input_ids.shape

        # Get embeddings
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.layer_norm(x)

        # Encode (bidirectional attention)
        x = self.encoder(x)

        # Project to vocabulary logits
        logits = self.output_proj(x)
        return logits
In[11]:
Code
def create_masked_input(token_ids, mask_token_id, mask_prob=0.15):
    """Mask tokens and return masked input, labels, and mask positions."""
    labels = token_ids.clone()
    masked_ids = token_ids.clone()

    # Sample positions to mask
    probability_matrix = torch.full(token_ids.shape, mask_prob)
    masked_indices = torch.bernoulli(probability_matrix).bool()

    # Replace with [MASK] token
    masked_ids[masked_indices] = mask_token_id

    # Labels are -100 for non-masked positions (ignored in loss)
    labels[~masked_indices] = -100

    return masked_ids, labels, masked_indices

Let's see the generator in action:

In[12]:
Code
# Create a small vocabulary for demonstration
vocab = [
    "[PAD]",
    "[MASK]",
    "[CLS]",
    "[SEP]",
    "the",
    "chef",
    "cooked",
    "meal",
    "waiter",
    "served",
    "food",
    "ate",
    "a",
    "an",
]
vocab_size = len(vocab)
word_to_idx = {w: i for i, w in enumerate(vocab)}
idx_to_word = {i: w for w, i in word_to_idx.items()}
mask_token_id = word_to_idx["[MASK]"]

# Initialize generator
torch.manual_seed(42)
generator = RTDGenerator(vocab_size, d_model=32, n_heads=2, n_layers=1)

# Create sample input
sentence = ["the", "chef", "cooked", "the", "meal"]
token_ids = torch.tensor([[word_to_idx[w] for w in sentence]])

# Mask the input (40% probability here, as discussed below) to produce
# masked_ids and labels for the generator and the later sampling step.
masked_ids, labels, masked_indices = create_masked_input(
    token_ids, mask_token_id, mask_prob=0.4
)
Out[13]:
Console
Original sentence: ['the', 'chef', 'cooked', 'the', 'meal']
Token IDs: [4, 5, 6, 4, 7]

Masked IDs: [1, 5, 1, 4, 1]
Masked sequence: ['[MASK]', 'chef', '[MASK]', 'the', '[MASK]']
Mask positions: [True, False, True, False, True]

The masking function randomly selects positions to mask. With a 40% masking probability on our 5-token sentence, we expect about 2 masked positions on average; this particular draw happened to mask 3. The labels tensor stores the original tokens at masked positions for computing the generator loss.

Now let's sample replacements from the generator:

In[14]:
Code
def sample_replacements(generator, masked_ids, labels, temperature=1.0):
    """Sample replacement tokens from the generator."""
    generator.eval()

    with torch.no_grad():
        logits = generator(masked_ids)

    # Apply temperature and sample
    probs = F.softmax(logits / temperature, dim=-1)

    # Create replaced sequence
    replaced_ids = masked_ids.clone()
    is_replaced = torch.zeros_like(masked_ids, dtype=torch.bool)

    # Sample only at masked positions
    mask_positions = labels != -100

    for batch_idx in range(masked_ids.size(0)):
        for pos in range(masked_ids.size(1)):
            if mask_positions[batch_idx, pos]:
                sampled_token = torch.multinomial(
                    probs[batch_idx, pos], num_samples=1
                )
                replaced_ids[batch_idx, pos] = sampled_token
                # Mark as replaced if sampled token differs from original
                original_token = labels[batch_idx, pos]
                is_replaced[batch_idx, pos] = sampled_token != original_token

    return replaced_ids, is_replaced
In[15]:
Code
torch.manual_seed(42)
replaced_ids, is_replaced = sample_replacements(generator, masked_ids, labels)
Out[16]:
Console
Original: ['the', 'chef', 'cooked', 'the', 'meal']
Replaced: ['[MASK]', 'chef', 'ate', 'the', '[MASK]']
Is replaced: [True, False, True, False, True]

The generator samples tokens for masked positions. Because we're using an untrained generator, the replacements are essentially random from the small vocabulary. With training, the generator would produce more plausible replacements that better challenge the discriminator.

Implementing the Discriminator

The discriminator is a larger transformer that performs binary classification at each position. Its architecture is similar to BERT, but the output layer produces a single logit per position rather than vocabulary logits.

In[17]:
Code
class RTDDiscriminator(nn.Module):
    """Discriminator network for detecting replaced tokens."""

    def __init__(
        self, vocab_size, d_model=128, n_heads=4, n_layers=4, max_len=128
    ):
        super().__init__()

        # Embeddings (shared embedding table in ELECTRA, separate here for clarity)
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

        # Larger transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Binary classification head (one output per position)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, 1),
        )

    def forward(self, input_ids):
        batch_size, seq_len = input_ids.shape

        # Get embeddings
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.layer_norm(x)

        # Encode
        x = self.encoder(x)

        # Binary classification at each position
        logits = self.classifier(x).squeeze(-1)  # (batch, seq_len)
        return logits
In[18]:
Code
# Initialize discriminator (larger than generator)
discriminator = RTDDiscriminator(vocab_size, d_model=64, n_heads=4, n_layers=2)

# Get discriminator predictions
with torch.no_grad():
    disc_logits = discriminator(replaced_ids)
    disc_probs = torch.sigmoid(disc_logits)
Out[19]:
Console
Generator parameters: 17,774
Discriminator parameters: 113,409
Ratio (disc/gen): 6.4x

Input sequence: ['[MASK]', 'chef', 'ate', 'the', '[MASK]']
Discriminator probabilities (is original): [0.49  0.546 0.56  0.53  0.491]
True labels (0=replaced, 1=original): [0, 1, 0, 1, 0]

In this demo the discriminator has roughly 6x more parameters than the generator, reflecting the ELECTRA design principle of concentrating capacity in the discriminator. The probabilities show the untrained discriminator's near-random guesses. With training, these should approach 1.0 for original tokens and 0.0 for replaced tokens, correctly identifying which positions contain generator-produced replacements.
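
The parameter counts printed above can be reproduced with a small helper; this is a sketch of how such a count might be computed, assuming the generator and discriminator defined earlier are in scope:

def count_parameters(model):
    """Total number of trainable parameters in a module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

gen_params = count_parameters(generator)
disc_params = count_parameters(discriminator)
print(f"Generator parameters: {gen_params:,}")
print(f"Discriminator parameters: {disc_params:,}")
print(f"Ratio (disc/gen): {disc_params / gen_params:.1f}x")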

The Complete Training Loop

Training RTD involves three steps per batch: mask inputs, generate replacements, and update both networks.

In[20]:
Code
def compute_rtd_loss(
    generator,
    discriminator,
    token_ids,
    mask_token_id,
    mask_prob=0.15,
    disc_weight=50.0,
):
    """Compute combined RTD loss for generator and discriminator."""

    # Step 1: Mask tokens
    masked_ids, gen_labels, mask_positions = create_masked_input(
        token_ids, mask_token_id, mask_prob
    )

    # Step 2: Generator forward pass and loss
    gen_logits = generator(masked_ids)
    gen_loss = F.cross_entropy(
        gen_logits.view(-1, gen_logits.size(-1)),
        gen_labels.view(-1),
        ignore_index=-100,
    )

    # Step 3: Sample replacements (detach to not backprop through sampling)
    with torch.no_grad():
        gen_probs = F.softmax(gen_logits, dim=-1)

    replaced_ids = masked_ids.clone()
    disc_labels = torch.ones_like(token_ids, dtype=torch.float)  # 1 = original

    for batch_idx in range(token_ids.size(0)):
        for pos in range(token_ids.size(1)):
            if mask_positions[batch_idx, pos]:
                # Sample from generator
                sampled = torch.multinomial(
                    gen_probs[batch_idx, pos], num_samples=1
                )
                replaced_ids[batch_idx, pos] = sampled
                # Label: 0 if replaced with different token, 1 if same (or original)
                if sampled != token_ids[batch_idx, pos]:
                    disc_labels[batch_idx, pos] = 0.0

    # Step 4: Discriminator forward pass and loss
    disc_logits = discriminator(replaced_ids)
    disc_loss = F.binary_cross_entropy_with_logits(disc_logits, disc_labels)

    # Combined loss
    total_loss = gen_loss + disc_weight * disc_loss

    return total_loss, gen_loss, disc_loss

Let's train on a small corpus to see the dynamics:

In[21]:
Code
# Create a small training corpus
corpus_tokens = []
sentences = [
    ["the", "chef", "cooked", "the", "meal"],
    ["the", "waiter", "served", "the", "food"],
    ["the", "chef", "ate", "the", "food"],
    ["a", "waiter", "cooked", "a", "meal"],
]

for sent in sentences:
    corpus_tokens.append([word_to_idx[w] for w in sent])

corpus_tensor = torch.tensor(corpus_tokens)

# Initialize models
torch.manual_seed(42)
generator = RTDGenerator(vocab_size, d_model=32, n_heads=2, n_layers=1)
discriminator = RTDDiscriminator(vocab_size, d_model=48, n_heads=2, n_layers=2)

# Optimizers
gen_optimizer = torch.optim.AdamW(generator.parameters(), lr=1e-3)
disc_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=1e-3)

# Training loop
gen_losses = []
disc_losses = []

for step in range(200):
    # Sample a batch
    batch_idx = torch.randint(0, len(corpus_tensor), (2,))
    batch = corpus_tensor[batch_idx]

    # Compute loss
    total_loss, gen_loss, disc_loss = compute_rtd_loss(
        generator, discriminator, batch, mask_token_id, mask_prob=0.3
    )

    # Update both models
    gen_optimizer.zero_grad()
    disc_optimizer.zero_grad()
    total_loss.backward()
    gen_optimizer.step()
    disc_optimizer.step()

    gen_losses.append(gen_loss.item())
    disc_losses.append(disc_loss.item())
Out[22]:
Console
Initial generator loss: 2.2161
Final generator loss: 0.4695
Initial discriminator loss: 0.6715
Final discriminator loss: 0.2582

Both losses decrease over training. The generator loss drops as it learns to predict masked tokens, while the discriminator loss decreases as it learns to distinguish original from replaced tokens. The initial discriminator loss of about 0.67 sits near ln 2 ≈ 0.693, the binary cross-entropy of a 50/50 guess, confirming the model starts with essentially no knowledge of which tokens are replaced.
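
The random-guessing baseline is easy to verify: a discriminator whose logits are all zero predicts 0.5 everywhere and incurs a binary cross-entropy of ln 2 regardless of the labels. A quick check (a sketch, assuming torch is imported):

import math
import torch
import torch.nn.functional as F

logits = torch.zeros(10)  # sigmoid(0) = 0.5 at every position
labels = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0])
print(F.binary_cross_entropy_with_logits(logits, labels).item())  # ~0.693
print(math.log(2))                                                # 0.6931...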

Out[23]:
Visualization
Generator loss decreases as it learns to predict masked tokens, producing more plausible replacements over time. Discriminator loss tracks its ability to distinguish original from replaced tokens; the two losses are interdependent.

The training dynamics show both losses decreasing. The generator learns to predict masked tokens better, while the discriminator learns to detect replacements. These objectives are complementary: a better generator produces harder replacements, which in turn trains a better discriminator.

RTD Efficiency Advantages

The key advantage of RTD is sample efficiency. Let's quantify the difference:

In[24]:
Code
def compare_efficiency(seq_len, mask_prob=0.15):
    """Compare loss signal per token between MLM and RTD."""

    # MLM: loss only at masked positions
    mlm_loss_positions = int(seq_len * mask_prob)
    mlm_efficiency = mlm_loss_positions / seq_len

    # RTD: discriminator loss at all positions
    rtd_loss_positions = seq_len
    rtd_efficiency = rtd_loss_positions / seq_len

    efficiency_ratio = rtd_efficiency / mlm_efficiency

    return {
        "seq_len": seq_len,
        "mlm_positions": mlm_loss_positions,
        "rtd_positions": rtd_loss_positions,
        "mlm_efficiency": mlm_efficiency,
        "rtd_efficiency": rtd_efficiency,
        "ratio": efficiency_ratio,
    }
Out[25]:
Console
Sample Efficiency Comparison (512 token sequence, 15% masking):
  MLM loss positions: 76 (14.8% of tokens)
  RTD loss positions: 512 (100.0% of tokens)
  RTD efficiency gain: 6.7x more signal per forward pass

This 6-7x efficiency gain is substantial. In practice, ELECTRA achieves comparable results to BERT with roughly 1/4 of the compute. The savings come from learning from every token rather than just masked ones.
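
The console output above can be reproduced by calling compare_efficiency directly; a minimal usage sketch (the exact formatting of the hidden print cell is assumed, not shown in the article):

stats = compare_efficiency(seq_len=512, mask_prob=0.15)
print(f"MLM loss positions: {stats['mlm_positions']} "
      f"({stats['mlm_efficiency']:.1%} of tokens)")
print(f"RTD loss positions: {stats['rtd_positions']} "
      f"({stats['rtd_efficiency']:.1%} of tokens)")
print(f"RTD efficiency gain: {stats['ratio']:.1f}x more signal per forward pass")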

Out[26]:
Visualization
Compute efficiency comparison between MLM and RTD. MLM provides loss signal at roughly 15% of positions while RTD provides it at every position, allowing RTD to match MLM performance with much less compute.

Weight Sharing and Embedding Tying

The original ELECTRA paper uses weight sharing between generator and discriminator embeddings. This has two benefits: it reduces total parameters and ensures both networks have compatible token representations.

The embedding matrix $E$ maps token IDs to dense vectors. Since both networks process the same vocabulary, they can share this large matrix. The generator uses the embeddings directly at its smaller hidden dimension, while the discriminator projects them up to its larger dimension:

In[27]:
Code
class ELECTRAWithSharedEmbeddings(nn.Module):
    """ELECTRA with shared embeddings between generator and discriminator."""

    def __init__(
        self,
        vocab_size,
        gen_d_model=64,
        disc_d_model=256,
        n_gen_layers=4,
        n_disc_layers=12,
        max_len=512,
    ):
        super().__init__()

        # Shared token embeddings at generator size
        self.shared_token_emb = nn.Embedding(vocab_size, gen_d_model)

        # Generator uses embeddings directly
        self.gen_pos_emb = nn.Embedding(max_len, gen_d_model)
        self.gen_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=gen_d_model,
                nhead=4,
                dim_feedforward=gen_d_model * 4,
                batch_first=True,
                activation="gelu",
            ),
            num_layers=n_gen_layers,
        )
        self.gen_output = nn.Linear(gen_d_model, vocab_size)

        # Discriminator projects embeddings to its larger dimension
        self.disc_embed_proj = nn.Linear(gen_d_model, disc_d_model)
        self.disc_pos_emb = nn.Embedding(max_len, disc_d_model)
        self.disc_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=disc_d_model,
                nhead=8,
                dim_feedforward=disc_d_model * 4,
                batch_first=True,
                activation="gelu",
            ),
            num_layers=n_disc_layers,
        )
        self.disc_classifier = nn.Linear(disc_d_model, 1)

    def generator_forward(self, input_ids):
        """Generator forward pass for MLM."""
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        x = self.shared_token_emb(input_ids) + self.gen_pos_emb(positions)
        x = self.gen_encoder(x)
        logits = self.gen_output(x)
        return logits

    def discriminator_forward(self, input_ids):
        """Discriminator forward pass for RTD."""
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        # Project shared embeddings to discriminator dimension
        x = self.shared_token_emb(input_ids)
        x = self.disc_embed_proj(x) + self.disc_pos_emb(positions)
        x = self.disc_encoder(x)
        logits = self.disc_classifier(x).squeeze(-1)
        return logits

The generator embedding dimension determines the shared embedding size. The discriminator projects these embeddings up to its larger hidden dimension. This asymmetry reflects the design principle: the generator needs only enough capacity to produce plausible replacements, while the discriminator needs more capacity to develop rich representations.
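
As a quick sanity check of the shared-embedding design, the following sketch instantiates the class above with illustrative sizes and verifies the output shapes (assuming torch is in scope):

import torch

torch.manual_seed(0)
model = ELECTRAWithSharedEmbeddings(
    vocab_size=30522, gen_d_model=64, disc_d_model=256,
    n_gen_layers=2, n_disc_layers=4,
)

dummy_ids = torch.randint(0, 30522, (2, 16))           # (batch, seq_len)
gen_logits = model.generator_forward(dummy_ids)         # (2, 16, 30522)
disc_logits = model.discriminator_forward(dummy_ids)    # (2, 16)
print(gen_logits.shape, disc_logits.shape)

# Both passes read the same shared_token_emb table, so gradients from either
# loss improve the embeddings used by the other network.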

Generator Size and Training Dynamics

How big should the generator be relative to the discriminator? Too small, and it produces random, easily-detected replacements. Too large, and it wastes parameters that could go to the discriminator, which is what we actually use for downstream tasks.

The ELECTRA paper found that generator size between 1/4 and 1/2 of the discriminator works best. The sweet spot balances challenging replacements against parameter efficiency.

Out[28]:
Visualization
Effect of generator size on downstream performance. Very small generators produce easily-detected replacements, limiting discriminator learning. Very large generators waste parameters. The optimal ratio is around 1/4 to 1/3 of the discriminator size.

The training dynamics create an interesting interplay. As the generator improves, it produces harder replacements. This forces the discriminator to learn more nuanced representations. But if the generator becomes too good, detection becomes impossible, and learning stalls. The asymmetric sizing naturally prevents this collapse: the smaller generator cannot outpace the larger discriminator.

RTD vs MLM: When to Choose Each

Both objectives produce strong pretrained models, but they have different strengths.

Choose RTD (ELECTRA) when:

  • Compute budget is limited.
  • You need BERT-level performance with less training.
  • Pretraining from scratch rather than using existing checkpoints.
  • Sample efficiency matters more than final absolute performance.

Choose MLM (BERT) when:

  • You can afford extensive pretraining.
  • Using established pretrained checkpoints.
  • The ecosystem (fine-tuning code, adapters) assumes MLM architecture.
  • Simplicity of single-network training is preferred.

The following table summarizes key differences:

Comparison of MLM and RTD pretraining objectives. RTD provides more training signal per batch at the cost of increased architectural complexity.

| Aspect | MLM (BERT) | RTD (ELECTRA) |
| --- | --- | --- |
| Loss positions | 15% (masked only) | 100% (all positions) |
| Network count | Single transformer | Generator + discriminator |
| Training efficiency | Baseline | ~4x more efficient |
| Architecture complexity | Simple | Moderate |
| Output head | Vocabulary logits | Binary classifier |
| Downstream model | Full encoder | Discriminator only |

Limitations and Impact

Replaced token detection offers compelling efficiency gains but comes with trade-offs that shape its practical applications.

The two-network training setup adds complexity. Managing separate generator and discriminator networks with different sizes requires careful hyperparameter tuning. The generator learning rate, size ratio, and loss weighting all affect final performance. In contrast, MLM training has fewer moving parts. For practitioners without extensive compute resources to tune these hyperparameters, this complexity can be a barrier.

The discriminator's binary classification objective differs from downstream task objectives. While MLM directly trains vocabulary prediction, which transfers naturally to tasks involving token-level understanding, RTD trains original-vs-replaced classification. The representations transfer well in practice, but the training signal is less directly aligned with common downstream tasks like sequence labeling or question answering.

Weight sharing between generator and discriminator, while parameter-efficient, constrains architecture choices. The generator's embedding dimension becomes the shared base, potentially limiting discriminator capacity. Some practitioners prefer fully separate networks despite the parameter cost.

Despite these limitations, ELECTRA demonstrated that the dominant MLM paradigm was leaving efficiency on the table. The paper's key insight that detection is easier than generation and therefore allows learning from more positions has influenced subsequent work on efficient pretraining. The approach showed that with 1/4 of BERT's compute budget, comparable downstream performance was achievable.

The efficiency gains are most pronounced in the small-to-medium model regime where compute budgets are constrained. At the largest scales, where organizations can afford extensive pretraining, the absolute performance differences between objectives diminish. But for academic researchers, startups, and practitioners training models from scratch, RTD's efficiency remains attractive.

Key Parameters

When implementing RTD training, several parameters affect performance:

  • mask_prob is the fraction of tokens to mask before generating replacements. The default 0.15 (15%) balances training signal against context preservation. Higher rates provide more training examples but degrade generation quality.

  • disc_weight ($\lambda$) is the weighting factor for the discriminator loss in the combined objective. ELECTRA uses 50, heavily prioritizing discriminator learning since the discriminator is the model used for downstream tasks.

  • generator_size_ratio is the ratio of generator to discriminator model size. Optimal values lie between 0.25 and 0.33. Smaller generators produce easily-detected replacements, while larger generators waste parameters.

  • d_model (discriminator) is the hidden dimension of the discriminator transformer. Larger values increase capacity but require more compute. ELECTRA-Small uses 256, ELECTRA-Base 768, and ELECTRA-Large 1024.

  • d_model (generator) is the hidden dimension of the generator, typically 1/4 to 1/3 of the discriminator's dimension. This asymmetry ensures the generator produces challenging but not impossible replacements.

  • n_layers (discriminator) is the number of transformer layers in the discriminator. More layers increase representational capacity. ELECTRA-Base uses 12 layers, ELECTRA-Large uses 24.

  • learning_rate can be set separately for generator and discriminator to improve training stability. Typical values range from 1e-4 to 5e-4, with some implementations using slightly lower rates for the generator.

  • temperature controls randomness when sampling from the generator's output distribution. Higher temperatures produce more diverse replacements, while lower temperatures favor the most likely tokens (see the short sketch after this list).
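
As a small illustration of the temperature parameter, the following sketch applies softmax to made-up logits at three temperatures:

import torch
import torch.nn.functional as F

# Hypothetical generator logits for one masked position over a 5-token vocabulary.
logits = torch.tensor([3.0, 2.0, 1.0, 0.0, -1.0])

for temperature in [0.5, 1.0, 2.0]:
    probs = F.softmax(logits / temperature, dim=-1)
    print(f"T={temperature}:", [round(p, 3) for p in probs.tolist()])

# Lower temperatures sharpen the distribution toward the top token;
# higher temperatures flatten it, producing more diverse replacements.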

Summary

Replaced token detection reformulates pretraining as a detection problem rather than a generation problem. This chapter covered the key concepts:

  • Generator-discriminator architecture uses a small network to produce replacement tokens and a larger network to detect them, with only the discriminator used for downstream tasks.
  • Detection vs. generation is easier because binary classification requires less capacity than vocabulary prediction, enabling learning from every position.
  • Sample efficiency improves by 6-7x because the discriminator loss applies to all positions, not just the 15% that would be masked in MLM.
  • Generator sizing at 1/4 to 1/3 of discriminator size balances challenging replacements against parameter efficiency.
  • Weight sharing between generator and discriminator embeddings reduces parameters while maintaining compatible representations.
  • Training dynamics create complementary learning where better generators produce harder replacements that train better discriminators.

The next chapter explores denoising objectives, a family of pretraining tasks that corrupt inputs in various ways and train models to reconstruct them.
