Grouped Query Attention: Memory-Efficient LLM Inference

Michael Brenndoerfer | Updated August 6, 2025 | 39 min read

Master GQA, the attention mechanism behind LLaMA 2 and Mistral. Learn KV head sharing, memory savings, implementation, and quality tradeoffs.

Grouped Query Attention

Multi-Query Attention (MQA) reduces the key-value cache to a single head, cutting cache memory by a factor equal to the number of attention heads. This aggressive sharing works well for many applications, but it forces every query head to read from the same key-value representation. For tasks requiring nuanced multi-step reasoning or tracking multiple independent information streams, that single representation can become a bottleneck. Grouped Query Attention (GQA) offers a middle path: share key-value heads among groups of query heads, trading off between MQA's extreme memory efficiency and Multi-Head Attention's full representational capacity.

This chapter explores GQA in depth. You'll understand why intermediate sharing often strikes a better quality-efficiency balance than either extreme, learn the mathematical formulation that makes grouped sharing work, implement GQA from scratch, and analyze the tradeoffs that have made it the default choice for models like LLaMA 2 70B and Mistral 7B.

The Spectrum of Key-Value Sharing

Before diving into GQA, let's place it within the broader landscape of attention mechanisms. The key insight is that key-value sharing exists on a spectrum, with Multi-Head Attention (MHA) and Multi-Query Attention (MQA) as the two extremes.

Grouped Query Attention (GQA)

Grouped Query Attention divides query heads into groups, with each group sharing a single set of key-value heads. This provides a tunable trade-off between MHA's full capacity and MQA's memory efficiency, controlled by the number of key-value groups.

In standard MHA, each of the $h$ query heads has its own dedicated key and value projections. This maximizes representational capacity but requires storing $h$ key-value pairs in the cache during inference. In MQA, all query heads share a single key-value pair, reducing cache size by a factor of $h$ but potentially limiting what information the model can attend to in parallel.

GQA sits between these extremes. With $h$ query heads and $g$ key-value groups (where $1 < g < h$ and $g$ divides $h$), each group of $h/g$ query heads shares one key-value pair. The memory reduction factor relative to MHA is $h/g$, which can be tuned to balance efficiency and quality.
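Because $g$ must divide $h$ evenly, a given head count admits only a handful of configurations. The short sketch below, using a helper name (`valid_gqa_configs`) chosen here purely for illustration, lists them for $h = 32$; the endpoints $g = h$ and $g = 1$ recover MHA and MQA.

```python
def valid_gqa_configs(num_query_heads: int) -> list[tuple[int, int]]:
    """Return (g, h // g) for every group count g that divides h evenly."""
    configs = []
    for g in range(1, num_query_heads + 1):
        if num_query_heads % g == 0:  # every group must hold the same number of heads
            configs.append((g, num_query_heads // g))
    return configs


h = 32
for g, heads_per_group in valid_gqa_configs(h):
    label = "MHA" if g == h else "MQA" if g == 1 else f"GQA-{g}"
    # heads_per_group = h / g is also the KV cache reduction factor vs. MHA
    print(f"{label:7s} g={g:2d}  query heads per KV head (= reduction) = {heads_per_group}")
```

Note that the group size and the cache reduction factor are the same number, $h/g$, which is why the two are often quoted interchangeably.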

In[2]:
Code
def attention_variants_comparison(
    num_query_heads: int,
    num_layers: int,
    head_dim: int,
    seq_len: int,
    batch_size: int,
):
    """Compare KV cache sizes for different attention variants."""
    bytes_per_element = 2  # fp16

    results = []

    # Test various numbers of KV heads
    kv_head_options = [num_query_heads, 8, 4, 2, 1]  # MHA, GQA variants, MQA

    for num_kv_heads in kv_head_options:
        # KV cache: 2 (K and V) × layers × kv_heads × head_dim × seq × batch × bytes
        cache_size = (
            2
            * num_layers
            * num_kv_heads
            * head_dim
            * seq_len
            * batch_size
            * bytes_per_element
        )

        # Determine variant name
        if num_kv_heads == num_query_heads:
            name = "MHA"
        elif num_kv_heads == 1:
            name = "MQA"
        else:
            name = f"GQA-{num_kv_heads}"

        results.append(
            {
                "name": name,
                "kv_heads": num_kv_heads,
                "cache_gb": cache_size / 1e9,
                "reduction": num_query_heads / num_kv_heads,
            }
        )

    return results


# LLaMA 2 70B-like configuration
results = attention_variants_comparison(
    num_query_heads=64, num_layers=80, head_dim=128, seq_len=4096, batch_size=32
)
Out[3]:
Console
KV Cache Comparison (LLaMA 2 70B config, batch=32, seq=4096)
============================================================

Variant      KV Heads     Cache Size      Reduction
------------------------------------------------------------
MHA          64           343.6 GB         1x
GQA-8        8            42.9 GB         8x
GQA-4        4            21.5 GB         16x
GQA-2        2            10.7 GB         32x
MQA          1            5.4 GB         64x

GQA-8 provides significant memory savings while retaining
more representational capacity than MQA

The numbers tell a clear story. GQA with 8 key-value heads reduces cache size by 8x compared to MHA, a substantial improvement that enables larger batches or longer sequences. At the same time, it retains 8 separate key-value representations instead of collapsing everything to one, preserving the model's ability to attend to multiple distinct aspects of the input.

Why Groups Work Better Than Extremes

The effectiveness of GQA stems from an empirical observation about attention head behavior: many attention heads learn similar patterns. Research on trained transformers has shown that heads often cluster into functional groups, with heads in the same cluster attending to similar types of relationships. GQA exploits this redundancy by formalizing the grouping structure.

Consider a model with 32 query heads. In MHA, each head has its own key-value projection. Suppose analysis reveals that heads 1-4 learn similar positional patterns, heads 5-8 focus on syntactic relationships, and so on. Clusters like these suggest that forcing heads within a cluster to share key-value projections might not significantly impact model quality while providing substantial memory savings. Note that GQA does not discover such clusters: the grouping is fixed in advance, with consecutive heads sharing a group, and the heads learn to work within that structure.

Figure: Grouped Query Attention architecture with 8 query heads and 2 KV groups. Query heads are partitioned into groups of 4, with each group sharing a single key-value pair. This reduces the KV cache by 4x while maintaining 2 independent key-value representations.

The diagram illustrates GQA with 8 query heads divided into 2 groups. Query heads Q1-Q4 share key-value pair (K1, V1), while Q5-Q8 share (K2, V2). Each group can attend to different aspects of the input through their shared key-value representation, while the queries within each group specialize their attention patterns using different learned projections.

Mathematical Formulation

Let's formalize GQA mathematically. The core idea is straightforward: instead of each query head having its own key-value projections (as in MHA) or all query heads sharing one key-value projection (as in MQA), GQA divides query heads into groups where each group shares a single key-value projection.

Given an input sequence $X \in \mathbb{R}^{T \times d}$, GQA is parameterized by:

  • $h$: number of query heads (determines query diversity)
  • $g$: number of key-value groups, where $g$ divides $h$ evenly (controls the memory vs. capacity trade-off)
  • $d_k = d / h$: dimension per head (same as standard multi-head attention)

The ratio $h/g$ determines how many query heads share each key-value head. For example, with $h = 32$ query heads and $g = 8$ KV groups, each group of 4 query heads shares one key-value representation.

Query Projections

Query projections work exactly as in standard multi-head attention. Each query head $i$ has its own learned projection matrix that transforms the input into a query representation:

$$Q_i = X W^Q_i \quad \text{for } i \in \{1, \ldots, h\}$$

where:

  • $Q_i \in \mathbb{R}^{T \times d_k}$: the query matrix for head $i$, containing a $d_k$-dimensional query vector for each of the $T$ positions
  • $X \in \mathbb{R}^{T \times d}$: the input sequence with $T$ tokens, each represented as a $d$-dimensional vector
  • $W^Q_i \in \mathbb{R}^{d \times d_k}$: the learned query projection matrix for head $i$, transforming from model dimension to head dimension

Having $h$ separate query projections preserves the model's ability to ask $h$ different "questions" of the input. Each query head can still specialize in detecting different patterns, maintaining expressive power.
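In practice the $h$ matrices $W^Q_i$ are usually stored as one fused matrix of shape $d \times (h \cdot d_k)$, as the `w_q` layer in the implementation later in this chapter does. The sketch below, under that assumption and with arbitrary small sizes, checks that projecting once and splitting the result into heads gives the same $Q_i$ as applying each per-head column block separately. It uses plain matrices rather than `nn.Linear`, whose stored weight is the transpose of this layout.

```python
import torch

torch.manual_seed(0)
T, d, h = 5, 16, 4  # sequence length, model dimension, query heads (assumed sizes)
d_k = d // h

x = torch.randn(T, d)
# One fused query weight of shape (d, h * d_k); head i owns columns [i*d_k, (i+1)*d_k)
w_q = torch.randn(d, h * d_k)

# Project once, then split the last dimension into h heads: (T, h, d_k)
q_fused = (x @ w_q).view(T, h, d_k)

# Per-head projections Q_i = X W^Q_i using the i-th column block
q_per_head = [x @ w_q[:, i * d_k:(i + 1) * d_k] for i in range(h)]

print(all(torch.allclose(q_fused[:, i, :], q_per_head[i], atol=1e-5) for i in range(h)))  # True
```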

Key-Value Projections (The Key Difference)

Here is where GQA diverges from MHA. Instead of $h$ key-value projections, we have only $g$ projections that are shared among groups of query heads:

$$K_j = X W^K_j, \quad V_j = X W^V_j \quad \text{for } j \in \{1, \ldots, g\}$$

where:

  • $K_j \in \mathbb{R}^{T \times d_k}$: the key matrix for group $j$, containing key vectors for all $T$ positions
  • $V_j \in \mathbb{R}^{T \times d_k}$: the value matrix for group $j$, containing value vectors for all $T$ positions
  • $W^K_j \in \mathbb{R}^{d \times d_k}$: the learned key projection matrix for group $j$
  • $W^V_j \in \mathbb{R}^{d \times d_k}$: the learned value projection matrix for group $j$

With only $g$ key-value projections instead of $h$, the KV cache during inference stores $g$ key-value pairs rather than $h$, reducing memory by a factor of $h/g$.
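The same factor of $h/g$ applies to the key and value projection weights themselves. A quick count, assuming a layer with $d = 4096$, $h = 32$, and $g = 8$ (so $d_k = 128$; these shapes match Mistral 7B's published configuration), shows the K and V weights shrinking by 4x while $W^Q$ and $W^O$ are unchanged. The helper name is illustrative.

```python
def attention_projection_params(d_model: int, num_heads: int, num_kv_heads: int) -> dict:
    """Weight counts for the four attention projections (no biases)."""
    head_dim = d_model // num_heads
    return {
        "W_Q": d_model * num_heads * head_dim,     # h query heads
        "W_K": d_model * num_kv_heads * head_dim,  # only g key heads
        "W_V": d_model * num_kv_heads * head_dim,  # only g value heads
        "W_O": num_heads * head_dim * d_model,     # projects concatenated heads back to d
    }


mha = attention_projection_params(4096, 32, 32)
gqa = attention_projection_params(4096, 32, 8)
print(mha["W_K"] + mha["W_V"])  # 33554432
print(gqa["W_K"] + gqa["W_V"])  # 8388608, i.e. h / g = 4x fewer K/V weights
```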

Grouping Function

To connect query heads to their shared key-value representations, we define a grouping function that maps each query head index to its corresponding key-value group:

$$\text{group}(i) = \left\lfloor \frac{(i - 1) \cdot g}{h} \right\rfloor + 1$$

where:

  • $i$: the query head index, ranging from 1 to $h$
  • $g$: the total number of key-value groups
  • $h$: the total number of query heads
  • $\lfloor \cdot \rfloor$: the floor function (rounds down to the nearest integer)

This function partitions the $h$ query heads into $g$ groups of size $h/g$ each. For example, with $h = 8$ query heads and $g = 2$ groups:

  • Query heads 1, 2, 3, 4 map to group 1 (using $K_1$, $V_1$)
  • Query heads 5, 6, 7, 8 map to group 2 (using $K_2$, $V_2$)
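The grouping function translates directly into code. This sketch keeps the 1-indexed convention of the formula and also checks that, whenever $g$ divides $h$, it agrees with the simpler "consecutive blocks" form $\lfloor (i-1)/(h/g) \rfloor + 1$. The function name `group` is illustrative.

```python
def group(i: int, h: int, g: int) -> int:
    """Grouping function from the text: floor((i - 1) * g / h) + 1, with 1-indexed heads."""
    return ((i - 1) * g) // h + 1


print([group(i, 8, 2) for i in range(1, 9)])  # [1, 1, 1, 1, 2, 2, 2, 2]

# When g divides h, this equals "consecutive blocks of h/g heads"
for heads, groups in [(8, 2), (32, 8), (64, 8), (64, 1), (64, 64)]:
    block = heads // groups
    assert all(group(i, heads, groups) == (i - 1) // block + 1 for i in range(1, heads + 1))
```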

Attention Computation

Each query head computes attention using its own query but the shared key-value from its group:

$$\text{head}_i = \text{softmax}\left(\frac{Q_i K_{\text{group}(i)}^T}{\sqrt{d_k}}\right) V_{\text{group}(i)}$$

where:

  • $\text{head}_i \in \mathbb{R}^{T \times d_k}$: the output of attention head $i$
  • $Q_i \in \mathbb{R}^{T \times d_k}$: the query matrix for head $i$ (unique to this head)
  • $K_{\text{group}(i)} \in \mathbb{R}^{T \times d_k}$: the key matrix shared by all heads in group $\text{group}(i)$
  • $V_{\text{group}(i)} \in \mathbb{R}^{T \times d_k}$: the value matrix shared by all heads in group $\text{group}(i)$
  • $Q_i K_{\text{group}(i)}^T \in \mathbb{R}^{T \times T}$: the attention score matrix
  • $\sqrt{d_k}$: scaling factor to prevent dot products from growing too large

The key insight is that different query heads in the same group can still learn different attention patterns. Even though heads 1-4 all use $K_1$ and $V_1$, their queries $Q_1, Q_2, Q_3, Q_4$ come from separate projections $W^Q_1, \ldots, W^Q_4$ and therefore produce different attention weights over the same key-value pairs.
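A tiny numeric sketch makes this concrete. Assuming random tensors in place of learned projections, two query heads read the same shared $K$ and $V$ yet produce different softmax weights:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_k = 4, 8

# One shared key/value pair for the whole group
k_shared = torch.randn(T, d_k)
v_shared = torch.randn(T, d_k)

# Two query heads in the same group, with different (random stand-in) queries
q1 = torch.randn(T, d_k)
q2 = torch.randn(T, d_k)


def head(q, k, v):
    weights = F.softmax(q @ k.T / d_k**0.5, dim=-1)  # (T, T), each row sums to 1
    return weights, weights @ v                       # output: (T, d_k)


w1, out1 = head(q1, k_shared, v_shared)
w2, out2 = head(q2, k_shared, v_shared)
print(torch.allclose(w1, w2))  # False: same K and V, different attention patterns
```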

Output Combination

Finally, all head outputs are concatenated and projected back to the model dimension:

$$\text{GQA}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

where:

  • $\text{Concat}(\text{head}_1, \ldots, \text{head}_h) \in \mathbb{R}^{T \times (h \cdot d_k)}$: concatenation of all $h$ head outputs along the feature dimension
  • $W^O \in \mathbb{R}^{(h \cdot d_k) \times d}$: the output projection matrix
  • $\text{GQA}(X) \in \mathbb{R}^{T \times d}$: the final output, matching the input dimension

Since $h \cdot d_k = d$, the concatenated heads have dimension $d$, and $W^O$ projects back to dimension $d$, preserving the input shape for residual connections.
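Putting the pieces together, here is a deliberately literal, loop-based sketch of the equations above for a single unbatched sequence, with no causal mask. The helper `gqa_reference` and its list-of-matrices arguments are illustrative choices rather than a standard API; it is meant for checking shapes and semantics, not for speed.

```python
import torch
import torch.nn.functional as F


def gqa_reference(x, w_q, w_k, w_v, w_o, h: int, g: int):
    """Loop-based GQA following the equations (0-indexed heads and groups).

    x: (T, d); w_q: list of h matrices (d, d_k); w_k, w_v: lists of g matrices (d, d_k);
    w_o: (h * d_k, d).
    """
    d_k = w_q[0].shape[1]
    keys = [x @ w_k[j] for j in range(g)]    # K_j = X W^K_j
    values = [x @ w_v[j] for j in range(g)]  # V_j = X W^V_j
    heads = []
    for i in range(h):
        j = (i * g) // h                     # group(i), shifted to 0-indexing
        q_i = x @ w_q[i]                     # Q_i = X W^Q_i
        scores = q_i @ keys[j].T / d_k**0.5
        heads.append(F.softmax(scores, dim=-1) @ values[j])
    return torch.cat(heads, dim=-1) @ w_o    # Concat(head_1, ..., head_h) W^O


torch.manual_seed(0)
T, d, h, g = 6, 32, 8, 2
d_k = d // h
x = torch.randn(T, d)
w_q = [torch.randn(d, d_k) / d**0.5 for _ in range(h)]
w_k = [torch.randn(d, d_k) / d**0.5 for _ in range(g)]
w_v = [torch.randn(d, d_k) / d**0.5 for _ in range(g)]
w_o = torch.randn(h * d_k, d) / d**0.5

out = gqa_reference(x, w_q, w_k, w_v, w_o, h, g)
print(out.shape)  # torch.Size([6, 32]): same shape as x, ready for a residual add
```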

In[5]:
Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention implementation."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert d_model % num_heads == 0, (
            "d_model must be divisible by num_heads"
        )
        assert num_heads % num_kv_heads == 0, (
            "num_heads must be divisible by num_kv_heads"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = num_heads // num_kv_heads
        self.head_dim = d_model // num_heads

        # Query projection: full size for all query heads
        self.w_q = nn.Linear(d_model, num_heads * self.head_dim, bias=False)

        # Key and value projections: reduced size for KV heads only
        self.w_k = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.w_v = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)

        # Output projection
        self.w_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim**-0.5

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        kv_cache: tuple = None,
    ) -> tuple:
        batch_size, seq_len, _ = x.shape

        # Compute queries: (batch, seq, num_heads, head_dim)
        q = self.w_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Compute keys and values: (batch, seq, num_kv_heads, head_dim)
        k = self.w_k(x).view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        )
        v = self.w_v(x).view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        )

        # Handle KV cache for incremental decoding
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=1)
            v = torch.cat([v_cache, v], dim=1)

        new_kv_cache = (k, v)
        kv_seq_len = k.size(1)

        # Reshape for attention computation
        # q: (batch, num_heads, seq, head_dim)
        q = q.transpose(1, 2)
        # k, v: (batch, num_kv_heads, kv_seq, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Expand k, v to match query heads by repeating
        # (batch, num_kv_heads, kv_seq, head_dim) -> (batch, num_heads, kv_seq, head_dim)
        k = k.repeat_interleave(self.num_queries_per_kv, dim=1)
        v = v.repeat_interleave(self.num_queries_per_kv, dim=1)

        # Compute attention scores: (batch, num_heads, seq, kv_seq)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values: (batch, num_heads, seq, head_dim)
        attn_output = torch.matmul(attn_weights, v)

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)

        return self.w_o(attn_output), new_kv_cache
Out[6]:
Console
Grouped Query Attention Test
=======================================================

Configuration:
  d_model: 512
  Query heads: 8
  KV heads: 2
  Queries per KV group: 4

Shapes:
  Input: (2, 16, 512)
  Output: (2, 16, 512)
  KV cache K: (2, 16, 2, 64)
  KV cache V: (2, 16, 2, 64)

Parameter counts:
  GQA: 655,360
  MHA (equivalent): 1,048,576
  MQA (equivalent): 589,824
Figure: Parameter counts for MHA, GQA, and MQA, broken down by projection type. GQA reduces parameters compared to MHA by using fewer K,V projections, while retaining more capacity than MQA. The query and output projections remain the same size across all variants.

The key implementation detail is the repeat_interleave operation that expands the key-value tensors to match the number of query heads. This creates the grouping structure: consecutive query heads share the same key-value entries. The KV cache stores only the reduced number of key-value heads, providing the memory savings.
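To confirm that `repeat_interleave` implements exactly the grouping function (0-indexed query head $i$ reads KV head $i // (h/g)$), the sketch below compares it with an equivalent formulation that reshapes the queries into groups and lets `torch.matmul` broadcast a singleton dimension over the shared keys. The broadcast version avoids an explicit repeat in user code; whether memory is actually saved depends on the backend, since a broadcasting matmul may still expand operands internally. The tensor sizes are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)
B, h, g, T, d_k = 2, 8, 2, 5, 16
n_rep = h // g  # query heads per KV head

q = torch.randn(B, h, T, d_k)
k = torch.randn(B, g, T, d_k)

# (1) repeat_interleave path, as in GroupedQueryAttention: K copied out to h heads
k_rep = k.repeat_interleave(n_rep, dim=1)  # (B, h, T, d_k)
scores_rep = q @ k_rep.transpose(-2, -1)   # (B, h, T, T)

# Query head i reads KV head i // n_rep, matching the grouping function
for i in range(h):
    assert torch.equal(k_rep[:, i], k[:, i // n_rep])

# (2) Grouped path: view queries as (B, g, n_rep, T, d_k); K gets a singleton
#     dimension so matmul broadcasts it across the n_rep heads of each group
q_grouped = q.view(B, g, n_rep, T, d_k)
scores_grouped = q_grouped @ k.unsqueeze(2).transpose(-2, -1)  # (B, g, n_rep, T, T)

print(torch.allclose(scores_rep, scores_grouped.reshape(B, h, T, T), atol=1e-5))  # True
```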

Figure: Attention patterns (heatmaps) for query heads Q1, Q2, Q3, and Q4, which all belong to Group 1 and share the same key-value pair.

The attention heatmaps above reveal a crucial property of GQA: even though all four query heads share the same key and value tensors (they're all in Group 1), each head produces distinct attention patterns. Head Q1 might focus heavily on recent tokens, Q2 on the first token, Q3 on positions 2-3, and Q4 shows yet another pattern. This diversity emerges because each query head has its own learned projection. The queries are different even when the keys are identical.

Memory Savings Analysis

GQA's memory benefits depend on the ratio of query heads to key-value heads. During autoregressive generation, the model caches key and value tensors from previous positions to avoid redundant computation. The size of this KV cache determines how many sequences can be processed in parallel and how long those sequences can be.

KV Cache Memory Formula

For a single transformer layer using GQA, the KV cache stores keys and values for all $g$ groups across all cached positions:

$$\text{Cache}_{\text{GQA}} = 2 \times g \times d_k \times T \times B \times \text{bytes}$$

where:

  • $2$: accounts for storing both keys and values (two tensors per group)
  • $g$: number of key-value groups in GQA
  • $d_k$: dimension per head (typically 64 or 128 in modern models)
  • $T$: current sequence length (grows during generation)
  • $B$: batch size (number of sequences processed in parallel)
  • $\text{bytes}$: bytes per element (2 for float16/bfloat16, 4 for float32)

For a full model with $L$ layers, multiply by $L$. The total cache size is $2 \times L \times g \times d_k \times T \times B \times \text{bytes}$.
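Plugging numbers into the total-cache formula is a useful sanity check. This sketch assumes fp16 (2 bytes per element) and LLaMA 2 70B-like shapes ($L = 80$, $g = 8$, $d_k = 128$); the helper name is ours.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int, bytes_per_element: int = 2) -> int:
    """Total KV cache size: 2 (K and V) x L x g x d_k x T x B x bytes."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_element


per_token = kv_cache_bytes(80, 8, 128, seq_len=1, batch_size=1)
print(per_token)  # 327680 bytes (320 KiB) per token, per sequence

gqa_4k = kv_cache_bytes(80, 8, 128, seq_len=4096, batch_size=1)
mha_4k = kv_cache_bytes(80, 64, 128, seq_len=4096, batch_size=1)
print(gqa_4k / 1e9, mha_4k / gqa_4k)  # about 1.34 GB for GQA-8, and MHA is 8.0x larger
```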

Memory Reduction Factor

Comparing GQA to MHA, which stores $h$ key-value heads instead of $g$, the memory reduction is simply the ratio of heads to groups:

$$\text{Reduction factor} = \frac{h}{g}$$

where:

  • $h$: number of query heads (and number of KV heads in MHA)
  • $g$: number of KV groups in GQA

For GQA-8 with 64 query heads, the reduction factor is $64/8 = 8\times$. This means the KV cache is 8 times smaller than it would be with full MHA, enabling either 8 times longer sequences, 8 times larger batches, or some combination of both.
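The reduction factor can also be read the other way: for a fixed memory budget, the cache holds roughly $h/g$ times as many tokens. A sketch under assumed numbers (a hypothetical 20 GB KV-cache budget, a batch of 8, 70B-like shapes, fp16):

```python
def max_tokens_in_budget(budget_bytes: int, num_layers: int, num_kv_heads: int,
                         head_dim: int, batch_size: int, bytes_per_element: int = 2) -> int:
    """Largest sequence length whose KV cache fits in budget_bytes."""
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * batch_size * bytes_per_element
    return budget_bytes // bytes_per_token


budget = 20 * 10**9  # hypothetical 20 GB KV-cache budget
mha_tokens = max_tokens_in_budget(budget, 80, 64, 128, batch_size=8)
gqa_tokens = max_tokens_in_budget(budget, 80, 8, 128, batch_size=8)
print(mha_tokens, gqa_tokens)  # 953 7629
```

Rounding down to whole tokens makes the ratio very slightly above 8 rather than exactly 8.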

In[9]:
Code
def memory_analysis(
    num_query_heads: int,
    head_dim: int,
    num_layers: int,
    seq_lengths: list,
    batch_size: int = 1,
):
    """Analyze KV cache memory for different GQA configurations."""
    bytes_per_element = 2  # fp16

    kv_head_options = [num_query_heads, 8, 4, 2, 1]

    results = {"seq_lengths": seq_lengths, "configs": []}

    for num_kv_heads in kv_head_options:
        config_data = {
            "name": "MHA"
            if num_kv_heads == num_query_heads
            else "MQA"
            if num_kv_heads == 1
            else f"GQA-{num_kv_heads}",
            "kv_heads": num_kv_heads,
            "memory_gb": [],
        }

        for seq_len in seq_lengths:
            cache_bytes = (
                2
                * num_layers
                * num_kv_heads
                * head_dim
                * seq_len
                * batch_size
                * bytes_per_element
            )
            config_data["memory_gb"].append(cache_bytes / 1e9)

        results["configs"].append(config_data)

    return results


# Analyze for a 70B-scale model
memory_data = memory_analysis(
    num_query_heads=64,
    head_dim=128,
    num_layers=80,
    seq_lengths=[1024, 2048, 4096, 8192, 16384],
    batch_size=1,
)
Figure: KV cache memory scaling with sequence length for MHA, GQA variants, and MQA. GQA-8 (used by LLaMA 2 70B) provides 8x memory reduction compared to MHA while retaining more capacity than MQA. At 16K tokens, this translates to about 5.4 GB vs about 43 GB for MHA.
Out[11]:
Console
KV Cache Memory at 16K Sequence Length (70B model, batch=1)
=======================================================

Variant      KV Heads     Memory (GB)     Reduction vs MHA
-------------------------------------------------------
MHA          64           42.95           1.0x
GQA-8        8            5.37            8.0x
GQA-4        4            2.68            16.0x
GQA-2        2            1.34            32.0x
MQA          1            0.67            64.0x

The analysis shows why GQA-8 has become a standard choice for large models. At 16K tokens, MHA would require about 43 GB for the KV cache of a single sequence, on top of the model weights, which exceeds the memory of many GPUs. GQA-8 reduces this to about 5.4 GB, a far more manageable footprint, while retaining 8 independent key-value representations.

GQA vs MQA: Quality Comparison

Does GQA's additional key-value capacity translate to meaningful quality improvements over MQA? Research from Meta's LLaMA 2 paper and subsequent work provides empirical guidance.

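The original GQA paper (Ainslie et al., 2023) compared MHA, GQA with various group sizes, and MQA on T5 models converted from multi-head checkpoints and briefly uptrained, evaluating on summarization, translation, and question answering. Its findings, together with a common interpretation of them:

  • GQA with 8 KV heads came close to MHA quality while providing a much smaller KV cache and inference speed close to MQA
  • MQA trailed MHA by a noticeably larger margin than GQA-8 did
  • The gap between GQA-8 and MQA is expected to matter most on tasks that track multiple information streams, such as multi-hop reasoning and long-context retrieval; this is an interpretation rather than a result that paper measured directly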
In[12]:
Code
# Illustrative numbers, not measured results: shaped to mirror qualitative trends in published work
benchmark_results = {
    "model_configs": [
        {"name": "MHA (baseline)", "kv_heads": 32, "relative_quality": 100.0},
        {"name": "GQA-8", "kv_heads": 8, "relative_quality": 99.5},
        {"name": "GQA-4", "kv_heads": 4, "relative_quality": 99.0},
        {"name": "GQA-2", "kv_heads": 2, "relative_quality": 98.0},
        {"name": "MQA", "kv_heads": 1, "relative_quality": 96.5},
    ],
    "task_breakdown": {
        "General LM": [100.0, 99.8, 99.5, 99.0, 98.0],
        "Reasoning": [100.0, 99.2, 98.5, 97.0, 95.0],
        "Long Context": [100.0, 99.0, 97.5, 95.0, 92.0],
    },
}
Figure: Relative quality across attention variants and task types, using the illustrative numbers above. GQA-8 closely matches MHA across all tasks, while MQA shows more degradation on reasoning and long-context tasks that require tracking multiple information streams.

The task breakdown illustrates an expected pattern. For general language modeling, all variants perform similarly, plausibly because much of the task relies on local pattern matching. Reasoning tasks show more separation because they require the model to hold multiple pieces of evidence in mind simultaneously. Long-context tasks magnify this effect, as the model must retrieve and combine information from distant positions.

In this illustration, GQA-8 provides the best balance: within 1% of MHA on all tasks while providing substantial memory savings. A balance of this kind helps explain its adoption in LLaMA 2 70B, Mistral 7B, and other recent models.

Implementation for Inference

For production deployment, efficient KV cache management matters. Let's implement a complete GQA module optimized for autoregressive generation with proper caching.

In[14]:
Code
class GQAForGeneration(nn.Module):
    """GQA implementation optimized for autoregressive generation."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_kv_heads: int,
        max_seq_len: int = 4096,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert d_model % num_heads == 0
        assert num_heads % num_kv_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = num_heads // num_kv_heads
        self.head_dim = d_model // num_heads
        self.max_seq_len = max_seq_len

        # Projections
        self.w_q = nn.Linear(d_model, num_heads * self.head_dim, bias=False)
        self.w_k = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.w_v = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim**-0.5

        # Pre-compute causal mask
        self.register_buffer(
            "causal_mask", torch.tril(torch.ones(max_seq_len, max_seq_len))
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int = 0,
        kv_cache: tuple = None,
    ) -> tuple:
        batch_size, seq_len, _ = x.shape

        # Compute queries
        q = self.w_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        q = q.transpose(1, 2)  # (batch, num_heads, seq, head_dim)

        # Compute new keys and values
        k_new = self.w_k(x).view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        )
        v_new = self.w_v(x).view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        )

        # Update cache
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k_new], dim=1)
            v = torch.cat([v_cache, v_new], dim=1)
        else:
            k = k_new
            v = v_new

        new_kv_cache = (k, v)
        kv_seq_len = k.size(1)

        # Reshape for attention
        k = k.transpose(1, 2)  # (batch, num_kv_heads, kv_seq, head_dim)
        v = v.transpose(1, 2)

        # Expand KV to match query heads
        k = k.repeat_interleave(self.num_queries_per_kv, dim=1)
        v = v.repeat_interleave(self.num_queries_per_kv, dim=1)

        # Attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Apply causal mask
        mask = self.causal_mask[start_pos : start_pos + seq_len, :kv_seq_len]
        mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq, kv_seq)
        attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Output
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)

        return self.w_o(attn_output), new_kv_cache


def simulate_generation(model: GQAForGeneration, prompt_len: int, gen_len: int):
    """Simulate token-by-token generation and track cache growth."""
    batch_size = 1
    d_model = model.d_model

    # Process prompt
    prompt = torch.randn(batch_size, prompt_len, d_model)
    _, kv_cache = model(prompt, start_pos=0)

    cache_sizes = []
    current_pos = prompt_len

    # Generate tokens one by one
    for _ in range(gen_len):
        # Simulate next token embedding
        next_token = torch.randn(batch_size, 1, d_model)

        # Process with cached KV
        _, kv_cache = model(
            next_token, start_pos=current_pos, kv_cache=kv_cache
        )
        current_pos += 1

        # Track cache size
        k_cache, v_cache = kv_cache
        cache_bytes = (k_cache.numel() + v_cache.numel()) * 4  # float32
        cache_sizes.append(
            {
                "position": current_pos,
                "cache_kb": cache_bytes / 1024,
                "k_shape": tuple(k_cache.shape),
            }
        )

    return cache_sizes
Out[15]:
Console
GQA Incremental Generation Demo
=======================================================

Model config: d_model=512, heads=8, kv_heads=2
Prompt: 20 tokens, Generated: 30 tokens

Cache growth:
  After prompt + 1 token: 21.0 KB, shape: (1, 21, 2, 64)
  After 15 tokens: 35.0 KB, shape: (1, 35, 2, 64)
  After 30 tokens: 50.0 KB, shape: (1, 50, 2, 64)

With MHA, cache would be 4x larger: 200.0 KB

The cache shape shows (batch, seq, num_kv_heads, head_dim) with only 2 KV heads instead of 8 query heads. During attention, these are expanded via repeat_interleave to match the query head count, but the cache itself remains compact.
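One cost of `GQAForGeneration` as written is that `torch.cat` allocates a new tensor and copies the whole cache at every decoding step. A common alternative is to preallocate the cache once, sized for `num_kv_heads` rather than `num_heads`, and write into it in place. The `StaticKVCache` class below is a hypothetical minimal sketch of that idea; wiring it into `forward` would also require replacing the concatenation logic, which is omitted here.

```python
import torch


class StaticKVCache:
    """Hypothetical preallocated KV cache that stores only num_kv_heads heads.

    Uses the (batch, seq, num_kv_heads, head_dim) layout from GQAForGeneration.
    """

    def __init__(self, batch_size, max_seq_len, num_kv_heads, head_dim, dtype=torch.float32):
        shape = (batch_size, max_seq_len, num_kv_heads, head_dim)
        self.k = torch.zeros(shape, dtype=dtype)
        self.v = torch.zeros(shape, dtype=dtype)
        self.length = 0  # number of positions filled so far

    def update(self, k_new, v_new):
        """Write (batch, n, num_kv_heads, head_dim) entries in place; return filled views."""
        n = k_new.size(1)
        end = self.length + n
        if end > self.k.size(1):
            raise ValueError("sequence exceeds preallocated cache length")
        self.k[:, self.length:end] = k_new
        self.v[:, self.length:end] = v_new
        self.length = end
        return self.k[:, :end], self.v[:, :end]


cache = StaticKVCache(batch_size=1, max_seq_len=64, num_kv_heads=2, head_dim=64)
k, v = cache.update(torch.randn(1, 20, 2, 64), torch.randn(1, 20, 2, 64))  # prompt
for _ in range(30):  # one new token per decoding step
    k, v = cache.update(torch.randn(1, 1, 2, 64), torch.randn(1, 1, 2, 64))
print(k.shape)  # torch.Size([1, 50, 2, 64])
```

The trade-off is that memory for `max_seq_len` positions is reserved up front, even if generation stops early.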

Figure: KV cache memory growth during token-by-token generation, with the MHA equivalent shown for comparison. The cache grows linearly with sequence length, but GQA-2 requires only 25% of the memory that MHA would need for the same sequence. This difference matters most for long-context generation.

The visualization clearly shows the linear growth of the KV cache during generation. The shaded region represents the memory savings from using GQA-2 instead of MHA. For longer sequences and larger models, these savings become substantial, often determining whether a model fits in GPU memory or requires model parallelism.

Choosing the Right Number of KV Heads

The optimal number of KV heads depends on your specific constraints and requirements. Let's develop a framework for making this decision.

In[17]:
Code
def analyze_gqa_tradeoffs(
    num_query_heads: int,
    d_model: int,
    num_layers: int,
    target_memory_gb: float,
    seq_length: int,
    batch_size: int,
):
    """Find GQA configurations that fit within a memory budget."""
    bytes_per_element = 2  # fp16
    head_dim = d_model // num_query_heads

    viable_configs = []

    for num_kv_heads in [1, 2, 4, 8, 16, num_query_heads]:
        if num_query_heads % num_kv_heads != 0:
            continue

        # KV cache memory
        cache_bytes = (
            2
            * num_layers
            * num_kv_heads
            * head_dim
            * seq_length
            * batch_size
            * bytes_per_element
        )
        cache_gb = cache_bytes / 1e9

        # KV projection parameters
        kv_params = 2 * d_model * num_kv_heads * head_dim

        # Heuristic quality proxy: an illustrative assumption, not fitted to published results
        quality_factor = min(
            1.0, 0.92 + 0.08 * (num_kv_heads / num_query_heads) ** 0.5
        )

        config = {
            "kv_heads": num_kv_heads,
            "cache_gb": cache_gb,
            "kv_params_m": kv_params / 1e6,
            "quality": quality_factor * 100,
            "fits_budget": cache_gb <= target_memory_gb,
            "memory_reduction": num_query_heads / num_kv_heads,
        }
        viable_configs.append(config)

    return viable_configs


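# Example: 70B-scale shapes with 8K context and a batch of 8.
# The 20 GB KV-cache budget is an assumption: fp16 weights for 70B parameters
# alone need about 140 GB, so this presumes quantized or sharded weights.
configs = analyze_gqa_tradeoffs(
    num_query_heads=64,
    d_model=8192,
    num_layers=80,
    target_memory_gb=20.0,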
    seq_length=8192,
    batch_size=8,
)
Out[18]:
Console
GQA Configuration Analysis (70B model, 8K context, batch=8)
Memory budget for KV cache: 20 GB
======================================================================

KV Heads   Cache       Quality    Reduction   Fits?
----------------------------------------------------------------------
1          2.7 GB      93.0%      64x         ✓
2          5.4 GB      93.4%      32x         ✓
4          10.7 GB     94.0%      16x         ✓
8          21.5 GB     94.8%      8x          ✗
16         42.9 GB     96.0%      4x          ✗
64         171.8 GB    100.0%     1x          ✗

Recommended: GQA-4 (highest quality that fits budget)
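The recommendation can be checked by hand. Each KV head adds a fixed amount of cache, so the budget caps the number of KV heads g directly. Using the example's settings (L = 80 layers, d_head = 8192 / 64 = 128, s = 8192 tokens, b = 8 sequences, 2 bytes per fp16 element):

```latex
\begin{aligned}
\text{bytes per KV head} &= 2 \cdot L \cdot d_{\text{head}} \cdot s \cdot b \cdot 2 \\
&= 2 \cdot 80 \cdot 128 \cdot 8192 \cdot 8 \cdot 2 = 2{,}684{,}354{,}560 \approx 2.68\ \text{GB} \\
g_{\max} &= \left\lfloor \frac{20 \times 10^{9}}{2.684 \times 10^{9}} \right\rfloor = \lfloor 7.45 \rfloor = 7
\end{aligned}
```

The largest divisor of 64 that does not exceed 7 is 4, which is why GQA-4 is the highest-quality configuration that fits, and why GQA-8 (8 × 2.68 ≈ 21.5 GB) just misses the budget.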
Out[19]:
Visualization
Scatter plot with quality on y-axis and memory on x-axis, showing viable GQA configurations.
Decision framework for choosing a GQA configuration. The plot shows the quality-memory trade-off curve, with the green region marking configurations that fit within the 20 GB KV-cache budget. In this scenario GQA-8 (21.5 GB) narrowly misses the budget, so GQA-4 is the highest-quality option that fits; with a slightly smaller batch or a larger budget, GQA-8 is usually the sweet spot.

The framework reveals several insights:

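  • MHA is only viable for small batches or short contexts: in this example it needs 171.8 GB, more than eight times the budget
  • GQA-8 fits many practical deployment scenarios while retaining near-MHA quality; here it narrowly exceeds the budget (21.5 GB vs. 20 GB), so a slightly smaller batch would bring it within reach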
  • GQA-4 or GQA-2 may be necessary for very long contexts or large batches
  • MQA should be reserved for extreme memory constraints
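To see how batch size trades against KV heads, the per-sequence cache can be inverted to give the largest batch that fits. This is a minimal sketch that reuses the example's assumptions (80 layers, head_dim 128, 8K context, fp16, 20 GB budget); max_batch_size is a hypothetical helper, not part of the analysis code above.

```python
def max_batch_size(
    num_kv_heads: int,
    num_layers: int = 80,
    head_dim: int = 128,
    seq_length: int = 8192,
    budget_gb: float = 20.0,
    bytes_per_element: int = 2,  # fp16
) -> int:
    """Largest batch whose KV cache fits in budget_gb (decimal GB, as above)."""
    # K and V for every layer, KV head, and position of a single sequence
    per_sequence_bytes = (
        2 * num_layers * num_kv_heads * head_dim * seq_length * bytes_per_element
    )
    return int(budget_gb * 1e9 // per_sequence_bytes)


for kv_heads in [1, 2, 4, 8, 16, 64]:
    print(f"KV heads={kv_heads:>2}: max batch = {max_batch_size(kv_heads)}")
```

Under these assumptions MHA cannot hold even one 8K-token sequence (about 21.5 GB per sequence), while GQA-8 fits up to 7 sequences, so trimming the batch from 8 to 7 would already make GQA-8 viable.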

Models Using GQA

GQA has become the de facto standard for modern large language models. Here's how major models configure their attention:

In[20]:
Code
production_models = [
    {
        "name": "LLaMA 2 7B",
        "query_heads": 32,
        "kv_heads": 32,
        "attention": "MHA",
        "head_dim": 128,
    },
    {
        "name": "LLaMA 2 70B",
        "query_heads": 64,
        "kv_heads": 8,
        "attention": "GQA-8",
        "head_dim": 128,
    },
    {
        "name": "Mistral 7B",
        "query_heads": 32,
        "kv_heads": 8,
        "attention": "GQA-8",
        "head_dim": 128,
    },
    {
        "name": "Mixtral 8x7B",
        "query_heads": 32,
        "kv_heads": 8,
        "attention": "GQA-8",
        "head_dim": 128,
    },
    {
        "name": "Falcon 40B",
        "query_heads": 128,
        "kv_heads": 8,
        "attention": "GQA-8",
        "head_dim": 64,
    },
    {
        "name": "Falcon 180B",
        "query_heads": 232,
        "kv_heads": 8,
        "attention": "GQA-8",
        "head_dim": 64,
    },
    {
        "name": "Qwen2 72B",
        "query_heads": 64,
        "kv_heads": 8,
        "attention": "GQA-8",
        "head_dim": 128,
    },
]
Out[21]:
Console
GQA in Production Models
======================================================================

Model              Attention    Q Heads    KV Heads   Ratio
----------------------------------------------------------------------
LLaMA 2 7B         MHA          32         32         1:1
LLaMA 2 70B        GQA-8        64         8          8:1
Mistral 7B         GQA-8        32         8          4:1
Mixtral 8x7B       GQA-8        32         8          4:1
Falcon 40B         GQA-8        128        8          16:1
Falcon 180B        GQA-8        232        8          29:1
Qwen2 72B          GQA-8        64         8          8:1

Key observations:
- Some 7B models (LLaMA 2 7B) keep MHA, while newer ones such as Mistral 7B already use GQA
- Every larger model in this table uses 8 KV heads
- 8 KV heads has emerged as the common choice, even though the query-to-KV ratio ranges from 4:1 to 29:1

A clear pattern emerges: models at 70B+ scale consistently adopt GQA with 8 key-value heads. Smaller models vary: LLaMA 2 7B retains MHA because memory pressure is manageable at that size, whereas Mistral 7B already uses GQA. This convergence at scale suggests that 8 KV heads is an empirically validated sweet spot for balancing quality and efficiency.
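The effect of these choices is easiest to see per generated token. The sketch below takes KV heads and head_dim from the table and adds the layer counts from the models' published configurations (32 layers for LLaMA 2 7B and Mistral 7B, 80 for LLaMA 2 70B), which are not part of the table; it assumes an fp16 cache.

```python
# KV heads and head_dim from the table above; layer counts are added from the
# models' published configurations (not part of the table)
MODELS = {
    "LLaMA 2 7B": {"layers": 32, "kv_heads": 32, "head_dim": 128},
    "Mistral 7B": {"layers": 32, "kv_heads": 8, "head_dim": 128},
    "LLaMA 2 70B": {"layers": 80, "kv_heads": 8, "head_dim": 128},
}


def kv_bytes_per_token(
    layers: int, kv_heads: int, head_dim: int, bytes_per_element: int = 2
) -> int:
    """KV cache bytes added per token: keys + values across all layers (fp16)."""
    return 2 * layers * kv_heads * head_dim * bytes_per_element


for name, cfg in MODELS.items():
    kib = kv_bytes_per_token(**cfg) / 1024
    print(f"{name:<12} {kib:6.0f} KiB per token")
```

Under these assumptions LLaMA 2 70B stores less KV data per token (320 KiB) than LLaMA 2 7B (512 KiB) despite being ten times larger, which shows why GQA matters most at scale.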

Converting MHA to GQA

What if you have an existing MHA model and want the benefits of GQA? The GQA paper (Ainslie et al., 2023) describes a conversion process called "uptraining" that adapts pretrained MHA checkpoints to use GQA.

The conversion process works as follows:

  1. Mean pooling: Average the key and value projection weights across the heads that will share a KV representation
  2. Uptraining: Continue training the converted model for a small fraction of the original pre-training compute (the GQA paper uses 5%)
  3. Validation: Verify quality recovery on held-out benchmarks
In[22]:
Code
import torch


def convert_mha_to_gqa(
    mha_k_weight: torch.Tensor,  # (d_model, num_heads * head_dim), used as x @ W
    mha_v_weight: torch.Tensor,  # (d_model, num_heads * head_dim)
    num_heads: int,
    num_kv_groups: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert MHA K,V weights to GQA by mean pooling within groups.

    Consecutive heads are grouped (heads 0..r-1 form group 0, and so on,
    with r = num_heads // num_kv_groups), matching repeat_interleave.
    """
    if num_heads % num_kv_groups != 0:
        raise ValueError("num_heads must be divisible by num_kv_groups")

    d_model = mha_k_weight.shape[0]
    head_dim = mha_k_weight.shape[1] // num_heads
    heads_per_group = num_heads // num_kv_groups

    # Split the output dimension into (num_kv_groups, heads_per_group, head_dim);
    # reshape (unlike view) also works for non-contiguous inputs
    k_grouped = mha_k_weight.reshape(d_model, num_kv_groups, heads_per_group, head_dim)
    v_grouped = mha_v_weight.reshape(d_model, num_kv_groups, heads_per_group, head_dim)

    # Mean pool within groups: (d_model, num_kv_groups, head_dim)
    gqa_k_weight = k_grouped.mean(dim=2)
    gqa_v_weight = v_grouped.mean(dim=2)

    # Flatten back to (d_model, num_kv_groups * head_dim)
    gqa_k_weight = gqa_k_weight.reshape(d_model, num_kv_groups * head_dim)
    gqa_v_weight = gqa_v_weight.reshape(d_model, num_kv_groups * head_dim)

    return gqa_k_weight, gqa_v_weight


# Example conversion
torch.manual_seed(42)
d_model, num_heads, head_dim = 512, 8, 64
num_kv_groups = 2

# Simulate MHA weights
mha_k = torch.randn(d_model, num_heads * head_dim)
mha_v = torch.randn(d_model, num_heads * head_dim)

# Convert to GQA
gqa_k, gqa_v = convert_mha_to_gqa(mha_k, mha_v, num_heads, num_kv_groups)
Out[23]:
Console
MHA to GQA Weight Conversion
==================================================

Original MHA weights:
  K shape: torch.Size([512, 512])
  V shape: torch.Size([512, 512])

Converted GQA weights:
  K shape: torch.Size([512, 128])
  V shape: torch.Size([512, 128])

Parameter reduction: 4x for K,V projections
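convert_mha_to_gqa assumes weights in (d_model, num_heads * head_dim) layout, as its shape comments state. PyTorch's nn.Linear stores its weight as (out_features, in_features), so in a checkpoint that uses nn.Linear for the K and V projections the head axis sits on dim 0. A minimal sketch for that layout, assuming heads are stored contiguously in head order; pool_linear_kv_weight is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def pool_linear_kv_weight(
    weight: torch.Tensor, num_heads: int, num_kv_heads: int
) -> torch.Tensor:
    """Mean-pool an nn.Linear K or V weight of shape (num_heads * head_dim, d_model)."""
    out_features, d_model = weight.shape
    head_dim = out_features // num_heads
    heads_per_group = num_heads // num_kv_heads
    # Group consecutive heads: (num_kv_heads, heads_per_group, head_dim, d_model)
    grouped = weight.reshape(num_kv_heads, heads_per_group, head_dim, d_model)
    # Average within each group, then flatten to (num_kv_heads * head_dim, d_model)
    return grouped.mean(dim=1).reshape(num_kv_heads * head_dim, d_model)


torch.manual_seed(0)
d_model, num_heads, num_kv_heads, head_dim = 512, 8, 2, 64
k_proj = nn.Linear(d_model, num_heads * head_dim, bias=False)  # MHA key projection

gqa_k_proj = nn.Linear(d_model, num_kv_heads * head_dim, bias=False)
with torch.no_grad():
    gqa_k_proj.weight.copy_(
        pool_linear_kv_weight(k_proj.weight, num_heads, num_kv_heads)
    )

print(gqa_k_proj.weight.shape)  # torch.Size([128, 512])
```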
Out[24]:
Visualization
Original MHA key weights with 8 heads; weights grouped for mean pooling (2 groups of 4); GQA key weights after mean pooling (2 heads).

Mean pooling provides a reasonable initialization, but the converted model typically shows 2-5% quality degradation before uptraining. Uptraining lets the model, in particular its query projections, adapt to the shared key-value representations. The GQA paper reports that uptraining with about 5% of the original pre-training compute recovers most of the lost quality.
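One reason mean pooling is a sensible starting point: the key projection is linear, so projecting an input with the pooled weight gives exactly the average of the keys the grouped heads produced before conversion. Each shared KV head therefore starts at the centroid of its group's original keys. A small check with random weights in the (d_model, num_heads * head_dim) layout used by convert_mha_to_gqa; the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
d_model, num_heads, head_dim, num_kv_groups = 64, 8, 16, 2
heads_per_group = num_heads // num_kv_groups

w_k = torch.randn(d_model, num_heads * head_dim)  # MHA key weights
x = torch.randn(10, d_model)                       # 10 token embeddings

# Per-head MHA keys: (tokens, num_heads, head_dim)
mha_keys = (x @ w_k).view(10, num_heads, head_dim)

# Mean-pooled GQA key weights: (d_model, num_kv_groups * head_dim)
w_k_gqa = (
    w_k.view(d_model, num_kv_groups, heads_per_group, head_dim)
    .mean(dim=2)
    .reshape(d_model, num_kv_groups * head_dim)
)
gqa_keys = (x @ w_k_gqa).view(10, num_kv_groups, head_dim)

# Average of the original MHA keys within each group
group_mean_keys = mha_keys.view(10, num_kv_groups, heads_per_group, head_dim).mean(dim=2)

print(torch.allclose(gqa_keys, group_mean_keys, atol=1e-4, rtol=1e-4))  # True
```

The check also shows what is lost: every query head in a group now sees the same averaged key, which is the mismatch uptraining has to repair.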

Limitations and Practical Considerations

While GQA provides an excellent balance between efficiency and quality, several practical considerations deserve attention.

The main limitation is inherent to the grouping structure. When multiple query heads share key-value representations, they cannot independently retrieve different information from the context. Consider a task requiring the model to simultaneously track a subject's location and their emotional state. With MHA, separate heads might specialize in these two aspects. With GQA, heads within the same group must coordinate through their queries alone, which may reduce independence. Empirically, this limitation rarely causes problems with 8 KV groups, but becomes more pronounced with fewer groups.

Implementation complexity is slightly higher than MHA. The straightforward approach expands the KV tensors with repeat_interleave so that every query head has a matching key-value head; this materializes h/g copies of each KV head and adds memory traffic that partly offsets GQA's savings. Optimized implementations avoid the copy by indexing the shared KV head directly. Custom kernels like FlashAttention need explicit GQA support to do this; most major frameworks now include it, but verify compatibility before deployment.
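A minimal sketch of the two approaches for unmasked attention without dropout. The first expands K and V with repeat_interleave; the second folds the h/g query heads that share a KV head into the row dimension, so each KV head is used as-is and never copied. Both function names are hypothetical. A causal mask would need to be tiled h/g times along the folded row dimension. Recent PyTorch releases also document an enable_gqa argument on F.scaled_dot_product_attention; check your version's documentation before relying on it.

```python
import torch
import torch.nn.functional as F


def gqa_attention_repeat(q, k, v):
    """q: (b, h, s, d); k, v: (b, g, s, d). Copies each KV head h/g times."""
    rep = q.shape[1] // k.shape[1]
    # Expand K/V so each query head gets its own copy: (b, g, s, d) -> (b, h, s, d)
    k = k.repeat_interleave(rep, dim=1)
    v = v.repeat_interleave(rep, dim=1)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v


def gqa_attention_grouped(q, k, v):
    """Same result without copying K/V."""
    b, h, s, d = q.shape
    g = k.shape[1]
    # Consecutive query heads share a KV head, so (b, h, s, d) can be folded
    # into (b, g, (h/g) * s, d): each group's queries become extra rows
    q_folded = q.reshape(b, g, (h // g) * s, d)
    scores = q_folded @ k.transpose(-2, -1) / d ** 0.5  # (b, g, (h/g)*s, s)
    out = F.softmax(scores, dim=-1) @ v                  # (b, g, (h/g)*s, d)
    return out.reshape(b, h, s, d)


torch.manual_seed(0)
q = torch.randn(2, 8, 5, 16)  # 8 query heads
k = torch.randn(2, 2, 5, 16)  # 2 KV heads (GQA-2)
v = torch.randn(2, 2, 5, 16)
print(torch.allclose(gqa_attention_repeat(q, k, v), gqa_attention_grouped(q, k, v), atol=1e-5))
```

The folded version performs the same arithmetic; the saving is that K and V are read once per group instead of being duplicated in memory first.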

Debugging attention patterns becomes more challenging with GQA. When visualizing attention weights, remember that heads within a group share the same underlying key-value representation. Patterns that look different may stem entirely from query differences, not from different "views" of the input. This can complicate interpretability analyses.
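When grouping attention maps for analysis, it helps to make the query-to-KV mapping explicit. Under the contiguous grouping produced by repeat_interleave, query head i reads KV head i // (h/g). A small sketch, with kv_groups as a hypothetical helper, listing which query heads share each KV head for a 32-query, 8-KV configuration like Mistral 7B:

```python
def kv_groups(num_heads: int, num_kv_heads: int) -> dict[int, list[int]]:
    """Map each KV head to the query heads that read it (contiguous grouping)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    heads_per_group = num_heads // num_kv_heads
    groups: dict[int, list[int]] = {j: [] for j in range(num_kv_heads)}
    for i in range(num_heads):
        groups[i // heads_per_group].append(i)
    return groups


groups = kv_groups(32, 8)
print(groups[0])  # [0, 1, 2, 3]
print(groups[7])  # [28, 29, 30, 31]
```

Attention maps within one list share identical keys and values, so their differences isolate query-side behavior; comparisons across lists also reflect different key-value representations.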

Finally, the optimal number of KV groups depends on both the model architecture and the target deployment scenario. A model trained with GQA-8 cannot easily be converted to GQA-4 or MQA without quality loss. If you anticipate needing flexibility, consider training multiple variants or designing for the most constrained deployment scenario.

Key Parameters

When implementing or configuring GQA, the following parameters have the most significant impact on behavior:

  • num_heads (h): The number of query heads. Determines query diversity and must be divisible by num_kv_heads. Common values: 32 (7B models), 64 (70B models). More heads enable more diverse attention patterns; with d_model fixed, adding heads shrinks head_dim rather than adding query projection parameters.

  • num_kv_heads (g): The number of key-value groups. Controls the memory-quality trade-off. Must divide num_heads evenly. Common values: 8 for large models (GQA-8), 1 for MQA, equal to num_heads for MHA. Lower values reduce KV cache memory but may impact quality on complex reasoning tasks.

  • head_dim (d_k): Dimension per attention head, computed as d_model // num_heads. Typical values: 64 or 128. Affects the expressiveness of each attention head and determines the KV cache size per head.

  • d_model: Model dimension, must be divisible by num_heads. Standard values: 4096 (7B), 8192 (70B). Larger dimensions increase model capacity but also increase memory requirements.

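  • Query-to-KV ratio (num_heads // num_kv_heads): Determines how many query heads share each KV head, and equals the KV-cache reduction relative to MHA. GQA-8 denotes 8 KV heads, not a fixed ratio: LLaMA 2 70B pairs 64 query heads with 8 KV heads (8:1, an 8x smaller cache), while Mistral 7B pairs 32 query heads with 8 KV heads (4:1, 4x). Eight KV heads has become the common choice for large models, with minimal reported quality impact.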
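The constraints above can be captured in a small validated config. This is a minimal sketch; GQAConfig is a hypothetical class, not a framework API, and it assumes head_dim = d_model // num_heads as described above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GQAConfig:
    """Hypothetical holder for the GQA parameters listed above."""
    d_model: int
    num_heads: int     # h: query heads
    num_kv_heads: int  # g: key-value heads

    def __post_init__(self) -> None:
        if self.d_model % self.num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        if self.num_heads % self.num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")

    @property
    def head_dim(self) -> int:
        return self.d_model // self.num_heads

    @property
    def queries_per_kv_head(self) -> int:
        # Also the KV-cache reduction factor relative to MHA
        return self.num_heads // self.num_kv_heads


llama2_70b = GQAConfig(d_model=8192, num_heads=64, num_kv_heads=8)
mistral_7b = GQAConfig(d_model=4096, num_heads=32, num_kv_heads=8)
print(llama2_70b.head_dim, llama2_70b.queries_per_kv_head)  # 128 8
print(mistral_7b.head_dim, mistral_7b.queries_per_kv_head)  # 128 4
```

Failing fast on a non-divisible combination (for example 32 query heads with 6 KV heads) is cheaper than discovering a shape error inside the attention kernel.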

Summary

Grouped Query Attention represents a practical evolution of transformer attention for efficient inference. By sharing key-value heads among groups of query heads, GQA achieves significant memory savings while retaining most of the representational capacity of Multi-Head Attention.

Key takeaways from this chapter:

  • GQA groups query heads to share key-value projections, providing a tunable trade-off between MHA's capacity and MQA's efficiency
  • Memory savings scale with the ratio of query heads to KV heads: 8 query heads sharing 1 KV head reduces cache by 8x
  • Quality impact is minimal for typical configurations: GQA-8 matches MHA within 0.5% on most benchmarks while reducing the KV cache by a factor of h/8 (8x for a 64-head model such as LLaMA 2 70B)
  • 8 KV heads has emerged as the industry standard for large models, used by LLaMA 2 70B, Mistral, Falcon 180B, and others
  • Implementation requires matching each query head to its shared KV head during attention, either by expanding KV tensors (typically via repeat_interleave) or by indexing the shared head directly in optimized kernels
  • MHA-to-GQA conversion is possible through mean pooling followed by uptraining with about 5% of the original pre-training compute

The convergence of major model families on GQA-8 suggests this configuration hits an important practical sweet spot. For most large-scale deployments, GQA provides the memory efficiency needed for practical inference without sacrificing the quality that makes the models useful in the first place.

Looking ahead, attention efficiency research continues to evolve. Techniques like sliding window attention, sparse attention, and linear attention offer complementary approaches to managing long sequences. GQA can be combined with these methods, and future chapters will explore how these techniques work together in modern architectures.


Reference

BibTeX
@misc{groupedqueryattentionmemoryefficientllminference, author = {Michael Brenndoerfer}, title = {Grouped Query Attention: Memory-Efficient LLM Inference}, year = {2025}, url = {https://mbrenndoerfer.com/writing/grouped-query-attention-gqa-efficient-llm-inference}, organization = {mbrenndoerfer.com}, note = {Accessed: 2025-12-19} }
APA
Michael Brenndoerfer (2025). Grouped Query Attention: Memory-Efficient LLM Inference. Retrieved from https://mbrenndoerfer.com/writing/grouped-query-attention-gqa-efficient-llm-inference
MLA
Michael Brenndoerfer. "Grouped Query Attention: Memory-Efficient LLM Inference." 2025. Web. 12/19/2025. <https://mbrenndoerfer.com/writing/grouped-query-attention-gqa-efficient-llm-inference>.
Chicago
Michael Brenndoerfer. "Grouped Query Attention: Memory-Efficient LLM Inference." Accessed 12/19/2025. https://mbrenndoerfer.com/writing/grouped-query-attention-gqa-efficient-llm-inference.
Harvard
Michael Brenndoerfer (2025) 'Grouped Query Attention: Memory-Efficient LLM Inference'. Available at: https://mbrenndoerfer.com/writing/grouped-query-attention-gqa-efficient-llm-inference (Accessed: 12/19/2025).
Simple
Michael Brenndoerfer (2025). Grouped Query Attention: Memory-Efficient LLM Inference. https://mbrenndoerfer.com/writing/grouped-query-attention-gqa-efficient-llm-inference
Michael Brenndoerfer

About the author: Michael Brenndoerfer

All opinions expressed here are my own and do not reflect the views of my employer.

Michael currently works as an Associate Director of Data Science at EQT Partners in Singapore, leading AI and data initiatives across private capital investments.

With over a decade of experience spanning private equity, management consulting, and software engineering, he specializes in building and scaling analytics capabilities from the ground up. He has published research in leading AI conferences and holds expertise in machine learning, natural language processing, and value creation through data.
