Mistral Architecture: Sliding Window Attention & Efficient LLM Design

Michael Brenndoerfer · Updated August 6, 2025 · 49 min read

Deep dive into Mistral 7B's architectural innovations including sliding window attention, grouped query attention, and rolling buffer KV cache. Learn how these techniques achieve LLaMA 2 13B performance with half the parameters.

Mistral Architecture

When Mistral AI released their 7B parameter model in September 2023, it achieved something remarkable: matching or exceeding the performance of LLaMA 2 13B on most benchmarks while using nearly half the parameters. This wasn't magic. It was the result of carefully chosen architectural innovations that improve efficiency without sacrificing capability. The most significant of these is sliding window attention, a technique that limits the computational burden of attention while preserving the model's ability to reason over long contexts.

This chapter dissects the Mistral architecture, examining how it builds on the LLaMA foundation while introducing key modifications that boost efficiency. We'll implement sliding window attention from scratch, visualize how it differs from full attention, and understand why Mistral's design choices have made it one of the most influential open-weight models in the field.

Architectural Foundation

Mistral 7B inherits the core structure of LLaMA while introducing targeted improvements. Understanding these changes requires first recognizing what stays the same: Mistral uses a decoder-only transformer with RMSNorm for layer normalization, SwiGLU for the feed-forward network, and rotary positional embeddings (RoPE) for encoding position information. These components, proven effective in LLaMA, form the foundation upon which Mistral builds.

The key architectural parameters for Mistral 7B are:

  • Hidden dimension: 4096
  • Number of layers: 32
  • Number of attention heads: 32
  • Number of KV heads: 8 (grouped query attention)
  • Feed-forward hidden dimension: 14336
  • Vocabulary size: 32000
  • Context length: 8192 tokens (sliding window of 4096)
  • Total parameters: 7.3 billion

Three innovations distinguish Mistral from its predecessors: sliding window attention (SWA), rolling buffer KV cache, and pre-fill chunking. Together, these enable efficient processing of long sequences while maintaining strong performance.

In[3]:
Code
# Mistral 7B configuration
mistral_config = {
    "hidden_size": 4096,
    "num_layers": 32,
    "num_attention_heads": 32,
    "num_kv_heads": 8,
    "intermediate_size": 14336,
    "vocab_size": 32000,
    "max_position_embeddings": 8192,
    "sliding_window": 4096,
    "rope_theta": 10000.0,
}
Out[4]:
Console
Mistral 7B Configuration:
----------------------------------------
Hidden Size                         4,096
Num Layers                             32
Num Attention Heads                    32
Num Kv Heads                            8
Intermediate Size                  14,336
Vocab Size                         32,000
Max Position Embeddings             8,192
Sliding Window                      4,096
Rope Theta                        10000.0

The configuration reveals Mistral's efficiency-focused design. The 8 KV heads (compared to 32 query heads) indicate grouped query attention, while the 4096-token sliding window enables the model to handle 8K context lengths without quadratic memory growth.

Sliding Window Attention

To understand sliding window attention, we need to first appreciate the problem it solves. Standard self-attention allows every token to attend to every other token. This seems natural: when predicting the next word, shouldn't the model consider the entire context? The answer is yes, in principle. But the computational reality tells a different story.

The Quadratic Problem

Consider what happens when you compute attention. For each token in your sequence, you calculate a score against every other token. With $n$ tokens, that's $n$ scores per token, and $n$ tokens total, giving $n \times n = n^2$ scores. Store these in memory, and you need space proportional to $n^2$. Compute them, and you need operations proportional to $n^2$.

This quadratic scaling creates a harsh reality:

Attention score counts grow quadratically with sequence length. Each 4x increase in length produces a 16x increase in scores.
| Sequence Length | Attention Scores | Growth Factor |
| --- | --- | --- |
| 1,024 tokens | ~1 million | 1x |
| 4,096 tokens | ~17 million | 16x |
| 16,384 tokens | ~268 million | 256x |

Doubling the sequence length quadruples both memory and compute. By the time you reach 16K tokens, a common context length for modern models, you're dealing with hundreds of millions of attention scores per layer. Multiply by 32 layers and the numbers add up quickly.
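
To make these numbers concrete, a few lines of arithmetic reproduce the score counts in the table above (a minimal sketch that counts query-key pairs for a single head in a single layer):

Code
for n in [1024, 4096, 16384]:
    scores = n * n                      # one score per query-key pair
    growth = scores / (1024 * 1024)     # relative to the 1,024-token baseline
    print(f"{n:>6} tokens: {scores / 1e6:>6.1f} million scores ({growth:.0f}x)")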

The Sliding Window Insight

Mistral's solution starts with a simple observation: in practice, tokens rarely need to attend to the entire preceding context with equal importance. Recent tokens matter more than distant ones for most language patterns. What if we formalized this intuition and restricted attention to only the most recent tokens?

This is exactly what sliding window attention does. Instead of allowing token $i$ to attend to all tokens from 0 to $i$, we restrict it to a fixed window of size $w$. Token $i$ can only attend to tokens in positions $\max(0, i - w + 1)$ through $i$, a span of at most $w$ tokens.

Think of it like looking through a sliding window that moves along with your position in the sequence. Early in the sequence, when you're at position 3, you might see positions 0, 1, 2, and 3. Later, at position 1000, you see positions 997, 998, 999, and 1000. The window always contains the same number of tokens (or fewer, at the start), regardless of how long the sequence grows.
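
The window boundaries are simple to compute directly. The helper below is a sketch (the name window_bounds is ours, not part of Mistral's code); it returns the range of positions a given token may attend to:

Code
def window_bounds(position, window_size):
    """Inclusive range of key positions a query at `position` may attend to."""
    return max(0, position - window_size + 1), position


# With a window of 4: early tokens see everything so far,
# later tokens see only the 4 most recent positions.
for pos in [0, 3, 1000]:
    start, end = window_bounds(pos, 4)
    print(f"position {pos}: attends to positions {start} through {end}")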

From Intuition to Formula

Let's formalize this sliding window mathematically. We need a mask that tells the attention mechanism which positions are allowed (within the window) and which are forbidden (outside the window or in the future).

The attention mask $\text{Mask}_{i,j}$ determines whether query position $i$ can attend to key position $j$:

$$\text{Mask}_{i,j} = \begin{cases} 0 & \text{if } \max(0, i-w+1) \leq j \leq i \\ -\infty & \text{otherwise} \end{cases}$$

Let's unpack each component:

  • $i$: the query position, meaning the token that's "asking" for information
  • $j$: the key position, meaning the token that might provide information
  • $w$: the window size (4096 for Mistral 7B)
  • $\max(0, i-w+1)$: the leftmost position in the window, clamped to 0 for early tokens
  • The condition $\max(0, i-w+1) \leq j \leq i$: defines the valid window of positions $j$ that query $i$ can attend to

The mask values themselves serve a specific purpose in the attention computation:

  • $0$: When added to attention scores, leaves them unchanged, allowing attention
  • $-\infty$: When added to attention scores, makes them negative infinity. After softmax, $e^{-\infty} = 0$, completely blocking attention

Why this particular formulation? The answer lies in how attention scores become attention weights. Raw scores pass through softmax, which exponentiates each score. By adding $-\infty$ to forbidden positions, we ensure their exponentiated values are exactly zero, not just small. This creates a hard boundary that the model cannot violate.
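
A quick NumPy experiment shows why $-\infty$, rather than some large negative constant, is the right masking value: after the stable softmax, masked positions receive a weight of exactly zero (a minimal sketch):

Code
import numpy as np

scores = np.array([2.0, 1.0, -np.inf, -np.inf])  # last two positions masked
weights = np.exp(scores - scores.max())          # stable softmax: subtract max
weights /= weights.sum()
print(weights)  # [0.731, 0.269, 0.0, 0.0] -- masked weights are exactly zero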

Complexity Transformation

The power of sliding window attention becomes clear when we analyze its complexity. With full attention, each of $n$ tokens computes scores against all $n$ tokens, giving $O(n^2)$ operations. With sliding window attention, each of $n$ tokens computes scores against only $w$ tokens, giving $O(n \cdot w)$ operations.

When $w$ is fixed (4096 for Mistral), this becomes $O(n)$: linear in sequence length. The difference is significant:

  • Full attention at 32K tokens: $32{,}768^2 \approx 1.07$ billion scores
  • Sliding window at 32K tokens: $32{,}768 \times 4{,}096 \approx 134$ million scores

That's an 8x reduction at 32K tokens, and the gap widens as sequences grow longer.
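
The same counting argument, written as code, reproduces these figures (a sketch; it counts attention scores for one head in one layer):

Code
n, w = 32_768, 4_096

full = n * n             # every token scores against every token
sliding = n * min(n, w)  # every token scores against at most w tokens

print(f"Full attention:  {full / 1e9:.2f} billion scores")
print(f"Sliding window:  {sliding / 1e6:.0f} million scores")
print(f"Reduction:       {full / sliding:.0f}x")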

Out[5]:
Visualization
Log-scale line plot showing full attention curve rising steeply while sliding window attention grows linearly.
Computational complexity comparison showing quadratic O(n^2) scaling for full attention versus linear O(n*w) scaling for sliding window attention. At 32K tokens, sliding window requires 8x fewer operations.
Out[6]:
Visualization
Lower triangular heatmap showing full causal attention pattern where darker cells indicate allowed attention.
Standard causal attention mask. Every token can attend to all previous tokens, creating a lower triangular pattern. Memory and compute scale quadratically with sequence length.
Banded diagonal heatmap showing sliding window attention pattern with fixed window width.
Sliding window attention mask with window size 4. Each token attends only to the 4 most recent tokens, creating a banded pattern. Memory and compute scale linearly with sequence length.

The visual difference is clear. Standard causal attention fills the entire lower triangle, meaning late tokens must process attention scores for all preceding tokens. Sliding window attention creates a narrow band along the diagonal, greatly reducing the number of computations.

Information Flow Across Layers

A natural concern arises: if each token can only see $w$ tokens back, how does the model capture long-range dependencies? If the sliding window is 4096 tokens, can a token at position 8000 ever learn anything from a token at position 100?

The answer is yes, and understanding why requires thinking about how information propagates through stacked transformer layers.

The Layered Propagation Mechanism

Consider a concrete example. At layer 1, token 8000 can directly attend to tokens 3905 through 8000 (a window of 4096 positions). It cannot see token 100 directly. But here's the key insight: token 3905 at layer 1 could see tokens 0 through 3905. When token 8000 attends to token 3905 at layer 2, it's attending to a representation that already contains information from the beginning of the sequence.

Each layer extends the reach:

  1. Layer 1: Token $i$ sees positions $[i-w+1, i]$, a direct window of $w$ tokens
  2. Layer 2: Token $i$ attends to tokens that themselves saw $w$ tokens back, extending indirect reach to $2w$
  3. Layer $\ell$: Information can travel up to $\ell \times w$ positions through layered propagation

This gives us the receptive field formula:

$$R_\ell = \ell \times w$$

where:

  • $R_\ell$: the receptive field at layer $\ell$, measuring how far back information can theoretically travel
  • $\ell$: the layer number (1 to $L$)
  • $w$: the sliding window size

Applying the Formula to Mistral

For Mistral 7B with $L = 32$ layers and $w = 4096$:

$$R_{32} = 32 \times 4096 = 131{,}072 \text{ tokens}$$

This theoretical receptive field of 131K tokens vastly exceeds Mistral's 8K context length. Even at layer 2, the receptive field is $2 \times 4096 = 8{,}192$ tokens, covering the entire context window.

This analysis reveals an important design choice: Mistral's window size and layer count are calibrated so that the receptive field exceeds the context length early in the network. By layer 2, any token can (indirectly) access information from any other token in the context.
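
Tabulating the receptive field for Mistral's configuration makes this calibration explicit (a sketch using the window size and layer count from the configuration above):

Code
window_size = 4096
context_length = 8192

for layer in [1, 2, 4, 8, 16, 32]:
    receptive_field = layer * window_size
    note = "covers full 8K context" if receptive_field >= context_length else ""
    print(f"Layer {layer:>2}: receptive field = {receptive_field:>7,} tokens  {note}")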

The Trade-off: Direct vs Indirect Access

While the receptive field ensures information can flow across the entire context, there's a qualitative difference between direct and indirect access. When token 8000 directly attends to token 7999, it sees that token's representation with full fidelity. When it indirectly accesses token 100 through layered propagation, the information has passed through multiple aggregation steps, potentially becoming more diffuse.

For tasks requiring precise retrieval of specific early tokens, this indirect propagation may be less effective than full attention. But for most language modeling tasks, where context builds gradually and recent tokens matter most, the sliding window provides an excellent trade-off between efficiency and capability.

Out[7]:
Visualization
Line plot showing receptive field expanding linearly from 4K tokens at layer 1 to 128K tokens at layer 32.
Receptive field growth across transformer layers. With each layer, the effective context a token can access grows linearly. By layer 32, the receptive field covers over 130K tokens, far exceeding the typical context window.

The visualization shows that by layer 2, the receptive field already exceeds the 8K context length. This means that even though individual attention operations are local, the overall architecture can still capture dependencies spanning the entire input.

Memory and Compute Savings

Understanding the theory of sliding window attention is valuable, but the practical impact is what matters for deployment. Let's derive exactly how much memory and compute we save, starting from first principles.

Deriving the Memory Formulas

The attention mechanism computes a score for each query-key pair. With $n$ query positions and $n$ key positions, we have $n \times n$ scores to store. Each attention head maintains its own score matrix, and each score requires storage space determined by the numerical precision.

For full attention, the memory required is:

$$M_{\text{full}} = n \times n \times h \times b = n^2 \cdot h \cdot b$$

For sliding window attention, each query position only attends to $\min(n, w)$ key positions. When the sequence is shorter than the window ($n \leq w$), we still compute $n$ scores per query. When the sequence exceeds the window ($n > w$), we compute exactly $w$ scores per query:

$$M_{\text{sliding}} = n \times \min(n, w) \times h \times b$$

Let's define each variable precisely:

  • $M$: total memory in bytes for all attention score matrices
  • $n$: sequence length (number of tokens)
  • $w$: sliding window size (4096 for Mistral)
  • $h$: number of attention heads (32 for Mistral)
  • $b$: bytes per element (2 for fp16, 4 for fp32)

The Crossover Point

The formulas reveal a key insight: the two methods behave identically when $n \leq w$. Both compute the same $n \times n$ attention matrix because the window hasn't started to "slide" yet. The divergence occurs when $n > w$:

  • Full attention: Memory grows as $n^2$, accelerating as sequences lengthen
  • Sliding window: Memory grows as $n \cdot w$, which is linear in $n$ since $w$ is fixed

For Mistral with $w = 4096$, the crossover happens at 4096 tokens. Below this, both methods are equivalent. Above this, sliding window provides increasing savings:

  • At $n = 8192$: Sliding window uses $\frac{8192 \times 4096}{8192^2} = \frac{1}{2}$ the memory (50% savings)
  • At $n = 32768$: Sliding window uses $\frac{32768 \times 4096}{32768^2} = \frac{1}{8}$ the memory (87.5% savings)

The savings percentage follows a simple formula: $1 - \frac{w}{n}$ when $n > w$. As sequences grow longer, the savings approach 100%.

Quantifying the Savings

Let's compute concrete memory numbers for Mistral's configuration:

In[8]:
Code
def attention_memory_bytes(seq_len, window_size, num_heads, dtype_bytes=2):
    """Calculate memory for attention scores matrix."""
    if window_size is None:
        # Full attention: n x n matrix per head
        return seq_len * seq_len * num_heads * dtype_bytes
    else:
        # Sliding window: n x w matrix per head (approximately)
        effective_size = min(seq_len, window_size)
        return seq_len * effective_size * num_heads * dtype_bytes


def memory_comparison(seq_lengths, window_size=4096, num_heads=32):
    """Compare memory for full vs sliding window attention."""
    results = []
    for n in seq_lengths:
        full_mem = attention_memory_bytes(n, None, num_heads)
        swa_mem = attention_memory_bytes(n, window_size, num_heads)
        savings = (1 - swa_mem / full_mem) * 100 if full_mem > 0 else 0
        results.append(
            {
                "seq_len": n,
                "full_attention_mb": full_mem / (1024**2),
                "sliding_window_mb": swa_mem / (1024**2),
                "savings_pct": savings,
            }
        )
    return results


seq_lengths = [1024, 2048, 4096, 8192, 16384, 32768]
comparison = memory_comparison(seq_lengths)
Out[9]:
Console
Attention Memory Comparison (32 heads, fp16):

  Seq Length      Full Attn   Sliding (4K)    Savings
------------------------------------------------------
       1,024         64.0 MB         64.0 MB       0.0%
       2,048        256.0 MB        256.0 MB       0.0%
       4,096       1024.0 MB       1024.0 MB       0.0%
       8,192       4096.0 MB       2048.0 MB      50.0%
      16,384      16384.0 MB       4096.0 MB      75.0%
      32,768      65536.0 MB       8192.0 MB      87.5%

At the 8K context length that Mistral supports, sliding window attention uses exactly half the memory of full attention. At 32K tokens, the savings exceed 87%. These memory reductions directly translate to the ability to process longer sequences or use larger batch sizes.

Out[10]:
Visualization
Log-scale plot comparing memory growth, with full attention curving upward quadratically and sliding window attention flattening into linear growth.
Memory consumption comparison between full attention and sliding window attention. Full attention grows quadratically (red), while sliding window attention grows linearly after reaching the window size (blue). The gap widens significantly at longer sequences.

The crossover point at the window size (4096) matters. Below this threshold, both methods use similar memory because sequences are shorter than the window. Beyond this point, sliding window attention's linear scaling provides increasing advantages.

Out[11]:
Visualization
Line plot showing memory savings percentage increasing from 0% at 4K tokens to nearly 90% at 64K tokens.
Memory savings from sliding window attention as a function of sequence length. Savings approach 100% as sequences grow longer, following the formula 1 - w/n where w is the window size (4096) and n is the sequence length.

Rolling Buffer KV Cache

During autoregressive generation, transformers cache the key and value projections of previous tokens to avoid redundant computation. In standard transformers, this KV cache grows linearly with sequence length, eventually consuming significant memory for long generations.

Mistral implements a rolling buffer KV cache that exploits sliding window attention. Since tokens beyond the window boundary will never be attended to again, their cached keys and values can be safely discarded. The cache maintains only the most recent $w$ positions, using modular indexing to overwrite old entries.

In[12]:
Code
import numpy as np


class RollingKVCache:
    """
    Rolling buffer cache for sliding window attention.

    Only stores the most recent `window_size` key-value pairs,
    using modular indexing to overwrite old entries.
    """

    def __init__(self, window_size, num_heads, head_dim, dtype=np.float16):
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Pre-allocate fixed-size buffers
        self.key_cache = np.zeros(
            (window_size, num_heads, head_dim), dtype=dtype
        )
        self.value_cache = np.zeros(
            (window_size, num_heads, head_dim), dtype=dtype
        )
        self.current_length = 0

    def update(self, position, key, value):
        """Add new key-value pair at the given position."""
        # Use modular indexing to wrap around
        cache_idx = position % self.window_size
        self.key_cache[cache_idx] = key
        self.value_cache[cache_idx] = value
        self.current_length = min(self.current_length + 1, self.window_size)

    def get_valid_entries(self, current_position):
        """Retrieve cached entries visible from current position."""
        if current_position < self.window_size:
            # Haven't filled the buffer yet
            return self.key_cache[: current_position + 1], self.value_cache[
                : current_position + 1
            ]
        else:
            # Buffer is full, need to handle wrap-around
            start_pos = current_position - self.window_size + 1
            indices = [
                (start_pos + i) % self.window_size
                for i in range(self.window_size)
            ]
            return self.key_cache[indices], self.value_cache[indices]

    def memory_usage_bytes(self):
        """Calculate actual memory usage."""
        return self.key_cache.nbytes + self.value_cache.nbytes
Out[13]:
Console
Rolling KV Cache Demonstration (window=4):

As we add tokens, old entries are overwritten:

Position 0: Stored at index 0, visible range [0, 0]
Position 1: Stored at index 1, visible range [0, 1]
Position 2: Stored at index 2, visible range [0, 2]
Position 3: Stored at index 3, visible range [0, 3]
Position 4: Stored at index 0, visible range [1, 4]
Position 5: Stored at index 1, visible range [2, 5]
Position 6: Stored at index 2, visible range [3, 6]
Position 7: Stored at index 3, visible range [4, 7]
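
The trace above could be produced by a small driver like the following (a hypothetical sketch built on the RollingKVCache class defined above, with tiny head dimensions for readability):

Code
import numpy as np

cache = RollingKVCache(window_size=4, num_heads=1, head_dim=8)

for position in range(8):
    key = np.random.randn(1, 8).astype(np.float16)    # (num_heads, head_dim)
    value = np.random.randn(1, 8).astype(np.float16)
    cache.update(position, key, value)

    cache_idx = position % cache.window_size
    visible_start = max(0, position - cache.window_size + 1)
    print(
        f"Position {position}: Stored at index {cache_idx}, "
        f"visible range [{visible_start}, {position}]"
    )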

The rolling buffer ensures that memory usage remains constant regardless of how many tokens are generated. For a 32-layer Mistral model with 8 KV heads per layer, the total cache memory is fixed at:

$$\text{Cache Size} = 2 \times L \times w \times h_{\text{kv}} \times d_h \times b$$

where:

  • $\text{Cache Size}$: total memory in bytes for the KV cache
  • $2$: factor accounting for storing both keys and values (each cached position stores one K vector and one V vector)
  • $L$: number of transformer layers (32 for Mistral 7B), since each layer maintains its own KV cache
  • $w$: sliding window size (4096 for Mistral), the maximum number of positions stored
  • $h_{\text{kv}}$: number of key-value heads per layer (8 for Mistral 7B with GQA)
  • $d_h$: dimension of each head, computed as $d_h = d_{\text{model}} / h_q$ where $d_{\text{model}} = 4096$ and $h_q = 32$, giving $d_h = 128$
  • $b$: bytes per element (2 for fp16, 4 for fp32)

For Mistral 7B with fp16 precision, the fixed cache size is:

$$2 \times 32 \times 4096 \times 8 \times 128 \times 2 = 536{,}870{,}912 \text{ bytes} \approx 512 \text{ MB}$$

This fixed 512 MB footprint contrasts sharply with standard transformers, where the cache grows linearly with sequence length and could reach several gigabytes for long generations.

In[14]:
Code
def kv_cache_size_bytes(
    num_layers, window_size, num_kv_heads, head_dim, dtype_bytes=2
):
    """Calculate KV cache memory for rolling buffer."""
    # 2 for keys and values
    return 2 * num_layers * window_size * num_kv_heads * head_dim * dtype_bytes


def compare_cache_sizes(seq_lengths, config):
    """Compare rolling vs standard KV cache."""
    num_layers = config["num_layers"]
    num_kv_heads = config["num_kv_heads"]
    head_dim = config["hidden_size"] // config["num_attention_heads"]
    window_size = config["sliding_window"]

    results = []
    for seq_len in seq_lengths:
        # Standard cache: stores all positions
        standard_size = kv_cache_size_bytes(
            num_layers, seq_len, num_kv_heads, head_dim
        )
        # Rolling cache: stores only window_size positions
        rolling_size = kv_cache_size_bytes(
            num_layers, min(seq_len, window_size), num_kv_heads, head_dim
        )
        results.append(
            {
                "seq_len": seq_len,
                "standard_mb": standard_size / (1024**2),
                "rolling_mb": rolling_size / (1024**2),
            }
        )
    return results


cache_comparison = compare_cache_sizes(seq_lengths, mistral_config)
Out[15]:
Console
KV Cache Memory Comparison (Mistral 7B config):

 Sequence Length   Standard Cache    Rolling Cache
----------------------------------------------------
           4,096         512.0 MB         512.0 MB
           8,192        1024.0 MB         512.0 MB
          16,384        2048.0 MB         512.0 MB
          32,768        4096.0 MB         512.0 MB
          65,536        8192.0 MB         512.0 MB

Rolling cache is fixed at 512.0 MB regardless of sequence length

The fixed memory footprint is valuable for deployment scenarios where memory budgets are tight and generation lengths are unpredictable.

Out[16]:
Visualization
Line plot showing standard KV cache growing linearly while rolling buffer stays flat at 512 MB.
KV cache memory comparison between standard (growing) and rolling buffer (fixed) approaches. The rolling buffer caps memory at 512 MB regardless of sequence length, while standard cache grows linearly.

Grouped Query Attention

Mistral employs grouped query attention (GQA), which reduces the number of key-value heads relative to query heads. Instead of the standard multi-head attention where each query head has its own key-value pair, GQA groups multiple query heads to share the same key-value heads.

GQA is a memory-bandwidth optimization where $h_q$ query heads share $h_{kv}$ key-value heads, with $h_q$ being a multiple of $h_{kv}$. In this notation:

  • $h_q$: number of query heads (32 for Mistral)
  • $h_{kv}$: number of key-value heads (8 for Mistral)
  • $h_q / h_{kv}$: the grouping ratio (4 for Mistral, meaning 4 query heads share each KV head)

This reduces KV cache size by a factor of $h_q / h_{kv}$ while maintaining most of the representational capacity of full multi-head attention.

For Mistral 7B:

  • Query heads: 32
  • KV heads: 8
  • Ratio: 4 query heads share each KV head

This means the KV cache is 4x smaller than it would be with standard multi-head attention, directly improving inference throughput by reducing memory bandwidth requirements.
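
The sketch below illustrates the sharing pattern in NumPy (an illustration of the idea, not Mistral's implementation): only 8 KV heads are stored, and each is repeated to serve a group of 4 query heads.

Code
import numpy as np

seq_len, head_dim = 16, 128
num_q_heads, num_kv_heads = 32, 8
group_size = num_q_heads // num_kv_heads  # 4 query heads per KV head

q = np.random.randn(seq_len, num_q_heads, head_dim)   # 32 independent query heads
k = np.random.randn(seq_len, num_kv_heads, head_dim)  # only 8 KV heads are cached
v = np.random.randn(seq_len, num_kv_heads, head_dim)

# At attention time, each KV head is shared by its group of 4 query heads
k_shared = np.repeat(k, group_size, axis=1)  # (seq_len, 32, head_dim)
v_shared = np.repeat(v, group_size, axis=1)

print(k.shape, "->", k_shared.shape)  # (16, 8, 128) -> (16, 32, 128)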

Out[17]:
Visualization
Diagram showing 32 query heads grouped into 8 groups of 4, with each group sharing one KV head pair.
Grouped Query Attention in Mistral. Each group of 4 query heads (shades of blue) shares a single key-value head pair (orange/green). This reduces KV cache size by 4x while maintaining 32 independent query heads for representation.

The combination of GQA and the rolling buffer creates substantial memory savings. With 8 KV heads instead of 32, and a fixed window size instead of growing with sequence length, Mistral's inference memory footprint remains manageable even for long-context applications.

Pre-fill and Chunking

When processing a prompt before generation begins, standard transformers compute attention for all tokens at once. For long prompts, this can create memory spikes that exceed available capacity. Mistral introduces pre-fill chunking to address this.

The idea is simple: instead of processing the entire prompt in a single forward pass, split it into chunks of size $w$ (the window size). Process each chunk sequentially, using the rolling KV cache to maintain context between chunks. This bounds memory usage during pre-fill to the same level as during generation.
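
In outline, chunked pre-fill might look like the sketch below (hypothetical names: model_forward stands in for a real transformer forward pass, and cache for the rolling KV cache from the previous section):

Code
def chunked_prefill(prompt_tokens, window_size, model_forward, cache):
    """Process a long prompt in window-sized chunks using a rolling KV cache."""
    hidden_states = None
    for start in range(0, len(prompt_tokens), window_size):
        chunk = prompt_tokens[start : start + window_size]
        # The chunk attends to its own tokens plus whatever earlier tokens
        # remain in the rolling cache (at most window_size positions).
        hidden_states = model_forward(chunk, cache)
    return hidden_states  # states for the final chunk, ready to start generation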

In[18]:
Code
def chunked_prefill_memory(
    prompt_length, window_size, hidden_size, num_heads, dtype_bytes=2
):
    """
    Calculate peak memory during chunked pre-fill.

    Instead of processing all tokens at once, we process
    chunks of size window_size sequentially.
    """
    chunk_size = min(prompt_length, window_size)

    # Attention matrix: chunk_size x chunk_size per head
    attention_mem = chunk_size * chunk_size * num_heads * dtype_bytes

    # Activations for current chunk: chunk_size x hidden_size
    activation_mem = chunk_size * hidden_size * dtype_bytes

    return attention_mem + activation_mem


def standard_prefill_memory(
    prompt_length, hidden_size, num_heads, dtype_bytes=2
):
    """Calculate memory for standard (non-chunked) pre-fill."""
    # Full attention matrix: n x n per head
    attention_mem = prompt_length * prompt_length * num_heads * dtype_bytes

    # Activations: n x hidden_size
    activation_mem = prompt_length * hidden_size * dtype_bytes

    return attention_mem + activation_mem


prompt_lengths = [1024, 2048, 4096, 8192, 16384]
window_size = 4096
hidden_size = 4096
num_heads = 32
Out[19]:
Console
Pre-fill Memory Comparison:

 Prompt Length       Standard        Chunked    Reduction
----------------------------------------------------------
         1,024         72.0 MB         72.0 MB         0.0%
         2,048        272.0 MB        272.0 MB         0.0%
         4,096       1056.0 MB       1056.0 MB         0.0%
         8,192       4160.0 MB       1056.0 MB        74.6%
        16,384      16512.0 MB       1056.0 MB        93.6%

For an 8K token prompt, chunked pre-fill uses 75% less memory than standard pre-fill. This reduction enables processing of longer prompts on the same hardware, or processing multiple requests concurrently.

Out[20]:
Visualization
Line plot comparing standard pre-fill memory growing quadratically with chunked pre-fill staying nearly flat.
Pre-fill memory comparison showing quadratic growth for standard pre-fill versus constant memory for chunked pre-fill. Beyond the window size, chunked pre-fill provides substantial memory savings.

Implementation: Sliding Window Attention

Having understood the mathematics behind sliding window attention, let's implement it from scratch. Building the mechanism ourselves will solidify the concepts and reveal implementation details that matter in practice.

Step 1: Building the Attention Mask

The foundation of sliding window attention is the mask that enforces which positions can attend to which. We need to encode two constraints:

  1. Causal constraint: Position $i$ cannot attend to any position $j > i$ (no looking into the future)
  2. Window constraint: Position $i$ cannot attend to any position $j < i - w + 1$ (outside the sliding window)

Both constraints translate to blocking certain positions with $-\infty$ values:

In[21]:
Code
def create_sliding_window_mask(seq_len, window_size):
    """
    Create a causal mask with sliding window constraint.

    Returns a mask where:
    - 0 indicates positions that CAN be attended to
    - -inf indicates positions that CANNOT be attended to

    Args:
        seq_len: Length of the sequence
        window_size: Size of the sliding window

    Returns:
        Mask tensor of shape (seq_len, seq_len)
    """
    # Start with zeros (allowing all positions)
    mask = np.zeros((seq_len, seq_len))

    for i in range(seq_len):
        for j in range(seq_len):
            if j > i:
                # Future positions: block (causal constraint)
                mask[i, j] = float("-inf")
            elif i - j >= window_size:
                # Past positions beyond window: block
                mask[i, j] = float("-inf")

    return mask

The logic is straightforward: for each position pair $(i, j)$, we check if $j$ is in the future ($j > i$) or outside the window ($i - j \geq w$). If either condition is true, we block with $-\infty$.

Let's visualize the mask to verify it matches our expectations:

Out[22]:
Console
Sliding Window Mask (8 tokens, window=4):
0 = can attend, -inf = blocked

Token 0:   .  X  X  X  X  X  X  X
Token 1:   .  .  X  X  X  X  X  X
Token 2:   .  .  .  X  X  X  X  X
Token 3:   .  .  .  .  X  X  X  X
Token 4:   X  .  .  .  .  X  X  X
Token 5:   X  X  .  .  .  .  X  X
Token 6:   X  X  X  .  .  .  .  X
Token 7:   X  X  X  X  .  .  .  .

Reading this output row by row tells the story of the sliding window:

  • Token 0: Can only attend to itself (the sequence just started)
  • Token 1: Can attend to positions 0-1 (window not yet full)
  • Token 2: Can attend to positions 0-2 (window not yet full)
  • Token 3: Can attend to positions 0-3 (window now full with 4 positions)
  • Token 4: The window slides! Can attend to positions 1-4, blocking position 0
  • Tokens 5-7: The window continues sliding, always covering exactly 4 positions

This banded diagonal structure is what transforms quadratic complexity into linear.

Step 2: The Complete Attention Forward Pass

With the mask in place, we can implement the full sliding window attention computation. The algorithm follows the standard attention formula, but with our sliding window mask applied before softmax:

In[23]:
Code
def sliding_window_attention(
    queries,  # (batch, seq_len, num_heads, head_dim)
    keys,  # (batch, seq_len, num_heads, head_dim)
    values,  # (batch, seq_len, num_heads, head_dim)
    window_size,
    scale=None,
):
    """
    Compute sliding window attention.

    Each query can only attend to keys within the sliding window.
    """
    batch_size, seq_len, num_heads, head_dim = queries.shape

    if scale is None:
        scale = 1.0 / np.sqrt(head_dim)

    # Reshape for batched matrix multiplication
    # (batch, num_heads, seq_len, head_dim)
    q = queries.transpose(0, 2, 1, 3)
    k = keys.transpose(0, 2, 1, 3)
    v = values.transpose(0, 2, 1, 3)

    # Compute attention scores: (batch, num_heads, seq_len, seq_len)
    scores = np.matmul(q, k.transpose(0, 1, 3, 2)) * scale

    # Apply sliding window mask
    mask = create_sliding_window_mask(seq_len, window_size)
    scores = scores + mask  # Broadcasting adds mask to all batches and heads

    # Softmax over keys dimension
    # Use stable softmax: subtract max before exp
    scores_max = np.max(scores, axis=-1, keepdims=True)
    scores_exp = np.exp(scores - scores_max)
    attention_weights = scores_exp / np.sum(scores_exp, axis=-1, keepdims=True)

    # Apply attention to values
    output = np.matmul(attention_weights, v)

    # Reshape back to (batch, seq_len, num_heads, head_dim)
    output = output.transpose(0, 2, 1, 3)

    return output, attention_weights

Let's walk through this implementation step by step:

  1. Reshaping: We transpose from (batch, seq, heads, dim) to (batch, heads, seq, dim) to enable batched matrix multiplication across all heads simultaneously

  2. Score computation: The matrix product q @ k.T computes dot products between all query-key pairs, giving raw attention scores. We scale by $1/\sqrt{d_k}$ to prevent the softmax from becoming too peaked

  3. Mask application: Adding the mask ($0$ or $-\infty$) to scores enforces our sliding window constraint. Positions with $-\infty$ will have zero attention weight after softmax

  4. Stable softmax: We subtract the maximum score before exponentiating to prevent numerical overflow. This is mathematically equivalent to standard softmax but numerically stable

  5. Value aggregation: The final matmul blends values according to attention weights, producing the output representation

Step 3: Testing and Visualization

Let's verify our implementation works correctly and visualize the resulting attention patterns:

In[24]:
Code
# Create sample data with fixed seed for reproducibility
np.random.seed(42)
batch_size = 1
seq_len = 16
num_heads = 4
head_dim = 32
window_size = 4

queries = np.random.randn(batch_size, seq_len, num_heads, head_dim).astype(
    np.float32
)
keys = np.random.randn(batch_size, seq_len, num_heads, head_dim).astype(
    np.float32
)
values = np.random.randn(batch_size, seq_len, num_heads, head_dim).astype(
    np.float32
)

# Run sliding window attention
output, attention_weights = sliding_window_attention(
    queries, keys, values, window_size
)
Out[25]:
Console
Input shapes:
  Queries: (1, 16, 4, 32)
  Keys:    (1, 16, 4, 32)
  Values:  (1, 16, 4, 32)

Output shape: (1, 16, 4, 32)
Attention weights shape: (1, 4, 16, 16)

Window size: 4
Attention pattern is banded with width 4

The output maintains the same shape as the input queries, as expected. The attention weights tensor has shape (batch, heads, seq_len, seq_len), but only the diagonal band of width 4 contains non-zero values due to our sliding window mask.

Out[26]:
Visualization
Heatmap showing sliding window attention weights for head 0 with banded diagonal structure.
Attention pattern for head 0. Each row shows which keys a query attends to, with the sliding window creating a banded diagonal structure.
Heatmap showing sliding window attention weights for head 1 with different weight distribution.
Attention pattern for head 1. Different heads learn different attention patterns within the same sliding window constraint.

The attention heatmaps reveal several important properties of sliding window attention:

The banded structure: Each row shows attention concentrated in a diagonal band of width 4 (our window size). The strict boundaries are enforced by our mask: positions outside the band have exactly zero attention weight.

Within-window variation: Inside the allowed window, different heads develop different patterns. Head 0 might emphasize recent positions while Head 1 focuses on slightly older ones. This diversity is valuable: different heads can specialize in different types of relationships.

Preserved expressiveness: Despite the hard constraint on which positions can be attended, the model retains full flexibility in how attention is distributed within the window. The softmax ensures weights sum to 1, but the specific allocation depends on the learned query-key interactions.

This implementation demonstrates that sliding window attention is a straightforward modification to standard attention. The only change is adding a carefully constructed mask before softmax. Everything else (the score computation, softmax normalization, and value aggregation) remains identical.
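
Two quick checks, using the tensors from the test above, confirm the key properties (a sketch): attention weights outside the band are exactly zero, every row sums to 1, and once the window covers the whole sequence, enlarging it further changes nothing, which is exactly standard causal attention.

Code
# Every attention row is a valid probability distribution
assert np.allclose(attention_weights.sum(axis=-1), 1.0)

# No weight leaks outside the causal sliding-window band
for i in range(seq_len):
    for j in range(seq_len):
        if j > i or i - j >= window_size:
            assert np.all(attention_weights[:, :, i, j] == 0.0)

# With window >= seq_len, the window constraint is inactive:
# the output equals standard causal attention and stops changing
out_a, _ = sliding_window_attention(queries, keys, values, window_size=seq_len)
out_b, _ = sliding_window_attention(queries, keys, values, window_size=seq_len + 1)
assert np.allclose(out_a, out_b)

print("All checks passed")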

Mistral vs LLaMA: Architectural Comparison

Let's directly compare the architectural choices between Mistral 7B and LLaMA 2 7B to understand what changed:

In[27]:
Code
llama2_7b_config = {
    "hidden_size": 4096,
    "num_layers": 32,
    "num_attention_heads": 32,
    "num_kv_heads": 32,  # Full MHA in base LLaMA 2 7B
    "intermediate_size": 11008,
    "vocab_size": 32000,
    "max_position_embeddings": 4096,
    "sliding_window": None,  # Full attention
}


def compare_configs(config1, config2, name1, name2):
    """Compare two model configurations."""
    all_keys = set(config1.keys()) | set(config2.keys())

    comparison = []
    for key in sorted(all_keys):
        val1 = config1.get(key, "N/A")
        val2 = config2.get(key, "N/A")
        if val1 != val2:
            comparison.append((key, val1, val2, "←"))
        else:
            comparison.append((key, val1, val2, ""))

    return comparison
Out[28]:
Console
Architecture Comparison: LLaMA 2 7B vs Mistral 7B

Parameter                          LLaMA 2 7B     Mistral 7B  Diff
--------------------------------------------------------------------
Hidden Size                             4,096          4,096      
Intermediate Size                      11,008         14,336     ←
Max Position Embeddings                 4,096          8,192     ←
Num Attention Heads                        32             32      
Num Kv Heads                               32              8     ←
Num Layers                                 32             32      
Rope Theta                                N/A        10000.0     ←
Sliding Window                           Full          4,096     ←
Vocab Size                             32,000         32,000      

The key differences are:

  1. KV heads: Mistral uses 8 KV heads (GQA) vs LLaMA's 32 (MHA), reducing KV cache by 4x
  2. FFN size: Mistral uses a larger intermediate dimension (14336 vs 11008), adding capacity
  3. Context length: Mistral supports 8K tokens vs LLaMA's 4K
  4. Sliding window: Mistral uses a 4096-token window, LLaMA uses full attention

These changes create a more efficient model that can handle longer contexts with less memory.

Performance Analysis

Mistral 7B achieves strong benchmark performance despite having fewer parameters than comparable models. The efficiency gains from its architectural choices translate to practical improvements in deployment.

In[29]:
Code
# Benchmark comparison (approximate scores from public benchmarks)
benchmarks = {
    "Model": ["LLaMA 2 7B", "LLaMA 2 13B", "Mistral 7B"],
    "Parameters": [7.0, 13.0, 7.3],
    "MMLU (5-shot)": [46.8, 54.8, 60.1],
    "HellaSwag": [77.2, 80.7, 81.3],
    "ARC-Challenge": [53.0, 59.4, 61.1],
    "TruthfulQA": [38.8, 37.4, 42.2],
    "WinoGrande": [74.0, 76.6, 78.4],
}

# Calculate efficiency (performance per billion parameters)
efficiency_metric = "MMLU (5-shot)"
Out[30]:
Console
Benchmark Comparison:

      Model  Parameters  MMLU (5-shot)  HellaSwag  ARC-Challenge  TruthfulQA  WinoGrande
 LLaMA 2 7B         7.0           46.8       77.2           53.0        38.8        74.0
LLaMA 2 13B        13.0           54.8       80.7           59.4        37.4        76.6
 Mistral 7B         7.3           60.1       81.3           61.1        42.2        78.4

============================================================
Efficiency Analysis (MMLU score per billion parameters):

LLaMA 2 7B      6.69 points/B params
LLaMA 2 13B     4.22 points/B params
Mistral 7B      8.23 points/B params

Mistral 7B achieves the highest efficiency at over 8 MMLU points per billion parameters, compared to about 6.7 for LLaMA 2 7B and 4.2 for LLaMA 2 13B. Mistral's architectural innovations translate directly to better parameter efficiency, not just reduced memory usage.

Out[31]:
Visualization
Scatter plot with MMLU score on y-axis and parameter count on x-axis, showing Mistral 7B above the trend line.
Parameter efficiency visualization. Mistral 7B (blue) achieves the highest MMLU score relative to its parameter count. The dashed lines show efficiency ratios, where steeper slopes indicate better performance per parameter.
Out[32]:
Visualization
Grouped bar chart comparing Mistral 7B to LLaMA 2 models across 5 benchmarks, showing Mistral matching the larger model.
Benchmark performance comparison across models. Mistral 7B (blue) matches or exceeds LLaMA 2 13B (orange) on most benchmarks despite having 44% fewer parameters. The efficiency gains from sliding window attention and GQA don't come at the cost of capability.

The results show that Mistral's architectural innovations enable parameter-efficient scaling. By using GQA and sliding window attention, the model frees up computational budget for a larger FFN layer and more efficient training, achieving better performance with fewer parameters.

Inference Efficiency

The practical impact of Mistral's design becomes most apparent during inference. Let's quantify the throughput improvements:

In[33]:
Code
def estimate_inference_memory(batch_size, seq_len, config, dtype_bytes=2):
    """
    Estimate peak memory during inference.

    Includes: model weights, KV cache, activations
    """
    hidden = config["hidden_size"]
    layers = config["num_layers"]
    heads = config["num_attention_heads"]
    kv_heads = config["num_kv_heads"]
    intermediate = config["intermediate_size"]
    vocab = config["vocab_size"]
    window = config.get("sliding_window", seq_len)
    head_dim = hidden // heads

    # Model weights (approximate)
    embed_params = vocab * hidden
    attn_params = layers * (hidden * hidden * 4)  # Q, K, V, O projections
    ffn_params = layers * (hidden * intermediate * 3)  # up, gate, down
    total_params = embed_params + attn_params + ffn_params
    weights_bytes = total_params * dtype_bytes

    # KV cache
    cache_len = min(seq_len, window) if window else seq_len
    kv_cache_bytes = (
        2 * layers * cache_len * kv_heads * head_dim * dtype_bytes * batch_size
    )

    # Activations (current layer only)
    activation_bytes = (
        batch_size * seq_len * hidden * dtype_bytes * 4
    )  # Rough estimate

    return {
        "weights_gb": weights_bytes / (1024**3),
        "kv_cache_gb": kv_cache_bytes / (1024**3),
        "activations_gb": activation_bytes / (1024**3),
        "total_gb": (weights_bytes + kv_cache_bytes + activation_bytes)
        / (1024**3),
    }


# Compare at different sequence lengths
seq_lengths = [1024, 4096, 8192, 16384]
batch_size = 1
Out[34]:
Console
Inference Memory Comparison (batch_size=1, fp16):

                               LLaMA 2 7B                     Mistral 7B          
  Seq Length     KV Cache        Total     KV Cache        Total
----------------------------------------------------------------------------
       1,024       0.50 GB      12.84 GB       0.12 GB      14.90 GB
       4,096       2.00 GB      14.43 GB       0.50 GB      15.37 GB
       8,192       4.00 GB      16.56 GB       0.50 GB      15.49 GB
      16,384       8.00 GB      20.81 GB       0.50 GB      15.74 GB

The comparison reveals Mistral's memory efficiency advantage. At 4K tokens, Mistral's KV cache is already 4x smaller due to GQA alone. At 8K tokens and beyond, the sliding window provides additional savings by capping cache growth. The total memory difference becomes substantial for long-context applications.

At 8K tokens, Mistral's KV cache is 8x smaller than LLaMA 2's would be (due to 4x from GQA and 2x from the context being capped at window size). This translates directly to higher throughput, as the memory bandwidth required to read the KV cache during each generation step is proportionally reduced.

Out[35]:
Visualization
Stacked bar chart showing LLaMA 2 memory breakdown with large KV cache portion.
LLaMA 2 7B inference memory breakdown at 8K context. KV cache dominates at longer sequences due to full attention and per-head storage.
Stacked bar chart showing Mistral memory breakdown with much smaller KV cache.
Mistral 7B inference memory breakdown at 8K context. GQA and rolling buffer reduce KV cache by 8x, shifting the memory profile.

The memory breakdown visualizations show the impact of Mistral's design choices. While model weights are similar between the two architectures, the KV cache differs by nearly an order of magnitude. For LLaMA 2, the KV cache grows to dominate memory usage at long sequences. Mistral's combination of GQA (4x reduction) and rolling buffer (2x reduction at 8K context) keeps the KV cache manageable, freeing memory for larger batch sizes or longer generations.

Limitations and Considerations

While Mistral's architecture offers substantial efficiency gains, the design involves trade-offs worth understanding.

Sliding Window Limitations: The sliding window creates a hard attention cutoff. While information can propagate across layers to reach distant positions, this propagation is lossy. Information from early tokens becomes increasingly diffuse as it passes through more layers, potentially limiting performance on tasks requiring precise retrieval from early context. For tasks like long-document question answering where a specific fact from the beginning must be recalled exactly, the indirect propagation may be insufficient. In contrast, full attention can directly access any position with full fidelity.

GQA Trade-offs: Reducing KV heads from 32 to 8 means fewer independent key-value representations. While empirically this has minimal impact on most benchmarks, certain tasks that benefit from diverse attention patterns may see degradation. This trade-off matters most for tasks requiring fine-grained distinctions in how different query heads interpret the same input.
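
To make the sharing concrete, here is a minimal sketch (a toy illustration, not Mistral's actual implementation) that expands 8 KV heads to serve 32 query heads; after the expansion, the 32 query heads attend to only 8 distinct key representations:

Code
import torch

# Toy tensors mirroring Mistral's head layout
num_q_heads, num_kv_heads, head_dim = 32, 8, 128
group_size = num_q_heads // num_kv_heads  # 4 query heads per KV head

batch, seq_len = 1, 16
keys = torch.randn(batch, num_kv_heads, seq_len, head_dim)  # only 8 distinct heads

# Duplicate each KV head for its group of 4 query heads. The expanded tensor
# has 32 head slots, but only 8 unique key representations behind them.
keys_expanded = keys.repeat_interleave(group_size, dim=1)

print(keys.shape, "->", keys_expanded.shape)
# torch.Size([1, 8, 16, 128]) -> torch.Size([1, 32, 16, 128])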

Pre-fill Chunking Overhead: Processing prompts in chunks introduces sequential steps where the entire prompt could otherwise be handled in a single parallel pass. For batch inference with many short prompts, this overhead can reduce throughput compared to processing all prompts simultaneously. The benefit only materializes for long individual sequences.

Rolling Buffer Complexity: The rolling buffer KV cache requires careful index management and can complicate certain inference optimizations. Speculative decoding and other advanced generation techniques may require modifications to work correctly with the circular buffer structure.
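
To illustrate why the bookkeeping is nontrivial, here is a minimal sketch of the circular write path using modular indexing (toy sizes, deliberately simplified; Mistral's production cache code is more involved):

Code
import torch

window, head_dim = 4, 2   # tiny window for illustration (Mistral uses 4096)
buffer = torch.zeros(window, head_dim)  # pre-allocated circular KV buffer

for pos in range(10):      # absolute token positions during decoding
    slot = pos % window    # circular write index
    buffer[slot] = float(pos)  # stand-in for the real key/value vector

# The buffer now holds positions 6-9, but not in temporal order: the slot
# order is [8, 9, 6, 7]. Attention code must track absolute positions (for
# RoPE and causal masking) rather than relying on buffer order.
print(buffer[:, 0].tolist())  # [8.0, 9.0, 6.0, 7.0]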

Despite these considerations, Mistral's design represents a compelling efficiency-capability trade-off that has proven effective across a wide range of applications. The architecture has influenced subsequent models and established sliding window attention as a viable approach for long-context language modeling.

Summary

Mistral 7B demonstrates that thoughtful architectural choices can achieve strong performance with fewer resources. The key innovations are:

  • Sliding Window Attention: Limits attention to the most recent $w = 4096$ tokens, reducing complexity from $O(n^2)$ to $O(n \cdot w)$, where $n$ is the sequence length and $w$ is the fixed window size. Information propagates across layers to effectively extend the receptive field to $L \times w$ tokens.

  • Rolling Buffer KV Cache: Stores only the most recent window of key-value pairs, keeping memory constant during generation regardless of sequence length. Combined with GQA (8 KV heads instead of 32), this reduces KV cache memory by 16x compared to full attention with standard MHA for a 16K-token sequence (4x from GQA, 4x from capping the cache at the 4K window); a quick arithmetic sketch follows this list.

  • Grouped Query Attention: Four query heads share each KV head pair, reducing KV cache size by 4x while maintaining expressive capacity through independent query representations.

  • Pre-fill Chunking: Processes long prompts in window-sized chunks to bound peak memory, enabling longer context lengths on memory-constrained hardware.
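
The arithmetic behind the receptive-field and cache-reduction claims can be checked in a few lines (a rough sketch assuming fp16 and a 16K-token sequence for the 16x figure):

Code
num_layers, window, hidden = 32, 4096, 4096
n_heads, n_kv_heads, dtype_bytes = 32, 8, 2
head_dim = hidden // n_heads
seq_len = 16_384

# Stacking L sliding-window layers lets information propagate roughly L * w tokens back
print(f"Theoretical receptive field: {num_layers * window:,} tokens")  # 131,072

def kv_cache_bytes(cache_len, kv_heads):
    # keys + values, across all layers, one batch element
    return 2 * num_layers * cache_len * kv_heads * head_dim * dtype_bytes

full_mha = kv_cache_bytes(seq_len, n_heads)                  # full attention, 32 KV heads
mistral = kv_cache_bytes(min(seq_len, window), n_kv_heads)   # window cap + GQA
print(f"Full-attention MHA cache: {full_mha / 1024**3:.1f} GiB")  # 8.0 GiB
print(f"Mistral cache:            {mistral / 1024**3:.1f} GiB")   # 0.5 GiB
print(f"Reduction:                {full_mha // mistral}x")        # 16x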

These innovations work together to create a model that outperforms LLaMA 2 13B while using 44% fewer parameters. The efficiency gains don't come at the cost of capability. Instead, they free up resources that can be allocated to other components like a larger FFN layer.

The Mistral architecture has influenced the design of subsequent models and established that efficiency and performance can advance together. Sliding window attention, in particular, has become a common technique for extending context length without quadratic memory growth.

Key Parameters

The following parameters control Mistral's efficiency-capability trade-offs; a short sketch after the list derives the quantities they imply:

  • sliding_window (int, default=4096): The number of previous tokens each position can attend to. Larger values provide more direct context access but increase memory usage linearly. Mistral uses 4096, which balances efficiency with sufficient local context. For tasks requiring precise long-range retrieval, consider larger windows or full attention.

  • num_kv_heads (int, default=8): The number of key-value heads in grouped query attention. Must divide num_attention_heads evenly. Lower values reduce KV cache size but may limit representational diversity. Mistral's ratio of 32:8 (4:1) provides a good balance.

  • num_attention_heads (int, default=32): The number of query heads. Each group of num_attention_heads / num_kv_heads query heads shares one KV head pair. More heads enable finer-grained attention patterns.

  • hidden_size (int, default=4096): The model dimension. Determines the size of embeddings and the input/output dimension for each layer. Head dimension is computed as hidden_size / num_attention_heads.

  • intermediate_size (int, default=14336): The hidden dimension of the feed-forward network. Mistral uses a larger FFN (3.5x hidden_size) than LLaMA (2.7x), trading some of the memory savings from GQA for additional model capacity.

  • max_position_embeddings (int, default=8192): The maximum context length. With sliding window attention, sequences can technically extend beyond this, but RoPE positional embeddings are optimized for this range.

  • rope_theta (float, default=10000.0): The base frequency for rotary positional embeddings. Higher values extend the effective context length but may reduce precision for nearby positions.
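
As a quick sanity check on how these parameters interact, the sketch below derives the implied head dimension, GQA group size, and per-token KV cache cost (fp16 assumed, values from the configuration shown earlier):

Code
hidden_size = 4096
num_attention_heads = 32
num_kv_heads = 8
num_layers = 32
sliding_window = 4096

# Query heads must split evenly across KV heads for GQA grouping to work
assert num_attention_heads % num_kv_heads == 0

head_dim = hidden_size // num_attention_heads      # 128
group_size = num_attention_heads // num_kv_heads   # 4 query heads per KV head

# fp16 KV cache cost of one cached token across all layers (keys + values)
kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * 2
window_cache_gib = sliding_window * kv_bytes_per_token / 1024**3

print(f"head_dim={head_dim}, group_size={group_size}")
print(f"KV cache: {kv_bytes_per_token / 1024:.0f} KiB per token, "
      f"{window_cache_gib:.2f} GiB for a full window")  # 128 KiB, 0.50 GiB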


