Catastrophic Forgetting in Fine-Tuning: Causes & Mitigation

Michael Brenndoerfer · November 26, 2025 · 44 min read

Learn why neural networks forget prior capabilities during fine-tuning and discover mitigation strategies like EWC, L2-SP regularization, and replay methods.


Catastrophic Forgetting

When you fine-tune a pre-trained language model on a new task, something troubling happens: the model gets better at your task but worse at everything else. Train BERT on sentiment analysis, and it may forget how to perform named entity recognition. Fine-tune GPT on medical text, and it might lose fluency in general English. This phenomenon, known as catastrophic forgetting, represents one of the central challenges in adapting powerful pre-trained models to specific domains.

As we discussed in the Transfer Learning chapter, the entire premise of modern NLP rests on leveraging knowledge from pre-training. But that knowledge is fragile. Without careful management, fine-tuning doesn't augment a model's capabilities; it overwrites them. Understanding why forgetting occurs and how to mitigate it is essential if you want to customize models without destroying what made them useful in the first place.

The Forgetting Phenomenon

Catastrophic forgetting occurs when a neural network trained on a new task loses performance on previously learned tasks. Unlike human forgetting, which is gradual and often recoverable, neural network forgetting can be sudden and complete. A model might drop from 90% accuracy on a prior task to near-random performance after just a few epochs of fine-tuning on new data.

To understand why this happens, consider the fundamental difference between how humans and neural networks store knowledge. Human memory is distributed across biological structures that maintain some degree of modularity. When you learn to play piano, the neural pathways encoding your knowledge of how to ride a bicycle remain largely undisturbed. Neural networks, in contrast, encode all their knowledge in a single set of shared weights. Every capability the model possesses depends on the same pool of parameters, and any change to those parameters affects all capabilities simultaneously.

Catastrophic Forgetting

The tendency of neural networks to abruptly lose previously learned information when trained on new data, caused by weight updates that overwrite the representations responsible for earlier capabilities.

The Stability-Plasticity Dilemma

The root cause of catastrophic forgetting lies in a fundamental tension in learning systems: the stability-plasticity dilemma. This dilemma, first articulated in the neuroscience literature, captures an essential trade-off that any adaptive system must navigate. A learning system must be:

  • Plastic enough to acquire new knowledge from incoming data
  • Stable enough to retain previously learned information

These two requirements exist in fundamental tension. High plasticity allows rapid learning but makes existing knowledge vulnerable to being overwritten. High stability preserves existing knowledge but prevents the acquisition of new information. Biological brains have evolved sophisticated mechanisms to balance these competing demands, including separate memory systems for different timescales and consolidation processes that protect important memories. Neural networks, by design, are highly plastic. Each gradient update modifies weights throughout the entire network, with no built-in mechanism to distinguish between weights that should change and weights that must remain stable.

Out[2]:
Visualization
Conceptual diagram showing the trade-off between stability and plasticity, with curves for learning rate and retention.
The stability-plasticity trade-off in neural networks. High plasticity enables rapid learning but causes forgetting, while high stability preserves knowledge but impedes adaptation. Effective fine-tuning strategies aim to find an optimal balance in the middle region.

Consider what happens during fine-tuning. The loss function measures performance only on the new task. It provides no signal about whether the model still performs well on other tasks, because the training process simply does not observe performance on those tasks. Gradients flow backward through the network, and every parameter update asks a single question: "How can I better predict sentiment?" The network has no mechanism to simultaneously ask: "Will this change hurt my ability to recognize named entities?" or "Am I still generating grammatically correct English?" The optimization process is entirely myopic, focused exclusively on the current batch of fine-tuning data.

This myopia is not a flaw in the training algorithm. It is a direct consequence of the objective function we specify. Standard supervised learning optimizes performance on the training distribution, and nothing in that objective encourages preservation of capabilities that are not represented in the training data. The resulting forgetting is thus not a bug but an inevitable consequence of how we have defined the learning problem.

Why Pre-trained Knowledge is Vulnerable

Pre-trained language models learn rich representations that capture syntax, semantics, and world knowledge. These representations are distributed across millions or billions of parameters, with individual neurons participating in many different capabilities. This distributed nature is both a strength and a vulnerability. It is a strength because distributed representations enable powerful generalization and efficient parameter use. It is a vulnerability because any modification to the shared parameters can propagate to affect multiple capabilities, even those that seem unrelated to the task being fine-tuned.

During fine-tuning, several distinct mechanisms contribute to forgetting, each operating through different pathways but all leading to the same outcome: degradation of pre-trained capabilities.

Parameter drift occurs when weights shift away from their pre-trained values. Each gradient update nudges parameters in the direction that improves performance on the fine-tuning task. Individually, these nudges may seem small and harmless, but they accumulate across millions of parameters and thousands of gradient updates, and the combined effect can dramatically alter the model's behavior on tasks not represented in the fine-tuning data. The model drifts through weight space, moving further and further from the region that supported its original capabilities.

Representational shift happens when the intermediate representations learned during pre-training are reorganized to better serve the fine-tuning task. During pre-training, the model learns to encode text into hidden states that capture general linguistic properties: part-of-speech information, semantic relationships, discourse structure, and more. During fine-tuning, these representations may be reshaped to emphasize features relevant to the new task while de-emphasizing features that are no longer useful. Hidden state distributions may shift in ways that break the assumptions of capabilities built on top of them. A downstream classifier trained to interpret the original hidden state space may receive inputs it cannot properly interpret when that space transforms.

Output distribution changes affect the final layers most severely. If pre-training encouraged a diverse vocabulary distribution, with the model learning to predict many different words in many different contexts, and fine-tuning focuses on domain-specific terms, the model may lose fluency with general vocabulary. The output layer's weights adapt to assign high probability to the words that appear frequently in the fine-tuning data, while the weights corresponding to rarely-seen words decay toward values that produce low probabilities.

Measuring Forgetting

To manage forgetting, we first need to quantify it. Without measurement, forgetting remains invisible until deployment reveals unexpected failures. Several metrics capture different aspects of how models deteriorate after fine-tuning, and choosing the right metric depends on what aspects of performance matter most for your application.

Backward Transfer

The most direct measure is backward transfer: how much does performance on prior tasks change after learning new ones? This metric captures the intuition that we want to compare the model's performance on old tasks before and after it learns something new. If the model performed well on named entity recognition before fine-tuning on sentiment analysis, backward transfer tells us how much of that NER performance survived.

To formalize this intuition, we need notation for tracking performance across multiple tasks learned sequentially. Imagine a model that learns task 1, then task 2, then task 3, and so on up to task T. At each point in this sequence, we can evaluate the model on any of the tasks it has encountered. The backward transfer metric computes the average change in performance across all previously learned tasks.

\text{BT} = \frac{1}{T-1} \sum_{i=1}^{T-1} (R_{T,i} - R_{i,i})

Let us unpack each component of this formula to understand exactly what it measures:

  • \text{BT}: the backward transfer score, which summarizes how much prior task performance has changed
  • T: the total number of tasks learned sequentially, representing the full learning trajectory
  • \frac{1}{T-1}: the normalization term that computes the average change across all prior tasks, ensuring the metric is comparable regardless of how many tasks were learned
  • R_{T,i}: the performance on task i after the model has finished training on all T tasks, representing the model's current capability on that task
  • R_{i,i}: the baseline performance on task i immediately after it was originally learned, representing the best performance the model achieved when that task was its primary focus
  • \sum_{i=1}^{T-1}: the summation over all previous tasks (excluding the current final task T), accumulating the performance changes across the model's learning history

The intuition behind this formula is straightforward: for each prior task, we compute how much performance changed from immediately after learning that task to the present moment, then average these changes. Negative backward transfer indicates forgetting, as it means current performance is lower than initial performance. A model with backward transfer of -0.15, for instance, has lost an average of 15 percentage points on prior tasks. Positive backward transfer would indicate that learning new tasks somehow improved performance on old tasks, a phenomenon called positive transfer that can occur when tasks share structure.
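To make the computation concrete, here is a minimal sketch using a small, hypothetical performance matrix (the numbers are purely illustrative): R[t, i] holds the performance on task i after training on task t.

import numpy as np

# Hypothetical performance matrix: R[t, i] = performance on task i
# after finishing training on task t (illustrative values only)
R = np.array([
    [0.90, 0.00, 0.00],  # after learning task 1
    [0.78, 0.88, 0.00],  # after learning task 2
    [0.70, 0.80, 0.92],  # after learning task 3 (final model)
])

T = R.shape[0]
# For each prior task, compare current performance with performance
# immediately after that task was learned, then average
backward_transfer = np.mean([R[T - 1, i] - R[i, i] for i in range(T - 1)])
print(f"Backward transfer: {backward_transfer:.3f}")  # negative => forgetting

With these illustrative numbers the result is -0.14: on average, the final model has lost 14 percentage points on the two earlier tasks.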

Out[3]:
Visualization
Heatmap showing task performance matrix across sequential learning, illustrating backward transfer calculation.
Visualization of backward transfer measurement across sequential task learning. The heatmap shows performance on each task (columns) after training on each sequential task (rows). Backward transfer compares diagonal entries (performance right after learning) with the final row (current performance).
Backward Transfer: -0.199

Forgetting Measure

A related metric tracks the maximum performance achieved on each task versus current performance. This metric addresses a subtle limitation of backward transfer: what if performance on a task fluctuates during training? Perhaps the model's NER performance actually increased slightly after learning sentiment analysis, before crashing when a third task was learned. The forgetting measure captures the worst-case degradation from peak performance.

F_i = \max_{t \in \{1, \ldots, T-1\}} R_{t,i} - R_{T,i}

Understanding each term in this formula reveals how it differs from backward transfer:

  • F_i: the forgetting measure for a specific task i, quantifying how much performance on that particular task has degraded from its peak
  • \max_{t} R_{t,i}: the highest performance recorded for task i at any previous time step t, representing the best the model ever did on this task throughout its training history
  • R_{t,i}: the performance on task i after training on task t, allowing us to track performance at each point in the training sequence
  • R_{T,i}: the current performance on task i after training on the final task T, representing where the model stands now

This captures how much performance has degraded from the best observed level, regardless of when that peak occurred. A task might achieve peak performance not immediately after being learned, but at some later point when a related task was trained. The forgetting measure accounts for this by always comparing against the best-ever performance, providing a more comprehensive picture of capability loss.
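The same hypothetical matrix from the backward transfer sketch makes the difference visible in code: here we compare against the best performance ever observed for each task, not just the performance right after it was learned.

import numpy as np

# R[t, i] = performance on task i after training on task t (illustrative values)
R = np.array([
    [0.90, 0.00, 0.00],
    [0.78, 0.88, 0.00],
    [0.70, 0.80, 0.92],
])

T = R.shape[0]
# Forgetting for task i: peak performance over earlier steps minus current performance
forgetting = [R[:T - 1, i].max() - R[T - 1, i] for i in range(T - 1)]
print([round(f, 2) for f in forgetting])  # [0.2, 0.08] with these numbers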

Perplexity Degradation

For language models, perplexity on held-out general text provides a task-agnostic forgetting measure. Perplexity quantifies how surprised the model is by text it has not seen before. Lower perplexity indicates better language modeling capability, as the model assigns higher probability to the actual words that appear.

If a model's perplexity on Wikipedia increases from 15 to 45 after fine-tuning on legal documents, that quantifies how much general language modeling capability was lost. The model has become three times more surprised by general English text, indicating substantial degradation in its ability to model the distribution of ordinary language. This metric is particularly valuable because it does not require defining specific tasks. It measures the model's overall competence with language rather than its performance on any particular downstream application.
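For causal language models this is straightforward to measure. The sketch below uses GPT-2 via Hugging Face Transformers as a stand-in; the model choice and test sentences are placeholders, and in practice you would evaluate your own model on a representative held-out corpus before and after fine-tuning.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def perplexity(model, tokenizer, texts, device="cpu"):
    """Token-weighted perplexity of a causal LM over a list of texts."""
    model.eval()
    nlls, n_tokens = [], 0
    with torch.no_grad():
        for text in texts:
            enc = tokenizer(text, return_tensors="pt").to(device)
            # Using the inputs as labels returns the average cross-entropy loss
            loss = model(**enc, labels=enc["input_ids"]).loss
            n = enc["input_ids"].size(1)
            nlls.append(loss * n)
            n_tokens += n
    return torch.exp(torch.stack(nlls).sum() / n_tokens).item()


# Placeholder usage: measure before fine-tuning, fine-tune, then re-measure
tokenizer = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")
ppl_before = perplexity(lm, tokenizer, ["The quick brown fox jumps over the lazy dog."])
# ... fine-tune lm on domain data, then call perplexity() again and compare ...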

Benchmark Suites

In practice, we use standardized benchmark suites to measure forgetting comprehensively. These suites provide a battery of diverse tasks that together assess the breadth of a model's capabilities. For encoder models like BERT, the GLUE benchmark provides multiple tasks spanning sentiment analysis, textual entailment, and semantic similarity. Performance across this suite reveals whether fine-tuning has preserved the model's general language understanding or degraded it.

For generative models, evaluations might include:

  • Language modeling perplexity on diverse corpora
  • Zero-shot performance on reasoning benchmarks
  • Few-shot classification accuracy
  • Generation quality metrics (fluency, coherence)

These comprehensive evaluations provide a multidimensional view of forgetting, revealing which capabilities survived fine-tuning and which did not.

Visualizing Forgetting Dynamics

Let's examine how forgetting manifests during fine-tuning. We'll simulate a simplified scenario to build intuition.

In[4]:
Code
import numpy as np

np.random.seed(42)

# Simulate fine-tuning dynamics
epochs = 50

# Task A: original task (pre-training proxy)
# Task B: new fine-tuning task

# Before fine-tuning
task_a_baseline = 0.85
task_b_baseline = 0.45

# Simulate learning curves during fine-tuning
task_b_performance = []
task_a_performance = []

for epoch in range(epochs):
    # Task B improves with training
    task_b = task_b_baseline + (0.95 - task_b_baseline) * (
        1 - np.exp(-epoch / 10)
    )
    task_b += np.random.normal(0, 0.01)
    task_b_performance.append(np.clip(task_b, 0, 1))

    # Task A degrades - rapid initial forgetting, then stabilizes
    forgetting_rate = 0.3 * (1 - np.exp(-epoch / 8))
    task_a = task_a_baseline - forgetting_rate + np.random.normal(0, 0.01)
    task_a_performance.append(np.clip(task_a, 0, 1))
Out[5]:
Visualization
Line plot showing Task B performance rising while Task A performance falls during fine-tuning epochs.
Performance dynamics during fine-tuning showing the classic forgetting pattern: the new task (Task B) improves while the original capability (Task A) degrades. Most forgetting occurs in early epochs when gradient updates are largest.

The visualization reveals a characteristic pattern: most forgetting happens early in training when learning rates are highest and the model is adjusting most aggressively. Later epochs show diminishing returns on both learning and forgetting as the model approaches a new equilibrium.

Causes at the Parameter Level

To understand forgetting deeply, we need to examine what happens to individual parameters during fine-tuning. The aggregate measures of forgetting we discussed earlier emerge from microscopic changes happening simultaneously across millions of weights. By understanding these parameter-level dynamics, we can design more targeted interventions.

Weight Importance Varies

Not all parameters contribute equally to each task. Some weights are critical for general language understanding, carrying information about syntax and semantics that applies broadly. Others specialize in specific capabilities, encoding task-specific patterns that are only relevant in certain contexts. Many are relatively unimportant for any particular function, serving as "slack" capacity that the model uses flexibly.

The challenge is that fine-tuning doesn't distinguish between these categories; it updates all parameters to reduce the fine-tuning loss. The gradient provides information about how changing each parameter would affect the fine-tuning objective, but it provides no information about how those same changes would affect other objectives. A parameter critical for named entity recognition receives the same treatment as a parameter that is entirely unimportant for prior capabilities.

In[6]:
Code
# Simulate parameter importance for two tasks
n_params = 1000

# Generate importance scores for original capabilities
importance_original = np.random.exponential(0.3, n_params)
importance_original = importance_original / importance_original.max()

# Generate importance scores for fine-tuning task
importance_finetune = np.random.exponential(0.3, n_params)
importance_finetune = importance_finetune / importance_finetune.max()

# Parameters important for both (overlap)
shared_important = (importance_original > 0.5) & (importance_finetune > 0.5)
original_only = (importance_original > 0.5) & (importance_finetune <= 0.5)
finetune_only = (importance_original <= 0.5) & (importance_finetune > 0.5)
Out[7]:
Visualization
Scatter plot showing parameter importance scores for original and fine-tuning tasks with highlighted regions.
Parameter importance varies by task. When fine-tuning modifies parameters critical to original capabilities (lower-right quadrant) without protecting them, forgetting occurs. Successful mitigation strategies identify and protect these shared parameters.
Out[8]:
Console
Parameters critical for original task: 21
Parameters critical for fine-tuning: 28
Shared critical parameters: 0
Low importance parameters: 951

The key insight is that parameters in the "original task critical" category (blue points) need protection during fine-tuning. If these weights change significantly, original capabilities degrade, even if the changes help the fine-tuning objective. The scatter plot reveals the geometry of this problem: parameters occupy different regions depending on their importance to each task, and the region that matters most for forgetting is the one containing parameters that are important for original capabilities but not for the fine-tuning task.

Gradient Interference

When gradients for the fine-tuning task oppose what would maintain original capabilities, we have gradient interference. This phenomenon occurs when the direction that improves the fine-tuning objective is exactly opposite to the direction that would preserve prior performance.

Consider a weight that needs to increase to improve sentiment classification but whose original value supported named entity recognition. Standard gradient descent simply follows the fine-tuning gradient, potentially destroying NER performance. The optimizer has no way to know that this particular weight is important for a capability we want to preserve. It only sees the gradient pointing upward and dutifully moves the weight in that direction.

Gradient interference is particularly insidious because it is invisible during training. The training loss decreases, validation accuracy on the fine-tuning task improves, and everything appears to be working correctly. Only when we evaluate on prior tasks do we discover that the gradients we followed led us away from the region of weight space that supported those capabilities.
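One way to make interference visible is to compare the gradient of the fine-tuning loss with the gradient of a loss computed on held-out prior-task data. This is a sketch, assuming you can evaluate both losses on the same model; a strongly negative cosine similarity signals that the two objectives pull the weights in opposing directions.

import torch


def gradient_cosine(model, loss_new, loss_prior):
    """Cosine similarity between the gradients of two losses.

    Values near -1 indicate interference: the update that improves the new
    task points away from the direction that preserves the prior task.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    g_new = torch.autograd.grad(loss_new, params, retain_graph=True, allow_unused=True)
    g_prior = torch.autograd.grad(loss_prior, params, allow_unused=True)
    # Treat parameters unused by either loss as having zero gradient
    flat_new = torch.cat([
        (g if g is not None else torch.zeros_like(p)).flatten()
        for g, p in zip(g_new, params)
    ])
    flat_prior = torch.cat([
        (g if g is not None else torch.zeros_like(p)).flatten()
        for g, p in zip(g_prior, params)
    ])
    return torch.nn.functional.cosine_similarity(flat_new, flat_prior, dim=0).item()

In practice you would compute loss_new from a fine-tuning batch and loss_prior from a held-out batch of the original task, then log this similarity periodically during training.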

Out[9]:
Visualization
Vector diagram showing gradient directions for new task versus prior task, illustrating gradient interference.
Vector diagrams showing three gradient interaction scenarios. Aligned gradients (left) facilitate positive transfer, orthogonal gradients (center) allow independent learning, and opposing gradients (right) cause interference where the new task update (red) degrades prior capabilities (green).

Forgetting Mitigation Strategies

We have developed numerous techniques to reduce forgetting. These fall into several broad categories.

Regularization-Based Methods

These approaches add terms to the loss function that penalize changes to important parameters. The key insight behind regularization methods is that we can encode our desire to preserve prior capabilities directly into the objective function. By adding a penalty for parameter drift, we change the optimization landscape so that the model must balance improving on the new task against staying close to its original weights.

Elastic Weight Consolidation (EWC) identifies parameters important to prior tasks using the Fisher information matrix, then adds a penalty for changing those parameters. The Fisher information matrix provides a principled way to estimate parameter importance: parameters with high Fisher information are those where small changes cause large changes in the model's predictions on prior tasks. By penalizing changes to these high-importance parameters more heavily than changes to low-importance parameters, EWC implements a form of selective protection.

The EWC loss function combines the standard task loss with a weighted regularization term:

\mathcal{L}_{\text{EWC}} = \mathcal{L}_{\text{task}} + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_i^*)^2

Understanding each component of this formula reveals the mechanism by which EWC prevents forgetting:

  • \mathcal{L}_{\text{EWC}}: the composite loss function used during fine-tuning, which the optimizer minimizes
  • \mathcal{L}_{\text{task}}: the standard loss function (e.g., cross-entropy) for the new task, ensuring the model still learns the fine-tuning objective
  • \lambda: a hyperparameter controlling the regularization strength, allowing you to trade off between task performance and capability preservation
  • F_i: the Fisher information value for parameter i, representing its importance to the previous task by approximating the curvature of the loss surface near the current parameter values
  • \theta_i: the current value of parameter i during training, which changes as optimization proceeds
  • \theta_i^*: the value of parameter i in the pre-trained (original) model, serving as the anchor point that regularization pulls toward
  • \sum_i: the summation over all parameters in the neural network, ensuring every weight is subject to regularization

The key innovation in EWC is the use of Fisher information to weight the regularization. Parameters with high Fisher information, those whose changes would strongly affect predictions on prior tasks, receive stronger protection. Parameters with low Fisher information can change more freely to accommodate the new task. This selective protection allows the model to adapt where it can and stay stable where it must.
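The following is a minimal sketch of diagonal-Fisher EWC, not a production implementation. It assumes batches are dictionaries the model accepts directly (as in the dataset used later in this chapter) and that some prior-task data is still available for estimating the Fisher information.

import torch


class EWCRegularizer:
    """Sketch of Elastic Weight Consolidation with a diagonal Fisher estimate."""

    def __init__(self, model, prior_task_loader, device, lam=100.0):
        self.lam = lam
        # Anchor: a copy of the pre-trained weights
        self.anchor = {
            n: p.detach().clone()
            for n, p in model.named_parameters()
            if p.requires_grad
        }
        self.fisher = {n: torch.zeros_like(p) for n, p in self.anchor.items()}

        # Estimate the diagonal Fisher information as the average squared gradient
        # of the prior-task loss around the pre-trained weights
        model.eval()
        n_batches = 0
        for batch in prior_task_loader:
            model.zero_grad()
            outputs = model(**{k: v.to(device) for k, v in batch.items()})
            outputs.loss.backward()
            for n, p in model.named_parameters():
                if n in self.fisher and p.grad is not None:
                    self.fisher[n] += p.grad.detach() ** 2
            n_batches += 1
        for n in self.fisher:
            self.fisher[n] /= max(n_batches, 1)
        model.zero_grad()

    def penalty(self, model):
        """Importance-weighted squared distance from the pre-trained weights."""
        loss = 0.0
        for n, p in model.named_parameters():
            if n in self.anchor:
                loss = loss + (self.fisher[n] * (p - self.anchor[n]) ** 2).sum()
        return (self.lam / 2) * loss

During fine-tuning you would add regularizer.penalty(model) to the task loss before calling backward(), just as the L2-SP example later in this chapter does.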

L2 regularization toward pre-trained weights is a simpler variant that penalizes all parameter changes equally. Rather than estimating parameter importance through Fisher information, L2-SP (L2 to Starting Point) assumes all parameters are equally important. This simplification dramatically reduces computational overhead while still providing meaningful protection against drift.

\mathcal{L}_{\text{L2-SP}} = \mathcal{L}_{\text{task}} + \frac{\alpha}{2} \sum_i (\theta_i - \theta_i^*)^2

Breaking down this formula shows how it implements uniform regularization:

  • \mathcal{L}_{\text{L2-SP}}: the regularized loss function (L2-SP stands for "L2 to Starting Point"), representing the objective the optimizer minimizes
  • \mathcal{L}_{\text{task}}: the loss for the current fine-tuning task, ensuring the model still learns to perform the new task well
  • \alpha: the regularization strength hyperparameter, controlling how strongly the model is pulled back toward its original weights
  • \theta_i: the current value of parameter i during training
  • \theta_i^*: the value of parameter i in the pre-trained (original) model
  • \sum_i (\theta_i - \theta_i^*)^2: the sum of squared differences between current parameters and pre-trained parameters, measuring total drift from the original model

This formulation is equivalent to EWC if we assume every parameter has equal importance (i.e., F_i = 1 for all i). It prevents the model parameters from drifting too far from their starting values, treating all weights as equally critical. While this uniform treatment is less sophisticated than EWC's selective protection, it often works surprisingly well in practice and requires no additional computation to estimate parameter importance.

Out[10]:
Visualization
Side-by-side comparison of L2-SP uniform regularization versus EWC importance-weighted regularization.
Comparison of regularization approaches. L2-SP applies uniform regularization to all parameters (left), while EWC weights regularization by parameter importance estimated via Fisher information (right). EWC allows more freedom for unimportant parameters while strongly protecting critical ones.

Replay-Based Methods

Replay methods maintain performance on prior tasks by mixing in examples from previous training:

  • Experience replay stores a subset of pre-training examples and includes them in each fine-tuning batch. This ensures the model continuously sees data from the original distribution; a minimal sketch appears after this list.
  • Generative replay uses the model itself (or a separate generator) to produce synthetic examples resembling pre-training data, avoiding the need to store actual training examples.
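A minimal sketch of an experience replay step, assuming replay_buffer holds stored prior-task batches in the same dictionary format the model accepts (the helper name and the 50/50 weighting are illustrative):

import random


def train_step_with_replay(model, optimizer, finetune_batch, replay_buffer, replay_weight=0.5):
    """One optimization step that mixes the new-task loss with a replayed prior-task loss."""
    optimizer.zero_grad()
    loss = model(**finetune_batch).loss
    if replay_buffer:
        replay_batch = random.choice(replay_buffer)  # stored prior-task batch
        loss = loss + replay_weight * model(**replay_batch).loss
    loss.backward()
    optimizer.step()
    return loss.item()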

Architecture-Based Methods

Rather than regularizing training, architecture approaches modify the model structure:

  • Progressive networks add new capacity for each task while freezing previous parameters entirely. This eliminates forgetting but dramatically increases model size.
  • Adapter layers insert small trainable modules while keeping pre-trained weights frozen. We'll explore this approach in detail in the upcoming chapters on parameter-efficient fine-tuning (PEFT); a bare-bones sketch of the idea follows below.
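To preview the idea, a bare-bones bottleneck adapter might look like the sketch below. The dimensions are placeholders, and this is not any particular library's implementation:

import torch
import torch.nn as nn


class BottleneckAdapter(nn.Module):
    """Small trainable module inserted after a frozen transformer sub-layer."""

    def __init__(self, hidden_dim=768, bottleneck_dim=64):
        super().__init__()
        self.down = nn.Linear(hidden_dim, bottleneck_dim)
        self.up = nn.Linear(bottleneck_dim, hidden_dim)
        self.act = nn.GELU()
        # Initialize the up-projection at zero so the adapter starts as an
        # identity function and cannot disrupt pre-trained behavior
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, hidden_states):
        # Residual connection: frozen representation plus a small learned correction
        return hidden_states + self.up(self.act(self.down(hidden_states)))

Because only the adapter's small number of parameters is trained, the frozen backbone retains its pre-trained capabilities by construction.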

Learning Rate Strategies

Careful learning rate management significantly impacts forgetting:

  • Discriminative learning rates apply smaller learning rates to earlier layers (which encode more general features) and larger rates to later layers (which are more task-specific).
  • Learning rate warmup starts with very small updates, allowing the model to find a gentle path toward the fine-tuning objective without abrupt changes (see the scheduler sketch after this list).
  • Early stopping based on validation performance on both the fine-tuning task and held-out prior tasks can prevent excessive forgetting.
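Warmup is usually a one-liner with a scheduler. The sketch below uses Hugging Face's get_linear_schedule_with_warmup with a toy model and placeholder step counts:

import torch
from transformers import get_linear_schedule_with_warmup

# Placeholder model and step counts purely for illustration
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
total_steps = 1000

# Learning rate rises linearly over the first 10% of steps, then decays linearly
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=total_steps // 10,
    num_training_steps=total_steps,
)

# In the training loop: optimizer.step(); scheduler.step()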

Implementing Forgetting Measurement

Let's implement a practical framework for measuring forgetting during fine-tuning. We'll use a sentiment classification task while monitoring language modeling capability.

In[11]:
Code
import warnings

warnings.filterwarnings("ignore")
In[12]:
Code
# Create a simple dataset for demonstration
import torch
from torch.utils.data import Dataset


class SentimentDataset(Dataset):
    def __init__(self, tokenizer, num_samples=200):
        self.tokenizer = tokenizer
        # Simple positive/negative examples
        positive = [
            "This movie was absolutely wonderful and amazing.",
            "I loved every minute of this fantastic film.",
            "Brilliant performances and excellent direction.",
            "A masterpiece of modern cinema.",
            "Outstanding work from the entire cast.",
        ]
        negative = [
            "This was the worst movie I have ever seen.",
            "Terrible acting and boring plot throughout.",
            "I regret wasting my time on this film.",
            "Completely disappointing and poorly made.",
            "A total disaster from start to finish.",
        ]

        # Expand dataset
        self.texts = (positive * (num_samples // 10)) + (
            negative * (num_samples // 10)
        )
        self.labels = [1] * (num_samples // 2) + [0] * (num_samples // 2)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=64,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": torch.tensor(self.labels[idx]),
        }
In[13]:
Code
def measure_language_modeling_capability(model, tokenizer, test_texts):
    """Measure the model's general language capability via hidden-state
    statistics (a rough proxy rather than true perplexity)."""
    model.eval()
    total_loss = 0

    device = next(model.parameters()).device

    with torch.no_grad():
        for text in test_texts:
            encodings = tokenizer(
                text, return_tensors="pt", truncation=True, max_length=128
            )
            input_ids = encodings["input_ids"].to(device)

            # For classification model, use embedding similarity as proxy
            # In practice, you'd use a separate LM head or the original model
            outputs = model.base_model(input_ids)
            hidden_states = outputs.last_hidden_state

            # Measure representation quality via hidden state statistics
            total_loss += hidden_states.var().item()

    return total_loss / len(test_texts)
In[14]:
Code
def train_epoch_with_monitoring(
    model, dataloader, optimizer, device, tokenizer, lm_test_texts
):
    """Train one epoch and return metrics."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    train_acc = correct / total
    avg_loss = total_loss / len(dataloader)

    # Measure language modeling capability
    lm_score = measure_language_modeling_capability(
        model, tokenizer, lm_test_texts
    )

    return avg_loss, train_acc, lm_score

Now let's run fine-tuning while tracking both task performance and general capability:

In[15]:
Code
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Use a small model for demonstration
model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=2
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Create dataset and dataloader
dataset = SentimentDataset(tokenizer, num_samples=200)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Test texts for monitoring language modeling capability
lm_test_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Natural language processing enables computers to understand human language.",
    "Machine learning algorithms improve through experience and data.",
    "The weather forecast predicts rain tomorrow afternoon.",
    "Scientists discovered a new species in the Amazon rainforest.",
]

# Measure baseline
baseline_lm_score = measure_language_modeling_capability(
    model, tokenizer, lm_test_texts
)
In[16]:
Code
# Fine-tune and track forgetting
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
num_epochs = 10

history = {
    "train_loss": [],
    "train_acc": [],
    "lm_score": [],
    "lm_score_change": [],
}

for epoch in range(num_epochs):
    train_loss, train_acc, lm_score = train_epoch_with_monitoring(
        model, dataloader, optimizer, device, tokenizer, lm_test_texts
    )

    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["lm_score"].append(lm_score)
    history["lm_score_change"].append(
        (lm_score - baseline_lm_score) / baseline_lm_score * 100
    )
Out[17]:
Visualization
Two-panel plot showing task accuracy increasing while LM score changes during fine-tuning.
Performance metrics during fine-tuning showing task accuracy (left) improving steadily while language modeling capability (right) changes as the model adapts. The divergence between these curves quantifies the stability-plasticity trade-off.
Out[18]:
Console
Final task accuracy: 100.00%
LM capability change: -6.3%
Max capability drift: 6.4%

The divergence between task accuracy and language modeling capability quantifies the severity of forgetting. While the model masters the sentiment task, the negative LM score change indicates a degradation in general linguistic competence, a warning sign for deployment.

Implementing L2-SP Regularization

Let's implement L2-SP (L2 regularization toward Starting Point), a simple but effective forgetting mitigation technique:

In[19]:
Code
class L2SPRegularizer:
    """L2 regularization toward pre-trained weights."""

    def __init__(self, model, alpha=0.01):
        self.alpha = alpha
        # Store pre-trained weights
        self.pretrained_weights = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.pretrained_weights[name] = param.data.clone()

    def penalty(self, model):
        """Calculate L2-SP penalty."""
        loss = 0.0
        for name, param in model.named_parameters():
            if name in self.pretrained_weights:
                diff = param - self.pretrained_weights[name]
                loss += (diff**2).sum()
        return self.alpha * loss
In[20]:
Code
def train_with_l2sp(
    model, dataloader, optimizer, regularizer, device, num_epochs=10
):
    """Train with L2-SP regularization."""
    history = {"loss": [], "acc": [], "reg_loss": []}

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_reg = 0
        correct = 0
        total = 0

        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            task_loss = outputs.loss
            reg_loss = regularizer.penalty(model)
            loss = task_loss + reg_loss

            loss.backward()
            optimizer.step()

            total_loss += task_loss.item()
            total_reg += reg_loss.item()
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        history["loss"].append(total_loss / len(dataloader))
        history["acc"].append(correct / total)
        history["reg_loss"].append(total_reg / len(dataloader))

    return history
In[21]:
Code
# Compare standard fine-tuning vs L2-SP
# Reset model for fair comparison
model_standard = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=2
).to(device)
model_l2sp = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=2
).to(device)

# Standard fine-tuning
optimizer_std = torch.optim.AdamW(model_standard.parameters(), lr=2e-4)
history_standard = {"loss": [], "acc": []}

for epoch in range(10):
    model_standard.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer_std.zero_grad()
        outputs = model_standard(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )
        outputs.loss.backward()
        optimizer_std.step()

        total_loss += outputs.loss.item()
        preds = outputs.logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    history_standard["loss"].append(total_loss / len(dataloader))
    history_standard["acc"].append(correct / total)

# L2-SP fine-tuning
optimizer_l2sp = torch.optim.AdamW(model_l2sp.parameters(), lr=2e-4)
regularizer = L2SPRegularizer(model_l2sp, alpha=0.01)
history_l2sp = train_with_l2sp(
    model_l2sp, dataloader, optimizer_l2sp, regularizer, device, num_epochs=10
)
In[22]:
Code
# Measure parameter drift for both models
def measure_parameter_drift(model, model_name_ref):
    """Calculate total parameter drift from pre-trained weights."""
    ref_model = AutoModelForSequenceClassification.from_pretrained(
        model_name_ref, num_labels=2
    )

    total_drift = 0
    total_params = 0

    for (name1, param1), (name2, param2) in zip(
        model.named_parameters(), ref_model.named_parameters()
    ):
        if "classifier" not in name1:  # Exclude task-specific head
            drift = ((param1.cpu() - param2) ** 2).sum().item()
            total_drift += drift
            total_params += param1.numel()

    return np.sqrt(total_drift / total_params)


drift_standard = measure_parameter_drift(model_standard, model_name)
drift_l2sp = measure_parameter_drift(model_l2sp, model_name)
Out[23]:
Visualization
Line chart comparing training accuracy of Standard vs L2-SP fine-tuning.
Task performance comparison between standard fine-tuning and L2-SP regularization. Both methods achieve similar training accuracy over 10 epochs, demonstrating that L2-SP regularization does not hinder the model's ability to learn the new task.
Out[24]:
Visualization
Bar chart showing lower parameter drift for L2-SP compared to Standard fine-tuning.
Parameter drift comparison. L2-SP significantly reduces the RMS parameter drift from pre-trained weights compared to standard fine-tuning, helping preserve original capabilities.
Out[25]:
Console
Standard fine-tuning drift: 0.0011
L2-SP fine-tuning drift: 0.0006
Drift reduction: 46.0%

The L2-SP regularizer successfully constrains parameter drift while maintaining task performance. In practice, you would tune the regularization strength α based on validation performance on both the target task and held-out tasks measuring prior capabilities.

Preserving Pre-trained Capabilities

Beyond mitigating forgetting during training, you need strategies to actively preserve the valuable capabilities that pre-trained models bring.

Layer-wise Learning Rates

Different layers encode different levels of abstraction. Earlier layers learn general features (syntax, basic semantics) while later layers specialize for the pre-training task. This suggests using different learning rates:

In[26]:
Code
def create_discriminative_learning_rates(model, base_lr=2e-5, decay_factor=0.9):
    """Create parameter groups with layer-wise learning rates."""
    param_groups = []

    # Get all encoder layers
    num_layers = len(model.base_model.encoder.layer)

    for i, layer in enumerate(model.base_model.encoder.layer):
        # Earlier layers get smaller learning rates
        layer_lr = base_lr * (decay_factor ** (num_layers - i - 1))
        param_groups.append({"params": layer.parameters(), "lr": layer_lr})

    # Classifier head gets full learning rate
    param_groups.append(
        {"params": model.classifier.parameters(), "lr": base_lr}
    )

    return param_groups


# Example usage
param_groups = create_discriminative_learning_rates(
    model_standard, base_lr=2e-5
)
Out[27]:
Console
Layer-wise learning rates:
  Encoder layer 0: 1.80e-05
  Encoder layer 1: 2.00e-05
  Classifier head: 2.00e-05
Out[28]:
Visualization
Bar chart showing increasing learning rates from early to late layers in a transformer model.
Discriminative learning rate schedule across model layers. Earlier layers (containing general linguistic features) receive smaller learning rates to limit updates, while later layers and the task-specific head receive larger rates to enable adaptation.

This discriminative schedule applies aggressive updates to the task-specific classifier head while protecting the foundational features in early encoder layers with much smaller learning rates.

Gradual Unfreezing

Another approach is to initially freeze early layers and progressively unfreeze them during training:

In[29]:
Code
def freeze_layers(model, num_layers_to_freeze):
    """Freeze the first n encoder layers."""
    for i, layer in enumerate(model.base_model.encoder.layer):
        if i < num_layers_to_freeze:
            for param in layer.parameters():
                param.requires_grad = False
        else:
            for param in layer.parameters():
                param.requires_grad = True


def count_trainable_params(model):
    """Count trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
Out[30]:
Console
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Gradual unfreezing schedule:
  Freeze 2 layers: 3,989,634 / 4,386,178 trainable (91.0%)
  Freeze 1 layers: 4,187,906 / 4,386,178 trainable (95.5%)
  Freeze 0 layers: 4,386,178 / 4,386,178 trainable (100.0%)
Out[31]:
Visualization
Timeline diagram showing progressive unfreezing of model layers during fine-tuning epochs.
Gradual unfreezing schedule during fine-tuning. Training begins with most layers frozen (high stability), then progressively unfreezes layers to allow more adaptation. This controlled approach prevents early catastrophic updates while eventually enabling full model adaptation.

Starting with most parameters frozen limits the plasticity of the model, preventing large weight updates that could destroy pre-trained knowledge. As training stabilizes, more layers are unfrozen to allow fine-grained adaptation.
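Putting the helpers above into a schedule might look like the sketch below; the epoch thresholds are arbitrary and would be tuned for your task.

# Sketch: unfreeze one additional layer every few epochs
unfreeze_schedule = {0: 2, 3: 1, 6: 0}  # epoch -> number of encoder layers kept frozen

for epoch in range(10):
    if epoch in unfreeze_schedule:
        freeze_layers(model, unfreeze_schedule[epoch])
        print(f"Epoch {epoch}: {count_trainable_params(model):,} trainable parameters")
    # ... run one training epoch here ...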

Evaluation-Driven Stopping

Perhaps the most reliable way to preserve capabilities is to continuously monitor them:

In[32]:
Code
class ForgettingAwareTrainer:
    """Trainer that stops when forgetting exceeds threshold."""

    def __init__(self, model, max_forgetting_pct=10.0):
        self.model = model
        self.max_forgetting_pct = max_forgetting_pct
        self.baseline_scores = {}

    def set_baseline(self, eval_fn, eval_name):
        """Record baseline performance on capability."""
        self.baseline_scores[eval_name] = eval_fn(self.model)

    def check_forgetting(self, eval_fn, eval_name):
        """Check if forgetting exceeds threshold."""
        if eval_name not in self.baseline_scores:
            return False, 0.0

        current = eval_fn(self.model)
        baseline = self.baseline_scores[eval_name]

        # Calculate percentage degradation
        if baseline != 0:
            degradation = (baseline - current) / abs(baseline) * 100
        else:
            degradation = 0.0

        return degradation > self.max_forgetting_pct, degradation
In[33]:
Code
# Demonstrate usage with the trainer
trainer = ForgettingAwareTrainer(model, max_forgetting_pct=10.0)


# Define a mock evaluation function (e.g., accuracy on held-out task)
def eval_capability(m):
    return 0.85


# Set the baseline before training
trainer.set_baseline(eval_capability, "reasoning")


# Check for forgetting (simulating a degradation)
def eval_degraded(m):
    return 0.75  # Performance dropped significantly


stop_signal, degradation = trainer.check_forgetting(eval_degraded, "reasoning")
Out[34]:
Console
Degradation detected: 11.8%
Stop training signal: True

The trainer detects that the performance drop (from 0.85 to 0.75) exceeds the allowable threshold, triggering a stop signal. This active monitoring acts as a safety brake, ensuring that the model retains essential capabilities throughout the fine-tuning process.

Limitations and Impact

Catastrophic forgetting remains an active research area with no perfect solutions. Each mitigation strategy involves trade-offs.

Regularization approaches like EWC and L2-SP slow forgetting but cannot eliminate it entirely. They also require storing additional information (Fisher matrices or original weights) and add hyperparameters that need tuning. For very long fine-tuning runs or significantly different task distributions, regularization alone may be insufficient.

Replay methods effectively maintain prior performance but require access to pre-training data, which may not be available for proprietary models. They also increase training time and memory requirements proportionally to how much replay data is included. Generative replay avoids the data storage problem but introduces its own errors through imperfect generation.

Architecture-based solutions like adapter layers offer the strongest forgetting prevention by freezing original weights entirely, but they constrain the model's ability to adapt representations. For tasks requiring significant representational change, frozen architectures may underperform full fine-tuning. We'll explore these approaches in detail in the upcoming PEFT chapters, where methods like LoRA provide a compelling middle ground.

The broader impact of understanding catastrophic forgetting extends beyond single-model fine-tuning. Continual learning systems that must adapt to streaming data face forgetting challenges at every update. Multi-task models must balance performance across many objectives simultaneously. Even alignment techniques like RLHF must carefully manage forgetting to prevent models from losing capabilities while learning to follow instructions.

Summary

Catastrophic forgetting occurs when fine-tuning overwrites the neural representations responsible for pre-trained capabilities. This phenomenon stems from the fundamental stability-plasticity dilemma: neural networks must be plastic enough to learn new tasks but stable enough to retain old knowledge.

Measuring forgetting requires evaluating models on held-out tasks representing prior capabilities, using metrics like backward transfer, perplexity degradation, or benchmark suite performance. Without measurement, forgetting goes undetected until deployment reveals unexpected failures.

Mitigation strategies span multiple approaches. Regularization methods like L2-SP and EWC penalize parameter drift from pre-trained values. Replay methods mix prior task data into fine-tuning batches. Architecture methods freeze original weights while adding trainable components. Learning rate strategies apply differential updates across layers.

Preserving pre-trained capabilities requires combining these techniques with careful monitoring. Layer-wise learning rates protect general features in early layers. Gradual unfreezing allows controlled adaptation. Most importantly, continuous evaluation on capability benchmarks enables early detection and intervention when forgetting exceeds acceptable thresholds.

As we'll see in the Fine-tuning Learning Rates chapter, the choice of learning rate schedule profoundly impacts both task learning and forgetting dynamics. The PEFT methods covered later offer particularly elegant solutions by adding small trainable modules while keeping pre-trained weights completely frozen, effectively eliminating forgetting at the cost of reduced adaptation flexibility.

Quiz

Ready to test your understanding? Take this quick quiz to reinforce what you've learned about catastrophic forgetting and strategies to mitigate it during fine-tuning.

