Instruction Tuning Training: Data Mixing & Loss Masking

Michael Brenndoerfer · December 19, 2025 · 24 min read

Master instruction tuning training with data mixing strategies, loss masking, and hyperparameter selection for effective language model fine-tuning.

Instruction Tuning Training

With instruction data prepared and properly formatted, we turn to the training process itself. Instruction tuning teaches pre-trained models to follow requests and provide helpful responses. Doing this well requires more than running standard fine-tuning: you must decide how to mix instruction types, which tokens to learn from, and how to preserve pre-trained capabilities.

This chapter covers the mechanics of instruction tuning and the decisions required during training. We will look at balancing task types with data mixing, using loss masking to focus the model on responses, selecting hyperparameters, and using multi-task learning to improve generalization.

Data Mixing Strategies

Instruction tuning datasets contain examples from many task types, such as question answering, summarization, and coding. Each category exercises different skills, from factual recall to logical reasoning. The task distribution determines what the model learns and how it generalizes, so you must control it deliberately.

The Sampling Problem

Training on all available instruction data causes task imbalance. Instruction datasets are rarely balanced because some tasks are easier to collect than others. Question-answering pairs can be scraped from forums and documentation, while high-quality mathematical reasoning examples require expert annotation. If your dataset contains 120,000 question-answering examples but only 8,000 math examples, the model sees QA tasks 15 times more frequently during training. This imbalance causes the model to perform well on common tasks but struggle with rare ones, even if both are equally important for the application.

Consider a hypothetical instruction dataset with the following distribution, which reflects the kind of imbalance commonly found in real instruction tuning corpora:

In[2]:
Code
# Simulated task distribution in a typical instruction dataset
tasks = [
    "QA",
    "Summarization",
    "Code",
    "Math",
    "Creative",
    "Translation",
    "Classification",
]
raw_counts = [120000, 80000, 15000, 8000, 25000, 40000, 60000]

# Calculate proportions
total = sum(raw_counts)
proportions = [c / total for c in raw_counts]
Out[3]:
Console
Task Distribution in Raw Dataset:
----------------------------------------
QA               120,000 examples (34.5%)
Summarization     80,000 examples (23.0%)
Code              15,000 examples ( 4.3%)
Math               8,000 examples ( 2.3%)
Creative          25,000 examples ( 7.2%)
Translation       40,000 examples (11.5%)
Classification    60,000 examples (17.2%)
Total            348,000 examples

In proportional sampling, where examples are drawn according to their natural frequency, the model sees the QA category roughly 15 times more often than math. This imbalance biases the training signal. For example, after 1 million training examples, roughly 345,000 gradient contributions come from QA but only about 23,000 from math. The model's parameters are therefore shaped far more by what helps it answer questions than by what helps it solve mathematical problems, even if both are equally important.
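
To make these numbers concrete, here is a quick sketch (reusing the `tasks` and `proportions` computed above) that estimates how many gradient contributions each task receives after one million proportionally sampled examples:

# Estimate gradient contributions per task under proportional sampling,
# reusing `tasks` and `proportions` from the earlier cell.
total_training_examples = 1_000_000

for task, p in zip(tasks, proportions):
    expected_updates = int(p * total_training_examples)
    print(f"{task:15s} ~{expected_updates:>8,} examples seen")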

Sampling Strategies

Three main approaches address task imbalance, each with different trade-offs that make them suitable for different situations and goals:

Proportional sampling uses the natural data distribution without modification. Examples appear according to their frequency in the dataset, meaning that if 34% of your data is question-answering, then 34% of your training batches will contain question-answering examples on average. This works well if the natural distribution matches your goals, like when a general assistant should handle common tasks more fluently because you perform them more often. The main drawback is underfitting on minority tasks. The model may not see enough math examples to learn reasoning, even if that skill is important.

Equal sampling takes the opposite approach, assigning each task equal probability regardless of dataset size. If you have seven task categories, each receives roughly 14.3% of the training examples. This strategy ensures the model sees rare tasks frequently enough to learn from every category. However, it introduces its own problems: it may waste model capacity by over-training on already well-represented tasks, and more seriously, it can cause severe overfitting when small task datasets must be repeated many times to achieve equal representation. If your math dataset has only 8,000 examples but needs to match 120,000 QA examples, each math example would be seen 15 times, risking memorization rather than generalization.
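
To see how much repetition equal sampling implies, the following sketch (reusing `tasks` and `raw_counts` from above, and assuming every task is scaled up to match the largest category) computes the repetition factor each task would need:

# Repetition factor needed per task under equal sampling,
# assuming every task is matched to the largest one.
target_per_task = max(raw_counts)  # 120,000 QA examples

for task, count in zip(tasks, raw_counts):
    factor = target_per_task / count
    print(f"{task:15s} repeated {factor:4.1f}x to reach {target_per_task:,} examples")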

Temperature-based sampling uses a temperature parameter to balance the task distribution. This approach gives you continuous control between the extremes of proportional and equal sampling. Given raw task proportions $p_i$, the sampling probability becomes:

$$q_i = \frac{p_i^{1/T}}{\sum_{j} p_j^{1/T}}$$

where:

  • $q_i$: the adjusted sampling probability for task $i$, representing how likely you are to draw an example from this task during training
  • $p_i$: the raw proportion of task $i$ in the original dataset, calculated as the number of examples in task $i$ divided by the total number of examples
  • $p_i^{1/T}$: the exponentiated proportion for task $i$, where the $1/T$ power reduces the differences between large and small values to make the distribution more uniform
  • $T$: the temperature parameter that controls distribution smoothing, typically set to values of $T \ge 1$
  • $\sum_{j} p_j^{1/T}$: the sum of exponentiated proportions across all tasks, which ensures the adjusted probabilities sum to one

The temperature parameter changes the distribution by exponentiating probabilities. Understanding how it works helps you choose the right setting:

  1. When $T = 1$: The exponent is $1$, so $p_i^{1/1} = p_i$, which means $q_i = p_i$. This preserves the original proportional sampling exactly, with no adjustment to the natural distribution.
  2. When $T \to \infty$: The exponent approaches $0$. Since any positive number raised to the power of $0$ equals $1$, all weights approach equality. The result is perfectly uniform sampling across all tasks, regardless of their original sizes.
  3. When $T \approx 2$-$3$: The distribution is flattened but not made uniform. Rare tasks are upweighted to ensure they appear frequently enough during training for the model to learn from them, while common tasks are still sampled more frequently than rare ones, reflecting their greater natural occurrence. This intermediate regime often provides the best balance in practice.
In[4]:
Code
def temperature_sampling(proportions, temperature):
    """Apply temperature to sampling distribution."""
    adjusted = [p ** (1 / temperature) for p in proportions]
    total = sum(adjusted)
    return [a / total for a in adjusted]


# Compare different temperatures
temperatures = [1.0, 2.0, 5.0, float("inf")]
sampling_distributions = {}

for temp in temperatures:
    if temp == float("inf"):
        # Equal sampling
        sampling_distributions[temp] = [1 / len(tasks)] * len(tasks)
    else:
        sampling_distributions[temp] = temperature_sampling(proportions, temp)
Out[5]:
Visualization
Task sampling probabilities under different temperature regimes ($T=1, 2, 5, \infty$). Higher temperatures progressively flatten the distribution, upweighting minority tasks like Math and Code, until $T=\infty$ produces perfectly equal sampling across all categories.

Research from instruction tuning papers like FLAN suggests that moderate temperature values ($T \approx 2$-$3$) often work well in practice. These values prevent minority tasks from being neglected during training, ensuring the model develops at least basic competency across all task types, while still reflecting the natural importance of common tasks by sampling them somewhat more frequently. The exact optimal temperature depends on your specific dataset and use case, but the $T = 2$ to $T = 3$ range provides a reasonable starting point for experimentation.
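
As a concrete check, the sketch below (reusing `tasks`, `proportions`, and the `temperature_sampling` function from earlier) prints the adjusted probabilities and the resulting boost factor $q_i / p_i$ for each task at $T = 2$:

# Adjusted sampling probabilities and boost factors at T = 2,
# reusing `tasks`, `proportions`, and `temperature_sampling` from above.
adjusted = temperature_sampling(proportions, temperature=2.0)

print(f"{'Task':15s} {'p_i':>7s} {'q_i (T=2)':>10s} {'boost':>7s}")
for task, p, q in zip(tasks, proportions, adjusted):
    print(f"{task:15s} {p:7.1%} {q:10.1%} {q / p:6.2f}x")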

Out[6]:
Visualization
Sampling probabilities for minority tasks (Math, Code, Creative) across different temperature settings. Higher temperatures ($T \geq 2$) significantly increase the probability of sampling these rare tasks compared to the proportional baseline ($T=1$).
Out[7]:
Visualization
Boost factors showing the relative upweighting of tasks compared to proportional sampling. Small tasks receive the largest boost factors at higher temperatures, ensuring they are seen frequently enough for learning, while large tasks are downweighted relative to their natural frequency.

Multi-Epoch Considerations

When training for multiple epochs, you must manage how sampling and repetition interact. With proportional sampling, common tasks dominate every epoch: after three epochs the model has seen each QA example three times, yet it may still struggle with math because those examples remain sparse in every pass. With equal sampling, rare-task examples must be repeated far more often than common ones within each epoch, and across multiple epochs this repetition compounds, sharply increasing the risk of memorizing specific examples rather than learning general patterns.

To prevent this, you can cap how many times any single example appears during training:

In[8]:
Code
import numpy as np


def create_capped_dataset(task_examples, target_size, max_repeats=3):
    """
    Create balanced dataset with capped repetition.

    Args:
        task_examples: Dict mapping task name to list of examples
        target_size: Desired total dataset size
        max_repeats: Maximum times any example can appear
    """
    balanced_data = []
    examples_per_task = target_size // len(task_examples)

    for task, examples in task_examples.items():
        # How many times to repeat the full dataset
        n_examples = len(examples)
        repeats_needed = examples_per_task / n_examples

        if repeats_needed <= 1:
            # Subsample large datasets
            indices = np.random.choice(
                n_examples, examples_per_task, replace=False
            )
            balanced_data.extend([examples[i] for i in indices])
        elif repeats_needed <= max_repeats:
            # Repeat small datasets
            full_repeats = int(repeats_needed)
            remainder = examples_per_task - (full_repeats * n_examples)

            for _ in range(full_repeats):
                balanced_data.extend(examples)
            balanced_data.extend(
                np.random.choice(examples, remainder, replace=False).tolist()
            )
        else:
            # Cap at max_repeats
            capped_total = n_examples * max_repeats
            for _ in range(max_repeats):
                balanced_data.extend(examples)
            print(
                f"Warning: {task} capped at {capped_total} examples (wanted {examples_per_task})"
            )

    return balanced_data
In[9]:
Code
# Demonstrate with example task sizes
task_examples = {
    "QA": list(range(10000)),
    "Math": list(range(500)),
    "Code": list(range(2000)),
}

balanced = create_capped_dataset(
    task_examples, target_size=15000, max_repeats=3
)
Out[10]:
Console
Created balanced dataset with 11500 examples

The final dataset has 11,500 examples. The 'Math' task was capped at 3 repetitions to prevent overfitting. This balances task representation while avoiding memorization.

Out[11]:
Visualization
Comparison of required versus capped repetitions for three sample tasks. The 'Math' task would require 10 repetitions to achieve equal sampling, but is capped at 3 to prevent overfitting, while 'QA' requires no repetition.
Out[12]:
Visualization
Final composition of the dataset after applying repetition capping. The resulting distribution balances task representation without allowing any single small task (like Math) to dominate through excessive repetition.

Key Parameters

The key parameters for data mixing and sampling are:

  • temperature: Controls the sharpness of the sampling distribution. A higher temperature (e.g., T=2T=2) upweights minority tasks compared to their natural frequency, ensuring they receive more training attention. Lower temperatures preserve more of the original distribution, while very high temperatures approach uniform sampling.
  • max_repeats: The maximum number of times a single example can be repeated in a balanced dataset. This parameter prevents overfitting on small task categories by capping repetition even when equal sampling would require seeing those examples many more times.

Loss Masking

Standard language modeling computes the loss over every token in the sequence. In instruction tuning, that spends part of the learning signal on reproducing the instruction rather than on learning to respond.

Why Mask the Prompt

Consider an instruction tuning example:

User: Explain photosynthesis in simple terms.
Assistant: Photosynthesis is the process plants use to convert sunlight into food...

Without loss masking, the model receives gradient updates for every token, including the instruction "Explain photosynthesis in simple terms." But reproducing user instructions is not the goal. The model should learn to generate helpful responses, not memorize prompts. Training on prompt tokens dilutes the learning signal with information the model already understands from pre-training.

Loss masking addresses this by zeroing out the loss for prompt tokens. Only the assistant's response contributes to parameter updates:

In[13]:
Code
import torch


def compute_masked_loss(logits, labels, mask):
    """
    Compute cross-entropy loss only on unmasked positions.

    Args:
        logits: Model outputs of shape (batch, seq_len, vocab_size)
        labels: Target token IDs of shape (batch, seq_len)
        mask: Binary mask where 1 = compute loss, 0 = ignore
    """
    # Flatten for cross-entropy computation
    batch_size, seq_len, vocab_size = logits.shape

    # Shift for next-token prediction
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    shift_mask = mask[:, 1:].contiguous()

    # Compute per-token loss
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    per_token_loss = loss_fn(
        shift_logits.view(-1, vocab_size), shift_labels.view(-1)
    ).view(batch_size, -1)

    # Apply mask and average over valid tokens
    masked_loss = per_token_loss * shift_mask
    total_loss = masked_loss.sum() / shift_mask.sum()

    return total_loss

The mask indicates which tokens should contribute to the loss. Typically, prompt tokens (system message and user input) receive a mask value of 0, while response tokens receive a mask value of 1.

Creating Loss Masks

Loss masks must align with tokenized sequences. This requires tracking where the prompt ends and the response begins:

In[14]:
Code
def create_loss_mask(input_ids, response_start_idx):
    """
    Create a loss mask that zeros out prompt tokens.

    Args:
        input_ids: Tokenized sequence
        response_start_idx: Index where assistant response begins

    Returns:
        Binary mask tensor
    """
    seq_len = len(input_ids)
    mask = torch.zeros(seq_len)
    mask[response_start_idx:] = 1.0
    return mask


# Example usage
prompt = "User: What is machine learning?\nAssistant:"
response = " Machine learning is a subset of AI..."
full_text = prompt + response

# Simulate tokenization (actual implementation uses a tokenizer)
prompt_tokens = 12  # Number of tokens in prompt
response_tokens = 8  # Number of tokens in response

mask = create_loss_mask(
    input_ids=list(range(prompt_tokens + response_tokens)),
    response_start_idx=prompt_tokens,
)
Out[15]:
Console
Loss mask pattern:
Prompt tokens (masked):   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Response tokens (active): [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

Visualizing Masked Loss

The following visualization shows how loss masking focuses training on the response portion of each example:

Out[16]:
Visualization
Visualization of loss masking in instruction tuning. Gray tokens (prompt) contribute zero loss, while colored tokens (response) contribute to gradient updates. This focuses learning on generating helpful responses rather than reproducing instructions.

Impact of Loss Masking

Loss masking significantly affects what the model learns during instruction tuning. Without masking, gradients flow from both prompt and response tokens, diluting the signal that teaches response generation. With masking, all gradient information comes from response tokens, making training more efficient and focused.
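
A quick way to see the effect is to run `compute_masked_loss` from the earlier cell on toy tensors, once with the prompt masked out and once with an all-ones mask so every token contributes. This is only a sketch, with random logits and labels standing in for real model outputs:

# Compare masked vs. unmasked loss on toy tensors,
# reusing compute_masked_loss from the earlier cell.
torch.manual_seed(0)

batch_size, seq_len, vocab_size = 2, 20, 100
prompt_len = 14  # first 14 positions are prompt tokens

logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

mask = torch.zeros(batch_size, seq_len)
mask[:, prompt_len:] = 1.0  # loss only on the response positions

masked = compute_masked_loss(logits, labels, mask)
unmasked = compute_masked_loss(logits, labels, torch.ones_like(mask))

print(f"Loss averaged over response tokens only: {masked.item():.3f}")
print(f"Loss averaged over all tokens:           {unmasked.item():.3f}")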

Out[17]:
Visualization
Comparison of gradient contribution with and without loss masking. Without masking, prompt tokens consume a significant portion of the learning signal. With masking, 100% of gradients come from response tokens.

The effect is particularly pronounced in scenarios with long prompts relative to responses. In long-context applications or multi-turn conversations, prompts can comprise 60-70% of tokens. Without masking, most of the learning signal would be spent on reproducing context rather than learning to respond appropriately.

Training Hyperparameters

Instruction tuning requires careful hyperparameter selection. Unlike pre-training, which processes trillions of tokens over many epochs, instruction tuning typically involves much smaller datasets and fewer training steps. This concentrated training amplifies the importance of each hyperparameter choice.

Learning Rate

The learning rate is typically lower than in pre-training, usually in the range of $1 \times 10^{-5}$ to $5 \times 10^{-5}$. Higher learning rates risk catastrophic forgetting, where the model loses pre-trained capabilities while learning to follow instructions. Lower learning rates preserve more of the base model's knowledge but may require more training steps.

In[18]:
Code
# Typical learning rate schedules for instruction tuning
hyperparameters = {
    "learning_rate": 2e-5,
    "warmup_ratio": 0.03,  # 3% of training steps
    "lr_scheduler": "cosine",
    "weight_decay": 0.01,
}

A warmup period helps stabilize early training. During warmup, the learning rate gradually increases from near-zero to the target value. This prevents large gradient updates before the model adapts to the new data distribution.
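
The snippet below sketches one common implementation of this schedule: linear warmup to the peak learning rate followed by cosine decay. The peak value and warmup ratio match the configuration above; decaying all the way to zero is an assumption, since some setups decay to a small floor instead.

import math


def lr_at_step(step, total_steps, peak_lr=2e-5, warmup_ratio=0.03):
    """Linear warmup to peak_lr, then cosine decay toward zero."""
    warmup_steps = max(int(total_steps * warmup_ratio), 1)
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))


total_steps = 3000
for step in [0, 45, 90, 1500, 3000]:
    print(f"step {step:5d}: lr = {lr_at_step(step, total_steps):.2e}")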

Batch Size and Gradient Accumulation

Larger batch sizes provide more stable gradient estimates but require more memory. Gradient accumulation allows effective large batches on limited hardware:

In[19]:
Code
# Effective batch size = batch_size × gradient_accumulation_steps × num_gpus
training_config = {
    "per_device_batch_size": 4,
    "gradient_accumulation_steps": 8,
    "num_gpus": 4,
    # Effective batch size: 4 × 8 × 4 = 128
}

Effective batch sizes between 64 and 256 are common for instruction tuning. Smaller batches introduce more noise into gradient estimates, which can help generalization but may slow convergence. Larger batches converge faster but may generalize less well.
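
The following sketch shows how gradient accumulation typically looks inside a training loop, using a toy linear model and random tensors as stand-ins for a real language model and dataloader. The key detail is dividing each loss by the number of accumulation steps so the accumulated gradient averages over the effective batch.

import torch

torch.manual_seed(0)
model = torch.nn.Linear(16, 4)  # toy stand-in for a language model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()

per_device_batch_size = 4
accumulation_steps = 8  # effective batch on one device: 4 x 8 = 32

optimizer.zero_grad()
for step in range(32):
    # Stand-in for a real dataloader batch
    inputs = torch.randn(per_device_batch_size, 16)
    targets = torch.randint(0, 4, (per_device_batch_size,))

    loss = loss_fn(model(inputs), targets)
    (loss / accumulation_steps).backward()  # accumulate scaled gradients

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()      # one optimizer update per effective batch
        optimizer.zero_grad()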

Number of Epochs

Instruction tuning typically uses 1-3 epochs over the training data. More epochs risk overfitting to the instruction format rather than learning general instruction-following capabilities. Signs of overfitting include:

  • Training loss continues decreasing while validation loss increases
  • Model outputs become formulaic or repetitive
  • Performance degrades on held-out tasks
Out[20]:
Visualization
Typical loss curves during instruction tuning showing the relationship between training and validation loss. Overfitting begins when validation loss starts increasing while training loss continues to decrease.

Early stopping based on validation loss helps prevent overfitting. Save checkpoints regularly and evaluate on held-out instruction examples to identify the optimal stopping point.
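
A minimal early stopping loop might look like the following sketch, where the per-epoch validation losses are hypothetical placeholder values and "saving a checkpoint" is reduced to remembering the best epoch:

# Minimal early stopping on validation loss (values below are hypothetical).
val_losses = [1.82, 1.64, 1.58, 1.60, 1.66]

best_val_loss = float("inf")
best_epoch = None
patience, patience_counter = 2, 0

for epoch, val_loss in enumerate(val_losses, start=1):
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch  # in practice, save the model weights here
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Stopping after epoch {epoch}; best checkpoint was epoch {best_epoch}")
            break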

Multi-Task Learning Benefits

Instruction tuning naturally supports multi-task learning. By training on diverse instruction types simultaneously, the model learns transferable skills. A model trained on summarization, translation, and question-answering often performs better on each individual task than models trained separately, because the shared representations capture general language understanding.

The key benefits include:

  • Positive transfer: Skills learned on one task improve performance on related tasks
  • Regularization: Task diversity prevents overfitting to any single task pattern
  • Emergent capabilities: Models sometimes gain abilities not explicitly trained, arising from the combination of skills

However, multi-task learning also introduces potential negative transfer when tasks conflict. For example, tasks requiring very different output styles (terse classification labels versus verbose explanations) may interfere with each other. Temperature-based sampling and careful task curation help mitigate these conflicts.

Summary

Instruction tuning training requires balancing multiple competing concerns: ensuring adequate coverage of minority tasks without overfitting, focusing learning on response generation through loss masking, and selecting hyperparameters that preserve pre-trained capabilities while teaching instruction following.

The key decisions covered in this chapter are:

  1. Data mixing strategy: Use temperature-based sampling with $T \approx 2$-$3$ to balance task representation
  2. Repetition capping: Limit how often small datasets are repeated to prevent memorization
  3. Loss masking: Zero out loss on prompt tokens to focus learning on response generation
  4. Hyperparameter selection: Use lower learning rates ($1 \times 10^{-5}$ to $5 \times 10^{-5}$), moderate batch sizes (64-256), and few epochs (1-3) to avoid catastrophic forgetting

These techniques work together to produce models that follow instructions reliably while maintaining their broader language capabilities. The next chapter examines how to evaluate whether instruction tuning has succeeded.

