BERT Architecture: Deep Dive into Model Structure and Components

Michael Brenndoerfer · Updated July 16, 2025 · 32 min read

Explore the BERT architecture in detail, covering model sizes (Base vs Large), the three-layer embedding system, bidirectional attention patterns, and output representations for downstream tasks.


BERT Architecture

In October 2018, Google released a paper titled "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." Within months, BERT had shattered performance records on eleven NLP benchmarks and fundamentally changed how the field approached language understanding tasks. The architecture that made this possible wasn't revolutionary in its components: it stacked transformer encoder blocks with multi-head self-attention and feed-forward networks. The revolution lay in how these familiar pieces were configured, trained, and applied.

This chapter examines the BERT architecture in detail. We'll explore the two model sizes (Base and Large) and their layer configurations, understand how BERT's three embedding types combine to represent input, examine the bidirectional attention patterns that distinguish BERT from autoregressive models, and analyze the output representations that downstream tasks use. By the end, you'll understand not just what BERT's architecture looks like, but why each design choice matters.

Model Sizes: Base vs Large

BERT comes in two standard configurations: BERT-Base and BERT-Large. These sizes were chosen deliberately to balance capability against practical deployment constraints. BERT-Base matches the hidden dimension of OpenAI's GPT (768), enabling direct comparisons, while BERT-Large pushes scale to demonstrate that larger models capture more nuanced language patterns.

BERT Model Variants

BERT-Base contains 12 transformer encoder layers with 12 attention heads and a hidden dimension of 768, totaling approximately 110 million parameters. BERT-Large doubles the layers to 24, increases heads to 16, and expands the hidden dimension to 1024, reaching approximately 340 million parameters.

The architectural specifications for each variant are:

BERT architectural specifications. The head dimension remains constant at 64 across both variants, with BERT-Large achieving greater capacity through more heads and layers.
Parameter               BERT-Base   BERT-Large
Layers (L)              12          24
Hidden size (H)         768         1024
Attention heads (A)     12          16
Head dimension (H/A)    64          64
Feed-forward size       3072        4096
Vocabulary size         30,522      30,522
Max sequence length     512         512
Parameters              ~110M       ~340M

Notice that the head dimension remains constant at 64 across both variants. This means BERT-Large achieves more capacity by having more heads (16 vs 12) and more layers (24 vs 12), not by making each head larger. The feed-forward dimension follows the standard 4x multiplier relative to hidden size (768 × 4 = 3072 for Base, 1024 × 4 = 4096 for Large).

Let's compute the exact parameter counts to understand where capacity resides:

In[3]:
Code
def count_bert_parameters(
    vocab_size: int = 30522,
    hidden_size: int = 768,
    num_layers: int = 12,
    num_heads: int = 12,
    intermediate_size: int = 3072,
    max_position: int = 512,
    type_vocab_size: int = 2,
) -> dict:
    """Count parameters in each component of BERT."""
    params = {}

    # Embedding layers
    params["token_embeddings"] = vocab_size * hidden_size
    params["position_embeddings"] = max_position * hidden_size
    params["segment_embeddings"] = type_vocab_size * hidden_size
    params["embedding_layernorm"] = 2 * hidden_size  # gamma and beta

    # Per-layer parameters
    # Self-attention: Q, K, V projections + output projection
    attention_params = 4 * (hidden_size * hidden_size + hidden_size)
    # Feed-forward: two linear layers (expand to intermediate, project back)
    ff_params = hidden_size * intermediate_size + intermediate_size  # first layer + bias
    ff_params += intermediate_size * hidden_size + hidden_size  # second layer + bias
    # Layer norms (2 per layer)
    layernorm_params = 4 * hidden_size

    params["per_layer"] = attention_params + ff_params + layernorm_params
    params["all_layers"] = num_layers * params["per_layer"]

    # Pooler (for [CLS] representation)
    params["pooler"] = hidden_size * hidden_size + hidden_size

    # Total
    params["embeddings_total"] = (
        params["token_embeddings"]
        + params["position_embeddings"]
        + params["segment_embeddings"]
        + params["embedding_layernorm"]
    )
    params["total"] = (
        params["embeddings_total"] + params["all_layers"] + params["pooler"]
    )

    return params
Out[4]:
Console
BERT-Base Parameter Distribution:
  Embeddings: 23,837,184 (21.8%)
  Transformer Layers: 85,054,464 (77.7%)
  Pooler: 590,592 (0.5%)
  Total: 109,482,240

BERT-Large Parameter Distribution:
  Embeddings: 31,782,912 (9.5%)
  Transformer Layers: 302,309,376 (90.2%)
  Pooler: 1,049,600 (0.3%)
  Total: 335,141,888
Out[5]:
Visualization
Stacked bar chart comparing parameter distribution between BERT-Base and BERT-Large.
Parameter distribution across BERT components. Transformer layers dominate in both variants, but their share grows from 78% to 90% as model size increases, while embeddings become proportionally smaller.

The vast majority of parameters reside in the transformer layers, particularly in the feed-forward networks. Token embeddings represent a significant portion due to the large vocabulary, but their contribution decreases proportionally as model depth increases. This distribution matters for understanding where BERT stores knowledge: factual information tends to concentrate in feed-forward weights, while attention patterns encode syntactic and semantic relationships.

The Input Representation

BERT's input representation is one of its most distinctive features. Unlike simpler models that use only token embeddings, BERT combines three embedding types to capture different aspects of the input. This design enables BERT to process sentence pairs for tasks like question answering and natural language inference.

Three Embedding Layers

Every input token receives three embeddings that are summed together:

  1. Token embeddings: Standard learned embeddings that map each vocabulary token to a dense vector
  2. Position embeddings: Learned embeddings for each position (0 to 511) encoding sequential order
  3. Segment embeddings: Two learned embeddings (A and B) indicating which sentence a token belongs to
Out[6]:
Visualization
Diagram showing three stacked embedding matrices being summed to produce the final input representation.
BERT combines three embedding types by element-wise addition. Token embeddings capture word meaning, position embeddings encode sequential order, and segment embeddings distinguish sentence pairs.

The segment embeddings deserve special attention. BERT was designed for tasks involving sentence pairs: given two sentences, determine if the second follows the first (next sentence prediction), if the first entails the second (natural language inference), or find the answer span (question answering). The segment embeddings allow the model to distinguish which tokens belong to which sentence even after they're concatenated.

Special Tokens

BERT uses several special tokens to structure its input:

  • [CLS]: Prepended to every input. Its final representation is used for classification tasks
  • [SEP]: Inserted between sentences and at the end of the input to mark boundaries
  • [MASK]: Used during pretraining to indicate positions the model should predict
  • [PAD]: Fills sequences shorter than the batch maximum length
  • [UNK]: Represents tokens not in the vocabulary
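
To make this layout concrete, here is a minimal sketch of how a sentence pair is packed with these special tokens and paired with segment IDs. The IDs for [CLS], [SEP], and [PAD] (101, 102, 0) match the standard BERT vocabulary; the sentence token IDs below are placeholders chosen for illustration.

# Minimal sketch: packing "[CLS] A [SEP] B [SEP]" with segment IDs.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def pack_sentence_pair(
    tokens_a: list[int], tokens_b: list[int], max_len: int = 12
) -> tuple[list[int], list[int]]:
    """Build input_ids and segment_ids for a sentence-pair input."""
    input_ids = [CLS_ID] + tokens_a + [SEP_ID] + tokens_b + [SEP_ID]
    # Segment 0 covers [CLS], sentence A, and the first [SEP];
    # segment 1 covers sentence B and the final [SEP]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    # Pad both lists up to max_len
    padding = max_len - len(input_ids)
    input_ids += [PAD_ID] * padding
    segment_ids += [0] * padding
    return input_ids, segment_ids


input_ids, segment_ids = pack_sentence_pair([7592, 2088], [2129, 2024, 2017])
print(input_ids)
print(segment_ids)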

Let's implement the embedding layer to see how these components combine:

In[7]:
Code
import torch
import torch.nn as nn
import torch.nn.functional as F  # used by the attention and feed-forward cells later in this chapter


class BertEmbeddings(nn.Module):
    """BERT embedding layer combining token, position, and segment embeddings."""

    def __init__(
        self,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        max_position: int = 512,
        type_vocab_size: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Three embedding tables
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position, hidden_size)
        self.segment_embeddings = nn.Embedding(type_vocab_size, hidden_size)

        # Layer normalization and dropout
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

        # Register position ids as buffer (not a parameter)
        self.register_buffer(
            "position_ids", torch.arange(max_position).unsqueeze(0)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        segment_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Combine embeddings for input tokens.

        Args:
            input_ids: Token indices, shape (batch_size, seq_len)
            segment_ids: Segment indices (0 or 1), shape (batch_size, seq_len)

        Returns:
            Combined embeddings, shape (batch_size, seq_len, hidden_size)
        """
        seq_len = input_ids.size(1)

        # Get position ids for this sequence length
        position_ids = self.position_ids[:, :seq_len]

        # Default segment ids to all zeros (single sentence)
        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids)

        # Look up embeddings
        token_emb = self.token_embeddings(input_ids)
        position_emb = self.position_embeddings(position_ids)
        segment_emb = self.segment_embeddings(segment_ids)

        # Sum all three
        embeddings = token_emb + position_emb + segment_emb

        # Normalize and apply dropout
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings
Out[8]:
Console
Input shape: torch.Size([1, 8])
Output shape: torch.Size([1, 8, 768])
Embedding dimension: 768

The embedding layer transforms our 8-token input into a tensor of shape (1, 8, 768), where each token now has a 768-dimensional representation combining its identity, position, and segment information. This dense representation is what flows through the subsequent transformer layers.

The element-wise addition of three embedding types may seem unusual. Why not concatenate them? The answer lies in parameter efficiency and training dynamics. Addition keeps the hidden dimension fixed at 768, while concatenation would triple it. More subtly, addition forces the model to learn representations where token meaning, positional information, and segment identity can coexist in the same vector space. The layer normalization after addition rescales these combined representations to have consistent statistics.

Learned vs Sinusoidal Positions

BERT uses learned position embeddings rather than the sinusoidal encodings from the original transformer. Each of the 512 possible positions gets its own learned vector. This choice trades generalization for expressiveness: learned embeddings cannot extrapolate to positions beyond 512, but they can capture position-specific patterns that sinusoidal encodings cannot.

In[9]:
Code
import numpy as np


# Compare learned vs sinusoidal position embeddings
def sinusoidal_position_encoding(max_len: int, d_model: int) -> np.ndarray:
    """Generate sinusoidal position encodings."""
    position = np.arange(max_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))

    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe


# Get learned embeddings from our model
learned_pe = embeddings.position_embeddings.weight.detach().numpy()
sinusoidal_pe = sinusoidal_position_encoding(512, 768)
Out[10]:
Visualization
Heatmap of learned position embeddings showing irregular but structured patterns.
Learned position embeddings show irregular patterns that capture position-specific information discovered during training.
Heatmap of sinusoidal position embeddings showing regular wave patterns.
Sinusoidal position embeddings follow a structured pattern that enables extrapolation to unseen sequence lengths.

The learned embeddings appear less regular because they capture whatever positional patterns help with the pretraining objectives. The sinusoidal pattern's mathematical regularity means positions can be expressed as linear combinations of other positions, enabling some length generalization. BERT prioritized expressiveness over extrapolation since most downstream tasks don't require sequences longer than 512 tokens.

Transformer Encoder Layers

The core of BERT consists of stacked transformer encoder blocks. Each block applies multi-head self-attention followed by a position-wise feed-forward network, with residual connections and layer normalization around each sub-layer.

Layer Structure

Each transformer layer follows a consistent pattern:

  1. Multi-head self-attention with residual connection
  2. Layer normalization
  3. Feed-forward network with residual connection
  4. Layer normalization

BERT uses a "post-norm" architecture, where layer normalization follows each sub-layer rather than preceding it. This differs from the "pre-norm" variant used in some later models such as GPT-2.
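
The difference between the two orderings is easiest to see side by side. The following sketch is illustrative only, with `sublayer` standing in for either the attention or the feed-forward module:

import torch
import torch.nn as nn


def post_norm_step(x: torch.Tensor, sublayer: nn.Module, norm: nn.LayerNorm) -> torch.Tensor:
    # BERT (post-norm): apply the sub-layer, add the residual, then normalize
    return norm(x + sublayer(x))


def pre_norm_step(x: torch.Tensor, sublayer: nn.Module, norm: nn.LayerNorm) -> torch.Tensor:
    # Pre-norm (e.g. GPT-2): normalize first, apply the sub-layer, add the residual outside
    return x + sublayer(norm(x))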

In[11]:
Code
class BertSelfAttention(nn.Module):
    """Multi-head self-attention for BERT."""

    def __init__(
        self, hidden_size: int = 768, num_heads: int = 12, dropout: float = 0.1
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Q, K, V projections
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

        # Output projection
        self.output = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        # Reshape for multi-head attention: (batch, seq, heads, head_dim)
        q = q.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        k = k.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        v = v.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)

        # Apply attention mask if provided
        if attention_mask is not None:
            scores = scores + attention_mask

        # Softmax and dropout
        attention_probs = F.softmax(scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Apply attention to values
        context = torch.matmul(attention_probs, v)

        # Reshape back: (batch, seq, hidden)
        context = (
            context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        )

        # Output projection
        output = self.output(context)

        return output, attention_probs

The attention computation follows the standard scaled dot-product formula. Given query, key, and value matrices, attention computes a weighted combination of values where the weights depend on query-key similarity:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

where:

  • Q: the query matrix with shape (sequence length, head dimension), representing what each position is "looking for"
  • K: the key matrix with shape (sequence length, head dimension), representing what each position "offers" for matching
  • V: the value matrix with shape (sequence length, head dimension), containing the information to aggregate
  • K^T: the transpose of K, enabling the matrix multiplication QK^T that produces similarity scores
  • d_k: the head dimension (64 in BERT), used for scaling
  • √d_k: the scaling factor that prevents dot products from growing too large
  • softmax(·): normalizes scores to a probability distribution over positions

The scaling by √d_k is crucial for training stability. Without it, the variance of the dot products QK^T grows with the head dimension, pushing the softmax into regions where gradients vanish. With 64-dimensional heads, unscaled dot products routinely reach magnitudes of 8-10, making the softmax output nearly one-hot and preventing gradient flow.
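
A quick standalone check makes the effect concrete (this is a sketch, not the code that produced the histograms below):

import torch

torch.manual_seed(0)
d_k = 64
q = torch.randn(1000, d_k)  # random query vectors
k = torch.randn(1000, d_k)  # random key vectors

raw_scores = q @ k.T                       # unscaled dot products
scaled_scores = raw_scores / (d_k ** 0.5)  # divide by sqrt(d_k)

print(f"Raw score std:    {raw_scores.std().item():.2f}")     # roughly sqrt(64) = 8
print(f"Scaled score std: {scaled_scores.std().item():.2f}")  # roughly 1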

Out[12]:
Visualization
Histogram of unscaled attention scores showing wide spread from -15 to 15.
Without scaling, attention scores have high variance, causing softmax to produce near-one-hot distributions.
Histogram of scaled attention scores showing concentrated distribution from -2 to 2.
With scaling by the square root of dimension, scores have controlled variance, producing softer attention distributions.

The histograms above demonstrate this effect. Raw dot products of 64-dimensional vectors have variance around 64, producing scores that span a wide range. After dividing by √64 = 8, the variance drops to approximately 1, keeping scores in a range where softmax produces meaningful probability distributions rather than near-deterministic outputs.

Feed-Forward Network

Each layer's feed-forward network expands the representation to 4x the hidden dimension, applies a non-linearity, then projects back:

In[13]:
Code
class BertFeedForward(nn.Module):
    """Position-wise feed-forward network for BERT."""

    def __init__(
        self,
        hidden_size: int = 768,
        intermediate_size: int = 3072,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.dense1 = nn.Linear(hidden_size, intermediate_size)
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden = self.dense1(hidden_states)
        hidden = F.gelu(hidden)  # BERT uses GELU activation
        hidden = self.dense2(hidden)
        hidden = self.dropout(hidden)
        return hidden

BERT uses GELU (Gaussian Error Linear Unit) activation rather than ReLU. GELU provides a smooth approximation to the gating mechanism, allowing small negative values to pass through while still providing non-linearity. The function can be understood as multiplying each input by the probability that a standard normal random variable would be less than that input:

\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]

where:

  • x: the input value to the activation function
  • Φ(x): the cumulative distribution function (CDF) of the standard normal distribution, giving the probability that a standard normal random variable is less than x
  • erf(·): the error function, a mathematical function closely related to the normal distribution's CDF
  • √2: a scaling constant that converts the standard error function into the normal CDF

Unlike ReLU, which abruptly zeroes out all negative inputs, GELU provides a smooth transition. For large positive x, Φ(x) ≈ 1, so GELU(x) ≈ x. For large negative x, Φ(x) ≈ 0, so GELU(x) ≈ 0. The smooth transition around zero means small negative values can still contribute, which empirically improves training dynamics in transformer models.
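
A few sample values illustrate the difference, using PyTorch's built-in activations:

import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, -0.1, 0.0, 0.1, 1.0, 3.0])
print("x:   ", x.tolist())
print("ReLU:", F.relu(x).tolist())
# GELU(-1) is about -0.16: small negative inputs pass through, attenuated
print("GELU:", [round(v, 3) for v in F.gelu(x).tolist()])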

Out[14]:
Visualization
Line plot comparing GELU and ReLU activation functions from -4 to 4.
GELU provides a smooth transition compared to ReLU's hard cutoff at zero. The shaded region shows where GELU allows small negative values to pass through, potentially preserving useful gradient information.

Complete Encoder Layer

Combining attention and feed-forward with residual connections and layer normalization:

In[15]:
Code
class BertLayer(nn.Module):
    """Single BERT encoder layer."""

    def __init__(
        self,
        hidden_size: int = 768,
        num_heads: int = 12,
        intermediate_size: int = 3072,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.attention = BertSelfAttention(hidden_size, num_heads, dropout)
        self.feed_forward = BertFeedForward(
            hidden_size, intermediate_size, dropout
        )
        self.attention_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.output_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self-attention with residual
        attention_output, attention_probs = self.attention(
            hidden_states, attention_mask
        )
        attention_output = self.dropout(attention_output)
        hidden_states = self.attention_norm(hidden_states + attention_output)

        # Feed-forward with residual
        ff_output = self.feed_forward(hidden_states)
        hidden_states = self.output_norm(hidden_states + ff_output)

        return hidden_states, attention_probs
Out[16]:
Console
Input shape: torch.Size([2, 10, 768])
Output shape: torch.Size([2, 10, 768])
Attention probabilities shape: torch.Size([2, 12, 10, 10])

The layer maintains the same tensor shape between input and output, which is essential for stacking layers and enabling residual connections. The attention probabilities tensor shows dimensions for batch, heads, and both query and key positions, confirming that each of the 12 heads computes its own attention pattern over the sequence.

The residual connections are crucial for training deep networks. They allow gradients to flow directly backward through the network, mitigating the vanishing gradient problem. Without residuals, training a 12 or 24-layer network would be extremely difficult.

Bidirectional Attention Patterns

BERT's attention is bidirectional: every token can attend to every other token in the sequence. This contrasts sharply with the causal (left-to-right) attention used in GPT and other autoregressive models. The difference has profound implications for what patterns BERT can learn.
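
The structural difference is visible in the additive attention masks each model family uses. This sketch follows the same additive-mask convention used later in this chapter (0 means attend, a large negative value means ignore):

import torch

seq_len = 6

# BERT (bidirectional): no structural mask, every position may attend everywhere
bidirectional_mask = torch.zeros(seq_len, seq_len)

# GPT-style (causal): position i may only attend to positions j <= i,
# so everything above the diagonal is blocked
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), float("-inf")), diagonal=1
)

print("Bidirectional additive mask:\n", bidirectional_mask)
print("Causal additive mask:\n", causal_mask)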

Visualizing Attention

Let's examine what attention patterns look like in practice. We'll create sample attention weights and visualize how different heads might specialize:

Out[17]:
Visualization
Heatmap showing attention weights distributed broadly across all positions.
A 'broad attention' head that attends relatively uniformly across the sequence, useful for gathering global context.
Heatmap showing attention weights concentrated near the diagonal.
A 'local attention' head that focuses on nearby tokens, capturing local syntactic relationships.

Research on BERT's attention patterns reveals consistent specialization:

  • Early layers tend to attend broadly, with some heads focusing on the [CLS] token
  • Middle layers develop syntactic patterns, with heads tracking subject-verb relationships, dependency arcs, and constituent boundaries
  • Later layers show more semantic attention, with heads focusing on semantically related tokens

Attention Mask for Padding

When processing batches of variable-length sequences, BERT uses attention masks to prevent attending to padding tokens. The mask is applied as a large negative value before softmax, effectively zeroing out attention to padded positions:

In[18]:
Code
def create_attention_mask(
    input_ids: torch.Tensor, pad_token_id: int = 0
) -> torch.Tensor:
    """
    Create attention mask for padded sequences.

    Args:
        input_ids: Token indices, shape (batch, seq_len)
        pad_token_id: ID of the padding token

    Returns:
        Attention mask, shape (batch, 1, 1, seq_len) for broadcasting
    """
    # 1 for real tokens, 0 for padding
    mask = (input_ids != pad_token_id).float()

    # Reshape for broadcasting with attention scores (batch, heads, seq, seq)
    mask = mask.unsqueeze(1).unsqueeze(2)

    # Convert to additive mask: 0 for attend, -inf for ignore
    mask = (1.0 - mask) * -10000.0

    return mask
Out[19]:
Console
Input tokens (0 = padding):
tensor([[ 101, 7592, 2088,  102,    0,    0,    0,    0]])

Attention mask shape: torch.Size([1, 1, 1, 8])
Mask values (0 = attend, -10000 = ignore):
tensor([    -0.,     -0.,     -0.,     -0., -10000., -10000., -10000., -10000.])

The large negative value (-10000) becomes approximately zero after softmax, ensuring padded positions receive no attention weight.
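
A quick check shows the effect of adding the mask before the softmax, using the eight-token example above where the last four positions are padding (the raw scores here are made up for illustration):

import torch
import torch.nn.functional as F

# Hypothetical raw attention scores for one query over the 8 positions
scores = torch.tensor([2.0, 1.0, 0.5, 3.0, 0.2, 0.1, 0.4, 0.3])
mask = torch.tensor([0.0, 0.0, 0.0, 0.0, -10000.0, -10000.0, -10000.0, -10000.0])

probs = F.softmax(scores + mask, dim=-1)
print(probs)  # the four padded positions receive essentially zero weight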

Output Representations

BERT produces contextualized representations at every position. Different downstream tasks use these representations in different ways.

The [CLS] Token

The [CLS] token's final hidden state is designed to aggregate sequence-level information. During pretraining with the Next Sentence Prediction task, the [CLS] representation must contain enough information to determine whether two sentences are consecutive. This encourages [CLS] to capture global semantic content.

For classification tasks, you typically pass the [CLS] representation through a task-specific linear layer:

In[20]:
Code
class BertClassificationHead(nn.Module):
    """Classification head using [CLS] token representation."""

    def __init__(
        self, hidden_size: int = 768, num_classes: int = 2, dropout: float = 0.1
    ):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """
        Args:
            sequence_output: BERT output, shape (batch, seq_len, hidden)

        Returns:
            Classification logits, shape (batch, num_classes)
        """
        # Take [CLS] token (first position)
        cls_output = sequence_output[:, 0, :]

        # Project through dense layer
        pooled = self.dense(cls_output)
        pooled = torch.tanh(pooled)
        pooled = self.dropout(pooled)

        # Classify
        logits = self.classifier(pooled)
        return logits

Token-Level Representations

For sequence labeling tasks like named entity recognition or part-of-speech tagging, you use the representation at each token position:

In[21]:
Code
class BertTokenClassificationHead(nn.Module):
    """Token classification head for sequence labeling."""

    def __init__(
        self, hidden_size: int = 768, num_labels: int = 9, dropout: float = 0.1
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """
        Args:
            sequence_output: BERT output, shape (batch, seq_len, hidden)

        Returns:
            Token logits, shape (batch, seq_len, num_labels)
        """
        output = self.dropout(sequence_output)
        logits = self.classifier(output)
        return logits

Span Representations

For extractive question answering, BERT predicts start and end positions of the answer span. Two linear layers project each token's representation to start and end logits:

In[22]:
Code
class BertQuestionAnsweringHead(nn.Module):
    """QA head for extractive question answering."""

    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.start_classifier = nn.Linear(hidden_size, 1)
        self.end_classifier = nn.Linear(hidden_size, 1)

    def forward(
        self, sequence_output: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            sequence_output: BERT output, shape (batch, seq_len, hidden)

        Returns:
            start_logits: shape (batch, seq_len)
            end_logits: shape (batch, seq_len)
        """
        start_logits = self.start_classifier(sequence_output).squeeze(-1)
        end_logits = self.end_classifier(sequence_output).squeeze(-1)
        return start_logits, end_logits
Out[23]:
Console
Classification output shape: torch.Size([2, 3])
Token classification output shape: torch.Size([2, 20, 9])
QA start/end logits shape: torch.Size([2, 20]) torch.Size([2, 20])

Each head produces output with the appropriate shape for its task. The classification head reduces the sequence to a single prediction per sample. The token classification head produces a label prediction for each position. The QA head generates start and end scores for every token, allowing span extraction by finding the highest-scoring start-end pair.
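
As an illustration of that last step, here is a minimal sketch of span selection from start and end logits. The function and its constraints are a simplification for this chapter, not BERT's reference decoding code:

import torch


def best_span(
    start_logits: torch.Tensor, end_logits: torch.Tensor, max_answer_len: int = 30
) -> tuple[int, int]:
    """Pick the (start, end) pair with the highest combined score, with start <= end."""
    seq_len = start_logits.size(0)
    # score[i, j] = start_logits[i] + end_logits[j] for every candidate span
    scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)
    # Keep only spans where end >= start and the span is not too long
    valid = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool))
    valid &= ~torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=max_answer_len
    )
    scores = scores.masked_fill(~valid, float("-inf"))
    flat_index = scores.view(-1).argmax()
    return (flat_index // seq_len).item(), (flat_index % seq_len).item()


start, end = best_span(torch.randn(20), torch.randn(20))
print(f"Predicted answer span: tokens {start} to {end}")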

The flexibility of BERT's output representations is key to its success. The same pretrained model can power classification, sequence labeling, question answering, and many other tasks by simply changing the task-specific head.

The Complete BERT Model

Let's assemble all components into a complete BERT implementation:

In[24]:
Code
class BertEncoder(nn.Module):
    """Stack of BERT encoder layers."""

    def __init__(
        self,
        num_layers: int = 12,
        hidden_size: int = 768,
        num_heads: int = 12,
        intermediate_size: int = 3072,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                BertLayer(hidden_size, num_heads, intermediate_size, dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        all_attention_probs = []

        for layer in self.layers:
            hidden_states, attention_probs = layer(
                hidden_states, attention_mask
            )
            all_attention_probs.append(attention_probs)

        return hidden_states, all_attention_probs


class BertModel(nn.Module):
    """Complete BERT model."""

    def __init__(
        self,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        intermediate_size: int = 3072,
        max_position: int = 512,
        type_vocab_size: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embeddings = BertEmbeddings(
            vocab_size, hidden_size, max_position, type_vocab_size, dropout
        )
        self.encoder = BertEncoder(
            num_layers, hidden_size, num_heads, intermediate_size, dropout
        )

        # Pooler for [CLS] representation
        self.pooler = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        segment_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        # Create attention mask if not provided
        if attention_mask is None:
            attention_mask = create_attention_mask(input_ids)

        # Embed inputs
        hidden_states = self.embeddings(input_ids, segment_ids)

        # Pass through encoder layers
        sequence_output, all_attention_probs = self.encoder(
            hidden_states, attention_mask
        )

        # Pool [CLS] token
        pooled_output = torch.tanh(self.pooler(sequence_output[:, 0, :]))

        return {
            "last_hidden_state": sequence_output,
            "pooler_output": pooled_output,
            "attention_probs": all_attention_probs,
        }
Out[25]:
Console
Total parameters: 109,482,240

Input shape: torch.Size([2, 32])
Last hidden state shape: torch.Size([2, 32, 768])
Pooler output shape: torch.Size([2, 768])
Number of attention layers: 12

Our implementation produces approximately 109.5 million parameters, in line with the roughly 110M figure quoted for BERT-Base; the small remaining difference comes from the masked language modeling and next sentence prediction heads used only during pretraining, which we omit. The model processes a batch of 2 sequences with 32 tokens each, producing contextualized representations at every position plus a pooled representation for sequence-level tasks.

Practical Considerations

When deploying BERT, two factors dominate resource requirements: memory consumption and computational cost. Both scale with sequence length, making long documents particularly expensive to process.

Memory Requirements

BERT's memory usage during inference scales with sequence length squared due to the attention mechanism. For a batch of sequences, the dominant terms are:

  • Embeddings: O(batch × seq × d)
  • Attention scores: O(batch × heads × seq²)
  • Intermediate activations: O(batch × seq × 4d)

For BERT-Base with a sequence length of 512 and batch size of 1, the attention scores alone require approximately 12 × 512 × 512 × 4 bytes (float32) = 12.6 MB per layer, or about 150 MB for all 12 layers.
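
The same back-of-envelope arithmetic, written as a small helper (a rough estimate that counts only the attention score matrices, ignoring weights and other activations):

def attention_score_memory_mb(
    seq_len: int,
    num_heads: int = 12,
    num_layers: int = 12,
    batch_size: int = 1,
    bytes_per_value: int = 4,  # float32
) -> float:
    """Approximate memory for attention score matrices across all layers, in MB."""
    per_layer = batch_size * num_heads * seq_len * seq_len * bytes_per_value
    return num_layers * per_layer / 1e6


for n in (128, 256, 512):
    print(f"seq_len={n:3d}: ~{attention_score_memory_mb(n):6.1f} MB")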

Out[26]:
Visualization
Line plot showing memory increasing quadratically as sequence length grows from 128 to 512.
BERT memory usage scales quadratically with sequence length due to the attention mechanism, making long sequences significantly more expensive to process.

Computational Complexity

The computational complexity of BERT is dominated by three operations:

  1. Attention: O(n² · d) for sequence length n and dimension d
  2. Feed-forward: O(n · d²) with the 4x expansion
  3. Embeddings: O(n · d)

For typical sequence lengths (128-512), attention and feed-forward costs are comparable. The quadratic attention cost becomes prohibitive only for very long sequences, motivating efficient attention variants like Longformer and BigBird.
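
The rough multiply-accumulate counts below (per layer, with constant factors only approximated) illustrate why: the quadratic attention term overtakes the projection and feed-forward terms only for sequences far longer than 512 tokens.

d = 768  # hidden size

for n in (128, 512, 2048):
    qkv_out_proj = 4 * n * d * d      # Q, K, V, and output projections
    attn_quadratic = 2 * n * n * d    # QK^T scores plus the weighted sum over values
    feed_forward = 2 * n * d * 4 * d  # two linear layers with 4x expansion
    print(
        f"n={n:5d}: projections {qkv_out_proj / 1e9:5.2f}  "
        f"quadratic attention {attn_quadratic / 1e9:5.2f}  "
        f"feed-forward {feed_forward / 1e9:5.2f}   (billions of mult-adds per layer)"
    )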

Limitations and Impact

BERT's architecture introduced several constraints that its successors have worked to address. The fixed sequence length of 512 tokens limits document-level understanding; longer documents must be chunked and processed separately, losing cross-chunk context. The quadratic attention complexity makes extending this limit computationally expensive. Later models like Longformer use sparse attention patterns to handle 4,096+ tokens efficiently.

The pretrain-then-finetune paradigm, while successful, requires task-specific training data and separate models for each task. This limitation motivated research into prompt-based methods where a single model handles multiple tasks through careful input formatting. BERT also cannot generate text autoregressively, restricting it to discriminative tasks. The [CLS] token aggregation assumes that a single vector can capture document-level meaning, which may oversimplify for long or complex texts.

Despite these limitations, BERT's architectural choices proved remarkably effective. The bidirectional attention mechanism captures context that autoregressive models miss. The three-embedding input representation elegantly handles sentence pairs. The standardized output format enables easy adaptation to diverse tasks. These design decisions established patterns that influenced nearly every subsequent language model.

BERT's release in late 2018 catalyzed a transformation in NLP. Within months, BERT or BERT-derived models topped leaderboards for question answering (SQuAD), natural language inference (MNLI), sentiment analysis (SST-2), and many other benchmarks. The pretrain-finetune paradigm became standard practice, and "BERT" became shorthand for transformer-based language understanding. The architecture we've examined in this chapter, while not the final word in language model design, remains a foundational reference point for understanding modern NLP.

Key Parameters

The BERT architecture is defined by a small set of core hyperparameters that determine model capacity, memory usage, and computational cost:

  • hidden_size (768 for Base, 1024 for Large): The dimensionality of token representations throughout the model. Larger values increase capacity, but parameter counts and the cost of the projection and feed-forward layers grow roughly quadratically with hidden size.

  • num_layers (12 for Base, 24 for Large): The number of stacked transformer encoder blocks. More layers enable more complex feature hierarchies but increase memory and compute linearly.

  • num_heads (12 for Base, 16 for Large): The number of parallel attention heads. More heads allow the model to attend to different aspects of the input simultaneously. The head dimension is typically hidden_size / num_heads.

  • intermediate_size (3072 for Base, 4096 for Large): The hidden dimension of the feed-forward network, typically 4× hidden_size. This expansion allows the FFN to learn complex transformations.

  • max_position (512): The maximum sequence length the model can process. Longer sequences require more memory due to quadratic attention complexity.

  • vocab_size (30,522 for BERT): The number of unique tokens in the vocabulary. Larger vocabularies reduce out-of-vocabulary issues but increase embedding parameter count.

  • dropout (0.1): Applied to attention weights, feed-forward outputs, and embeddings during training to prevent overfitting. Set to 0 during inference.

When adapting BERT for specific applications, the most impactful parameters are max_position (for document length requirements) and dropout (for controlling overfitting on small datasets).

Summary

BERT's architecture combines familiar transformer components in a configuration optimized for language understanding. The key architectural elements include:

  • Two model sizes: BERT-Base (110M parameters, 12 layers) and BERT-Large (340M parameters, 24 layers) balancing capability against deployment constraints
  • Three embedding types: Token, position, and segment embeddings summed together enable rich input representation including sentence pairs
  • Bidirectional attention: Unlike autoregressive models, every token attends to every other token, capturing full context for understanding tasks
  • Flexible outputs: The [CLS] token representation supports classification, token representations enable sequence labeling, and span predictions handle extractive QA

The architecture's success established the pretrain-finetune paradigm that dominated NLP through the early 2020s. While subsequent models have extended and improved upon BERT's design, understanding this architecture provides essential foundation for comprehending the evolution of language models.

