KV Cache Explained: Efficient Attention for LLM Generation

Michael Brenndoerfer · January 6, 2026 · 38 min read

Learn how KV cache eliminates redundant attention computations in transformers. Understand memory requirements, cache structure, and implementation details.


KV Cache

When you ask a language model to complete the sentence "The capital of France is," it generates tokens one at a time: first "Paris," then perhaps a comma, then additional context. At each step, the model computes attention over all previous tokens to decide what comes next. Without optimization, this means the model recomputes the same attention calculations for "The," "capital," "of," "France," "is" every single time it generates a new token. For a 100-token response, that adds up to over 5,000 redundant token computations, with the prompt tokens alone being reprocessed a hundred times over.

The key-value cache, commonly called KV cache, eliminates this redundancy by storing the intermediate attention computations and reusing them across generation steps. This optimization reduces the projection work of autoregressive generation from quadratic to linear in sequence length, making inference with large language models practical.

The Redundancy Problem in Autoregressive Generation

As we explored in Part XVIII, decoder-only transformers generate text one token at a time. At each step, the model takes all tokens generated so far (including the prompt) and predicts the next token. The computational core of this process is the self-attention mechanism, which allows each token to gather information from all tokens that came before it. To understand why caching provides such dramatic benefits, we must first examine exactly how attention operates during generation and identify where the redundant work occurs.

Recall from Part X that self-attention computes queries, keys, and values from the input. The fundamental insight behind these three components is that attention works like an information retrieval system: queries represent what a token is looking for, keys represent what each token offers, and values contain the actual information to be retrieved. We can express this through three linear projections:

Q = XW_Q, \quad K = XW_K, \quad V = XW_V

where:

  • X: the input embedding matrix, containing the vector representations of all tokens in the sequence
  • W_Q, W_K, W_V: the learned projection matrices that transform inputs into query, key, and value subspaces, each capturing a different aspect of the token's meaning
  • Q, K, V: the resulting matrices, where Q contains vectors representing what each token seeks, K contains vectors representing what each token can be matched against, and V contains the actual content to be aggregated

Each of these projection matrices is learned during training to extract the most useful representations for the attention mechanism. Once training is complete, these matrices remain fixed during inference. This means that given the same input token, the projections will always produce the same key and value vectors, a property that forms the foundation for caching.
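
To make this concrete, here is a minimal PyTorch sketch of that property. The weights are randomly initialized stand-ins for trained projections, and the context lengths are arbitrary; the point is only that a token's key and value depend on its own embedding and the fixed weights, not on what follows it:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8

# Stand-ins for the trained, fixed projection matrices W_K and W_V.
W_K = nn.Linear(d_model, d_model, bias=False)
W_V = nn.Linear(d_model, d_model, bias=False)

# The same token embedding appearing in two different contexts.
token = torch.randn(1, d_model)
context_a = torch.cat([token, torch.randn(3, d_model)])  # token followed by 3 others
context_b = torch.cat([token, torch.randn(7, d_model)])  # token followed by 7 others

# The token's key and value depend only on its own embedding and the weights,
# so they come out identical in both contexts.
print(torch.allclose(W_K(context_a)[0], W_K(context_b)[0]))  # True
print(torch.allclose(W_V(context_a)[0], W_V(context_b)[0]))  # True
```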

The attention mechanism then computes a weighted sum of values based on the compatibility between queries and keys. This compatibility is measured through dot products, which capture how well a query's "question" matches each key's "answer." The full attention formula is:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

where:

  • Q, K, V: the query, key, and value matrices computed from the input
  • K^T: the transpose of the key matrix, arranged so that the dot product between queries and keys can be computed efficiently through matrix multiplication
  • d_k: the dimension of the key vectors, used for scaling to prevent the dot products from growing too large and causing vanishing gradients in the softmax
  • softmax: the function that converts raw compatibility scores into probability weights summing to 1, ensuring all weights are positive and amplifying differences so that tokens with higher compatibility receive substantially more attention

The scaling factor \sqrt{d_k} is critical. Without this normalization, the dot products between high-dimensional vectors can become very large in magnitude. When these large values pass through the softmax function, the result becomes extremely peaked, meaning almost all attention weight concentrates on a single token. This causes gradients to vanish for all other positions, making training unstable. The square root scaling keeps the variance of the dot products roughly constant regardless of the dimension.
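
A small numerical illustration of why the scaling matters, using random vectors as stand-ins for one query and a handful of keys (the dimension and count are arbitrary choices for this sketch):

```python
import torch

torch.manual_seed(0)
d_k = 512

q = torch.randn(1, d_k)       # one query
keys = torch.randn(16, d_k)   # 16 keys

raw = q @ keys.T              # dot products have variance roughly d_k
scaled = raw / d_k**0.5       # scaling brings the variance back to roughly 1

print(raw.std().item(), scaled.std().item())

# Unscaled scores drive the softmax toward a near one-hot distribution,
# while scaled scores keep the attention weights spread out.
print(torch.softmax(raw, dim=-1).max().item())
print(torch.softmax(scaled, dim=-1).max().item())
```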

During generation, let's trace what happens at each step to understand precisely where redundancy arises. Suppose we're generating a response to a 10-token prompt. The model must produce one token at a time, with each new token depending on everything that came before it.

Step 1: The model processes all 10 prompt tokens simultaneously in what we call the prefill phase. It computes the Q, K, V matrices, each of shape (10, d_k) or (10, d_v) depending on the specific projection. The attention mechanism then allows each prompt token to attend to all previous prompt tokens, establishing the initial context representation.

Step 2: After generating the first new token, we now have 11 tokens total. Without any optimization, the model would need to process the entire 11-token sequence from scratch. This means computing Q, K, V for all 11 tokens, even though 10 of them are unchanged from the previous step.

Step 3: With 12 tokens total, we compute Q, K, V for all 12 tokens, once again recomputing the projections for the original prompt tokens that haven't changed at all.

Notice the wasteful pattern emerging: at step t, we recompute the keys and values for the first 10 prompt tokens, even though nothing about them has changed since step 1. The same tokens, passing through the same fixed projection matrices W_K and W_V, produce the same K and V vectors every single time. This redundancy stems from a fundamental property of the attention mechanism: the key and value for a given token depend only on that token's embedding and the projection weights, not on what comes after it in the sequence.
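
The following deliberately naive sketch makes this pattern visible. It uses random embeddings and randomly initialized projections as stand-ins and simply counts how many token projections the uncached approach performs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
W_K = nn.Linear(d_model, d_model, bias=False)
W_V = nn.Linear(d_model, d_model, bias=False)

tokens = torch.randn(10, d_model)  # stand-in embeddings for a 10-token prompt
token_projections = 0

for step in range(5):  # pretend we generate 5 new tokens
    # Naive approach: recompute K and V for every token at every step,
    # even though the prompt rows never change.
    K, V = W_K(tokens), W_V(tokens)
    token_projections += tokens.shape[0]

    new_token = torch.randn(1, d_model)  # stand-in for the sampled token's embedding
    tokens = torch.cat([tokens, new_token])

print(token_projections)  # 10 + 11 + 12 + 13 + 14 = 60
```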

This redundancy has concrete costs that grow dramatically with sequence length. For a model with L layers, H attention heads, and head dimension d_h, generating T tokens requires computing K and V projections repeatedly at every step. The total number of projection operations across all generation steps sums to:

\sum_{t=1}^{T} t = \frac{T(T+1)}{2} \approx \frac{T^2}{2}

where:

  • T: the final sequence length, representing the total number of tokens processed including both the prompt and generated tokens
  • t: the current step number, which also represents the sequence length at that particular step
  • \sum_{t=1}^{T} t: the sum of the arithmetic progression 1 + 2 + 3 + \ldots + T, capturing the total number of token-projection operations across all steps

This quadratic scaling has severe implications for practical deployment. If generating 100 tokens requires \frac{100 \times 101}{2} = 5,050 projection operations, then generating 1,000 tokens requires \frac{1000 \times 1001}{2} = 500,500 operations, a 100-fold increase in work for only a 10-fold increase in output length. This contrasts with the ideal case where each token's keys and values are computed only once, requiring T projection operations. The KV cache achieves this by computing each projection once and storing it for future reuse.

Out[2]:
Visualization
Computational complexity comparison between generation with and without KV cache. Without caching, the number of projection operations grows quadratically with sequence length. With KV cache, the growth is linear, providing dramatic savings for longer sequences.

The KV Cache Solution

The insight behind KV caching is straightforward once we recognize where the redundancy lies: since each token's key and value vectors depend only on that token's representation and the fixed projection weights, we can compute them once and store them for reuse in all subsequent generation steps. The query vectors, by contrast, must be recomputed because they are used differently at each step. A query asks "what information should I attend to?" and the answer depends on the token's position in the evolving sequence.

This asymmetry between queries and key-values is fundamental to understanding the cache. When token 5 generates a query, it needs to attend to tokens 1 through 4. When token 10 generates a query, it needs to attend to tokens 1 through 9. The keys and values for tokens 1 through 4 are the same in both cases; they haven't changed. But the query's role is different: it's always asking about everything that came before, and "everything that came before" grows with each step.

During generation with KV caching, the process changes fundamentally from the naive approach:

Prefill phase (Step 1): Process all prompt tokens at once, computing queries, keys, and values for the entire prompt. Store the computed K and V matrices in the cache, as these will be reused throughout generation. Use the queries to compute attention within the prompt, establishing the initial hidden states.

Decode phase (Steps 2+): For each new token, the process is efficient:

  1. Compute Q, K, V only for the single new token
  2. Append the new K and V vectors to the end of the cache
  3. Compute attention using the new Q against all cached K values
  4. Weight all cached V values by the resulting attention scores
  5. Produce the output for this position

The crucial difference is that we only compute projections for one token per step (the newly generated one), while the attention computation can freely access all previously computed keys and values through the cache. This transforms the projection work from quadratic to linear in the total sequence length.

Mathematical Formulation

Let's formalize this caching mechanism to understand exactly how it achieves the computational savings. The formalization reveals that caching is not an approximation but an algebraically equivalent reformulation of the attention computation.

At generation step t, let x_t denote the embedding of the newly generated token. This embedding captures all the information the model has about this token at the input to the attention layer. We project this single token's embedding to obtain its query, key, and value vectors:

q_t = x_t W_Q, \quad k_t = x_t W_K, \quad v_t = x_t W_V

where:

  • x_t: the input embedding for the current token at step t, a vector of dimension d_{\text{model}}
  • W_Q, W_K, W_V: the fixed projection matrices shared across all positions and all generation steps
  • q_t: the query vector representing what the current token is seeking to attend to
  • k_t: the key vector representing how this token can be matched by future queries
  • v_t: the value vector containing the information this token contributes when attended to

Note the important distinction from the non-cached case: each of these is a single vector of dimension d_k or d_v, not a full matrix covering all positions. We compute projections for exactly one token, not the entire sequence.

The cache serves as the model's memory of past tokens, maintaining the history of key and value vectors computed in all previous steps. At the beginning of step t, before processing the new token, the cache contains concatenated matrices:

K_{\text{cache}} = [k_1; k_2; \ldots; k_{t-1}] \in \mathbb{R}^{(t-1) \times d_k}
V_{\text{cache}} = [v_1; v_2; \ldots; v_{t-1}] \in \mathbb{R}^{(t-1) \times d_v}

where:

  • K_{\text{cache}}, V_{\text{cache}}: matrices storing the complete history of keys and values for all previous t-1 tokens
  • k_i, v_i: the key and value vectors computed when token i was first processed
  • [;]: the concatenation operation along the sequence dimension, stacking vectors as rows
  • d_k, d_v: the dimensions of the key and value vectors, often equal in practice

This cache represents all the "memory" the attention mechanism has of the sequence so far. Each row corresponds to a token's contribution to the attention computation.

After computing the new token's projections, we update the cache by appending the new key and value vectors:

K_{\text{cache}} \leftarrow [K_{\text{cache}}; k_t]
V_{\text{cache}} \leftarrow [V_{\text{cache}}; v_t]

where:

  • K_{\text{cache}}, V_{\text{cache}}: the cached key and value matrices being extended with new entries
  • \leftarrow: the assignment operator indicating an in-place update to the stored cache state
  • k_t, v_t: the newly computed key and value vectors being appended to preserve the sequence history

After this update, the cache contains keys and values for all t tokens, ready for the attention computation.

With the cache updated, the model computes attention for the new token by having its query vector interact with the entire history:

a_t = \text{softmax}\left(\frac{q_t K_{\text{cache}}^T}{\sqrt{d_k}}\right) \in \mathbb{R}^{1 \times t}
o_t = a_t V_{\text{cache}} \in \mathbb{R}^{1 \times d_v}

where:

  • a_t: the attention weights for the current step, a vector of t probabilities representing how much the new token should attend to each previous position
  • q_t: the query vector for the current token, seeking relevant information from the context
  • K_{\text{cache}}^T: the transpose of the cached key matrix, shaped for efficient dot product computation with the query
  • V_{\text{cache}}: the cached value matrix containing the actual content to be retrieved and aggregated
  • d_k: the dimension of key vectors, providing the scaling factor for numerical stability
  • d_v: the dimension of value vectors, determining the output size
  • o_t: the final attention output for step t, a weighted combination of all cached values

The computation q_t K_{\text{cache}}^T produces a vector of t scores, one for each cached position. These scores measure how relevant each previous token is to the current query. After scaling and applying softmax, these become proper attention weights that sum to 1. Finally, multiplying by V_{\text{cache}} aggregates the cached values according to these weights, producing the output vector that captures all the relevant information from the sequence history.

This formulation achieves efficiency because the matrix multiplications involve a single query vector against the cached matrices, requiring O(t \times d) operations rather than the O(t^2 \times d) that would be needed to recompute attention from scratch.
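
The whole cached attention step can be written in a few lines. The sketch below uses random tensors for the cache and the new token's projections; only the shapes and the order of operations matter here:

```python
import torch

torch.manual_seed(0)
d_k = d_v = 16
t = 12  # current step: 11 cached tokens plus the new one

K_cache = torch.randn(t - 1, d_k)  # keys for the t-1 previous tokens
V_cache = torch.randn(t - 1, d_v)  # values for the t-1 previous tokens

# Projections for the single new token, computed once this step.
q_t, k_t, v_t = torch.randn(1, d_k), torch.randn(1, d_k), torch.randn(1, d_v)

# Append, then attend: one query against all t cached positions.
K_cache = torch.cat([K_cache, k_t])  # (t, d_k)
V_cache = torch.cat([V_cache, v_t])  # (t, d_v)

a_t = torch.softmax(q_t @ K_cache.T / d_k**0.5, dim=-1)  # (1, t) attention weights
o_t = a_t @ V_cache                                      # (1, d_v) output
print(a_t.shape, o_t.shape)
```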

Out[3]:
Visualization
Cache growth during autoregressive generation. Starting with an 8-token prompt processed during prefill, the cache grows by one token at each decode step. The prefill phase processes all prompt tokens simultaneously, while each subsequent step adds exactly one new token's keys and values to the cache.

Cache Structure

The KV cache must be maintained separately for each layer and each attention head in the transformer architecture. This requirement arises because different layers and heads learn different attention patterns and have different projection weights. Layer 1's keys and values are fundamentally different from layer 32's, as each layer captures different levels of abstraction. Similarly, within a layer, head 1 might learn to track syntactic relationships while head 4 tracks semantic similarity, each requiring its own separate cache.

Per-Layer, Per-Head Organization

For a transformer with L layers and H attention heads per layer, the complete cache consists of 2 \times L \times H tensors: one K cache and one V cache for each attention head at each layer. This combinatorial structure means that even modest models require tracking many separate cache tensors.

In practice, deep learning frameworks typically consolidate this organization into two tensors per layer, using an additional dimension to index across heads:

K_{\text{cache}}^{(l)} \in \mathbb{R}^{B \times H \times T \times d_h}
V_{\text{cache}}^{(l)} \in \mathbb{R}^{B \times H \times T \times d_h}

where:

  • K_{\text{cache}}^{(l)}, V_{\text{cache}}^{(l)}: the key and value cache tensors for layer l, containing caches for all heads
  • B: the batch size, allowing multiple sequences to be processed in parallel
  • H: the number of attention heads, each with its own cached key-value pairs
  • T: the current sequence length, growing as generation proceeds
  • d_h: the head dimension, typically d_{\text{model}} / H

Some implementations combine keys and values into a single tensor of shape (B, H, T, 2, d_h) or (B, 2, H, T, d_h) for improved memory locality. This layout keeps each position's key and value adjacent in memory, which can improve cache efficiency during the attention computation. However, the logical structure remains the same regardless of the physical memory layout.
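
As a rough sketch of what these layouts look like in code, the snippet below allocates per-layer cache tensors for a small hypothetical configuration and shows both the separate and the fused arrangement (the sizes are illustrative, not tied to any particular model):

```python
import torch

# Illustrative configuration, not tied to any particular model.
B, H, T_max, d_h = 2, 8, 1024, 64
num_layers = 12

# One K tensor and one V tensor per layer, each (B, H, T, d_h).
kv_cache = [
    {
        "k": torch.zeros(B, H, T_max, d_h, dtype=torch.float16),
        "v": torch.zeros(B, H, T_max, d_h, dtype=torch.float16),
    }
    for _ in range(num_layers)
]

# Fused alternative: keys and values adjacent for each position, (B, H, T, 2, d_h).
fused_cache = [
    torch.zeros(B, H, T_max, 2, d_h, dtype=torch.float16) for _ in range(num_layers)
]

total_bytes = sum(c["k"].numel() + c["v"].numel() for c in kv_cache) * 2  # FP16 = 2 bytes
print(f"{total_bytes / 1024**2:.0f} MiB")  # 48 MiB for this toy configuration
```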

How Cached Values Flow Through the Model

Understanding how the cache integrates with the transformer's layer-by-layer computation clarifies why the optimization is both correct and efficient. During the decode phase, when generating token by token, each transformer layer performs the following sequence of operations:

  1. Receives the hidden state for only the new token: h_t^{(l)} \in \mathbb{R}^{B \times 1 \times d_{\text{model}}}. This is the representation of the new token as computed by the previous layer.
  2. Projects this hidden state to obtain q_t, k_t, v_t for the new token only.
  3. Retrieves the cached K_{\text{cache}}^{(l)} and V_{\text{cache}}^{(l)} containing all previously computed keys and values for this layer.
  4. Appends the new k_t and v_t to the cache, extending the history by one position.
  5. Computes attention between the new query q_t and the full cache, allowing the new token to attend to all previous positions.
  6. Passes the attention output through the feed-forward network, which processes each position independently.
  7. Returns the hidden state for the new token: h_t^{(l+1)}, ready for the next layer.

Notice that operations 1, 2, 6, and 7 all operate on just a single position. The feed-forward network, layer normalization, and residual connections process only the new token's hidden state, making them computationally trivial during decoding. The attention computation in step 5 is the only operation that touches the full sequence length, and even there, the work is linear in t rather than quadratic because we compute attention for just one query position.

Cache Memory Requirements

Understanding KV cache memory consumption is essential for deployment planning, as the cache often becomes the primary memory bottleneck in production systems. Unlike model weights that have a fixed size, the cache grows dynamically during generation and can easily exceed the model weights in memory usage for long sequences or large batches.

Memory Formula

To derive the memory requirements, we must account for all the tensors that constitute the complete cache. For a single sequence of length TT, we need to store keys and values for every layer and every head. The fundamental memory formula is:

\text{Memory} = 2 \times L \times H \times d_h \times T \times \text{bytes per element}

where:

  • 2: accounts for storing both keys and values, as each requires a separate tensor
  • L: the number of transformer layers, each maintaining its own independent cache
  • H: the number of attention heads per layer
  • d_h: the dimension of each head, determining the size of individual key and value vectors
  • T: the sequence length in tokens, the dynamic factor that grows during generation
  • bytes per element: the memory size required for a single floating-point number (e.g., 2 bytes for FP16, 4 bytes for FP32)

Since the total model dimension is typically expressed as d_{\text{model}} = H \times d_h, we can simplify this formula by substituting:

\text{Memory} = 2 \times L \times d_{\text{model}} \times T \times \text{bytes per element}

where:

  • d_{\text{model}}: the total model dimension, equal to H \times d_h
  • L: the number of layers in the model
  • T: the sequence length
  • bytes per element: the numerical precision in bytes

This simplified formula reveals an important insight: the cache memory scales with the product of model depth and width, multiplied by sequence length. Doubling any of these factors doubles the memory requirement.

Example: LLaMA 2 7B

To make these formulas concrete, let's calculate the cache requirements for a widely deployed model. The LLaMA 2 7B model has the following specifications:

  • L = 32 layers
  • d_{\text{model}} = 4096
  • Typically stored in FP16 (2 bytes per element)

For a single sequence of length T = 2048, we can compute the cache size step by step:

\text{Memory} = 2 \times 32 \times 4096 \times 2048 \times 2 = 1,073,741,824 \text{ bytes} = 1 \text{ GB}

A single 2048-token sequence requires a full gigabyte just for the KV cache. This is separate from the memory needed for model weights, activations, and other inference overhead.

For a batch of 8 sequences at the 4096-token context length, the memory requirement grows proportionally:

\text{Memory} = 8 \times 2 \times 32 \times 4096 \times 4096 \times 2 = 17,179,869,184 \text{ bytes} = 16 \text{ GB}

For comparison, the model weights themselves require about 14 GB in FP16. This means that for a batch of 8 long-context requests, the KV cache consumes more memory than the entire model. At longer context lengths or larger batch sizes, the cache can substantially exceed the model size, making memory management the critical challenge for deployment.

Scaling with Modern Models

Larger models and longer contexts exacerbate the memory pressure, creating significant challenges for deployment at scale. The following table shows KV cache sizes for various model configurations, illustrating how the requirements grow across different architectural choices:

KV cache memory requirements for different model sizes at their respective context lengths in FP16 precision.
Model             | Layers | d_model | Context | Cache per Sequence
------------------|--------|---------|---------|-------------------
LLaMA 7B          | 32     | 4096    | 4K      | 2 GB
LLaMA 13B         | 40     | 5120    | 4K      | 3.1 GB
LLaMA 70B         | 80     | 8192    | 4K      | 10.0 GB
GPT-4 (estimated) | 120    | 12288   | 128K    | 590 GB

The cache memory scales linearly with both model size (via the L \times d_{\text{model}} product) and context length. This linear scaling in context length, combined with the quadratic scaling of the attention computation itself, makes long-context inference particularly challenging. A model like GPT-4 with a 128K context window requires hundreds of gigabytes per sequence, necessitating distributed systems and sophisticated memory management strategies.

Out[4]:
Visualization
KV cache memory scaling with context length for different model sizes. Memory grows linearly with context length, and larger models require proportionally more cache memory. The dashed horizontal lines indicate typical GPU memory capacities, showing how quickly the cache can exhaust available memory.

Cache Management

Efficient cache management involves critical decisions about memory allocation strategies, handling variable sequence lengths across concurrent requests, and managing batched generation effectively. These operational considerations often determine whether a deployment can achieve its throughput and latency targets.

Static vs Dynamic Allocation

Static allocation pre-allocates cache tensors for the maximum sequence length at the start of generation. This approach reserves all memory upfront, avoiding the overhead of repeated memory allocation during the generation process. Static allocation is preferred when:

  • Maximum sequence length is known in advance
  • Memory fragmentation must be avoided
  • Consistent latency is required

Dynamic allocation grows the cache as needed, typically by allocating larger tensors and copying existing data when the current allocation is exhausted. This saves memory for shorter sequences but introduces allocation overhead and potential fragmentation. Modern frameworks often use a hybrid approach, allocating in chunks (e.g., 128 or 256 tokens at a time) to balance these tradeoffs.
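
A minimal sketch of the chunked strategy is shown below. The `ChunkedKVCache` class is hypothetical and fixes the batch size to 1 for brevity; it grows its buffers to the next multiple of the chunk size and copies existing entries over, which is exactly the tradeoff dynamic allocation makes:

```python
import torch


class ChunkedKVCache:
    """Hypothetical cache that grows in fixed-size chunks (batch size 1 for brevity)."""

    def __init__(self, num_heads: int, head_dim: int, chunk: int = 256):
        self.chunk, self.num_heads, self.head_dim = chunk, num_heads, head_dim
        self.k = torch.empty(1, num_heads, 0, head_dim)
        self.v = torch.empty(1, num_heads, 0, head_dim)
        self.seq_len = 0

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        needed = self.seq_len + k_new.shape[2]
        if needed > self.k.shape[2]:
            # Round capacity up to the next chunk boundary and copy existing entries.
            capacity = ((needed + self.chunk - 1) // self.chunk) * self.chunk
            k_buf = torch.zeros(1, self.num_heads, capacity, self.head_dim)
            v_buf = torch.zeros_like(k_buf)
            k_buf[:, :, : self.seq_len] = self.k[:, :, : self.seq_len]
            v_buf[:, :, : self.seq_len] = self.v[:, :, : self.seq_len]
            self.k, self.v = k_buf, v_buf
        self.k[:, :, self.seq_len : needed] = k_new
        self.v[:, :, self.seq_len : needed] = v_new
        self.seq_len = needed
        return self.k[:, :, : self.seq_len], self.v[:, :, : self.seq_len]


cache = ChunkedKVCache(num_heads=4, head_dim=16, chunk=256)
k, v = cache.update(torch.randn(1, 4, 8, 16), torch.randn(1, 4, 8, 16))
print(k.shape, cache.k.shape)  # used: (1, 4, 8, 16), allocated: (1, 4, 256, 16)
```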

Batched Generation

When generating for multiple sequences simultaneously, cache management becomes more complex because different sequences may have different lengths. The typical approaches are:

  • Padding: Allocate cache based on the longest sequence in the batch, padding shorter sequences. Simple but wasteful when sequence lengths vary significantly.
  • Separate caches: Maintain independent cache tensors per sequence. Avoids wasted memory but complicates the attention computation and prevents batched matrix operations.
  • Paged attention: Allocate cache in fixed-size blocks and track which blocks belong to which sequence. This approach, which we'll explore in detail in the upcoming chapter on Paged Attention, enables efficient memory utilization with variable-length sequences.

Cache Clearing and Context Window Management

When a sequence reaches the model's maximum context length, the system must decide how to proceed:

  • Truncation: Simply stop accepting new tokens or drop the oldest tokens
  • Sliding window: Maintain only the most recent W tokens, discarding older ones (as used in Mistral's architecture, discussed in Part XIX)
  • Attention sinks: Keep initial tokens plus recent tokens, leveraging the attention sink phenomenon we covered in Part XV

Each strategy has tradeoffs between memory usage, coherence over long conversations, and computational complexity.
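
The sliding-window and attention-sink strategies both reduce to trimming the cache along the sequence dimension. The helper below is a hypothetical sketch of that trimming step; it ignores the position re-indexing (for example, with rotary embeddings) that real implementations must also handle:

```python
import torch


def trim_cache(k_cache, v_cache, max_len: int, num_sink: int = 0):
    """Trim a (B, H, T, d_h) cache along the sequence axis.

    num_sink=0 gives a plain sliding window; num_sink > 0 keeps the first
    few "attention sink" tokens plus the most recent ones.
    """
    T = k_cache.shape[2]
    if T <= max_len:
        return k_cache, v_cache
    recent = max_len - num_sink
    k_keep = torch.cat([k_cache[:, :, :num_sink], k_cache[:, :, -recent:]], dim=2)
    v_keep = torch.cat([v_cache[:, :, :num_sink], v_cache[:, :, -recent:]], dim=2)
    return k_keep, v_keep


k = torch.randn(1, 4, 100, 16)
v = torch.randn(1, 4, 100, 16)
k, v = trim_cache(k, v, max_len=32, num_sink=4)
print(k.shape)  # torch.Size([1, 4, 32, 16])
```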

Implementation

Let's build a simple KV cache implementation to see how these concepts work in practice.

In[5]:
Code
import torch


class KVCache:
    """
    Simple KV cache for a single attention layer.
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        head_dim: int,
        max_seq_len: int,
        dtype=torch.float32,
        device="cpu",
    ):
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len

        # Pre-allocate cache tensors
        cache_shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.k_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.v_cache = torch.zeros(cache_shape, dtype=dtype, device=device)

        # Track current sequence length
        self.seq_len = 0

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor) -> tuple:
        """
        Append new key-value pairs to cache and return full cache.

        Args:
            k_new: New keys of shape (batch, num_heads, new_len, head_dim)
            v_new: New values of shape (batch, num_heads, new_len, head_dim)

        Returns:
            Tuple of (all_keys, all_values) including new additions
        """
        new_len = k_new.shape[2]

        # Store new keys and values
        self.k_cache[:, :, self.seq_len : self.seq_len + new_len, :] = k_new
        self.v_cache[:, :, self.seq_len : self.seq_len + new_len, :] = v_new

        self.seq_len += new_len

        # Return only the valid portion of the cache
        return (
            self.k_cache[:, :, : self.seq_len, :],
            self.v_cache[:, :, : self.seq_len, :],
        )

    def get_seq_len(self) -> int:
        return self.seq_len
In[6]:
Code
from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttentionWithCache(nn.Module):
    """
    Causal self-attention module that supports KV caching.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Combined QKV projection for efficiency
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: "KVCache | None" = None,
        use_cache: bool = False,
    ) -> "tuple[torch.Tensor, KVCache | None]":
        """
        Forward pass with optional KV caching.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model)
            kv_cache: Optional KVCache object for incremental decoding
            use_cache: Whether to use/update the cache

        Returns:
            Tuple of (output, updated_cache)
        """
        batch_size, seq_len, _ = x.shape

        # Compute Q, K, V projections
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, seq, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Handle KV cache
        if use_cache and kv_cache is not None:
            k, v = kv_cache.update(k, v)
            cache_len = kv_cache.get_seq_len()
        else:
            cache_len = seq_len

        # Compute attention scores
        # q: (batch, heads, seq_len, head_dim)
        # k: (batch, heads, cache_len, head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Apply causal mask
        # For cached attention: new queries can attend to all cached keys
        if seq_len == 1 and cache_len > 1:
            # Single token attending to full cache - no masking needed
            pass
        else:
            # Create causal mask for the relevant portion
            mask = torch.triu(
                torch.ones(
                    seq_len, cache_len, device=x.device, dtype=torch.bool
                ),
                diagonal=cache_len - seq_len + 1,
            )
            scores = scores.masked_fill(
                mask.unsqueeze(0).unsqueeze(0), float("-inf")
            )

        attn_weights = F.softmax(scores, dim=-1)

        # Apply attention to values
        out = torch.matmul(attn_weights, v)

        # Reshape and project output
        out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)
        out = self.out_proj(out)

        return out, kv_cache

Now let's demonstrate how the cache accelerates generation:

In[7]:
Code
# Set up a small model configuration
d_model = 64
num_heads = 4
batch_size = 1
max_seq_len = 32

# Create attention module
attention = CausalSelfAttentionWithCache(d_model, num_heads)

# Simulate a prompt of 8 tokens
prompt_len = 8
prompt = torch.randn(batch_size, prompt_len, d_model)

# PREFILL PHASE: Process entire prompt at once
cache = KVCache(batch_size, num_heads, d_model // num_heads, max_seq_len)
output, cache = attention(prompt, cache, use_cache=True)
Out[8]:
Console
After prefill phase:
  Cache sequence length: 8
  K cache shape (used portion): torch.Size([1, 4, 8, 16])

The output confirms that the cache has been initialized with the 8 prompt tokens. The key cache shape shows we have stored the projection results for these tokens, ready to be attended to by future generation steps.

In[9]:
Code
# DECODE PHASE: Generate tokens one at a time
num_new_tokens = 5
generation_outputs = []

for step in range(num_new_tokens):
    # Simulate embedding of newly generated token
    new_token = torch.randn(batch_size, 1, d_model)

    # Process only the new token, reusing cached K,V
    output, cache = attention(new_token, cache, use_cache=True)
    generation_outputs.append(output)
Out[10]:
Console

After generating 5 tokens:
  Cache sequence length: 13
  Total tokens processed: 13
  K cache used: torch.Size([1, 4, 13, 16])

The cache has grown by 5 tokens, now containing the history for 13 tokens total.

Let's compare the computational savings. The crucial detail is that we only processed the 5 new tokens through the model's projection layers, yet the attention mechanism had access to the full 13-token history via the cache.

In[11]:
Code
def count_qkv_projections(
    prompt_len: int, gen_len: int, use_cache: bool
) -> dict:
    """
    Count Q, K, V projection operations with and without caching.
    """
    if use_cache:
        # Prefill: compute Q, K, V for all prompt tokens
        prefill_projections = prompt_len * 3  # Q, K, V for each token

        # Decode: compute Q, K, V only for new tokens
        decode_projections = gen_len * 3

        total = prefill_projections + decode_projections
    else:
        # Without cache: recompute all projections at each step
        total = 0
        for t in range(1, gen_len + 1):
            seq_len = prompt_len + t
            total += seq_len * 3  # Q, K, V for entire sequence

    return {"total_projections": total, "with_cache": use_cache}


# Compare for realistic generation scenario
prompt_len = 100
gen_len = 200

with_cache = count_qkv_projections(prompt_len, gen_len, use_cache=True)
without_cache = count_qkv_projections(prompt_len, gen_len, use_cache=False)

speedup = without_cache["total_projections"] / with_cache["total_projections"]
Out[12]:
Console
Generating 200 tokens from 100-token prompt:
  With KV cache: 900 projection operations
  Without cache: 120,300 projection operations
  Reduction factor: 133.7x fewer operations

The savings are substantial, and because the redundant projection work grows quadratically with sequence length, the benefit of caching only increases as generation gets longer.

Out[13]:
Visualization
Attention pattern during cached generation. The heatmap illustrates attention weights for a new token (position 12) attending to all cached positions (0-12). During decode, each new token computes attention scores against the entire cache, allowing it to gather information from the full context while only computing projections for itself.

Verifying Cache Correctness

A critical property of KV caching is that it must produce identical outputs to non-cached attention. Let's verify this:

In[14]:
Code
def verify_cache_equivalence():
    """
    Verify that cached and non-cached attention produce identical outputs.
    """
    torch.manual_seed(42)

    d_model = 64
    num_heads = 4
    batch_size = 2

    attention = CausalSelfAttentionWithCache(d_model, num_heads)

    # Create a sequence
    prompt = torch.randn(batch_size, 5, d_model)
    next_tokens = torch.randn(batch_size, 3, d_model)
    full_sequence = torch.cat([prompt, next_tokens], dim=1)

    # Method 1: Process entire sequence without cache
    output_no_cache, _ = attention(full_sequence, use_cache=False)

    # Method 2: Process with cache (prefill + decode)
    cache = KVCache(batch_size, num_heads, d_model // num_heads, max_seq_len=32)

    # Prefill with prompt
    output_prefill, cache = attention(prompt, cache, use_cache=True)

    # Decode remaining tokens one at a time
    cached_outputs = [output_prefill]
    for i in range(next_tokens.shape[1]):
        token = next_tokens[:, i : i + 1, :]
        output_step, cache = attention(token, cache, use_cache=True)
        cached_outputs.append(output_step)

    output_with_cache = torch.cat(cached_outputs, dim=1)

    return output_no_cache, output_with_cache


output_no_cache, output_with_cache = verify_cache_equivalence()

# Compare outputs
max_diff = (output_no_cache - output_with_cache).abs().max().item()
are_equal = torch.allclose(output_no_cache, output_with_cache, atol=1e-5)
Out[15]:
Console
Maximum difference between cached and non-cached: 1.19e-07
Outputs are equivalent: True

The numerical equivalence confirms that caching is purely an optimization: it changes how we compute the result, not what we compute.

Profiling Memory Usage

Let's measure actual memory consumption for different configurations:

In[16]:
Code
def calculate_cache_memory(
    num_layers: int,
    d_model: int,
    max_seq_len: int,
    batch_size: int,
    dtype_bytes: int = 2,  # FP16
) -> dict:
    """
    Calculate KV cache memory requirements.
    """
    # Memory per layer: 2 (K and V) × batch × seq_len × d_model × bytes
    memory_per_layer = 2 * batch_size * max_seq_len * d_model * dtype_bytes
    total_memory = num_layers * memory_per_layer

    return {
        "memory_per_layer_mb": memory_per_layer / (1024**2),
        "total_memory_mb": total_memory / (1024**2),
        "total_memory_gb": total_memory / (1024**3),
    }


# Model configurations
models = {
    "LLaMA-7B": {"layers": 32, "d_model": 4096},
    "LLaMA-13B": {"layers": 40, "d_model": 5120},
    "LLaMA-70B": {"layers": 80, "d_model": 8192},
}

# Calculate for different context lengths
context_lengths = [2048, 4096, 8192, 16384]
batch_size_1 = 1
memory_table_1 = []

for name, config in models.items():
    row = []
    for ctx in context_lengths:
        mem = calculate_cache_memory(
            config["layers"], config["d_model"], ctx, batch_size_1
        )
        row.append(mem["total_memory_gb"])
    memory_table_1.append((name, row))

# Calculate for different batch sizes (Context 4096)
target_ctx = 4096
batch_sizes = [1, 8, 32]
memory_table_2 = []

for name, config in models.items():
    row = []
    for bs in batch_sizes:
        mem = calculate_cache_memory(
            config["layers"], config["d_model"], target_ctx, bs
        )
        row.append(mem["total_memory_gb"])
    memory_table_2.append((name, row))
Out[17]:
Console
KV Cache Memory (GB) - Batch Size 1, FP16

Model        |   2048 |   4096 |   8192 |  16384
-------------------------------------------------------
LLaMA-7B     |   1.0G |   2.0G |   4.0G |   8.0G |
LLaMA-13B    |   1.6G |   3.1G |   6.2G |  12.5G |
LLaMA-70B    |   5.0G |  10.0G |  20.0G |  40.0G |

As context length increases, memory usage grows linearly. For the LLaMA-70B model, a single sequence at 8192 context requires 20 GB of cache memory, which is significant even for high-end hardware.

Out[18]:
Console


KV Cache Memory (GB) - Context 4096, FP16

Model        | Batch 1 | Batch 8 | Batch 32
--------------------------------------------------
LLaMA-7B     |    2.0G |   16.0G |   64.0G |
LLaMA-13B    |    3.1G |   25.0G |  100.0G |
LLaMA-70B    |   10.0G |   80.0G |  320.0G |

These numbers reveal why KV cache memory management is critical for production systems. A LLaMA-70B model serving 32 concurrent requests at 4K context requires 320 GB just for the KV cache, far exceeding what any single GPU can provide.

Out[19]:
Visualization
KV cache memory scaling with batch size at fixed 4096 context length. Memory usage grows linearly with batch size, reaching 320 GB for LLaMA 70B at batch size 32.
Out[20]:
Visualization
Comparison of KV cache memory to model weights for different batch sizes. For large batches (16+), the transient cache memory exceeds the static model weights, particularly for larger models like LLaMA 70B.

Key Parameters

The key parameters for the KV cache implementation are:

  • d_model: The dimensionality of the model's hidden states.
  • num_heads: The number of attention heads.
  • head_dim: The dimension of each attention head (d_{\text{model}} / \text{num\_heads}).
  • max_seq_len: The maximum sequence length the cache can store.
  • batch_size: The number of sequences processed simultaneously.

Limitations and Practical Considerations

KV caching introduces several challenges that drive ongoing research in inference optimization.

Memory consumption scales linearly with sequence length. While this is better than the quadratic scaling we'd face without caching, it still creates a hard constraint on context length and batch size. A server with 80 GB of GPU memory might fit the model weights comfortably but struggle to serve multiple long-context requests concurrently. This tension between throughput and context length is fundamental to LLM deployment.

Memory fragmentation becomes problematic with variable-length requests. When different sequences in a batch have different lengths, naive approaches either waste memory through padding or sacrifice batching efficiency. Production systems must carefully manage cache allocation to maximize GPU utilization. The upcoming chapter on Paged Attention addresses this with a memory management approach inspired by operating system virtual memory.

The cache must persist across the entire generation process. Unlike model weights that are read-only during inference, the KV cache is constantly updated. This means the cache cannot be easily offloaded to CPU during generation without incurring significant latency for memory transfers. Systems that need to handle many concurrent requests must carefully orchestrate which caches are active on GPU at any moment.

Grouped-Query Attention reduces cache size. As we discussed in Part XIX, architectures like LLaMA 2 70B use Grouped-Query Attention (GQA), where multiple query heads share a single key-value head. This reduces KV cache size proportionally. For example, if 8 query heads share 1 KV head, the cache is 8× smaller. This architectural choice was motivated specifically by the memory constraints we've described here.
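
A quick back-of-the-envelope comparison shows the effect. The function below is a hypothetical helper; the configuration (80 layers, 64 query heads, 8 KV heads, head dimension 128, 4096-token context, FP16) approximates a 70B-class model with GQA:

```python
def kv_cache_bytes(layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, dtype_bytes: int = 2) -> int:
    # The leading 2 accounts for storing both keys and values.
    return 2 * layers * num_kv_heads * head_dim * seq_len * dtype_bytes


mha = kv_cache_bytes(layers=80, num_kv_heads=64, head_dim=128, seq_len=4096)  # all 64 heads cached
gqa = kv_cache_bytes(layers=80, num_kv_heads=8, head_dim=128, seq_len=4096)   # 8 shared KV heads

print(f"MHA cache: {mha / 1024**3:.1f} GiB")   # 10.0 GiB
print(f"GQA cache: {gqa / 1024**3:.2f} GiB")   # 1.25 GiB, 8x smaller
```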

Summary

The KV cache is a fundamental optimization that makes autoregressive generation practical. By storing and reusing the key and value projections from attention, we avoid recomputing the same values at every generation step. This reduces the computational overhead from quadratic to linear in the number of generated tokens.

The key concepts covered in this chapter are:

  • Redundancy problem: Without caching, each generation step recomputes K and V for all previous tokens, wasting computation
  • Cache structure: Separate K and V tensors per layer, growing with sequence length, storing projections for all processed tokens
  • Memory requirements: Cache size scales as 2 \times L \times d_{\text{model}} \times T \times \text{bytes}, easily reaching multiple gigabytes for long sequences
  • Prefill vs decode: The initial prompt processing (prefill) populates the cache, while subsequent generation steps (decode) each process only one token
  • Cache management: Decisions about static vs dynamic allocation, batching strategies, and context length limits significantly impact deployment efficiency

Understanding KV cache mechanics is essential for working with modern LLMs, as cache memory often becomes the primary constraint on throughput and context length. The next chapter examines KV cache memory in greater detail, followed by Paged Attention's approach to efficient memory management and techniques for compressing the cache to extend context lengths further.

