Weight Tying: Sharing Embeddings Between Input and Output Layers

Michael Brenndoerfer · Updated June 17, 2025 · 31 min read

Learn how weight tying reduces transformer parameters by sharing the input embedding and output projection matrices. Covers the theoretical justification, implementation details, encoder-decoder tying, and when to use this technique.

Weight Tying

Language models contain two large embedding matrices that seem to serve different purposes: one converts input tokens into vectors, and another converts output vectors back into token probabilities. But these matrices are surprisingly similar in structure, both mapping between the same vocabulary and the same hidden dimension. Weight tying exploits this similarity by making them literally the same matrix, cutting parameter count and often improving model quality.

In this chapter, you'll learn how weight tying works, why it makes theoretical sense, and when to use it in your own models. We'll implement tied embeddings from scratch and examine the practical considerations that determine whether tying helps or hurts.

The Two Embedding Matrices

Before diving into weight tying, let's be clear about what we're tying together. A language model has two distinct embedding operations that bookend the transformer layers.

The input embedding matrix $\mathbf{E} \in \mathbb{R}^{V \times d}$ converts discrete token indices into continuous vectors. Each row contains the learned embedding for one vocabulary token:

$$\mathbf{h}_{\text{in}} = \mathbf{E}[t]$$

where:

  • $\mathbf{E}$: the input embedding matrix of shape $V \times d$
  • $t$: the input token index (an integer from 0 to $V - 1$)
  • $\mathbf{h}_{\text{in}}$: the resulting embedding vector of dimension $d$
  • $V$: the vocabulary size
  • $d$: the embedding/hidden dimension

This is a simple lookup operation. Token 42 retrieves row 42 from the matrix.
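
To make the lookup concrete, here is a minimal check (toy sizes chosen only for illustration, not the article's setup) that the embedding for token 42 really is row 42 of the weight matrix:

Code
import torch
import torch.nn as nn

# Hypothetical sizes, purely for illustration
emb = nn.Embedding(num_embeddings=100, embedding_dim=8)

token_ids = torch.tensor([42])
looked_up = emb(token_ids)[0]        # embedding vector for token 42, shape (8,)

# The lookup is exactly row 42 of the weight matrix
print(torch.equal(looked_up, emb.weight[42]))  # True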

The output projection matrix $\mathbf{W}_{\text{out}} \in \mathbb{R}^{V \times d}$ (also called the "LM head") does the reverse. It takes the transformer's final hidden state and produces logits for each vocabulary token:

$$\mathbf{z} = \mathbf{W}_{\text{out}} \mathbf{h}_{\text{out}}$$

where:

  • $\mathbf{z}$: the output logits vector of length $V$, one score per vocabulary token
  • $\mathbf{W}_{\text{out}}$: the output projection matrix of shape $V \times d$
  • $\mathbf{h}_{\text{out}}$: the final hidden state from the transformer, a column vector of dimension $d$

Each row of $\mathbf{W}_{\text{out}}$ contains the "output embedding" for one vocabulary token. The matrix multiplication computes the dot product between the hidden state and each token's output embedding, yielding a score for every token in the vocabulary.

After applying softmax to these logits, we get a probability distribution over the vocabulary for the next token prediction.

Notice the dimensional symmetry. Both the input embedding and output projection have shape $V \times d$, with each row corresponding to one vocabulary token. This structural identity suggests they might be doing fundamentally similar things, which is exactly the insight behind weight tying.

In[2]:
Code
import torch
import torch.nn as nn

# Typical model dimensions
vocab_size = 50000
hidden_dim = 768

# Without weight tying: two separate matrices
input_embedding = nn.Embedding(vocab_size, hidden_dim)
output_projection = nn.Linear(hidden_dim, vocab_size, bias=False)

# Count parameters
input_params = vocab_size * hidden_dim
output_params = hidden_dim * vocab_size
total_untied = input_params + output_params
Out[3]:
Console
Parameter count without weight tying:
  Input embedding:      38,400,000 parameters
  Output projection:    38,400,000 parameters
  Total:                76,800,000 parameters

Both matrices have identical size: 50,000 × 768

With a vocabulary of 50,000 tokens and a hidden dimension of 768, each matrix contains over 38 million parameters. Together, they account for nearly 77 million parameters, often a substantial fraction of a smaller model's total parameter count.

The Weight Tying Idea

We've established that language models maintain two large matrices: one for reading tokens (input embedding) and one for generating them (output projection). Both have identical dimensions, $V \times d$, where each row represents one vocabulary token. This structural symmetry hints at a deeper connection.

Think about what these matrices actually encode. The input embedding learns "what does this token mean when I read it?" The output projection learns "what hidden state pattern should produce this token?" But these are really two perspectives on the same question: what is the semantic identity of this token within the model's learned representation space?

The Core Insight: One Matrix, Two Roles

Weight tying formalizes this intuition by collapsing both matrices into one:

$$\mathbf{W}_{\text{out}} = \mathbf{E}$$

where:

  • $\mathbf{W}_{\text{out}}$: the output projection matrix, now identical to $\mathbf{E}$
  • $\mathbf{E}$: the input embedding matrix of shape $V \times d$

This single equation halves the embedding-related parameters. But what does it mean computationally? Let's trace through the math to understand how the shared matrix serves both roles.

From Intuition to Formula

When a token enters the model, we look up its embedding. This is unchanged:

$$\mathbf{h}_{\text{in}} = \mathbf{E}[t]$$

The token at index $t$ retrieves row $t$ from the matrix. Simple table lookup.

When the transformer produces its final hidden state $\mathbf{h}_{\text{out}}$, we need to convert this $d$-dimensional vector into scores for all $V$ vocabulary tokens. With weight tying, we compute:

$$\mathbf{z} = \mathbf{E} \mathbf{h}_{\text{out}}$$

This matrix multiplication produces a vector of $V$ logits. But what's actually happening inside this operation? Let's expand it to see the underlying mechanism.

The Dot Product as Similarity

Each logit $z_i$ is computed by taking the dot product between the hidden state and the embedding of token $i$:

$$z_i = \mathbf{e}_i \cdot \mathbf{h}_{\text{out}} = \sum_{j=1}^{d} e_{i,j} \cdot h_j$$

where:

  • $z_i$: the logit (unnormalized score) for token $i$, a single real number
  • $\mathbf{e}_i$: the embedding vector for token $i$ (row $i$ of $\mathbf{E}$), with $d$ components: $e_{i,1}, e_{i,2}, \ldots, e_{i,d}$
  • $\mathbf{h}_{\text{out}}$: the final hidden state vector, with $d$ components: $h_1, h_2, \ldots, h_d$
  • $d$: the embedding dimension

The dot product measures alignment. Two vectors pointing in the same direction yield a large positive value. Orthogonal vectors yield zero. Opposite directions yield large negative values.
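
A quick numeric check, using arbitrary illustrative vectors, shows all three regimes:

Code
import torch

a = torch.tensor([1.0, 2.0, 0.5])

print(torch.dot(a, a))                               # aligned: large positive (5.25)
print(torch.dot(a, torch.tensor([2.0, -1.0, 0.0])))  # orthogonal: 0.0
print(torch.dot(a, -a))                              # opposite: large negative (-5.25)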

This creates a compelling interpretation: the model predicts the next token by finding which embedding best matches its internal representation. When the transformer processes "The cat sat on the ___", it generates a hidden state that should be similar to the embedding of "mat" (or "floor" or "couch"). The dot product with each vocabulary token quantifies this similarity, and softmax converts these similarities into probabilities:

$$P(\text{token}_i \mid \text{context}) = \frac{\exp(z_i)}{\sum_{k=1}^{V} \exp(z_k)}$$

Why This Works

The elegance of weight tying lies in what it forces the model to learn. Without tying, the input embedding for "cat" could be completely unrelated to the output pattern for "cat". The model might use one representation for reading and an entirely different one for generating.

With tying, these must be the same. The embedding that represents "cat" when reading is exactly the target pattern the model must produce when generating "cat". This constraint creates coherence: the model develops a unified semantic space where reading and writing use consistent representations.

Implementation Note

In PyTorch, the embedding matrix has shape $(V, d)$ and hidden states typically have shape $(B, S, d)$, where $B$ is batch size and $S$ is sequence length. To compute logits, we use hidden_states @ embedding.weight.T, which computes dot products between each hidden state and each embedding row, yielding output shape $(B, S, V)$.
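
A small shape check (with toy dimensions chosen only for illustration) confirms how the batched matrix multiplication behaves:

Code
import torch
import torch.nn as nn

B, S, V, d = 2, 5, 100, 16           # toy batch, sequence, vocab, hidden sizes
embedding = nn.Embedding(V, d)
hidden_states = torch.randn(B, S, d)

logits = hidden_states @ embedding.weight.T
print(logits.shape)                   # torch.Size([2, 5, 100]), i.e. (B, S, V)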

Implementation

Translating this math into code is straightforward. We create a single embedding matrix and use it for both encoding (lookup) and decoding (matrix multiplication):

In[4]:
Code
class TiedEmbeddings(nn.Module):
    """Language model head with tied input and output embeddings."""

    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        # Single shared embedding matrix
        self.embedding = nn.Embedding(vocab_size, hidden_dim)

    def encode(self, token_ids):
        """Convert token IDs to embeddings."""
        return self.embedding(token_ids)

    def decode(self, hidden_states):
        """Project hidden states to vocabulary logits."""
        # Use embedding weight matrix transposed
        return torch.matmul(hidden_states, self.embedding.weight.T)


# Create tied embeddings
tied_model = TiedEmbeddings(vocab_size, hidden_dim)
tied_params = vocab_size * hidden_dim
Out[5]:
Console
Parameter count with weight tying:
  Shared embedding:     38,400,000 parameters

Parameter reduction: 38,400,000 parameters (50% savings)

The parameter savings are substantial. We've eliminated one entire $V \times d$ matrix, cutting embedding-related parameters in half. For models with large vocabularies or smaller hidden dimensions, this reduction represents a significant fraction of total model size.

Why Weight Tying Makes Sense

Weight tying isn't just a memory optimization. There's a theoretical justification for why input and output embeddings should be related.

Consider what each embedding represents:

  • Input embedding: Encodes a token's meaning so the model can process it
  • Output embedding: Defines what hidden state pattern should produce this token

Both embeddings capture the same fundamental question: "What does this token mean in the context of this model?" A token's input representation should be similar to the output pattern that generates it.

Think about the word "cat." When you read it (input), you activate a certain semantic representation. When you want to produce it (output), you need to generate a hidden state that matches that same representation. It would be strange if the model's conception of "cat" for reading was completely different from its conception for writing.

Distributional Hypothesis Connection

Weight tying aligns with the distributional hypothesis: words that appear in similar contexts have similar meanings. The input embedding captures what contexts a word appears in; the output embedding captures what words appear in a given context. These are two sides of the same distributional coin.

Empirical Evidence

Research has consistently shown that weight tying helps rather than hurts:

  • Press & Wolf (2017) demonstrated that tying input and output embeddings improves perplexity in language models, despite having fewer parameters
  • Inan et al. (2017) showed similar results and analyzed the theoretical connections
  • Most modern language models, including GPT-2, BERT, and their descendants, use weight tying by default

The improvement isn't just about regularization from having fewer parameters. Tied embeddings actually learn better representations because gradients from the output loss flow directly into the input embeddings, and vice versa.

Implementation Details

Let's implement a complete language model head with weight tying, handling the practical details that matter in real systems.

In[6]:
Code
class LanguageModelHead(nn.Module):
    """
    Complete language model embedding layer with weight tying.

    Handles token embedding, positional encoding, and output projection
    using a single shared vocabulary embedding matrix.
    """

    def __init__(
        self, vocab_size, hidden_dim, max_seq_len, dropout=0.1, tie_weights=True
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.tie_weights = tie_weights

        # Token embeddings (always needed)
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)

        # Position embeddings (not tied)
        self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)

        # Output projection: tied or separate
        if tie_weights:
            self.output_projection = None  # Will use token_embedding.weight
        else:
            self.output_projection = nn.Linear(
                hidden_dim, vocab_size, bias=False
            )

        self.dropout = nn.Dropout(dropout)

        # Initialize embeddings
        self._init_weights()

    def _init_weights(self):
        """Initialize embedding weights."""
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        if self.output_projection is not None:
            nn.init.normal_(self.output_projection.weight, mean=0.0, std=0.02)

    def embed(self, token_ids):
        """Convert token IDs to embeddings with position information."""
        batch_size, seq_len = token_ids.shape

        # Token embeddings
        tok_emb = self.token_embedding(token_ids)

        # Position embeddings
        positions = torch.arange(seq_len, device=token_ids.device)
        pos_emb = self.position_embedding(positions)

        # Combine and apply dropout
        return self.dropout(tok_emb + pos_emb)

    def project(self, hidden_states):
        """Project hidden states to vocabulary logits."""
        if self.tie_weights:
            # Use shared embedding matrix
            return torch.matmul(hidden_states, self.token_embedding.weight.T)
        else:
            return self.output_projection(hidden_states)


# Compare tied vs untied
model_tied = LanguageModelHead(
    vocab_size, hidden_dim, max_seq_len=512, tie_weights=True
)
model_untied = LanguageModelHead(
    vocab_size, hidden_dim, max_seq_len=512, tie_weights=False
)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters())
Out[7]:
Console
Parameter comparison:
  Tied weights:     38,793,216 parameters
  Untied weights:   77,193,216 parameters
  Difference:       38,400,000 parameters

The tied model saves 49.7% of embedding-related parameters

The key implementation detail is in the project method. With tied weights, we directly use self.token_embedding.weight.T instead of a separate projection layer. This ensures that the input lookup and the output projection use the same parameters, so gradients from both paths accumulate in a single matrix.
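
An equivalent idiom you may encounter in other codebases keeps a separate linear head but points its weight at the embedding's parameter, so both modules share a single tensor. Here is a brief sketch of that alternative (not part of the LanguageModelHead above):

Code
import torch
import torch.nn as nn

vocab_size, hidden_dim = 50000, 768
token_embedding = nn.Embedding(vocab_size, hidden_dim)
lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

# nn.Linear stores its weight as (out_features, in_features) = (V, d),
# which matches the embedding's (V, d) shape, so the Parameter can be shared directly.
lm_head.weight = token_embedding.weight

print(lm_head.weight is token_embedding.weight)   # True: a single shared Parameter

Functionally this matches the matmul approach above: both paths read and update the same underlying parameter.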

Verifying the Tying

Let's verify that our implementation actually shares parameters correctly:

In[8]:
Code
import torch

# Test the tied model
test_tokens = torch.randint(0, vocab_size, (2, 10))  # Batch of 2, length 10

# Get embeddings
embeddings = model_tied.embed(test_tokens)

# Simulate transformer processing (just use embeddings as output for demo)
hidden_output = embeddings

# Get logits
logits = model_tied.project(hidden_output)

# Verify gradient flow
loss = logits.sum()
loss.backward()
Out[9]:
Console
Input shape: torch.Size([2, 10])
Embedding shape: torch.Size([2, 10, 768])
Output logits shape: torch.Size([2, 10, 50000])

Token embedding gradient shape: torch.Size([50000, 768])
Gradient flows through shared weights: True

The gradient from the output loss flows directly into the token embedding matrix, confirming that the tying works correctly.

A Worked Example

The formulas above describe weight tying abstractly, but seeing the mechanism in action makes it concrete. Let's build a tiny language model and trace through exactly how tied embeddings convert a hidden state into token probabilities.

Setting Up a Toy Vocabulary

We'll work with a vocabulary of just seven words, using 4-dimensional embeddings. These small numbers let us inspect every value and understand exactly what's happening at each step.

In[10]:
Code
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Create a tiny vocabulary for illustration
tiny_vocab = ["the", "cat", "sat", "on", "mat", "dog", "ran"]
tiny_size = len(tiny_vocab)
tiny_dim = 4

# Initialize with small, interpretable embeddings
np.random.seed(42)
torch.manual_seed(42)


# Create tied embedding model
class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(tiny_size, tiny_dim)
        # Initialize with small random values
        nn.init.uniform_(self.embedding.weight, -0.5, 0.5)

    def forward(self, hidden_state):
        """Get probabilities for next token given hidden state."""
        logits = torch.matmul(hidden_state, self.embedding.weight.T)
        return F.softmax(logits, dim=-1)


model = TinyLM()
Out[11]:
Console
Vocabulary: ['the', 'cat', 'sat', 'on', 'mat', 'dog', 'ran']

Embedding matrix (7 tokens × 4 dims):
---------------------------------------------
  the   : [+0.289, -0.219, +0.289, +0.089]
  cat   : [+0.254, -0.305, -0.495, -0.193]
  sat   : [-0.384, +0.410, +0.144, +0.207]
  on    : [+0.158, -0.009, +0.391, -0.355]
  mat   : [+0.031, -0.341, +0.154, -0.172]
  dog   : [+0.153, -0.104, +0.415, -0.296]
  ran   : [-0.298, -0.298, +0.450, +0.167]

Each word has a 4-dimensional embedding vector. These are randomly initialized, as they would be before training. In a trained model, semantically similar words would have similar embeddings.

The Prediction Mechanism

Now comes the key insight. Suppose the transformer has processed some context and produced a final hidden state. With weight tying, we predict the next token by computing dot products between this hidden state and every embedding in our vocabulary.

Let's simulate this. We'll create a hidden state that's similar to the "cat" embedding, as if the model were about to predict "cat" as the next word:

In[12]:
Code
# Suppose the transformer outputs a hidden state similar to "cat"'s embedding
cat_embedding = model.embedding.weight[1].detach()  # Index 1 = "cat"

# Create a hidden state that's close to "cat" but with some noise
hidden_state = cat_embedding + torch.randn(tiny_dim) * 0.1
hidden_state = hidden_state.unsqueeze(0)  # Add batch dimension

# Get next token probabilities
probs = model(hidden_state).squeeze()

The hidden state is intentionally constructed to be similar to "cat"'s embedding, with a small amount of noise added. In a real model, the transformer layers would produce this hidden state based on the input context.

Tracing the Dot Products

Now let's examine what happens inside the forward pass. For each vocabulary token, we compute the dot product between the hidden state and that token's embedding:

Out[13]:
Console
Hidden state (similar to 'cat' embedding):
  [ 0.26889548 -0.32564193 -0.5336563  -0.09405649]

Dot products with each token embedding:
----------------------------------------
  the   : -0.0135  →  P = 0.1434
  cat   : +0.4498  →  P = 0.2279
  sat   : -0.3331  →  P = 0.1042
  on    : -0.1301  →  P = 0.1276
  mat   : +0.0535  →  P = 0.1533
  dog   : -0.1183  →  P = 0.1291
  ran   : -0.2387  →  P = 0.1145

Most likely next token: 'cat'

The dot product reveals the alignment between the hidden state and each embedding. "Cat" receives the highest score because its embedding is most similar to the hidden state we constructed. After softmax normalization, this translates to the highest probability.

This is the heart of weight tying: the same embedding that represents "cat" for input is also the target pattern the model aims to produce when generating "cat". The transformer's job is to transform the input context into a hidden state that aligns with the correct next token's embedding.

Visualizing Embedding Similarity

To better understand how the dot product works as a similarity measure, let's visualize the pairwise similarities between all embeddings in our tiny vocabulary along with the hidden state:

Out[14]:
Visualization
Heatmap showing dot product similarities between seven token embeddings and a hidden state vector.
Pairwise dot product similarities between token embeddings and the hidden state. The hidden state (bottom row) shows highest similarity with 'cat', which is expected since we constructed it to be close to the cat embedding. The diagonal tends to show high values because each embedding is perfectly aligned with itself.

The heatmap reveals the structure of our embedding space. The bottom row (and rightmost column) shows how similar the hidden state is to each vocabulary embedding. Since we constructed the hidden state to resemble "cat", that cell shows the highest similarity in the hidden row. The diagonal entries are all positive because each embedding has positive self-similarity. Off-diagonal entries show how similar different tokens are to each other in the learned representation space.

Scaling Considerations

Weight tying becomes more impactful as vocabulary size grows relative to model depth. Let's analyze how the savings scale:

In[15]:
Code
def analyze_weight_tying_impact(
    vocab_size, hidden_dim, num_layers, d_ff_multiplier=4
):
    """Calculate what fraction of parameters weight tying saves."""

    # Embedding parameters
    embedding_params = vocab_size * hidden_dim
    tied_savings = embedding_params  # One matrix instead of two

    # Per-layer transformer parameters (approximate)
    attention_params = 4 * hidden_dim * hidden_dim  # Q, K, V, O projections
    ffn_params = 2 * hidden_dim * (hidden_dim * d_ff_multiplier)  # Up and down
    layer_norm_params = 4 * hidden_dim  # Two layer norms
    layer_params = attention_params + ffn_params + layer_norm_params

    # Total model parameters
    total_layers = num_layers * layer_params
    total_untied = (
        2 * embedding_params + total_layers
    )  # Input + output embeddings
    total_tied = embedding_params + total_layers  # Shared embedding

    return {
        "embedding_params": embedding_params,
        "layer_params": total_layers,
        "total_untied": total_untied,
        "total_tied": total_tied,
        "savings": total_untied - total_tied,
        "savings_pct": 100 * (total_untied - total_tied) / total_untied,
    }


# Analyze different model configurations
configs = [
    ("Small (GPT-2)", 50257, 768, 12),
    ("Medium", 50257, 1024, 24),
    ("Large", 50257, 1280, 36),
    ("XL", 50257, 1600, 48),
    ("Large vocab", 128000, 1024, 24),  # Like newer models
]
Out[16]:
Console
Weight Tying Impact Analysis
=====================================================================================
Config                    Vocab   Hidden  Layers         Savings % of Model
-------------------------------------------------------------------------------------
Small (GPT-2)            50,257      768      12      38,597,376      23.8%
Medium                   50,257     1024      24      51,463,168      12.7%
Large                    50,257     1280      36      64,328,960       7.7%
XL                       50,257     1600      48      80,411,200       4.9%
Large vocab             128,000     1024      24     131,072,000      23.2%

For smaller models, weight tying can save 10-20% of total parameters. As models grow deeper, the relative savings decrease because transformer layers dominate, but the absolute parameter savings remain substantial.

To understand why weight tying matters more for some models than others, let's visualize how parameters are distributed:

Out[17]:
Visualization
Stacked bar chart showing parameter distribution across embedding and transformer layers for five model configurations.
Parameter distribution in different model configurations. Smaller models allocate a larger fraction of their parameters to embeddings, making weight tying more impactful. Larger models are dominated by transformer layer parameters.

The visualization makes the trade-off clear. In GPT-2 Small, a single embedding matrix accounts for roughly 24% of the untied parameter total, so weight tying provides substantial savings. In the XL configuration, embeddings are less than 5% of the model. The "Large vocab" case is interesting: despite its deeper 24-layer stack, the 128K vocabulary pushes embedding costs back up.

Out[18]:
Visualization
Bar chart showing absolute parameter savings in millions for five model configurations.
Absolute parameter savings from weight tying across different model scales. Larger vocabularies yield more savings in absolute terms.
Bar chart showing percentage parameter savings for five model configurations.
Relative parameter savings as a percentage of total model parameters. Weight tying provides larger relative benefits for smaller models.

The "Large vocab" configuration shows an interesting pattern: models with bigger vocabularies (like those using 128K+ token vocabularies for multilingual support) benefit more from weight tying because the embedding matrix is a larger fraction of total parameters.

Encoder-Decoder Weight Tying

So far we've discussed tying input and output embeddings within a single model. Encoder-decoder architectures offer additional tying opportunities.

In a sequence-to-sequence model like T5 or BART, you have:

  • Encoder input embeddings
  • Decoder input embeddings
  • Decoder output embeddings

Three separate matrices, all with the same shape $V \times d$, mapping between the same vocabulary and hidden dimension. Research has shown that tying all three together works well:

$$\mathbf{E}_{\text{enc}} = \mathbf{E}_{\text{dec}} = \mathbf{W}_{\text{out}}$$

where:

  • $\mathbf{E}_{\text{enc}}$: the encoder input embedding matrix of shape $V \times d$
  • $\mathbf{E}_{\text{dec}}$: the decoder input embedding matrix of shape $V \times d$
  • $\mathbf{W}_{\text{out}}$: the decoder output projection matrix of shape $V \times d$
  • $V$: the shared vocabulary size
  • $d$: the hidden dimension

This three-way tying means a single learned embedding serves all three roles: encoding source tokens, encoding target tokens during teacher forcing, and defining the output distribution over the vocabulary.

In[19]:
Code
class EncoderDecoderEmbeddings(nn.Module):
    """
    Shared embeddings for encoder-decoder architecture.

    All three embedding matrices are tied together.
    """

    def __init__(self, vocab_size, hidden_dim, tie_all=True):
        super().__init__()
        self.tie_all = tie_all

        if tie_all:
            # Single shared embedding serves all three roles
            self.shared_embedding = nn.Embedding(vocab_size, hidden_dim)
        else:
            # Three separate matrices when not tying
            self.encoder_embedding = nn.Embedding(vocab_size, hidden_dim)
            self.decoder_embedding = nn.Embedding(vocab_size, hidden_dim)
            self.output_projection = nn.Linear(
                hidden_dim, vocab_size, bias=False
            )

    def encode_input(self, token_ids):
        """Embed encoder input tokens."""
        if self.tie_all:
            return self.shared_embedding(token_ids)
        return self.encoder_embedding(token_ids)

    def decode_input(self, token_ids):
        """Embed decoder input tokens."""
        if self.tie_all:
            return self.shared_embedding(token_ids)
        return self.decoder_embedding(token_ids)

    def project_output(self, hidden_states):
        """Project decoder hidden states to vocabulary."""
        if self.tie_all:
            return torch.matmul(hidden_states, self.shared_embedding.weight.T)
        return self.output_projection(hidden_states)


# Compare parameter counts
enc_dec_tied = EncoderDecoderEmbeddings(vocab_size, hidden_dim, tie_all=True)
enc_dec_untied = EncoderDecoderEmbeddings(vocab_size, hidden_dim, tie_all=False)
Out[20]:
Console
Encoder-Decoder Embedding Parameters:
  Fully tied:     38,400,000 (1 matrix)
  Untied:        115,200,000 (3 matrices)
  Savings:        76,800,000 (67%)

For encoder-decoder models, full weight tying eliminates two-thirds of embedding parameters. T5, one of the most successful encoder-decoder transformers, uses this three-way tying by default.
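
As a quick sanity check with the EncoderDecoderEmbeddings class defined above (the token shapes here are arbitrary), all three paths in the tied configuration run off a single parameter matrix:

Code
# Reuses vocab_size, torch, and enc_dec_tied from the cells above
src = torch.randint(0, vocab_size, (2, 7))    # pretend source tokens
tgt = torch.randint(0, vocab_size, (2, 5))    # pretend target tokens

enc_in = enc_dec_tied.encode_input(src)       # encoder input embeddings
dec_in = enc_dec_tied.decode_input(tgt)       # decoder input embeddings
logits = enc_dec_tied.project_output(dec_in)  # decoder output logits

# One shared matrix backs all three roles
print(sum(p.numel() for p in enc_dec_tied.parameters()))  # 38,400,000
print(logits.shape)                                       # torch.Size([2, 5, 50000])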

Effects on Training Dynamics

Weight tying doesn't just reduce parameters. It changes how the model learns.

Gradient Flow

With tied weights, the embedding matrix receives gradients from two sources:

  1. Input gradients: Backpropagated through the transformer from the loss
  2. Output gradients: Direct gradients from the output projection

This double gradient flow can be viewed as implicit multi-task learning. The embedding must simultaneously satisfy two objectives: representing tokens well for input processing, and providing good targets for output prediction.

In[21]:
Code
class GradientTracker(nn.Module):
    """Track gradient magnitudes from different sources."""

    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.fc = nn.Linear(hidden_dim, hidden_dim)  # Simulates transformer

    def forward(self, token_ids, return_hidden=False):
        # Input path
        x = self.embedding(token_ids)
        hidden = self.fc(x)

        # Output path (tied)
        logits = torch.matmul(hidden, self.embedding.weight.T)

        if return_hidden:
            return logits, hidden
        return logits


# Track gradients
model = GradientTracker(1000, 64)
tokens = torch.randint(0, 1000, (4, 16))
labels = torch.randint(0, 1000, (4, 16))

logits = model(tokens)
loss = F.cross_entropy(logits.view(-1, 1000), labels.view(-1))
loss.backward()

grad_magnitude = model.embedding.weight.grad.abs().mean()
Out[22]:
Console
Mean gradient magnitude on tied embedding: 0.001345

This gradient combines contributions from both:
  - Forward pass through the input embedding
  - Backward pass through the output projection

The gradient magnitude shows how much the embedding weights would change in a single training step (before learning rate scaling). With tied weights, this gradient is typically larger than it would be with separate embeddings because it aggregates signals from both the input and output paths. This can speed up learning for rare tokens that might otherwise receive sparse gradient updates.
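
To make the comparison concrete, here is a rough sketch with toy sizes and random data (so the exact values will differ from run to run) that measures the mean embedding gradient with and without tying; in the untied case, the embedding matrix receives gradient only through the input path:

Code
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
V, d = 1000, 64
tokens = torch.randint(0, V, (4, 16))
labels = torch.randint(0, V, (4, 16))

def embedding_grad(tied):
    emb = nn.Embedding(V, d)
    fc = nn.Linear(d, d)                      # stand-in for the transformer
    head = nn.Linear(d, V, bias=False)        # only used in the untied case
    hidden = fc(emb(tokens))
    logits = hidden @ emb.weight.T if tied else head(hidden)
    loss = F.cross_entropy(logits.view(-1, V), labels.view(-1))
    loss.backward()
    return emb.weight.grad.abs().mean().item()

print(f"Tied embedding gradient:   {embedding_grad(True):.6f}")
print(f"Untied embedding gradient: {embedding_grad(False):.6f}")

Because the output path produces a dense gradient over the whole vocabulary, the tied matrix typically shows the larger mean value, matching the effect described above.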

Embedding Scale

A subtle issue arises with weight tying: the optimal scale for input embeddings may differ from the optimal scale for output projections.

Input embeddings are often scaled by $\sqrt{d}$ before being added to positional encodings (following the original transformer paper). The scaled embedding is:

$$\mathbf{h}_{\text{scaled}} = \sqrt{d} \cdot \mathbf{E}[t]$$

where:

  • $\mathbf{h}_{\text{scaled}}$: the scaled input embedding
  • $d$: the embedding dimension
  • $\mathbf{E}[t]$: the raw embedding lookup for token $t$
  • $\sqrt{d}$: the scaling factor, which counterbalances the small variance of embeddings initialized with small values

Without this scaling, embeddings initialized with small weights would be overwhelmed by positional encodings. But if we apply this scaling directly to the weight matrix, it would also affect the output logits when using tied weights, potentially making them too large.

Modern implementations handle this by applying scaling at input time rather than modifying the embedding matrix:

In[23]:
Code
class ScaledTiedEmbeddings(nn.Module):
    """
    Tied embeddings with proper scaling.

    Input embeddings are scaled up, but output projection uses raw weights.
    """

    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.scale = np.sqrt(hidden_dim)

    def embed(self, token_ids):
        """Scaled embedding for input."""
        return self.embedding(token_ids) * self.scale

    def project(self, hidden_states):
        """Unscaled projection for output."""
        return torch.matmul(hidden_states, self.embedding.weight.T)


scaled_model = ScaledTiedEmbeddings(vocab_size, hidden_dim)
test_input = torch.randint(0, vocab_size, (1, 10))

# Forward pass
input_embeds = scaled_model.embed(test_input)
# ... transformer processing would happen here ...
output_logits = scaled_model.project(
    input_embeds
)  # Using embeds as proxy for hidden state
Out[24]:
Console
Input embedding scale factor: 27.71
Input embedding norm (per token): 757.44

The scale is applied at input time, not in the embedding matrix itself,
so the output projection sees the unscaled embeddings.

The scale factor of approximately 27.7 (for hidden dimension 768) significantly amplifies the input embeddings. By applying this scaling at runtime rather than baking it into the weights, we decouple the input and output requirements. The transformer processes scaled embeddings while the output projection uses raw embeddings, allowing each path to operate at its optimal scale.

When to Tie Weights

Weight tying isn't always the right choice. Here are the key considerations:

Tie weights when:

  • Your vocabulary is large relative to model depth
  • You want to reduce memory footprint without major architectural changes
  • You're training from scratch and can let the model adapt
  • Input and output domains are the same (e.g., language modeling)

Consider untied weights when:

  • Input and output vocabularies differ (e.g., translation between different tokenizers)
  • You're fine-tuning a pre-trained model that was trained without tying
  • Model capacity is more important than parameter efficiency
  • You observe that tied embeddings underperform in your specific task
In[25]:
Code
def recommend_weight_tying(
    vocab_size, hidden_dim, num_layers, same_io_vocab=True
):
    """Simple heuristic for weight tying recommendation."""

    embedding_params = vocab_size * hidden_dim
    # Rough estimate of transformer layer params
    layer_params = 12 * hidden_dim * hidden_dim  # Approximate
    total_layer_params = num_layers * layer_params

    embedding_ratio = embedding_params / (
        total_layer_params + 2 * embedding_params
    )

    recommendation = {
        "tie_weights": same_io_vocab and embedding_ratio > 0.05,
        "embedding_ratio": embedding_ratio,
        "reasoning": [],
    }

    if not same_io_vocab:
        recommendation["reasoning"].append(
            "Different input/output vocabularies - cannot tie"
        )
    elif embedding_ratio > 0.15:
        recommendation["reasoning"].append(
            "Embeddings are >15% of model - tying highly recommended"
        )
    elif embedding_ratio > 0.05:
        recommendation["reasoning"].append(
            "Embeddings are 5-15% of model - tying recommended"
        )
    else:
        recommendation["reasoning"].append(
            "Embeddings are <5% of model - tying optional"
        )

    return recommendation
Out[26]:
Console
Weight Tying Recommendations:
=================================================================

GPT-2 Small:
  Embedding ratio: 23.8%
  Recommendation: ✓ Tie
  Reason: Embeddings are >15% of model - tying highly recommended

GPT-3 175B:
  Embedding ratio: 0.4%
  Recommendation: ✗ Don't tie
  Reason: Embeddings are <5% of model - tying optional

Multilingual MT:
  Embedding ratio: 23.2%
  Recommendation: ✗ Don't tie
  Reason: Different input/output vocabularies - cannot tie

The heuristic reveals an important pattern: smaller models like GPT-2 Small have embeddings that constitute a significant fraction of total parameters (over 15%), making weight tying highly impactful. For massive models like GPT-3 175B, embeddings are less than 1% of parameters, so tying provides minimal savings. However, even large models typically use weight tying because it rarely hurts performance and provides a small memory benefit. The multilingual translation case shows when tying is impossible: different input/output vocabularies require separate embedding spaces.

Limitations and Impact

Weight tying represents one of those elegant techniques where reducing complexity actually improves results. The constraint that input and output embeddings share the same learned representation forces the model to develop more coherent internal semantics.

The primary limitation is inflexibility. When input and output tasks require genuinely different token representations, tied weights create tension. Machine translation between languages with different scripts is the canonical example: the optimal encoding for reading Japanese may differ from the optimal target representation for generating English, even if both pass through the same vocabulary. In such cases, untied weights give the model freedom to specialize.

There's also a capacity argument. Very large models may benefit from the additional expressiveness of separate embeddings. When parameter count is less constrained, the regularization effect of tying matters less, and the model might learn better with independent representations. However, the empirical evidence here is mixed, and many large models continue to use tied weights despite their scale.

Weight tying's impact extends beyond parameter efficiency. By forcing the model to reconcile input and output representations, it creates a more unified semantic space. This can improve generalization, especially for smaller models where every parameter must work harder. The technique has become standard practice in language modeling, appearing in GPT-2, BERT, T5, and most subsequent architectures. It's one of those design decisions that has become so universal that its presence is often assumed rather than stated.

Key Parameters

When implementing weight tying in your models, these are the key configuration choices (a consolidated sketch follows this list):

  • tie_weights (bool): Whether to share the embedding matrix between input and output. Set to True for most language models; set to False when input and output vocabularies differ or when fine-tuning models trained without tying.

  • vocab_size (int): Size of the vocabulary. Larger vocabularies increase the parameter savings from tying. With 50K+ tokens, embedding matrices can dominate smaller models.

  • hidden_dim / d_model (int): The embedding and hidden dimension. This determines both the embedding matrix size ($V \times d$) and the scale factor ($\sqrt{d}$) used for input scaling.

  • embedding_scale (float): Scaling factor applied to input embeddings, typically $\sqrt{d}$. Apply this at forward time rather than modifying the weight matrix to preserve unscaled embeddings for output projection.

  • bias (bool): Whether to include a bias term in the output projection. Most implementations set this to False when using tied weights, as the embedding matrix provides sufficient expressivity.
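
As a consolidated illustration, here is a minimal sketch (not the article's reference implementation; the class name and defaults are chosen for this example) showing how these choices might fit together:

Code
import math
import torch
import torch.nn as nn

class TiedHeadConfigDemo(nn.Module):
    def __init__(self, vocab_size=50257, hidden_dim=768,
                 tie_weights=True, embedding_scale=None, bias=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.scale = embedding_scale if embedding_scale is not None else math.sqrt(hidden_dim)
        self.head = nn.Linear(hidden_dim, vocab_size, bias=bias)
        if tie_weights:
            self.head.weight = self.embedding.weight   # share one Parameter

    def embed(self, token_ids):
        return self.embedding(token_ids) * self.scale  # scale at forward time only

    def project(self, hidden_states):
        return self.head(hidden_states)                # raw (unscaled) weights

demo = TiedHeadConfigDemo()
tokens = torch.randint(0, 50257, (1, 8))
print(demo.project(demo.embed(tokens)).shape)          # torch.Size([1, 8, 50257])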

Summary

Weight tying exploits the structural symmetry between the input embedding and output projection matrices in language models. Both matrices have shape $V \times d$, where $V$ is vocabulary size and $d$ is the hidden dimension. Instead of maintaining separate matrices $\mathbf{E}$ for input and $\mathbf{W}_{\text{out}}$ for output, weight tying sets $\mathbf{W}_{\text{out}} = \mathbf{E}$, so a single shared matrix serves both roles.

The key insights from this chapter:

  • Parameter reduction: Weight tying eliminates one full embedding matrix, saving $V \times d$ parameters. For a 50K vocabulary with 768-dimensional embeddings, that's 38 million fewer parameters.

  • Theoretical justification: Input and output embeddings both answer "what does this token mean?" Tying them forces a consistent semantic representation where the hidden state that produces a token is similar to that token's embedding.

  • Implementation: The output projection becomes a matrix multiplication with the transposed embedding matrix. Gradients flow to the shared weights from both input and output paths.

  • Encoder-decoder extension: In sequence-to-sequence models, all three embedding matrices (encoder input, decoder input, decoder output) can share weights, eliminating two-thirds of embedding parameters.

  • Training effects: Tied weights receive gradients from multiple sources, creating implicit multi-task learning. Care must be taken with embedding scaling to handle different requirements for input and output.

  • When to use: Weight tying is recommended when vocabularies are large relative to model depth and input/output domains match. Avoid it when vocabularies differ or when maximum capacity is more important than efficiency.

Modern language models nearly universally adopt weight tying, making it a fundamental architectural decision rather than an optimization. Understanding why it works helps you reason about the semantic structure these models learn.

