Multi-Query Attention: Memory-Efficient LLM Inference

Michael Brenndoerfer · Updated August 4, 2025 · 39 min read

Learn how Multi-Query Attention reduces KV cache memory by sharing keys and values across attention heads, enabling efficient long-context inference.


Multi-Query Attention

In standard multi-head attention, each head maintains its own query, key, and value projections. This design creates a bottleneck during inference: the KV cache grows linearly with the number of heads, consuming large amounts of GPU memory for long sequences. Multi-Query Attention (MQA) solves this problem by sharing a single key-value pair across all query heads, reducing memory requirements while maintaining most of the model's representational power.

This chapter explores the mechanics of MQA, quantifies its memory benefits, examines the quality trade-offs involved, and compares it with its more moderate cousin, Grouped Query Attention (GQA).

The KV Cache Memory Problem

Before diving into MQA, we need to understand why KV cache memory is such a critical concern for large language model inference. During autoregressive generation, models cache the keys and values computed for previous tokens to avoid redundant computation. This cache grows with every token generated.

For a model with $L$ layers, $h$ attention heads, head dimension $d_k$, and sequence length $T$, the KV cache size is:

\text{KV Cache Size} = 2 \times L \times h \times T \times d_k \times \text{bytes per element}

where:

  • $2$: accounts for storing both keys and values (each requiring the same memory)
  • $L$: the number of transformer layers, each maintaining its own KV cache
  • $h$: the number of attention heads per layer, each with independent K and V tensors
  • $T$: the current sequence length (this grows with each generated token)
  • $d_k$: the dimension per head, determining the size of each key/value vector
  • Bytes per element: depends on numerical precision (4 for FP32, 2 for FP16/BF16)

The formula multiplies these factors because we need to store one key vector and one value vector for every combination of layer, head, and sequence position. The memory grows linearly with sequence length, which becomes problematic for long-context generation.

Let's calculate the KV cache requirements for a concrete example.

In[2]:
Code
# KV cache size calculation for different model scales
def calculate_kv_cache_size(
    num_layers: int,
    num_heads: int,
    head_dim: int,
    seq_length: int,
    batch_size: int = 1,
    bytes_per_element: int = 2,  # FP16/BF16
) -> dict:
    """Calculate KV cache memory requirements in bytes and GB."""
    # Standard MHA: each head has its own K and V
    kv_cache_bytes = (
        2  # K and V
        * num_layers
        * num_heads
        * seq_length
        * head_dim
        * batch_size
        * bytes_per_element
    )

    return {
        "bytes": kv_cache_bytes,
        "megabytes": kv_cache_bytes / (1024**2),
        "gigabytes": kv_cache_bytes / (1024**3),
    }


# Example: LLaMA-2 7B scale model
llama_7b_config = {
    "num_layers": 32,
    "num_heads": 32,
    "head_dim": 128,
    "seq_length": 4096,
}

kv_cache = calculate_kv_cache_size(**llama_7b_config)
Out[3]:
Console
LLaMA-2 7B KV Cache (4K context, batch=1):
  Size: 2.00 GB
  Per-token overhead: 512.00 KB
  At 8,192 tokens: 4.00 GB
  At 16,384 tokens: 8.00 GB
  At 32,768 tokens: 16.00 GB
  At 131,072 tokens: 64.00 GB

For a 7B-parameter model, the KV cache alone consumes 2 GB for a 4K context window. At 128K tokens (common in modern models), this balloons to 64 GB, which far exceeds the memory of most consumer GPUs. The problem compounds further with larger batch sizes for throughput optimization.
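
A quick way to see how batch size compounds the problem is to reuse calculate_kv_cache_size from the cell above, which already accepts a batch_size argument. The numbers below are a rough sketch for the same 7B-scale configuration in FP16.

Code
# Cache grows linearly with batch size as well as sequence length.
# Reuses calculate_kv_cache_size defined above.
for batch_size in [1, 4, 16]:
    cache = calculate_kv_cache_size(
        num_layers=32,
        num_heads=32,
        head_dim=128,
        seq_length=4096,
        batch_size=batch_size,
    )
    print(f"batch={batch_size:>2}: {cache['gigabytes']:.2f} GB")

# batch= 1: 2.00 GB
# batch= 4: 8.00 GB
# batch=16: 32.00 GB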

The situation becomes even more acute for larger models. Let's compare across scales.

In[4]:
Code
# Model configurations showing query heads (for KV cache comparison assuming MHA)
# Note: LLaMA-2 70B actually uses GQA with 8 KV heads, not full MHA
models = {
    "LLaMA-2 7B": {"layers": 32, "heads": 32, "head_dim": 128},
    "LLaMA-2 13B": {"layers": 40, "heads": 40, "head_dim": 128},
    "LLaMA-2 70B": {"layers": 80, "heads": 64, "head_dim": 128},
    "GPT-3 175B": {"layers": 96, "heads": 96, "head_dim": 128},
}

seq_lengths = [4096, 32768, 131072]
Out[5]:
Visualization
Bar chart showing KV cache size in GB for different model scales at varying context lengths.
KV cache memory requirements grow rapidly with model size and context length. A 70B model serving 128K context requires over 300 GB of memory just for the cache, far exceeding the capacity of even enterprise GPUs like the A100-80GB.

This memory pressure is the driving motivation behind MQA. When the cache alone exceeds available GPU memory, you can't run inference at all, regardless of how fast your hardware is.

Multi-Query Attention: Extreme Sharing

Multi-Query Attention, introduced by Noam Shazeer in 2019, addresses the KV cache problem through aggressive parameter sharing. The core insight is simple: instead of having $h$ separate key and value projections (one per head), use just one shared key and one shared value projection across all heads.

Multi-Query Attention (MQA)

Multi-Query Attention maintains multiple query heads but shares a single key head and a single value head across all of them. Each query head computes attention using the same keys and values, reducing the KV cache size by a factor of $h$ (the number of heads).
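
As a quick sanity check on this claim, evaluating the cache formula from the previous section for the 7B-scale configuration (32 layers, 32 heads, head dimension 128, 4K context, FP16) gives:

\begin{aligned}
\text{MHA cache} &= 2 \times 32 \times 32 \times 4096 \times 128 \times 2 \text{ bytes} \approx 2 \text{ GB} \\
\text{MQA cache} &= 2 \times 32 \times 1 \times 4096 \times 128 \times 2 \text{ bytes} \approx 64 \text{ MB}
\end{aligned}

The ratio between the two is exactly 32, the number of heads.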

The MQA Formulation

To truly understand Multi-Query Attention, we need to first build intuition about what multi-head attention actually does and why each component exists. Think of attention as a retrieval system: queries ask questions, keys advertise what information is available, and values contain the actual content to retrieve. Multi-head attention runs multiple such retrieval systems in parallel, each learning to find different kinds of relationships.

Step 1: The Standard Multi-Head Attention Baseline

In standard multi-head attention, each head $i$ transforms the input through its own learned projection matrices. Given an input sequence $\mathbf{X}$, we compute head-specific queries, keys, and values:

\mathbf{Q}_i = \mathbf{X} \mathbf{W}^Q_i, \quad \mathbf{K}_i = \mathbf{X} \mathbf{W}^K_i, \quad \mathbf{V}_i = \mathbf{X} \mathbf{W}^V_i

where:

  • $\mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}}$: the input sequence with $n$ tokens, each represented as a $d_{\text{model}}$-dimensional vector
  • $\mathbf{W}^Q_i \in \mathbb{R}^{d_{\text{model}} \times d_k}$: the query projection matrix for head $i$
  • $\mathbf{W}^K_i \in \mathbb{R}^{d_{\text{model}} \times d_k}$: the key projection matrix for head $i$
  • $\mathbf{W}^V_i \in \mathbb{R}^{d_{\text{model}} \times d_v}$: the value projection matrix for head $i$
  • $d_k$: the dimension of queries and keys per head (typically $d_{\text{model}} / h$)
  • $d_v$: the dimension of values per head (often equal to $d_k$)

The subscript $i$ on each projection matrix is the key detail here. It tells us that every head learns its own transformation. Head 1 might learn to project tokens in a way that emphasizes syntactic features, while head 2 might emphasize semantic similarity. This specialization is what makes multi-head attention so powerful: the model can simultaneously attend to different types of relationships.

But this flexibility comes at a cost. During inference, we must cache the keys and values for each head separately. With 32 heads, we store 32 different $\mathbf{K}_i$ matrices and 32 different $\mathbf{V}_i$ matrices. This is where the memory bottleneck arises.

Step 2: The MQA Insight: Sharing Keys and Values

MQA asks a provocative question: do we really need different keys and values for each head? The hypothesis is that while different heads benefit from asking different questions (different queries), they might be able to share a common "index" (keys) and "content store" (values).

Mathematically, MQA removes the subscript from the key and value projections:

\mathbf{Q}_i = \mathbf{X} \mathbf{W}^Q_i, \quad \mathbf{K} = \mathbf{X} \mathbf{W}^K, \quad \mathbf{V} = \mathbf{X} \mathbf{W}^V

Notice the key change:

  • $\mathbf{W}^Q_i$: still has subscript $i$, meaning each head retains its own query projection
  • $\mathbf{W}^K$ and $\mathbf{W}^V$: no subscript, meaning a single projection shared across all heads

The resulting matrices have dimensions:

  • $\mathbf{Q}_i \in \mathbb{R}^{n \times d_k}$: unique query matrix for head $i$
  • $\mathbf{K} \in \mathbb{R}^{n \times d_k}$: one shared key matrix for all heads
  • $\mathbf{V} \in \mathbb{R}^{n \times d_v}$: one shared value matrix for all heads

This is a simple but effective change. Different heads can still "ask different questions" because each projects the input into its own query space. But they all look up answers from the same key-value store. It's like having 32 researchers who each have their own research questions but share access to a single library catalog and book collection.
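
A minimal PyTorch sketch makes these shapes concrete. The dimensions here are illustrative only (they do not correspond to any particular model), and the query heads are packed into one projection matrix, the same layout the full implementation later in this chapter uses.

Code
import torch
import torch.nn as nn

# Illustrative dimensions only
n, d_model, h, d_k = 10, 512, 8, 64

X = torch.randn(n, d_model)

# Per-head query projections packed into one matrix: d_model -> h * d_k
W_Q = nn.Linear(d_model, h * d_k, bias=False)
# Single shared key and value projections: d_model -> d_k
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_k, bias=False)

Q = W_Q(X).view(n, h, d_k)  # one query per head: (n, h, d_k)
K = W_K(X)                  # shared keys:        (n, d_k)
V = W_V(X)                  # shared values:      (n, d_k)

print(Q.shape, K.shape, V.shape)
# torch.Size([10, 8, 64]) torch.Size([10, 64]) torch.Size([10, 64])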

Step 3: Computing Attention with Shared KV

Each head now computes attention using its unique queries against the shared keys and values:

\text{head}_i = \text{softmax}\left(\frac{\mathbf{Q}_i \mathbf{K}^T}{\sqrt{d_k}}\right) \mathbf{V}

Let's trace through what happens in this formula:

  1. Compute attention scores: $\mathbf{Q}_i \mathbf{K}^T$ produces an $n \times n$ matrix where entry $(p, q)$ measures how much the token at position $p$ should attend to position $q$. Because $\mathbf{Q}_i$ is unique to head $i$, each head produces different attention patterns.

  2. Scale the scores: Dividing by $\sqrt{d_k}$ prevents the dot products from growing too large as dimension increases. Without scaling, the softmax would produce very peaked distributions, making gradients vanish during training.

  3. Normalize to weights: The softmax converts each row of scores into a probability distribution that sums to 1. These are the attention weights.

  4. Aggregate values: Multiplying by $\mathbf{V}$ computes a weighted sum of value vectors. All heads aggregate from the same $\mathbf{V}$, but with different weights.

The key insight is that expressiveness comes from two sources: the diversity of attention patterns (which heads preserve through unique queries) and the diversity of retrieved content (which MQA sacrifices through shared values). Empirically, the diversity of attention patterns matters more than the diversity of values for most tasks.
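
The following minimal sketch traces these four steps with random tensors. The dimensions are illustrative, it uses einsum instead of the broadcasting that the full module below relies on, and it applies no causal mask; the point is simply that heads sharing K and V still produce distinct attention patterns.

Code
import torch
import torch.nn.functional as F

n, h, d_k = 10, 8, 64
Q = torch.randn(n, h, d_k)  # per-head queries
K = torch.randn(n, d_k)     # shared keys
V = torch.randn(n, d_k)     # shared values

# Steps 1-2: each head's queries against the *same* keys, scaled -> (h, n, n)
scores = torch.einsum("nhd,md->hnm", Q, K) / d_k**0.5
# Step 3: one attention distribution per head
weights = F.softmax(scores, dim=-1)
# Step 4: every head aggregates from the same shared V
heads = torch.einsum("hnm,md->hnd", weights, V)

print(weights.shape, heads.shape)
# torch.Size([8, 10, 10]) torch.Size([8, 10, 64])

# Different heads, different patterns, despite the shared K and V:
print(torch.allclose(weights[0], weights[1]))  # False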

To illustrate this, let's visualize how different query heads can produce distinct attention patterns even when sharing the same keys and values.

Out[6]:
Visualization
Head 1: Local focus pattern, attending primarily to recent tokens.
Head 2: Start token pattern, strongly attending to the first position.
Head 3: Uniform pattern, equal attention across all previous tokens.
Head 4: Alternating pattern, preferring even-indexed positions.

Each head learns a different way to "ask questions" of the shared key-value store. Head 1 might focus on recent context for local coherence, Head 2 might anchor to the beginning of the sequence, Head 3 might aggregate broadly, and Head 4 might pick up periodic patterns. The shared keys and values don't limit this diversity; they simply mean all heads retrieve from the same "library."

Step 4: Combining Head Outputs

Finally, we concatenate all head outputs and project back to the model dimension:

\text{MQA}(\mathbf{X}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \mathbf{W}^O

where:

  • $\text{head}_i \in \mathbb{R}^{n \times d_v}$: the output from attention head $i$, containing each head's weighted aggregation
  • $\text{Concat}(\cdot) \in \mathbb{R}^{n \times (h \cdot d_v)}$: concatenation along the feature dimension
  • $\mathbf{W}^O \in \mathbb{R}^{(h \cdot d_v) \times d_{\text{model}}}$: the output projection that mixes information across heads
  • $h$: the total number of attention heads

This final projection is unchanged from standard MHA. Each head contributes its perspective, and the output projection learns to combine them into a coherent representation. The key benefit of MQA is that we achieve this with much less memory overhead during inference.

Visualizing the Difference

The architectural difference between MHA and MQA becomes clear when we visualize the projection structure.

Out[7]:
Visualization
Multi-Head Attention (MHA): Each head maintains its own Q, K, V projections, requiring h copies in the KV cache.
Multi-Query Attention (MQA): All heads share a single K and V projection, reducing cache size by a factor of h.

Implementing Multi-Query Attention

Now that we understand the mathematical formulation, let's translate it into working code. The implementation will reveal exactly how the parameter sharing manifests in practice and why it leads to such significant memory savings.

We'll build MQA from scratch, examining each component to understand how it differs from standard multi-head attention. By the end, you'll see that the core change is just a matter of adjusting projection dimensions.

The Core MQA Module

The main difference between MHA and MQA lies in the output dimensions of the projection layers. In MHA, the key and value projections output num_heads * head_dim dimensions (one set of K/V vectors per head). In MQA, they output just head_dim (a single set of K/V vectors shared across all heads).

Let's implement this step by step, with comments highlighting the critical differences.

In[8]:
Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention: all heads share a single K and V projection.

    This reduces KV cache size by a factor of num_heads while maintaining
    expressive power through multiple query heads.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int | None = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim or (hidden_size // num_heads)
        self.dropout = dropout

        # Each head has its own query projection
        # Total query params: hidden_size -> num_heads * head_dim
        self.q_proj = nn.Linear(
            hidden_size, num_heads * self.head_dim, bias=False
        )

        # Single shared key projection: hidden_size -> head_dim
        self.k_proj = nn.Linear(hidden_size, self.head_dim, bias=False)

        # Single shared value projection: hidden_size -> head_dim
        self.v_proj = nn.Linear(hidden_size, self.head_dim, bias=False)

        # Output projection
        self.o_proj = nn.Linear(
            num_heads * self.head_dim, hidden_size, bias=False
        )

        self.scale = self.head_dim**-0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        batch_size, seq_len, _ = hidden_states.shape

        # Project queries: (batch, seq, num_heads * head_dim)
        q = self.q_proj(hidden_states)

        # Project keys and values: (batch, seq, head_dim)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape queries for multi-head: (batch, num_heads, seq, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        q = q.transpose(1, 2)

        # Keys and values stay as single head: (batch, 1, seq, head_dim)
        k = k.unsqueeze(1)
        v = v.unsqueeze(1)

        # Handle KV cache
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_key_value = (k, v) if use_cache else None

        # Compute attention scores: (batch, num_heads, seq_q, seq_kv)
        # Broadcasting: Q is (batch, heads, seq, dim), K is (batch, 1, seq, dim)
        # The single K is broadcast across all heads
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Apply causal mask if needed
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax and dropout
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = F.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        # Apply attention to values
        # attn_weights: (batch, heads, seq_q, seq_kv)
        # v: (batch, 1, seq_kv, dim) -> broadcasts to all heads
        attn_output = torch.matmul(attn_weights, v)

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, -1)
        output = self.o_proj(attn_output)

        return output, present_key_value

Several implementation details deserve attention:

Projection dimensions: The query projection outputs num_heads * head_dim dimensions because each head needs its own query vectors. But the key and value projections output only head_dim, producing a single set of K/V vectors. This is the source of our memory savings.

Broadcasting: When we compute torch.matmul(q, k.transpose(-2, -1)), PyTorch's broadcasting rules handle the dimension mismatch. The query tensor has shape (batch, num_heads, seq, head_dim) while the key tensor has shape (batch, 1, seq, head_dim). Broadcasting automatically replicates the single key across all heads, computing the correct attention scores without physically copying the data.

KV cache handling: The cache stores only the shared K and V tensors, not per-head copies. When we concatenate past keys with new keys, we're working with tensors of shape (batch, 1, seq, head_dim) rather than (batch, num_heads, seq, head_dim). This is where the memory reduction materializes during autoregressive generation.
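
A short usage sketch of the module above confirms these cache shapes during generation. The configuration is small and illustrative, no causal mask is passed (this is a shape check, not a correctness check), and it assumes the imports and class definition from the previous cell.

Code
torch.manual_seed(0)
mqa_demo = MultiQueryAttention(hidden_size=256, num_heads=8, head_dim=32)

# Prefill: 16 prompt tokens
prompt = torch.randn(1, 16, 256)
out, cache = mqa_demo(prompt, use_cache=True)
print(out.shape)       # torch.Size([1, 16, 256])
print(cache[0].shape)  # torch.Size([1, 1, 16, 32])  <- one KV head, not eight

# Decode: one new token, appending to the shared cache
next_token = torch.randn(1, 1, 256)
out, cache = mqa_demo(next_token, past_key_value=cache, use_cache=True)
print(cache[0].shape)  # torch.Size([1, 1, 17, 32])  <- grows by one position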

Comparing Parameter Counts

To make the savings concrete, let's count parameters in MHA versus MQA for a realistic configuration.

In[9]:
Code
def count_parameters(module: nn.Module) -> int:
    """Count trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


class StandardMultiHeadAttention(nn.Module):
    """Standard MHA for comparison."""

    def __init__(
        self, hidden_size: int, num_heads: int, head_dim: int | None = None
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim or (hidden_size // num_heads)
        total_dim = num_heads * self.head_dim

        # All three projections are full-sized
        self.q_proj = nn.Linear(hidden_size, total_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, total_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, total_dim, bias=False)
        self.o_proj = nn.Linear(total_dim, hidden_size, bias=False)


# Create instances with typical LLaMA-like config
hidden_size = 4096
num_heads = 32
head_dim = 128

mha = StandardMultiHeadAttention(hidden_size, num_heads, head_dim)
mqa = MultiQueryAttention(hidden_size, num_heads, head_dim)
Out[10]:
Console
Configuration: hidden_size=4096, num_heads=32, head_dim=128

Parameter counts:
  Standard MHA: 67,108,864 parameters
  Multi-Query:  34,603,008 parameters
  Reduction:    48.4%

Projection breakdown (MHA):
  Q: 4096 × 4096 = 16,777,216
  K: 4096 × 4096 = 16,777,216
  V: 4096 × 4096 = 16,777,216

Projection breakdown (MQA):
  Q: 4096 × 4096 = 16,777,216
  K: 4096 × 128 = 524,288
  V: 4096 × 128 = 524,288

The numbers reveal the trade-off clearly. MQA reduces the K and V projection parameters by exactly num_heads (32x in this case). However, the total attention layer parameter reduction is more modest, around 50%, because the query and output projections remain unchanged at full size.

This distinction matters: the parameter savings affect model size and training memory, but the real win comes during inference. The KV cache stores the outputs of the K and V projections, not the projection matrices themselves. Since each forward pass produces only one K vector and one V vector per token (instead of 32), the cache reduction matches the full 32x factor. This is why MQA changes the economics of inference even though its impact on total parameter count is more modest.
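
A short arithmetic check makes the distinction explicit: per token and per layer, the cache stores the outputs of the K and V projections, so the saving is the full factor of num_heads even though the parameter saving is only about half. The figures assume the same 7B-scale configuration in FP16.

Code
# Per-token, per-layer KV cache in bytes (FP16), 7B-scale configuration
num_heads, head_dim, bytes_per_element = 32, 128, 2

mha_per_token = 2 * num_heads * head_dim * bytes_per_element  # one K and V per head
mqa_per_token = 2 * 1 * head_dim * bytes_per_element          # single shared K and V

print(f"MHA: {mha_per_token} bytes per token per layer")  # 16384
print(f"MQA: {mqa_per_token} bytes per token per layer")  # 512
print(f"Reduction: {mha_per_token // mqa_per_token}x")    # 32x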

MQA Memory Benefits

MQA's main benefit appears during inference. The KV cache reduction is proportional to the number of heads, which can be substantial for modern large models.

Calculating KV Cache Savings

In[11]:
Code
def calculate_kv_cache_mqa(
    num_layers: int,
    num_heads: int,  # Only matters for MHA
    head_dim: int,
    seq_length: int,
    batch_size: int = 1,
    bytes_per_element: int = 2,
    use_mqa: bool = False,
) -> dict:
    """Calculate KV cache for MHA vs MQA."""

    if use_mqa:
        # MQA: single K and V regardless of num_heads
        num_kv_heads = 1
    else:
        # MHA: K and V for each head
        num_kv_heads = num_heads

    kv_cache_bytes = (
        2  # K and V
        * num_layers
        * num_kv_heads
        * seq_length
        * head_dim
        * batch_size
        * bytes_per_element
    )

    return {
        "bytes": kv_cache_bytes,
        "megabytes": kv_cache_bytes / (1024**2),
        "gigabytes": kv_cache_bytes / (1024**3),
        "num_kv_heads": num_kv_heads,
    }


# Compare for various configurations
configs = [
    {"name": "7B-scale", "layers": 32, "heads": 32, "head_dim": 128},
    {"name": "13B-scale", "layers": 40, "heads": 40, "head_dim": 128},
    {"name": "70B-scale", "layers": 80, "heads": 64, "head_dim": 128},
]
Out[12]:
Visualization
Grouped bar chart comparing KV cache size between MHA and MQA across different model scales.
Multi-Query Attention reduces KV cache size by a factor equal to the number of heads. For a 70B model with 64 heads, this means a 64x reduction in cache memory, making long context generation feasible on consumer hardware.
Out[13]:
Console
KV Cache Comparison (32K context):

Model        MHA Cache    MQA Cache    Reduction    Heads   
--------------------------------------------------------
7B-scale     16.00 GB      0.500 GB     32x           32
13B-scale    25.00 GB      0.625 GB     40x           40
70B-scale    80.00 GB      1.250 GB     64x           64

The table shows that the cache reduction factor exactly matches the number of attention heads. At 32K context, the 7B-scale model with 32 heads sees its cache shrink from 16 GB to 0.5 GB, while the 70B-scale model with 64 heads drops from 80 GB to 1.25 GB. These reductions determine whether inference is feasible on a given accelerator: they can be the difference between requiring a cluster of GPUs and running inference on a single consumer card.

The scaling behavior becomes even clearer when we plot cache growth as a function of context length.

Out[14]:
Visualization
Line plot showing KV cache size in GB versus context length for MHA and MQA, with horizontal lines marking GPU memory limits.
KV cache memory grows linearly with context length. For a 7B model, MHA (red) crosses the 24 GB consumer GPU limit at roughly 48K tokens, while MQA (green) stays around 2 GB even at 128K tokens. The 32x gap between the lines represents the memory savings from sharing keys and values.

The linear relationship between context length and cache size explains why long-context models are so memory-hungry. At 128K tokens, the MHA cache alone requires 64 GB, which exceeds the memory of most consumer GPUs before we even load the model weights. MQA brings this down to 2 GB, making long-context inference practical on a single GPU.

Throughput Implications

The memory savings translate directly into throughput improvements. With less memory consumed by the KV cache, you can:

  1. Serve longer contexts: Fit 128K tokens instead of 8K on the same hardware
  2. Increase batch sizes: Process more requests in parallel
  3. Reduce hardware costs: Serve the same load with fewer GPUs
In[15]:
Code
def estimate_max_batch_size(
    gpu_memory_gb: float,
    model_memory_gb: float,
    num_layers: int,
    num_heads: int,
    head_dim: int,
    seq_length: int,
    use_mqa: bool,
) -> int:
    """Estimate maximum batch size given GPU memory constraints."""
    available_for_cache = gpu_memory_gb - model_memory_gb

    cache_per_sequence = calculate_kv_cache_mqa(
        num_layers=num_layers,
        num_heads=num_heads,
        head_dim=head_dim,
        seq_length=seq_length,
        batch_size=1,
        use_mqa=use_mqa,
    )["gigabytes"]

    max_batch = int(available_for_cache / cache_per_sequence)
    return max(1, max_batch)


# Example: 7B model on A100-80GB
model_config = {"num_layers": 32, "num_heads": 32, "head_dim": 128}
model_memory = 14  # FP16 7B model ~14GB
gpu_memory = 80
Out[16]:
Console
Maximum batch sizes on A100-80GB (7B model, 66GB for cache):

Context Length    MHA Batch    MQA Batch    Throughput Gain
------------------------------------------------------------
4,096 tokens      33           1056         32.0x
8,192 tokens      16           528          33.0x
16,384 tokens     8            264          33.0x
32,768 tokens     4            132          33.0x
65,536 tokens     2            66           33.0x

At 64K context, MHA fits only 2 concurrent requests, while MQA can batch 66. This translates to significant throughput improvements in production settings.

The following heatmap visualizes the feasible operating space for both architectures. Each cell shows whether a given batch size and context length combination fits within GPU memory.

Out[17]:
Visualization
MHA feasibility: limited combinations fit in memory, especially at longer contexts.
MQA feasibility: most batch-context combinations fit, enabling high throughput.

The difference is clear. With MHA, the A100-80GB tops out at batch size 4 at 32K context and fits only one or two sequences beyond 64K. With MQA, you can run batch size 64 at 32K context, or batch size 16 even at 128K. This expanded operating space translates directly to higher throughput and better hardware utilization.

Quality Trade-offs

The extreme parameter sharing in MQA raises an important question: what do we lose by sharing keys and values across all heads? The answer is nuanced and depends on the task.

What Each Head Learns

In standard MHA, each head can specialize its key and value representations. Research has shown that different heads often capture distinct linguistic phenomena:

  • Syntactic heads: Attend to grammatical structure (subject-verb, modifier-noun)
  • Semantic heads: Capture meaning relationships (synonymy, entailment)
  • Positional heads: Focus on local context or specific positions
  • Rare pattern heads: Specialize in infrequent but important patterns

When we share K and V, all heads must use the same "retrieval index" (keys) and the same "information content" (values). They can still differ in what they query for (different Q projections), but they're constrained to retrieve from the same representation.

Empirical Quality Results

The original MQA paper and subsequent studies found modest quality degradation. Let's examine representative benchmark results across translation, summarization, and language modeling tasks.

Out[19]:
Visualization
Bar chart showing performance comparison between MHA and MQA across different NLP benchmarks.
MQA typically shows modest quality degradation of 1-3% compared to MHA across various benchmarks. The trade-off is often acceptable given the memory and throughput improvements.

The quality degradation typically falls in the 1-3% range. For many applications, especially inference-heavy deployments where throughput matters more than the last percentage point of accuracy, this is an excellent trade-off.

When MQA Hurts More

Some tasks are more sensitive to the reduced expressiveness:

  • Long-range reasoning: Tasks requiring diverse attention patterns over long contexts
  • Multi-task learning: Models expected to handle many different tasks
  • Fine-grained linguistic tasks: Parsing, detailed entity typing, relation extraction

For these applications, Grouped Query Attention (GQA) offers a middle ground, sharing K/V among groups of heads rather than all heads.

MQA vs GQA

Grouped Query Attention interpolates between MHA and MQA by grouping heads and sharing K/V within each group.

In[20]:
Code
def attention_comparison(num_heads: int) -> dict:
    """Compare MHA, MQA, and various GQA configurations."""
    configs = {
        "MHA": {"kv_heads": num_heads, "sharing_factor": 1},
        "GQA-8": {"kv_heads": num_heads // 4, "sharing_factor": 4},
        "GQA-4": {"kv_heads": num_heads // 8, "sharing_factor": 8},
        "GQA-2": {"kv_heads": 2, "sharing_factor": num_heads // 2},
        "MQA": {"kv_heads": 1, "sharing_factor": num_heads},
    }
    return configs


num_heads = 32
configs = attention_comparison(num_heads)
Out[21]:
Visualization
Diagram showing the spectrum of attention variants from MHA to MQA with GQA options in between.
The attention mechanism spectrum from full MHA to extreme MQA. GQA offers intermediate points that balance memory efficiency with model expressiveness. Most modern LLMs use GQA with 2-8 KV heads.
Out[22]:
Console
Attention Variant Comparison (32 query heads):

Variant      KV Heads     Sharing Factor   Cache Reduction 
--------------------------------------------------------
MHA          32           1                1x              
GQA-8        8            4                4x              
GQA-4        4            8                8x              
GQA-2        2            16               16x             
MQA          1            32               32x             

The spectrum shows a smooth trade-off between expressiveness and efficiency. GQA-8 with 8 KV heads provides a 4x cache reduction while maintaining most of the representational power of full MHA. GQA-4 doubles the savings to 8x. MQA represents the extreme endpoint with maximum savings but minimum KV diversity.

Why GQA Often Wins

In practice, GQA with 4-8 KV heads often provides the best trade-off:

  1. Near-MQA memory efficiency: 4-8x cache reduction covers most memory constraints
  2. Near-MHA quality: Multiple KV representations preserve expressiveness for complex tasks
  3. Flexible tuning: Choose the balance point appropriate for your deployment

LLaMA-2 70B uses GQA with 8 KV heads (each shared by 8 query heads), achieving excellent quality while maintaining reasonable inference costs. Mistral and many other modern models follow similar patterns.
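
To make the GQA mechanics concrete, here is a minimal sketch of the key step: the cache holds num_kv_heads key/value heads, and each is repeated so that every group of query heads attends to the same K/V. This is a simplified illustration with random tensors, not the exact LLaMA-2 or Mistral implementation.

Code
import torch
import torch.nn.functional as F

# Illustrative GQA configuration: 32 query heads, 8 KV heads
batch, seq, num_heads, num_kv_heads, head_dim = 1, 16, 32, 8, 128
group_size = num_heads // num_kv_heads  # 4 query heads share each KV head

q = torch.randn(batch, num_heads, seq, head_dim)
k = torch.randn(batch, num_kv_heads, seq, head_dim)  # cached at this reduced size
v = torch.randn(batch, num_kv_heads, seq, head_dim)

# Expand KV heads to match query heads for the attention computation
k_expanded = k.repeat_interleave(group_size, dim=1)  # (batch, num_heads, seq, head_dim)
v_expanded = v.repeat_interleave(group_size, dim=1)

scores = q @ k_expanded.transpose(-2, -1) / head_dim**0.5
weights = F.softmax(scores, dim=-1)
output = weights @ v_expanded

print(k.shape, "->", k_expanded.shape)  # cache holds 8 KV heads, attention sees 32
print(output.shape)                     # torch.Size([1, 32, 16, 128])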

Practical Implementation Considerations

Converting MHA to MQA

If you have a pre-trained MHA model and want MQA's efficiency gains, you can convert it through careful weight averaging:

In[23]:
Code
def convert_mha_to_mqa(
    mha_module: nn.Module, num_heads: int, head_dim: int
) -> MultiQueryAttention:
    """
    Convert a pre-trained MHA module to MQA by averaging K and V projections.

    Averaging preserves the heads' mean K/V behavior while reducing parameters;
    a brief fine-tuning pass typically follows to recover quality.
    """
    hidden_size = mha_module.q_proj.in_features

    # Create MQA module
    mqa = MultiQueryAttention(hidden_size, num_heads, head_dim)

    # Copy Q projection directly
    mqa.q_proj.weight.data = mha_module.q_proj.weight.data.clone()

    # Average K projections across heads
    k_weights = mha_module.k_proj.weight.data
    k_weights = k_weights.view(num_heads, head_dim, hidden_size)
    mqa.k_proj.weight.data = k_weights.mean(dim=0)

    # Average V projections across heads
    v_weights = mha_module.v_proj.weight.data
    v_weights = v_weights.view(num_heads, head_dim, hidden_size)
    mqa.v_proj.weight.data = v_weights.mean(dim=0)

    # Copy output projection
    mqa.o_proj.weight.data = mha_module.o_proj.weight.data.clone()

    return mqa


# Demonstrate the conversion
mha_layer = StandardMultiHeadAttention(
    hidden_size=4096, num_heads=32, head_dim=128
)
# Add output projection for completeness
mha_layer.o_proj = nn.Linear(32 * 128, 4096, bias=False)
Out[24]:
Console
Conversion complete!
Original MHA K projection: torch.Size([4096, 4096])
Converted MQA K projection: torch.Size([128, 4096])
Parameter reduction in K: 32x

The conversion uses averaging, which works reasonably well but may require fine-tuning to recover any lost quality. Research suggests that a small amount of fine-tuning (1-5% of original training) can recover most of the original model's performance.

Uptraining for Better Quality

Rather than training from scratch with MQA, many practitioners start with MHA pre-training and then "uptrain" to MQA or GQA:

  1. Pre-train with MHA: Full expressiveness during the learning phase
  2. Convert to MQA/GQA: Average or group the KV projections
  3. Fine-tune briefly: Recover quality with the new architecture
  4. Deploy efficiently: Enjoy the memory benefits during inference

This approach often yields better results than training MQA from scratch, as the model benefits from the full representational capacity during the critical pre-training phase.
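
The same weight-averaging idea used in convert_mha_to_mqa extends naturally to the GQA case: average the per-head K and V projections within each group rather than across all heads. A hedged sketch, assuming the head-major weight layout used by StandardMultiHeadAttention above:

Code
import torch


def group_kv_weights(
    kv_weight: torch.Tensor,  # (num_heads * head_dim, hidden_size)
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> torch.Tensor:
    """Average per-head K or V projection weights within each GQA group."""
    hidden_size = kv_weight.shape[1]
    group_size = num_heads // num_kv_heads
    # (num_kv_heads, group_size, head_dim, hidden_size) -> mean over each group
    grouped = kv_weight.view(num_kv_heads, group_size, head_dim, hidden_size)
    return grouped.mean(dim=1).reshape(num_kv_heads * head_dim, hidden_size)


# Example: collapse 32 K heads into 8 grouped KV heads
k_weight = torch.randn(32 * 128, 4096)
gqa_k_weight = group_kv_weights(k_weight, num_heads=32, num_kv_heads=8, head_dim=128)
print(gqa_k_weight.shape)  # torch.Size([1024, 4096])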

Summary

Multi-Query Attention addresses a key bottleneck in large language model deployment: KV cache memory consumption. By sharing a single key and value projection across all attention heads, MQA achieves substantial memory reductions:

  • Memory reduction: KV cache shrinks by a factor of $h$ (the number of heads), often 32-64x for modern models
  • Throughput gains: Smaller caches enable larger batch sizes and longer contexts on the same hardware
  • Modest quality loss: Typically 1-3% degradation on standard benchmarks
  • Parameter savings: K and V projection parameters reduce by a factor of $h$

The trade-off between efficiency and expressiveness has led to the popular Grouped Query Attention (GQA), which shares K/V among groups of heads rather than all heads. Most modern LLMs, including LLaMA-2, Mistral, and others, use GQA with 4-8 KV heads as a practical middle ground.

For inference-heavy deployments where throughput matters more than the last percentage point of accuracy, MQA and GQA are essential techniques. They transform the economics of serving large language models, making long-context generation feasible on hardware that would otherwise be memory-bound.

Key Parameters

When implementing or configuring MQA and GQA, these parameters have the greatest impact on the memory-quality trade-off:

  • num_heads (int): The number of query heads. This determines model expressiveness and directly affects the cache reduction factor when using MQA. Typical values range from 32 (7B models) to 128 (largest models).
  • num_kv_heads (int): The number of key-value heads. Set to 1 for MQA, equal to num_heads for standard MHA, or an intermediate value for GQA. Common GQA configurations use 4-8 KV heads regardless of query head count.
  • head_dim (int): The dimension of each attention head. Standard values are 64 or 128. Larger head dimensions increase per-token cache size but may improve attention quality. Most modern LLMs use 128.
  • hidden_size (int): The model's hidden dimension, typically equal to num_heads * head_dim. This determines the size of projection matrices and overall model capacity.

When choosing between MHA, GQA, and MQA, consider your deployment constraints. If memory is severely limited or batch sizes must be large, MQA provides maximum savings. If quality is the priority and memory is less constrained, GQA with 4-8 KV heads offers a balanced approach that most production systems favor.
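
As a closing sketch, the cache formula from earlier can be turned into a simple helper that picks the largest num_kv_heads whose cache still fits a given memory budget. The function and the example numbers are illustrative, not a tuning recommendation.

Code
def max_kv_heads_that_fit(
    cache_budget_gb: float,
    num_layers: int,
    num_heads: int,
    head_dim: int,
    seq_length: int,
    batch_size: int = 1,
    bytes_per_element: int = 2,
) -> int:
    """Largest divisor of num_heads whose KV cache fits within the budget."""
    # Try the most KV heads first (MHA), falling back toward MQA
    candidates = [k for k in range(num_heads, 0, -1) if num_heads % k == 0]
    for num_kv_heads in candidates:
        cache_gb = (
            2 * num_layers * num_kv_heads * seq_length * head_dim
            * batch_size * bytes_per_element
        ) / (1024**3)
        if cache_gb <= cache_budget_gb:
            return num_kv_heads
    return 1


# Example: 7B-scale model, 32K context, batch 8, 40 GB left for the cache
print(max_kv_heads_that_fit(
    40, num_layers=32, num_heads=32, head_dim=128, seq_length=32768, batch_size=8
))
# 8  -> GQA-8 fits; full MHA would need 128 GB at this batch size

For this 7B-scale example the helper lands on 8 KV heads, which happens to match where most production GQA configurations sit.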
