Search

Search articles

BERT Pre-training: MLM, NSP & Training Strategies Explained

Michael BrenndoerferUpdated July 19, 202544 min read

Complete guide to BERT pre-training covering masked language modeling, next sentence prediction, data preparation, hyperparameters, and training dynamics with code implementations.

Track your reading progress

Sign in to mark chapters as read and track your learning journey

Sign in →
Reading Level

Choose your expertise level to adjust how many terms are explained. Beginners see more tooltips, experts see fewer to maintain reading flow. Hover over underlined terms for instant definitions.

BERT Pre-training

Pre-training is where BERT learns language. Starting from random weights, the model processes billions of words and emerges with representations that capture syntax, semantics, and even world knowledge. This transformation happens through two complementary objectives: predicting masked tokens and recognizing whether sentences belong together.

Understanding BERT's pre-training reveals why it works so well. The training data, masking strategies, auxiliary objectives, and hyperparameters all interact to produce a model that transfers effectively to downstream tasks. This chapter walks through each component: how the data is prepared, how the objectives shape learning, what hyperparameters matter, and how long training takes to converge.

Pre-training Data

BERT's knowledge comes from text. The original model trained on two massive corpora: the BooksCorpus and English Wikipedia. Together, these sources provide around 3.3 billion words of diverse, high-quality text.

BooksCorpus

A collection of approximately 11,000 unpublished books spanning multiple genres, totaling around 800 million words. Originally compiled for book-to-movie alignment research, it provides long-form narrative text with coherent multi-sentence structure.

The BooksCorpus contributes long-form narrative text with coherent story structure. Each book maintains consistent style, character references, and thematic development across thousands of sentences. This teaches the model long-range dependencies that isolated sentences cannot provide.

Wikipedia adds encyclopedic knowledge. Its articles cover history, science, geography, biography, and countless specialized domains. The text is well-edited, factually dense, and structurally organized with clear section boundaries. From Wikipedia, BERT learns factual associations, entity relationships, and formal writing conventions.

The combination matters. Fiction teaches narrative flow, dialogue patterns, and emotional language. Non-fiction teaches factual structure, technical vocabulary, and logical argumentation. Together, they produce representations that generalize across domains.

Document Structure

Pre-training data must preserve document boundaries because one of BERT's objectives requires sampling sentence pairs from the same document. Each document is processed as a sequence of sentences, maintaining the original order.

The data pipeline typically proceeds as follows:

  1. Document collection: Gather raw text organized by document (books, articles, pages)
  2. Sentence segmentation: Split documents into individual sentences using punctuation and linguistic rules
  3. Tokenization: Convert sentences to WordPiece token sequences
  4. Pair construction: Create training examples by combining sentence pairs

For Wikipedia, each article forms a document. For BooksCorpus, each book forms a document. Sentence boundaries are detected using simple heuristics (periods followed by spaces and capital letters) combined with special handling for abbreviations and quotations.

In[3]:
Code
import re


def segment_sentences(text):
    """
    Simple sentence segmentation.
    Production systems use more sophisticated methods.
    """
    # Split on sentence-ending punctuation followed by space and capital
    pattern = r"(?<=[.!?])\s+(?=[A-Z])"
    sentences = re.split(pattern, text)
    # Filter empty strings and strip whitespace
    return [s.strip() for s in sentences if s.strip()]


# Example document
sample_doc = """
The transformer architecture changed NLP. It introduced self-attention as 
the primary mechanism. Previous models relied heavily on recurrence. 
Attention allows parallel processing of sequences. This enables much 
faster training on modern hardware.
"""

sentences = segment_sentences(sample_doc)
Out[4]:
Console
Document segmented into 5 sentences:
  1. The transformer architecture changed NLP.
  2. It introduced self-attention as 
the primary mechanism.
  3. Previous models relied heavily on recurrence.
  4. Attention allows parallel processing of sequences.
  5. This enables much 
faster training on modern hardware.

Real preprocessing pipelines handle edge cases like abbreviations ("Dr.", "U.S."), URLs, quotations, and numbered lists. BERT's original preprocessing used a combination of sentence boundary detection libraries and manual filtering.

Creating Training Examples

Each training example for BERT consists of two segments (which we call sentence A and sentence B) joined with special tokens. The format is:

[CLS] sentence A tokens [SEP] sentence B tokens [SEP]

where [CLS] is a special classification token and [SEP] marks segment boundaries.

The two sentences are drawn from the same document. Half the time, sentence B immediately follows sentence A in the original text (a "positive" pair). Half the time, sentence B is a random sentence from a different location in the corpus (a "negative" pair). This sampling strategy creates the data for the Next Sentence Prediction task.

In[5]:
Code
import random


def create_training_pairs(documents, positive_ratio=0.5):
    """
    Create sentence pairs for BERT pre-training.

    Args:
        documents: List of documents, each a list of sentences
        positive_ratio: Fraction of pairs from consecutive sentences

    Returns:
        List of (sentence_a, sentence_b, is_next) tuples
    """
    pairs = []
    all_sentences = [sent for doc in documents for sent in doc]

    for doc in documents:
        for i in range(len(doc) - 1):
            sentence_a = doc[i]

            if random.random() < positive_ratio:
                # Positive: actual next sentence
                sentence_b = doc[i + 1]
                is_next = True
            else:
                # Negative: random sentence from corpus
                sentence_b = random.choice(all_sentences)
                is_next = False

            pairs.append((sentence_a, sentence_b, is_next))

    return pairs
In[6]:
Code
# Create sample documents
sample_documents = [
    [
        "The cat sat on the mat.",
        "It was a sunny afternoon.",
        "Birds sang in the nearby trees.",
    ],
    [
        "Machine learning transforms data into predictions.",
        "Neural networks learn hierarchical representations.",
        "Deep learning requires substantial compute.",
    ],
]

random.seed(42)
training_pairs = create_training_pairs(sample_documents)
Out[7]:
Console
Sample training pairs:

1. [NotNext]
   A: The cat sat on the mat.
   B: The cat sat on the mat.

2. [NotNext]
   A: It was a sunny afternoon.
   B: It was a sunny afternoon.

3. [IsNext]
   A: Machine learning transforms data into predictions.
   B: Neural networks learn hierarchical representations.

4. [NotNext]
   A: Neural networks learn hierarchical representations.
   B: Deep learning requires substantial compute.

The example shows how the sampling works. Consecutive sentence pairs teach the model about discourse coherence. Random pairs teach it to distinguish coherent text from jumbled fragments. This binary classification signal, combined with masked token prediction, gives BERT a richer understanding than masking alone would provide.

Sequence Length and Packing

BERT processes fixed-length sequences, typically 512 tokens. Most sentence pairs are shorter than this limit, so multiple pairs can be packed into a single sequence to maximize GPU utilization.

The packing strategy concatenates sentence pairs until the sequence reaches the maximum length:

[CLS] pair1_A [SEP] pair1_B [SEP] [CLS] pair2_A [SEP] pair2_B [SEP] ...

However, the original BERT paper uses a simpler approach: each training example contains exactly one sentence pair, padded to the maximum length. This wastes some compute on padding tokens but simplifies the training pipeline.

A compromise used in the original BERT paper is to sample 90% of training examples with short sequences (128 tokens) and 10% with the full length (512 tokens). Short sequences speed up initial training when the model is learning basic patterns, while long sequences in the final phase provide the context needed for complex relationships.

The MLM Objective in Practice

We covered masked language modeling conceptually in the previous chapter. Here we see exactly how BERT applies it during pre-training.

For each training sequence, BERT randomly selects 15% of tokens for prediction. Of these selected tokens:

  • 80% are replaced with the [MASK] token
  • 10% are replaced with a random vocabulary token
  • 10% are left unchanged

The model must predict the original token at all selected positions, regardless of how they were corrupted.

Out[8]:
Visualization
Pie chart showing 80% MASK token, 10% random token, 10% unchanged.
BERT's 80-10-10 masking strategy. The [MASK] tokens provide the primary learning signal, while random and unchanged tokens prevent the model from relying solely on the mask token during fine-tuning.
In[9]:
Code
from transformers import BertTokenizer

# Load BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def apply_bert_masking(input_ids, tokenizer, mask_prob=0.15):
    """
    Apply BERT's MLM masking strategy.

    Args:
        input_ids: Token IDs to mask
        tokenizer: Tokenizer with mask_token_id and vocab
        mask_prob: Probability of selecting each token

    Returns:
        masked_ids: Token IDs after masking
        labels: Original IDs at masked positions, -100 elsewhere
    """
    labels = input_ids.clone()
    masked_ids = input_ids.clone()

    # Probability matrix for masking
    probability_matrix = torch.full(input_ids.shape, mask_prob)

    # Don't mask special tokens
    special_tokens_mask = torch.tensor(
        [
            1
            if token_id
            in [
                tokenizer.cls_token_id,
                tokenizer.sep_token_id,
                tokenizer.pad_token_id,
            ]
            else 0
            for token_id in input_ids.tolist()
        ],
        dtype=torch.bool,
    )
    probability_matrix.masked_fill_(special_tokens_mask, 0.0)

    # Sample masked positions
    masked_indices = torch.bernoulli(probability_matrix).bool()

    # Labels are -100 for non-masked (ignored in loss)
    labels[~masked_indices] = -100

    # 80% -> [MASK]
    indices_replaced = (
        torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool()
        & masked_indices
    )
    masked_ids[indices_replaced] = tokenizer.mask_token_id

    # 10% -> random token
    indices_random = (
        torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(
        len(tokenizer), input_ids.shape, dtype=torch.long
    )
    masked_ids[indices_random] = random_words[indices_random]

    # Remaining 10% stay unchanged but still have labels

    return masked_ids, labels
In[10]:
Code
# Demonstrate masking on a sample sentence
text = "The quick brown fox jumps over the lazy dog."
encoding = tokenizer(text, return_tensors="pt", padding=False)
input_ids = encoding["input_ids"].squeeze()

torch.manual_seed(42)
masked_ids, labels = apply_bert_masking(input_ids, tokenizer)
Out[11]:
Console
Original text: The quick brown fox jumps over the lazy dog.

Token-level view:
Position   Original        Masked          Label     
--------------------------------------------------
0          [CLS]           [CLS]           ignore    
1          the             the             ignore    
2          quick           quick           ignore    
3          brown           brown           ignore    
4          fox             fox             ignore    
5          jumps           jumps           ignore    
6          over            over            ignore    
7          the             the             ignore    
8          lazy            lazy            ignore    
9          dog             [MASK]          dog       
10         .               .               ignore    
11         [SEP]           [SEP]           ignore    

The output shows which tokens were selected for masking and how each was treated. Positions with "ignore" labels don't contribute to the MLM loss. The model learns from the positions that do have labels, whether they show [MASK], a random token, or the original token.

Whole Word Masking

The original BERT masked individual WordPiece tokens. This created a problem: when a word like "playing" tokenizes to ["play", "##ing"], masking only "##ing" gives the model an easy hint.

BERT later adopted whole word masking (WWM), where all tokens from the same word are masked together. This provides a stronger learning signal because the model cannot peek at sibling tokens.

In[12]:
Code
def apply_whole_word_masking(input_ids, tokenizer, mask_prob=0.15):
    """
    Apply whole word masking for BERT.

    When any subword of a word is selected, all subwords are masked.
    """
    labels = input_ids.clone()
    masked_ids = input_ids.clone()

    # Get tokens and identify word boundaries
    tokens = tokenizer.convert_ids_to_tokens(input_ids.tolist())

    # Group tokens into words
    word_indices = []  # List of (start_idx, end_idx) for each word
    current_word_start = None

    for i, token in enumerate(tokens):
        if token in [
            tokenizer.cls_token,
            tokenizer.sep_token,
            tokenizer.pad_token,
        ]:
            continue
        if not token.startswith("##"):
            # Start of new word
            if current_word_start is not None:
                word_indices.append((current_word_start, i))
            current_word_start = i

    # Don't forget last word
    if current_word_start is not None:
        word_indices.append((current_word_start, len(tokens)))

    # Select words to mask (15% of words, not tokens)
    num_to_mask = max(1, int(len(word_indices) * mask_prob))
    words_to_mask = random.sample(
        word_indices, min(num_to_mask, len(word_indices))
    )

    # Apply masking to all tokens in selected words
    for start, end in words_to_mask:
        for i in range(start, end):
            labels[i] = input_ids[i]

            rand = random.random()
            if rand < 0.8:
                masked_ids[i] = tokenizer.mask_token_id
            elif rand < 0.9:
                masked_ids[i] = random.randint(0, len(tokenizer) - 1)
            # else: keep original (10%)

    # Set non-masked positions to -100
    for i in range(len(labels)):
        is_masked = any(start <= i < end for start, end in words_to_mask)
        if not is_masked:
            labels[i] = -100

    return masked_ids, labels
In[13]:
Code
# Demonstrate whole word masking
text_wwm = "The unbelievable transformation surprised everyone."
encoding_wwm = tokenizer(text_wwm, return_tensors="pt", padding=False)
input_ids_wwm = encoding_wwm["input_ids"].squeeze()

random.seed(42)
masked_ids_wwm, labels_wwm = apply_whole_word_masking(input_ids_wwm, tokenizer)
Out[14]:
Console
Original text: The unbelievable transformation surprised everyone.

Tokenization:
[CLS] the unbelievable transformation surprised everyone . [SEP]

After whole word masking:
[CLS] the unbelievable transformation surprised everyone [MASK] [MASK]

Notice how multi-token words get fully masked. The model cannot see any part of the word "unbelievable" or "transformation" if they are selected for masking. This produces more robust representations than token-level masking.

Out[15]:
Visualization
Diagram showing the same sentence with token-level masking leaving hints versus whole-word masking removing complete words.
Token-level vs whole-word masking comparison. Token-level masking can leave partial word hints (like '##ing'), while whole-word masking removes all subword units together, forcing the model to predict from true context.

Next Sentence Prediction

Masked language modeling teaches BERT about word-level and sentence-level patterns. But understanding full documents requires reasoning about how sentences relate to each other. Does this sentence follow logically from the previous one? Are these two sentences from the same document?

Next Sentence Prediction (NSP) provides this document-level signal. Given two sentences A and B, the model must predict whether B immediately follows A in the original corpus.

Next Sentence Prediction (NSP)

A binary classification task where the model predicts whether sentence B is the actual next sentence after sentence A in the source document (IsNext) or a random sentence from elsewhere in the corpus (NotNext).

How NSP Works

The [CLS] token's representation at the final layer is used for the NSP prediction. A small classification head (typically a single linear layer) maps this representation to two logits, one for IsNext and one for NotNext.

Out[16]:
Visualization
Diagram showing sentence pair input flowing through BERT to produce CLS embedding which feeds into binary classifier.
Next Sentence Prediction uses the [CLS] token representation. After processing both sentences through the transformer, the [CLS] embedding is passed to a binary classifier that predicts whether the sentences are consecutive.

The training procedure constructs balanced batches with 50% positive (actual next sentence) and 50% negative (random sentence) pairs. The NSP loss is cross-entropy between the predicted probabilities and the binary label.

In[17]:
Code
class NSPHead(nn.Module):
    """Classification head for Next Sentence Prediction."""

    def __init__(self, hidden_size):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, 2)

    def forward(self, cls_embedding):
        # cls_embedding: (batch_size, hidden_size)
        logits = self.classifier(cls_embedding)
        return logits  # (batch_size, 2)


# Example usage
hidden_size = 768
nsp_head = NSPHead(hidden_size)

# Simulate batch of CLS embeddings
batch_size = 4
cls_embeddings = torch.randn(batch_size, hidden_size)
nsp_logits = nsp_head(cls_embeddings)
Out[18]:
Console
Input CLS embeddings shape: torch.Size([4, 768])
NSP logits shape: torch.Size([4, 2])

Predicted probabilities:
  Example 1: IsNext=0.714, NotNext=0.286
  Example 2: IsNext=0.733, NotNext=0.267
  Example 3: IsNext=0.732, NotNext=0.268
  Example 4: IsNext=0.701, NotNext=0.299

The NSP Controversy

NSP was designed to help BERT understand document structure, particularly for tasks like question answering where reasoning across sentences matters. However, subsequent research questioned its value.

RoBERTa (2019) found that removing NSP entirely and training on longer contiguous sequences actually improved performance. The researchers hypothesized that NSP's negative examples (random sentences from different documents) were too easy to distinguish, providing weak signal. Topic differences between sentences gave obvious cues without requiring deeper understanding.

ALBERT took a different approach, replacing NSP with Sentence Order Prediction (SOP). Instead of distinguishing consecutive sentences from random ones, SOP asks whether two consecutive sentences are in the correct order or swapped. This harder task forced the model to learn finer-grained discourse relationships.

Out[19]:
Visualization
Diagram showing IsNext versus NotNext with different topics.
NSP distinguishes consecutive sentences from random pairs. The topic difference makes this easy.
Diagram showing correct order versus swapped order with same topic.
SOP distinguishes correct order from swapped order. Same topic forces focus on coherence.

Despite the controversy, understanding NSP matters for working with pre-trained BERT models. Many available checkpoints were trained with NSP, and the [CLS] token was optimized for this task. When fine-tuning for classification, you're building on representations shaped by both MLM and NSP.

Implementing NSP

Let's implement a complete NSP training step to see how the pieces fit together:

In[20]:
Code
def prepare_nsp_batch(sentence_pairs, tokenizer, max_length=128):
    """
    Prepare a batch of sentence pairs for NSP training.

    Args:
        sentence_pairs: List of (sent_a, sent_b, is_next) tuples
        tokenizer: BERT tokenizer
        max_length: Maximum sequence length

    Returns:
        Dictionary with input_ids, attention_mask, token_type_ids, labels
    """
    sentences_a = [pair[0] for pair in sentence_pairs]
    sentences_b = [pair[1] for pair in sentence_pairs]
    labels = torch.tensor([1 if pair[2] else 0 for pair in sentence_pairs])

    # Tokenize with sentence pairs
    encoding = tokenizer(
        sentences_a,
        sentences_b,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    encoding["labels"] = labels
    return encoding
In[21]:
Code
# Create sample pairs
nsp_pairs = [
    ("The weather is nice today.", "Let's go for a walk.", True),
    (
        "Python is a programming language.",
        "It supports multiple paradigms.",
        True,
    ),
    ("The cat is sleeping.", "Stock prices rose sharply.", False),
    ("She studied hard for the exam.", "The movie was entertaining.", False),
]

batch = prepare_nsp_batch(nsp_pairs, tokenizer)
Out[22]:
Console
Batch contents:
  input_ids shape: torch.Size([4, 17])
  attention_mask shape: torch.Size([4, 17])
  token_type_ids shape: torch.Size([4, 17])
  labels: [1, 1, 0, 0]

Decoded first example:
  [CLS] the weather is nice today. [SEP] let's go for a walk. [SEP]
  Label: IsNext

The token_type_ids distinguish sentence A (0) from sentence B (1). This helps the model understand which tokens belong to which segment, enabling it to reason about sentence relationships.

Combined Pre-training Objective

So far we've examined BERT's two objectives separately: MLM teaches the model to understand words in context, while NSP teaches it to recognize document-level coherence. But how do these objectives work together during training? The answer lies in how we combine their signals into a single optimization target.

Why Combine Two Objectives?

Consider what each objective teaches in isolation. MLM forces the model to build rich token representations by predicting masked words from surrounding context. A model that excels at MLM understands syntax, semantics, and even factual knowledge encoded in word patterns. However, MLM operates at the sentence level. It doesn't explicitly teach the model whether two sentences belong together or flow logically from one to the other.

NSP addresses this gap. By asking "does sentence B follow sentence A?", the model must develop representations that capture discourse relationships, topic coherence, and logical flow. The [CLS] token becomes a summary of the entire sentence pair, encoding whether they form a coherent unit.

Training on both objectives simultaneously allows the model to develop representations that are rich at multiple levels: tokens capture local meaning, while the [CLS] token captures global coherence. The question becomes: how do we combine these two learning signals?

The Combined Loss Formula

BERT uses the simplest possible combination: add the losses together. Each batch of training data produces two loss values, and the total loss is their sum:

Ltotal=LMLM+LNSP\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{MLM}} + \mathcal{L}_{\text{NSP}}

where:

  • Ltotal\mathcal{L}_{\text{total}}: the total pre-training loss that BERT minimizes through gradient descent
  • LMLM\mathcal{L}_{\text{MLM}}: the masked language modeling loss, measuring how well the model predicts masked tokens
  • LNSP\mathcal{L}_{\text{NSP}}: the next sentence prediction loss, measuring how well the model classifies sentence pairs

This equal weighting (1.0 for each) means both objectives contribute equally to the gradient updates. The model cannot cheat by ignoring one task to optimize the other. It must find parameters that satisfy both constraints simultaneously.

Understanding the MLM Loss Component

The MLM loss measures how surprised the model is by the correct answers at masked positions. If the model confidently predicts the right token, the loss is low. If it assigns low probability to the correct token, the loss is high.

Mathematically, we sum the negative log-probabilities across all masked positions:

LMLM=iMlogPθ(xix~)\mathcal{L}_{\text{MLM}} = -\sum_{i \in \mathcal{M}} \log P_\theta(x_i | \tilde{x})

Let's unpack each component:

  • M\mathcal{M}: the set of masked position indices. For a 512-token sequence with 15% masking, this contains approximately 77 positions.
  • xix_i: the original token at position ii before we corrupted it. This is the "ground truth" the model must recover.
  • x~\tilde{x}: the corrupted input sequence where masked positions show [MASK], random tokens, or unchanged tokens (following the 80-10-10 rule).
  • Pθ(xix~)P_\theta(x_i | \tilde{x}): the probability our model (with parameters θ\theta) assigns to the correct token xix_i, given the entire corrupted sequence as context.

The negative sign and logarithm work together to create our loss function. When Pθ(xix~)P_\theta(x_i | \tilde{x}) is close to 1 (confident and correct), log(1)=0\log(1) = 0, contributing zero loss. When Pθ(xix~)P_\theta(x_i | \tilde{x}) is small (the model missed the answer), log(small number)\log(\text{small number}) becomes a large negative value, and the negative sign makes it a large positive loss. This asymmetry creates strong gradients that push the model to assign high probability to correct tokens.

Understanding the NSP Loss Component

The NSP loss has a simpler structure because it's a binary classification problem. The model looks at the [CLS] token's representation and must decide: IsNext or NotNext?

LNSP=logPθ(yNSPh[CLS])\mathcal{L}_{\text{NSP}} = -\log P_\theta(y_{\text{NSP}} | h_{\text{[CLS]}})

Here's what each symbol means:

  • yNSPy_{\text{NSP}}: the true binary label. We encode IsNext as 1 and NotNext as 0. This comes from how we constructed the training data (50% consecutive pairs, 50% random pairs).
  • h[CLS]h_{\text{[CLS]}}: the hidden representation of the [CLS] token after passing through all transformer layers. This 768-dimensional vector (for BERT-base) summarizes the entire sentence pair.
  • Pθ(yNSPh[CLS])P_\theta(y_{\text{NSP}} | h_{\text{[CLS]}}): the probability the model assigns to the correct label, computed by passing h[CLS]h_{\text{[CLS]}} through the NSP classification head.

The structure mirrors the MLM loss: negative log-probability of the correct answer. When the model confidently predicts the right class, loss is near zero. When it's wrong or uncertain, loss increases.

How the Losses Interact

Both losses share the same transformer backbone. When we compute the gradient of Ltotal\mathcal{L}_{\text{total}}, it flows back through both the MLM head and the NSP head, then merges in the transformer layers. This means every transformer layer receives gradient signal from both objectives.

The interplay teaches complementary skills:

  1. Token embeddings learn to represent individual words in ways that support both masked token prediction (MLM) and sentence-pair understanding (NSP).

  2. Attention patterns develop to capture both local dependencies (which tokens help predict a masked word) and global coherence (what makes two sentences related).

  3. The [CLS] token becomes especially important for NSP, learning to aggregate sentence-level meaning, while other positions focus more on token-level prediction.

This multi-task learning is why BERT's representations transfer so well to diverse downstream tasks. The model has learned to encode information at multiple levels of abstraction.

Out[23]:
Visualization
Stacked area chart showing MLM and NSP loss contributions over training steps.
Relative contribution of MLM and NSP losses during training. Early in training, MLM dominates because predicting from a 30K vocabulary is much harder than binary classification. As training progresses, both losses decrease but MLM remains the primary learning signal.

Implementing the Combined Loss

Let's translate these formulas into code. PyTorch's cross_entropy function computes exactly the negative log-probability we described, making the implementation straightforward. The key insight is handling the MLM labels: we use -100 for positions that weren't masked, and PyTorch's ignore_index parameter automatically excludes these from the loss calculation.

In[24]:
Code
def compute_pretraining_loss(
    mlm_logits,  # (batch, seq_len, vocab_size)
    mlm_labels,  # (batch, seq_len) with -100 for non-masked
    nsp_logits,  # (batch, 2)
    nsp_labels,  # (batch,) with 0 or 1
    vocab_size,
):
    """
    Compute combined BERT pre-training loss.
    """
    # MLM loss (ignores -100 labels automatically)
    mlm_loss = F.cross_entropy(
        mlm_logits.view(-1, vocab_size), mlm_labels.view(-1), ignore_index=-100
    )

    # NSP loss
    nsp_loss = F.cross_entropy(nsp_logits, nsp_labels)

    # Combined loss
    total_loss = mlm_loss + nsp_loss

    return total_loss, mlm_loss, nsp_loss

The mlm_logits.view(-1, vocab_size) reshapes the predictions from a 3D tensor (batch, sequence, vocabulary) into a 2D tensor (batch × sequence, vocabulary). This flattening allows us to treat each position independently, which matches our formula where we sum over individual masked positions.

Now let's see what the losses look like at initialization, before any training has occurred. We'll create random logits (simulating an untrained model) and random labels (simulating our training data):

In[25]:
Code
# Simulate pre-training forward pass with random (untrained) model
batch_size, seq_len, vocab_size = 4, 128, 30522

# Random logits simulate an untrained model's predictions
mlm_logits = torch.randn(batch_size, seq_len, vocab_size)
mlm_labels = torch.full((batch_size, seq_len), -100)

# Mask about 15% of positions (as BERT does)
for i in range(batch_size):
    mask_positions = torch.randperm(seq_len)[: int(seq_len * 0.15)]
    mlm_labels[i, mask_positions] = torch.randint(
        0, vocab_size, (len(mask_positions),)
    )

nsp_logits = torch.randn(batch_size, 2)
nsp_labels = torch.randint(0, 2, (batch_size,))

total, mlm, nsp = compute_pretraining_loss(
    mlm_logits, mlm_labels, nsp_logits, nsp_labels, vocab_size
)
Out[26]:
Console
Pre-training losses (random initialization):
  MLM Loss:   10.8903
  NSP Loss:   1.1417
  Total Loss: 12.0319

Expected random MLM loss: ln(30522) = 10.3262
Expected random NSP loss: ln(2) = 0.6931

The results confirm our understanding of the loss functions. An untrained model assigns roughly uniform probability across all tokens, so the expected MLM loss is ln(V)=ln(30522)10.33\ln(V) = \ln(30522) \approx 10.33, representing maximum uncertainty over the vocabulary. Similarly, random binary classification yields expected loss ln(2)0.69\ln(2) \approx 0.69. The actual values fluctuate slightly due to random sampling, but they hover around these theoretical baselines.

As training progresses, both losses decrease. The MLM loss drops as the model learns which tokens fit in which contexts. The NSP loss drops as the model learns to distinguish coherent sentence pairs from random ones. The total loss, being their sum, tracks the model's overall progress toward understanding language.

Pre-training Hyperparameters

BERT's pre-training configuration was carefully tuned. The hyperparameters balance training stability, convergence speed, and final model quality.

Optimizer Settings

BERT uses AdamW, a variant of Adam that decouples weight decay from the gradient update:

  • Learning rate: Peak of 1e-4 for both BERT-base and BERT-large
  • Warmup steps: 10,000 steps of linear warmup from 0 to peak learning rate
  • Learning rate schedule: Linear decay from peak to 0 over the remaining training
  • Weight decay: 0.01 applied to all parameters except biases and LayerNorm
  • Adam β1\beta_1: 0.9
  • Adam β2\beta_2: 0.999
  • Adam ϵ\epsilon: 1e-6
Out[27]:
Visualization
Line plot showing learning rate rising linearly during warmup then decaying linearly to zero.
BERT learning rate schedule with 10,000-step linear warmup followed by linear decay. The warmup phase prevents unstable updates from large gradients early in training.

The warmup phase is critical for training stability. At initialization, gradients can be very large. Starting with a high learning rate would cause the model to diverge. Warmup gradually increases the learning rate, allowing the model to find a reasonable parameter region before taking larger steps.

Batch Size and Sequence Length

BERT uses large batch sizes to improve training efficiency:

  • Batch size: 256 sequences
  • Sequence length: 128 tokens for 90% of training, 512 tokens for final 10%
  • Effective tokens per batch: 256 × 128 = 32,768 tokens (short phase), 256 × 512 = 131,072 tokens (long phase)

The two-phase approach saves compute. Most learning happens with short sequences where attention is cheap. The final phase with long sequences teaches the model to handle extended context.

Out[28]:
Visualization
Stacked area chart showing sequence length phases across training steps.
BERT's two-phase training strategy. The first 90% of training uses shorter 128-token sequences for efficiency, while the final 10% uses full 512-token sequences to learn long-range dependencies.
In[29]:
Code
# Compute training statistics
batch_size = 256
short_seq_len = 128
long_seq_len = 512
total_steps = 1_000_000
short_phase_ratio = 0.9

short_phase_steps = int(total_steps * short_phase_ratio)
long_phase_steps = total_steps - short_phase_steps

short_phase_tokens = short_phase_steps * batch_size * short_seq_len
long_phase_tokens = long_phase_steps * batch_size * long_seq_len
total_tokens = short_phase_tokens + long_phase_tokens
Out[30]:
Console
BERT Pre-training Statistics:

Short sequence phase (seq_len=128):
  Steps: 900,000
  Tokens: 29,491,200,000 (29.5B)

Long sequence phase (seq_len=512):
  Steps: 100,000
  Tokens: 13,107,200,000 (13.1B)

Total tokens processed: 42,598,400,000 (42.6B)
Out[31]:
Visualization
Bar chart comparing tokens processed in short vs long sequence phases.
Tokens processed during BERT pre-training. Despite using short sequences for 90% of training, the long sequence phase processes 4x more tokens per step, accounting for about 31% of total token throughput.

BERT-base processes approximately 40 billion tokens during pre-training. This is roughly 12 passes through the 3.3 billion word corpus, allowing the model to see each token multiple times with different masking patterns.

Model Configurations

BERT comes in two sizes:

BERT model configurations. Times are for 16 TPU v3 chips.
ParameterBERT-baseBERT-large
Layers1224
Hidden size7681024
Attention heads1216
Feed-forward size30724096
Parameters110M340M
Pre-training time~4 days~12 days

The larger model achieves better performance but requires 3× more compute. Most practitioners use BERT-base because it offers a good balance of quality and efficiency.

In[32]:
Code
def count_bert_parameters(
    layers, hidden_size, attention_heads, ff_size, vocab_size=30522
):
    """
    Count parameters in a BERT model.
    """
    # Embedding layers
    token_embeddings = vocab_size * hidden_size
    position_embeddings = 512 * hidden_size
    segment_embeddings = 2 * hidden_size
    embedding_layernorm = 2 * hidden_size  # gamma, beta

    embeddings = (
        token_embeddings
        + position_embeddings
        + segment_embeddings
        + embedding_layernorm
    )

    # Each transformer layer
    # Multi-head attention: Q, K, V projections + output projection
    attention = (
        4 * hidden_size * hidden_size + 4 * hidden_size
    )  # weights + biases

    # Feed-forward: two linear layers
    ff = hidden_size * ff_size + ff_size  # first layer
    ff += ff_size * hidden_size + hidden_size  # second layer

    # Two layer norms per layer
    layer_norms = 4 * hidden_size

    layer_params = attention + ff + layer_norms
    transformer_params = layers * layer_params

    # Output heads
    mlm_head = hidden_size * hidden_size + hidden_size  # dense
    mlm_head += hidden_size  # layer norm
    mlm_head += hidden_size * vocab_size + vocab_size  # output projection

    nsp_head = hidden_size * 2 + 2  # linear classifier

    total = embeddings + transformer_params + mlm_head + nsp_head

    return {
        "embeddings": embeddings,
        "transformer": transformer_params,
        "mlm_head": mlm_head,
        "nsp_head": nsp_head,
        "total": total,
    }
Out[33]:
Console
Parameter breakdown (millions):

Component            BERT-base       BERT-large     
--------------------------------------------------
embeddings                 23.8M           31.8M
transformer                85.1M          302.3M
mlm_head                   24.1M           32.3M
nsp_head                    0.0M            0.0M
total                     133.0M          366.4M

The embeddings account for a significant fraction of parameters due to the large vocabulary. The transformer layers dominate in the larger model. The prediction heads are relatively small.

Out[34]:
Visualization
Horizontal bar chart showing parameter counts for BERT-base components.
BERT-base parameter distribution. Embeddings (23M) and transformer layers (85M) dominate.
Horizontal bar chart showing parameter counts for BERT-large components.
BERT-large parameter distribution. Transformer layers (302M) become the clear majority.

Pre-training Duration and Convergence

How do you know when pre-training is done? BERT trained for 1 million steps, but the loss curve provides more insight than a fixed number.

Loss Curves

Both MLM and NSP losses decrease rapidly in early training, then gradually level off. The MLM loss typically drops from around 10 (random baseline) to 1.5-2.0. The NSP loss drops from 0.69 (random) to below 0.1.

Out[35]:
Visualization
Line plot showing MLM loss starting around 10 and dropping to 2, NSP loss dropping quickly to near 0.
Simulated BERT pre-training loss curves. MLM loss dominates early training and continues improving throughout. NSP loss converges quickly as the task is easier. The total loss is the sum of both.

The curves show that NSP converges much faster than MLM. After about 100K steps, the NSP loss is near its final value. MLM continues improving throughout training, though with diminishing returns after 500K steps.

Validation Metrics

During pre-training, the primary metrics are MLM accuracy (fraction of masked tokens correctly predicted) and NSP accuracy. These provide intuition about how well the model is learning.

Typical final values for BERT-base:

  • MLM accuracy: 60-65%
  • NSP accuracy: 97-98%

The high NSP accuracy reflects that the task is relatively easy. The MLM accuracy of 60% might seem low, but remember the model is predicting from a 30,000-token vocabulary with only surrounding context. Many masked positions have multiple plausible answers.

Out[36]:
Visualization
Line plot showing MLM accuracy rising to 60% and NSP accuracy rising quickly to 97%.
Simulated validation accuracy during BERT pre-training. NSP accuracy rises quickly to near-perfect levels (the task is relatively easy), while MLM accuracy improves more gradually, plateauing around 60-65% due to the inherent difficulty of vocabulary prediction.

When to Stop

Pre-training should stop when:

  1. Loss plateaus: The MLM loss has stopped decreasing for many steps
  2. Downstream performance peaks: Validation on downstream tasks shows no improvement
  3. Compute budget exhausted: The allocated resources are fully used

In practice, most BERT variants train for a fixed number of steps (1M for original BERT, 100K for RoBERTa's larger batches) determined empirically. Longer training generally helps until the model begins overfitting to the pre-training corpus.

Putting It All Together

Let's implement a simplified but complete pre-training loop that demonstrates all the components:

In[37]:
Code
from torch.utils.data import Dataset


class BertPretrainingDataset(Dataset):
    """Dataset for BERT pre-training."""

    def __init__(self, documents, tokenizer, max_length=128, mask_prob=0.15):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mask_prob = mask_prob

        # Create sentence pairs
        self.pairs = []
        all_sentences = [s for doc in documents for s in doc]

        for doc in documents:
            for i in range(len(doc) - 1):
                # 50% positive, 50% negative
                if random.random() < 0.5:
                    self.pairs.append((doc[i], doc[i + 1], True))
                else:
                    random_sent = random.choice(all_sentences)
                    self.pairs.append((doc[i], random_sent, False))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        sent_a, sent_b, is_next = self.pairs[idx]

        # Tokenize
        encoding = self.tokenizer(
            sent_a,
            sent_b,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        token_type_ids = encoding["token_type_ids"].squeeze()

        # Apply MLM masking
        masked_ids, mlm_labels = apply_bert_masking(
            input_ids, self.tokenizer, self.mask_prob
        )

        return {
            "input_ids": masked_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "mlm_labels": mlm_labels,
            "nsp_labels": torch.tensor(1 if is_next else 0),
        }
In[38]:
Code
from torch.utils.data import DataLoader

# Create a small dataset for demonstration
demo_documents = [
    [
        "The sun was setting behind the mountains.",
        "Golden light painted the clouds.",
        "Birds flew home for the evening.",
        "The temperature began to drop.",
    ],
    [
        "Machine learning algorithms learn from data.",
        "They identify patterns automatically.",
        "This enables predictions on new inputs.",
        "Deep learning uses neural networks.",
    ],
    [
        "The chef prepared a delicious meal.",
        "Fresh ingredients make the difference.",
        "She seasoned everything perfectly.",
        "The guests enjoyed every bite.",
    ],
]

random.seed(42)
dataset = BertPretrainingDataset(demo_documents, tokenizer, max_length=64)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
Out[39]:
Console
Dataset size: 9 sentence pairs
Batch count: 3

Sample batch shapes:
  input_ids: torch.Size([64])
  attention_mask: torch.Size([64])
  token_type_ids: torch.Size([64])
  mlm_labels: torch.Size([64])
  nsp_labels: torch.Size([])

This simplified implementation captures the essential pre-training workflow. A production system would add:

  • Whole word masking instead of token-level masking
  • Dynamic masking (fresh masks each epoch)
  • Distributed training across multiple GPUs/TPUs
  • Gradient accumulation for larger effective batch sizes
  • Mixed precision training for memory efficiency
  • Checkpointing and resumption

Limitations and Practical Considerations

BERT's pre-training approach works well, but it comes with constraints that affect how the model can be used.

The compute requirements are substantial. Pre-training BERT-base on 4 TPU v3 chips takes approximately 4 days. BERT-large requires roughly 12 days. For organizations without access to TPU pods or large GPU clusters, pre-training from scratch is impractical. This is why most practitioners start from publicly available checkpoints and fine-tune rather than pre-training. The environmental cost is also non-trivial: training BERT-large once produces CO₂ emissions comparable to a trans-American flight.

The fixed vocabulary constrains domain adaptation. BERT's WordPiece vocabulary was trained on BooksCorpus and Wikipedia. For specialized domains like medicine or law, many terms tokenize into character-level fragments, losing semantic coherence. Continued pre-training on domain-specific text helps, but the vocabulary remains suboptimal. Some projects address this by training entirely new models with domain-specific vocabularies.

NSP's value remains debated. RoBERTa demonstrated that removing NSP entirely and using longer contiguous sequences improved downstream performance. This suggests the original NSP task may have been too easy, providing weak supervision. However, some tasks like question answering that require cross-sentence reasoning may still benefit from sentence-level objectives. The ALBERT paper's Sentence Order Prediction offers a middle ground.

Static masking limits training efficiency. Each training example shows the model the same masked positions across all epochs. Dynamic masking, where positions are re-sampled each epoch, provides more varied training signal. RoBERTa adopted dynamic masking as a default, showing modest improvements.

Key Parameters

The most important parameters when configuring BERT pre-training:

  • mask_prob (default: 0.15): Fraction of tokens selected for masking per sequence. The 15% rate balances learning signal against context preservation. Higher rates (up to 40%) have been explored with mixed results.

  • learning_rate (default: 1e-4): Peak learning rate after warmup. Lower rates (5e-5) may be needed for larger models or smaller batch sizes to prevent instability.

  • warmup_steps (default: 10,000): Number of steps to linearly increase learning rate from 0 to peak. Prevents large gradient updates early in training when loss landscape is rough.

  • batch_size (default: 256): Number of sequences per training batch. Larger batches (up to 8192 with gradient accumulation) improve training efficiency but require more memory.

  • max_seq_length (default: 512): Maximum sequence length the model can process. BERT uses 128 for 90% of training, then 512 for the final 10% to save compute.

  • num_train_steps (default: 1,000,000): Total training steps. More steps generally improve performance until overfitting begins.

  • weight_decay (default: 0.01): L2 regularization coefficient applied to all parameters except biases and LayerNorm. Helps prevent overfitting.

  • positive_ratio (default: 0.5): Fraction of sentence pairs that are actual consecutive sentences versus random pairs for NSP. Balanced 50/50 is standard.

Summary

BERT's pre-training combines two complementary objectives to learn rich language representations:

  • Training data from BooksCorpus and Wikipedia provides diverse, high-quality text totaling 3.3 billion words
  • Masked Language Modeling forces bidirectional context understanding by predicting randomly masked tokens (15% mask rate, 80-10-10 corruption strategy)
  • Next Sentence Prediction adds document-level signal by classifying whether sentence pairs are consecutive (though its value is debated)
  • Hyperparameters include AdamW optimizer with linear warmup and decay, 256 batch size, and two-phase sequence length (128 then 512)
  • Training duration spans 1 million steps over approximately 4 days for BERT-base, processing roughly 40 billion tokens
  • The combined loss equally weights MLM and NSP, allowing both objectives to shape the shared representations

The next chapter explores fine-tuning: how to adapt these pre-trained representations to specific downstream tasks.

Quiz

Ready to test your understanding? Take this quick quiz to reinforce what you've learned about BERT pre-training.

Loading component...
Track your reading progress

Sign in to mark chapters as read and track your learning journey

Sign in →

Comments

Reference

BIBTEXAcademic
@misc{bertpretrainingmlmnsptrainingstrategiesexplained, author = {Michael Brenndoerfer}, title = {BERT Pre-training: MLM, NSP & Training Strategies Explained}, year = {2025}, url = {https://mbrenndoerfer.com/writing/bert-pretraining-mlm-nsp-training-guide}, organization = {mbrenndoerfer.com}, note = {Accessed: 2025-12-19} }
APAAcademic
Michael Brenndoerfer (2025). BERT Pre-training: MLM, NSP & Training Strategies Explained. Retrieved from https://mbrenndoerfer.com/writing/bert-pretraining-mlm-nsp-training-guide
MLAAcademic
Michael Brenndoerfer. "BERT Pre-training: MLM, NSP & Training Strategies Explained." 2025. Web. 12/19/2025. <https://mbrenndoerfer.com/writing/bert-pretraining-mlm-nsp-training-guide>.
CHICAGOAcademic
Michael Brenndoerfer. "BERT Pre-training: MLM, NSP & Training Strategies Explained." Accessed 12/19/2025. https://mbrenndoerfer.com/writing/bert-pretraining-mlm-nsp-training-guide.
HARVARDAcademic
Michael Brenndoerfer (2025) 'BERT Pre-training: MLM, NSP & Training Strategies Explained'. Available at: https://mbrenndoerfer.com/writing/bert-pretraining-mlm-nsp-training-guide (Accessed: 12/19/2025).
SimpleBasic
Michael Brenndoerfer (2025). BERT Pre-training: MLM, NSP & Training Strategies Explained. https://mbrenndoerfer.com/writing/bert-pretraining-mlm-nsp-training-guide
Michael Brenndoerfer

About the author: Michael Brenndoerfer

All opinions expressed here are my own and do not reflect the views of my employer.

Michael currently works as an Associate Director of Data Science at EQT Partners in Singapore, leading AI and data initiatives across private capital investments.

With over a decade of experience spanning private equity, management consulting, and software engineering, he specializes in building and scaling analytics capabilities from the ground up. He has published research in leading AI conferences and holds expertise in machine learning, natural language processing, and value creation through data.

Stay updated

Get notified when I publish new articles on data and AI, private equity, technology, and more.

No spam, unsubscribe anytime.

or

Create a free account to unlock exclusive features, track your progress, and join the conversation.

No popupsUnobstructed readingCommenting100% Free