Sparse Attention Patterns: Local, Strided & Block-Sparse Approaches

Michael Brenndoerfer · Updated June 25, 2025 · 39 min read

Implement sparse attention patterns including local windows, strided attention, and block-sparse methods that reduce transformer complexity from quadratic to near-linear.

Sparse Attention Patterns

Standard attention computes pairwise scores between all tokens, creating an $n \times n$ attention matrix where $n$ is the sequence length. This quadratic scaling becomes prohibitively expensive for long sequences. Sparse attention offers a principled solution: instead of attending to every position, each token attends only to a carefully chosen subset. The key insight is that most attention weights in practice are small, meaning many token pairs contribute little to the final representation. By restricting attention to positions that matter most, sparse patterns achieve near-linear complexity while preserving the model's ability to capture important relationships.

This chapter explores the fundamental sparse attention patterns that form the building blocks of efficient transformers. We'll implement local windowed attention, strided patterns, and block-sparse attention, then combine them into hybrid approaches used by models like Sparse Transformer, Longformer, and BigBird.

The Sparsity Principle

Before diving into specific patterns, let's understand why sparsity works. In natural language, most dependencies are local: a word is most strongly influenced by nearby words. Long-range dependencies exist but are relatively rare. Consider the sentence "The cat that the dog chased ran away." The verb "ran" primarily depends on "cat" (its subject), not on every intervening word. A sparse attention pattern that captures this key dependency while ignoring irrelevant pairs can achieve similar quality to full attention at a fraction of the cost.

Attention Sparsity

Sparse attention restricts each query to attend only to a subset of keys. If each query attends to $k$ keys instead of all $n$, complexity drops from $O(n^2)$ to $O(nk)$, where $n$ is the sequence length and $k$ is the number of keys each query attends to. When $k$ is constant or grows slowly with $n$, this achieves effectively linear scaling.

Let's visualize how attention weights are distributed in practice to motivate sparsity.

In[2]:
Code
import numpy as np

np.random.seed(42)

# Simulate attention weights for a 64-token sequence
# In real transformers, weights tend to be concentrated on nearby tokens
seq_len = 64

# Create simulated attention pattern with local bias
# Each row sums to 1 (softmax output)
attention_weights = np.zeros((seq_len, seq_len))

for i in range(seq_len):
    # High attention to nearby positions, decaying with distance
    for j in range(seq_len):
        distance = abs(i - j)
        # Exponential decay with distance, plus small baseline
        attention_weights[i, j] = np.exp(-distance / 5) + 0.01

    # Normalize to sum to 1
    attention_weights[i] /= attention_weights[i].sum()
Out[3]:
Visualization
Heatmap showing attention weights concentrated along the diagonal, with values decaying as distance from diagonal increases.
Attention weight heatmap showing concentration along the diagonal. Darker colors indicate stronger attention, with most weight on nearby tokens.
Out[4]:
Visualization
Histogram of attention weights showing highly skewed distribution with most values near zero.
Distribution of attention weights across all position pairs. The red dashed line marks the 90th percentile, showing most weights are very small.
Out[5]:
Console
Attention sparsity analysis (threshold=0.01):
  Positions below threshold: 66.5%
  Coverage from top-10 positions per query: 64.8%
  Full attention computes: 4,096 scores
  Top-10 sparse attention: 640 scores (15.6%)

The analysis reveals a key insight: attention weights are highly concentrated. Most of the attention mass falls on a small number of positions per query, while the majority of positions receive negligible weight. This natural sparsity suggests we can skip computing many attention scores without significantly affecting the output.

Local Attention Windows

The most intuitive sparse pattern is local attention, where each token attends only to tokens within a fixed window around it. This pattern exploits the locality of language: consecutive words form phrases, sentences have local structure, and most grammatical dependencies span short distances.

Window Formulation

To formalize local attention, we need to answer a simple question: which positions should a token be allowed to attend to? The intuition is straightforward. Imagine you're reading position 10 in a sequence. Local attention says you can "look" at positions 8, 9, 10, 11, and 12 if the window size is 2, but nothing beyond that range. The window creates a neighborhood around each position.

Let's define this precisely. For a window size $w$, each query at position $i$ attends to positions in the range $[i - w, i + w]$. This range includes:

  • $w$ positions to the left (earlier in the sequence)
  • The current position $i$ itself
  • $w$ positions to the right (later in the sequence)

The total number of attended positions is $2w + 1$. Notice something crucial: this count is independent of sequence length $n$. Whether your sequence has 100 tokens or 100,000 tokens, each position still attends to exactly $2w + 1$ neighbors. This independence is what gives local attention its $O(n)$ complexity: $n$ positions, each computing $2w + 1$ attention scores, yields $n(2w + 1)$ total operations.
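To make the windowing concrete, here is a tiny illustrative sketch (plain Python, with arbitrary example values) that enumerates the attended range for a few positions, including the clamping that happens at the sequence boundaries:

w, seq_len = 2, 16
for i in [0, 1, 10, 15]:
    start, end = max(0, i - w), min(seq_len, i + w + 1)
    print(f"position {i}: attends to {list(range(start, end))}")

Position 10 attends to [8, 9, 10, 11, 12], exactly the neighborhood described above, while positions near the edges attend to fewer than $2w + 1$ neighbors.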

Now we need a mechanism to enforce this pattern. In transformer attention, we use an attention mask $M$ that modifies the attention scores before the softmax. The mask for local attention is:

$$M_{ij} = \begin{cases} 0 & \text{if } |i - j| \leq w \\ -\infty & \text{otherwise} \end{cases}$$

where:

  • $M_{ij}$: the mask value applied when query position $i$ attends to key position $j$
  • $i$: the query position (which token is asking "what should I attend to?")
  • $j$: the key position (a candidate token that might be attended to)
  • $w$: the window size (the radius of the local neighborhood)
  • $|i - j|$: the absolute distance between positions (how far apart are they?)

The mask works through the attention computation. Recall that standard attention computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right) V$$

When $M_{ij} = 0$, the attention score passes through unchanged. When $M_{ij} = -\infty$, adding negative infinity to any finite score produces negative infinity. The softmax then assigns that position zero weight, since $e^{-\infty} = 0$, effectively blocking query $i$ from attending to key $j$. This elegant mechanism lets us selectively disable attention connections without changing the core attention computation.
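Before building full masks, here is a quick standalone check (a sketch with made-up scores) showing that adding a large negative number to a score drives its softmax weight to zero:

import numpy as np

scores = np.array([2.0, 1.0, 0.5, -0.3])
allowed = np.array([True, True, False, True])   # block the third key

masked = np.where(allowed, scores, -1e9)        # -1e9 stands in for -infinity
weights = np.exp(masked - masked.max())
weights /= weights.sum()
print(weights.round(4))                         # third weight is zero; the rest renormalize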

In[6]:
Code
def create_local_attention_mask(seq_len, window_size):
    """
    Create a local attention mask where each position attends
    to window_size positions on each side.

    Args:
        seq_len: Length of the sequence
        window_size: Number of positions to attend on each side

    Returns:
        mask: (seq_len, seq_len) boolean mask (True = attend, False = block)
    """
    mask = np.zeros((seq_len, seq_len), dtype=bool)

    for i in range(seq_len):
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        mask[i, start:end] = True

    return mask


# Create masks with different window sizes
seq_len = 32
small_window = create_local_attention_mask(seq_len, window_size=2)
medium_window = create_local_attention_mask(seq_len, window_size=4)
large_window = create_local_attention_mask(seq_len, window_size=8)
Out[7]:
Visualization
Narrow diagonal band in attention mask with window size 2.
Window size 2: Each position attends to 5 positions (2 left, self, 2 right). Very efficient but limited context.
Medium diagonal band in attention mask with window size 4.
Window size 4: Each position attends to 9 positions. Better context coverage with moderate overhead.
Wide diagonal band in attention mask with window size 8.
Window size 8: Each position attends to 17 positions. Rich local context approaching the density of short sequences.
Out[8]:
Console
Window 2: 154 attention scores (85.0% sparse), 5 positions per query
Window 4: 268 attention scores (73.8% sparse), 9 positions per query
Window 8: 472 attention scores (53.9% sparse), 17 positions per query

Local attention achieves significant sparsity even with generous window sizes. A window of 8 positions (17 attended positions per query) cuts computation roughly in half for a 32-token sequence. The savings grow with sequence length: for a 4,096-token sequence with a window of 256, sparsity reaches roughly 87%.
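As a quick sanity check on that figure (ignoring the slightly smaller windows at the sequence boundaries), the sparsity follows directly from the window arithmetic:

n, w = 4096, 256
attended_per_query = 2 * w + 1          # 513 positions
sparsity = 1 - attended_per_query / n
print(f"{sparsity:.1%}")                # roughly 87.5%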

Choosing Window Size

The optimal window size depends on the task and sequence characteristics. Here are key considerations:

  • Linguistic dependencies: Most grammatical dependencies span fewer than 10-15 tokens, so even modest windows capture them; a window of 256-512 tokens comfortably covers sentence-level and most paragraph-level structure.
  • Task requirements: Sentiment analysis might need only local context, while question answering may require longer-range connections.
  • Computational budget: Larger windows provide more context but increase memory and compute proportionally.
  • Layer depth: Some architectures use smaller windows in early layers and larger windows (or global attention) in later layers.
In[9]:
Code
def compute_local_attention_complexity(seq_len, window_size):
    """
    Compute the number of attention scores for local attention.

    Returns:
        attended_pairs: Number of (query, key) pairs computed
        full_pairs: Number of pairs in full attention
        speedup: Ratio of full to local complexity
    """
    # Each of n positions attends to min(2w+1, n) positions
    # Edge positions attend to fewer
    attended = 0
    for i in range(seq_len):
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        attended += end - start

    full = seq_len * seq_len
    return attended, full, full / attended


# Analyze complexity for various sequence lengths
seq_lens = [128, 512, 2048, 8192, 32768]
window_size = 256
Out[10]:
Console
Local attention speedup (window_size=256):

  Seq Length      Full Pairs     Local Pairs    Speedup
-------------------------------------------------------
         128          16,384          16,384        1.0x
         512         262,144         196,864        1.3x
       2,048       4,194,304         984,832        4.3x
       8,192      67,108,864       4,136,704       16.2x
      32,768   1,073,741,824      16,744,192       64.1x

The speedup from local attention grows linearly with sequence length. At 32,768 tokens, local attention with a 256-token window provides over 60x speedup compared to full attention. This scaling is what makes local attention essential for processing long documents.

Out[11]:
Visualization
Line plot showing exponential growth in speedup as sequence length increases from 128 to 32768 tokens.
Speedup of local attention over full attention as sequence length increases. The speedup grows linearly because local attention complexity is O(n) while full attention is O(n squared). Longer sequences benefit dramatically from sparse patterns.

The log-log plot reveals the linear relationship between sequence length and speedup. Doubling the sequence length approximately doubles the speedup, regardless of window size. Smaller windows provide greater speedups but capture less context.

Strided Attention Patterns

While local attention captures nearby dependencies, it cannot directly model long-range relationships. Strided attention addresses this by having each position attend to positions at regular intervals throughout the sequence. This creates "highways" for information to flow across long distances.

Stride Formulation

Local attention has a fundamental limitation: it cannot see beyond the window. Position 0 can never directly attend to position 1000, no matter how many attention scores we compute within that single layer. To bridge long distances, we need a different pattern.

The key insight behind strided attention is the concept of hub positions. Think of hubs like train stations in a transit network. Not every location has a direct connection to every other location, but major stations connect to many destinations. Similarly, strided attention designates certain positions as hubs that all other positions can access.

For a stride $s$, we designate every $s$-th position as a hub. With stride $s = 4$, positions 0, 4, 8, 12, 16, ... become hubs. Each query position can attend to:

  1. All hub positions: Every position can reach the hubs, creating information highways across the sequence
  2. Itself: Self-attention is always preserved so each token can access its own representation

This sets up a two-hop connectivity pattern, provided the hubs themselves can gather content from their surroundings. Consider positions 3 and 97 in a sequence with stride 4. They cannot directly attend to each other (neither is a hub). But if the hubs also read from their local neighborhoods, for example because strided attention is paired with a local window, then in layer 1 hub 4 absorbs the content of position 3, and in layer 2 position 97 reads hub 4 directly, since every query attends to every hub. Information has crossed the whole distance in just two layers.

The formal mask for strided attention is:

$$M_{ij} = \begin{cases} 0 & \text{if } j \bmod s = 0 \text{ or } j = i \\ -\infty & \text{otherwise} \end{cases}$$

where:

  • $M_{ij}$: the mask value when query $i$ considers attending to key $j$
  • $s$: the stride parameter (the spacing between hub positions)
  • $j \bmod s$: the remainder when dividing $j$ by $s$. When this equals 0, position $j$ is a hub
  • $j = i$: the self-attention condition, ensuring diagonal connectivity

The first condition ($j \bmod s = 0$) produces the characteristic vertical stripes in the attention pattern: every query attends to the same set of hub columns, creating shared access points. The second condition ($j = i$) adds the diagonal, ensuring self-attention. Together, these conditions keep the number of attended positions per query at approximately $n/s + 1$, where $n$ is the sequence length, and give every query direct access to the hubs. As the next section shows, turning the hubs into true relays for long-range information additionally requires that they can gather content from non-hub positions, which is why strided attention is usually paired with a local pattern.

In[12]:
Code
def create_strided_attention_mask(seq_len, stride):
    """
    Create a strided attention mask.
    Each position attends to every stride-th position plus itself.

    Args:
        seq_len: Length of the sequence
        stride: Attend to every stride-th position

    Returns:
        mask: (seq_len, seq_len) boolean mask
    """
    mask = np.zeros((seq_len, seq_len), dtype=bool)

    for i in range(seq_len):
        # Attend to strided positions
        for j in range(0, seq_len, stride):
            mask[i, j] = True
        # Always attend to self
        mask[i, i] = True

    return mask


# Create strided masks
seq_len = 32
stride_4 = create_strided_attention_mask(seq_len, stride=4)
stride_8 = create_strided_attention_mask(seq_len, stride=8)
Out[13]:
Visualization
Grid pattern with vertical lines every 4 positions plus a diagonal.
Stride 4: Every 4th position (0, 4, 8, ...) receives attention from all queries, creating vertical stripes plus a diagonal for self-attention.
Grid pattern with vertical lines every 8 positions plus a diagonal.
Stride 8: Sparser pattern with fewer hub positions. Reduces computation but may miss intermediate positions.

Strided attention creates a distinctive pattern: vertical columns at hub positions (which all queries attend to) and a diagonal for self-attention. The hub positions act as shared access points: every query reads from the same hub columns, and once the hubs also aggregate content from their neighborhoods (through a companion local pattern), they can relay information across the sequence in two hops.

Information Flow in Strided Attention

A key property of strided attention, when it is paired with a local pattern, is the maximum "hop distance" between any two positions. With stride $s$:

  • Any position is at most roughly $s/2$ positions from the nearest hub
  • In one attention layer, that hub can absorb the position's content through its local window
  • In the next layer, any query can read the hub directly, since every query attends to every hub

Under this pairing, any two positions are connected within 2 layers, regardless of their distance in the sequence. This is significantly better than local attention alone, which requires $O(n/w)$ layers to connect distant positions, where $n$ is the sequence length and $w$ is the window size.

In[14]:
Code
def analyze_reachability(seq_len, stride, num_layers):
    """
    Count how many positions each query can draw information from (directly or via intermediate positions) after num_layers of strided attention.

    Returns:
        reachability: (seq_len,) array of reachable position counts per starting position
    """
    # Start with identity (each position reaches itself)
    reachable = np.eye(seq_len, dtype=bool)

    # Create strided mask
    mask = create_strided_attention_mask(seq_len, stride)

    # Propagate through layers
    for _ in range(num_layers):
        # Reachability is transitive: if A reaches B and B reaches C, A reaches C
        reachable = reachable @ mask | reachable
        reachable = reachable > 0  # Convert back to bool

    return reachable.sum(axis=1)


seq_len = 64
stride = 8
reachability_by_layer = []
for layers in range(1, 5):
    reach = analyze_reachability(seq_len, stride, layers)
    reachability_by_layer.append(
        (layers, reach.min(), reach.max(), reach.mean())
    )
Out[15]:
Console
Reachability in strided attention (seq_len=64, stride=8):

  Layers    Min Reach    Max Reach   Mean Reach
------------------------------------------------
       1            8            9          8.9
       2            8            9          8.9
       3            8            9          8.9
       4            8            9          8.9

The numbers reveal an important subtlety: the reachable set stops growing after the first layer. Each query can draw on the 8 hub positions plus itself, and stacking more layers does not change this, because no query ever attends to a non-hub position other than itself. Pure strided attention builds the long-range highways, but the hubs only become useful relays once they can also absorb content from their surroundings, which is exactly what pairing strided attention with a local window provides.

Out[16]:
Visualization
Heatmap showing partial connectivity after one layer of strided attention.
After layer 1: Each position reaches itself plus all hub positions. The vertical stripes show hubs are universally accessible.
Heatmap after two layers of strided attention, showing the same hub-plus-diagonal pattern as layer 1.
After layer 2: The reachable set is unchanged (hubs plus self). On its own, the pure strided pattern does not expand a query's receptive field with depth.

The visualization makes the limitation concrete. After layer 1, each position reaches the hubs (vertical stripes) and itself (diagonal). After layer 2, the picture is unchanged: composing the strided mask with itself gives back the same strided mask, because hubs only ever read from other hubs and themselves. The fix is to give the hubs something to relay by combining strided attention with a local window, so that each hub first absorbs its neighborhood; the short check below confirms that this combination connects every pair of positions within two hops.
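A minimal sketch of that check, reusing NumPy and the mask helpers defined earlier. The window and stride values are illustrative, chosen so that every position falls within some hub's window (window at least stride minus one), and the two_hop helper is defined just for this sketch:

seq_len, window, stride = 64, 8, 8

local = create_local_attention_mask(seq_len, window)
strided = create_strided_attention_mask(seq_len, stride)
combined = local | strided

def two_hop(mask):
    # Boolean two-hop composition: i reaches j in two hops if some k has
    # mask[i, k] and mask[k, j]
    return (mask.astype(int) @ mask.astype(int)) > 0

print("Strided only, fully connected after 2 hops: ", bool(two_hop(strided).all()))
print("Local + strided, fully connected after 2 hops:", bool(two_hop(combined).all()))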

Block-Sparse Attention

Block-sparse attention groups positions into blocks and defines attention patterns at the block level. This approach is particularly hardware-friendly because modern GPUs and TPUs are optimized for matrix operations on contiguous memory blocks.

Block Structure

The patterns we've seen so far, local and strided, define attention at the level of individual positions. Block-sparse attention takes a different approach: it groups positions into contiguous blocks and defines attention at the block level. This abstraction has a practical motivation. Modern GPUs are designed to process data in aligned, contiguous chunks. By organizing attention into blocks that match hardware execution units, we can translate theoretical speedups into actual wall-clock savings.

Imagine dividing a 1024-token sequence into blocks of size $b = 64$. This creates $1024 / 64 = 16$ blocks. Instead of asking "which of 1024 positions can position $i$ attend to?", we ask "which of 16 blocks can block $B_i$ attend to?" This coarser-grained question has fewer possible answers, and each answer involves a dense $64 \times 64$ matrix operation that GPUs handle efficiently.

The key parameters for block-sparse attention are:

  • $b$: the block size (number of positions per block)
  • $k$: the number of blocks each query block attends to

If we select $k$ blocks for each query block, the complexity analysis proceeds as follows:

  1. Number of query blocks: The sequence divides into $n/b$ blocks
  2. Attention per query block: Each block attends to $k$ other blocks
  3. Operations per block pair: Computing attention between two blocks of size $b$ requires $b \times b = b^2$ score computations

Multiplying these together:

$$\text{Total operations} = \frac{n}{b} \times k \times b^2 = n \cdot k \cdot b$$

When $k$ and $b$ are constants independent of $n$, this expression is linear in sequence length. The quadratic term has vanished. Even better, each $b \times b$ block attention can use highly optimized dense matrix multiplication routines, achieving near-peak GPU utilization.
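To make the hardware argument concrete, here is a small standalone sketch (illustrative sizes) of how a single query-block/key-block pair is scored: slice out contiguous tiles of Q and K and use an ordinary dense matrix product, which is exactly the operation GPUs execute best.

import numpy as np

n, b, d = 1024, 64, 64
Q = np.random.randn(n, d)
K = np.random.randn(n, d)

qi, kj = 3, 4                              # query block 3 attends to key block 4
q_blk = Q[qi * b:(qi + 1) * b]             # (b, d) contiguous slice
k_blk = K[kj * b:(kj + 1) * b]             # (b, d) contiguous slice

scores_blk = q_blk @ k_blk.T / np.sqrt(d)  # one dense (b, b) tile of the score matrix
print(scores_blk.shape)                    # (64, 64)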

In[17]:
Code
def create_block_sparse_mask(seq_len, block_size, pattern="diagonal"):
    """
    Create a block-sparse attention mask.

    Args:
        seq_len: Length of sequence
        block_size: Size of each block
        pattern: Which block pairs to attend to
            - "diagonal": Each block attends to itself only
            - "tridiagonal": Each block attends to itself and adjacent blocks

    Returns:
        mask: (seq_len, seq_len) boolean mask
    """
    num_blocks = (seq_len + block_size - 1) // block_size
    mask = np.zeros((seq_len, seq_len), dtype=bool)

    for i in range(num_blocks):
        i_start = i * block_size
        i_end = min((i + 1) * block_size, seq_len)

        if pattern == "diagonal":
            # Attend only to same block
            blocks_to_attend = [i]
        elif pattern == "tridiagonal":
            # Attend to previous, current, and next block
            blocks_to_attend = [
                j for j in [i - 1, i, i + 1] if 0 <= j < num_blocks
            ]
        else:
            raise ValueError(f"Unknown pattern: {pattern}")

        for j in blocks_to_attend:
            j_start = j * block_size
            j_end = min((j + 1) * block_size, seq_len)
            mask[i_start:i_end, j_start:j_end] = True

    return mask


seq_len = 32
block_size = 4
diagonal_mask = create_block_sparse_mask(seq_len, block_size, "diagonal")
tridiagonal_mask = create_block_sparse_mask(seq_len, block_size, "tridiagonal")
Out[18]:
Visualization
Checkerboard pattern of square blocks along the diagonal.
Diagonal block-sparse: Each block attends only to itself. Creates isolated groups with no cross-block communication.
Wider band of square blocks along the diagonal extending one block in each direction.
Tridiagonal block-sparse: Each block attends to itself and neighbors. Allows local information flow across block boundaries.

Block-sparse attention provides a clean abstraction for hardware-efficient implementation. The regular block structure maps directly to GPU thread blocks, enabling efficient parallel execution with minimal memory overhead.

Hardware Efficiency

The advantage of block-sparse attention goes beyond theoretical complexity. Modern GPUs achieve peak performance when operating on aligned, contiguous memory blocks. Random sparse patterns, while mathematically equivalent, suffer from irregular memory access patterns that underutilize hardware.

In[19]:
Code
def estimate_memory_efficiency(seq_len, block_size, sparsity_ratio):
    """
    Estimate memory efficiency of block-sparse vs random sparse attention.

    Block-sparse can use dense matmul on selected blocks.
    Random sparse requires custom kernels with lower efficiency.
    """
    total_pairs = seq_len * seq_len
    sparse_pairs = int(total_pairs * (1 - sparsity_ratio))

    # Block-sparse: compute full dense blocks
    num_blocks = seq_len // block_size
    blocks_per_row = int(num_blocks * (1 - sparsity_ratio))
    block_sparse_ops = num_blocks * blocks_per_row * (block_size**2)

    # Memory coalescing factor (rough estimate)
    # Block-sparse achieves near 100% efficiency, random ~30%
    block_efficiency = 0.95
    random_efficiency = 0.30

    return {
        "block_sparse_pairs": block_sparse_ops,
        "random_sparse_pairs": sparse_pairs,
        "block_effective_throughput": block_sparse_ops * block_efficiency,
        "random_effective_throughput": sparse_pairs * random_efficiency,
    }


results = estimate_memory_efficiency(1024, 64, 0.9)
Out[20]:
Console
Memory efficiency comparison (90% sparsity, 1024 tokens, block size 64):

Block-sparse computed pairs:        65,536
Random-sparse computed pairs:      104,857

Effective throughput (accounting for memory coalescing):
Block-sparse:        62,259 (95% memory efficiency)
Random-sparse:       31,457 (30% memory efficiency)

Block-sparse advantage: 2.0x effective throughput

In this rough estimate, block-sparse attention delivers about twice the effective throughput of a random sparse pattern at the same sparsity level, and the gap can grow larger in practice once kernel launch overhead and cache behavior are accounted for. This hardware awareness is crucial for practical implementations and explains why production models favor structured sparsity.

Combining Sparse Patterns

Real-world efficient attention mechanisms combine multiple patterns to balance local context, global reach, and computational efficiency. The Sparse Transformer paper introduced the idea of factorizing attention across multiple heads, with different heads using different patterns.

Local + Strided Combination

The most common combination pairs local attention for nearby tokens with strided attention for long-range connections. Together, they ensure every pair of positions can communicate within a small number of layers while maintaining overall sparsity.

In[21]:
Code
def create_combined_sparse_mask(seq_len, local_window, stride):
    """
    Combine local and strided attention patterns.

    Args:
        seq_len: Sequence length
        local_window: Window size for local attention
        stride: Stride for global attention

    Returns:
        mask: Combined boolean mask (OR of local and strided)
    """
    local_mask = create_local_attention_mask(seq_len, local_window)
    strided_mask = create_strided_attention_mask(seq_len, stride)

    # Combine with OR: attend if either pattern allows
    combined = local_mask | strided_mask

    return combined, local_mask, strided_mask


seq_len = 48
combined, local, strided = create_combined_sparse_mask(
    seq_len, local_window=3, stride=6
)
Out[22]:
Visualization
Diagonal band attention pattern.
Local attention with window 3: Creates a band-diagonal pattern capturing nearby context.
Vertical stripe pattern with evenly spaced lines.
Strided attention with stride 6: Vertical stripes at hub positions for long-range connectivity.
Diagonal band with vertical stripes overlaid.
Combined pattern: Union of local and strided patterns. Achieves both local context and global reach.
Out[23]:
Console
Pattern analysis (seq_len=48):

Pattern              Pairs     Sparsity
----------------------------------------
Full                 2,304         0.0%
Local (w=3)            324        85.9%
Strided (s=6)          424        81.6%
Combined               655        71.6%

Note: Combined has fewer pairs than sum of components due to overlap.

Multi-Head Factorization

An elegant approach from Sparse Transformer assigns different patterns to different attention heads. Half the heads might use local attention while the other half use strided attention. This factorization allows the model to learn which pattern is most useful for different types of dependencies.

In[24]:
Code
def create_factorized_attention(seq_len, num_heads, local_window, stride):
    """
    Create factorized attention patterns for multi-head attention.

    Half the heads use local attention, half use strided attention.

    Args:
        seq_len: Sequence length
        num_heads: Number of attention heads
        local_window: Window for local heads
        stride: Stride for strided heads

    Returns:
        masks: List of (seq_len, seq_len) masks, one per head
    """
    local_mask = create_local_attention_mask(seq_len, local_window)
    strided_mask = create_strided_attention_mask(seq_len, stride)

    masks = []
    for h in range(num_heads):
        if h < num_heads // 2:
            masks.append(local_mask)
        else:
            masks.append(strided_mask)

    return masks


# Create masks for 8 heads
num_heads = 8
head_masks = create_factorized_attention(
    seq_len, num_heads, local_window=3, stride=6
)
Out[25]:
Console
Factorized attention (8 heads):

  Heads 0-3: Local attention (window=3)
  Heads 4-7: Strided attention (stride=6)

Average pairs per head: 374
Full attention would be: 2,304
Overall sparsity: 83.8%

Factorized attention provides flexibility: local heads handle nearby dependencies while strided heads capture long-range patterns. The model learns to route information through the appropriate heads during training.

Implementing Sparse Attention

Having established the theory behind sparse patterns, let's implement a complete sparse attention module. This implementation prioritizes clarity over performance, demonstrating how the mask-based approach works in practice. Understanding this foundation will help you work with optimized libraries like xformers or Flash Attention later.

The Core Algorithm

Sparse attention follows the same computational structure as standard attention, with one addition: we apply a mask before the softmax to block certain attention connections. The algorithm proceeds in four steps:

  1. Compute raw attention scores: Multiply queries by keys to get similarity scores
  2. Apply the sparse mask: Set blocked positions to $-\infty$ (we use $-10^9$ numerically)
  3. Softmax normalization: Convert scores to probability weights (blocked positions become zero)
  4. Weighted sum: Combine values according to the attention weights

Let's implement each step:

In[26]:
Code
def sparse_attention(query, key, value, mask, scale=None):
    """
    Compute sparse attention given a boolean mask.

    Args:
        query: (seq_len, d_k) query vectors
        key: (seq_len, d_k) key vectors
        value: (seq_len, d_v) value vectors
        mask: (seq_len, seq_len) boolean mask (True = attend, False = block)
        scale: Optional scaling factor (default: 1/sqrt(d_k))

    Returns:
        output: (seq_len, d_v) attention output
        weights: (seq_len, seq_len) attention weights (masked)
    """
    seq_len, d_k = query.shape

    # Step 0: Determine scaling factor
    # The 1/sqrt(d_k) scaling prevents attention scores from growing
    # too large as dimension increases, which would push softmax into
    # saturation regions with near-zero gradients
    if scale is None:
        scale = 1.0 / np.sqrt(d_k)

    # Step 1: Compute raw attention scores
    # Each score[i,j] measures similarity between query[i] and key[j]
    scores = query @ key.T * scale  # (seq_len, seq_len)

    # Step 2: Apply the sparse mask
    # Where mask is True, keep the score unchanged
    # Where mask is False, replace with a very large negative number
    # This ensures exp(-1e9) ≈ 0 after softmax
    masked_scores = np.where(mask, scores, -1e9)

    # Step 3: Softmax normalization
    # We subtract the max for numerical stability (prevents overflow)
    # This doesn't change the result since softmax is shift-invariant
    exp_scores = np.exp(
        masked_scores - masked_scores.max(axis=1, keepdims=True)
    )
    weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)

    # Explicitly zero blocked positions
    # They're already near-zero from softmax, but this ensures exactness
    weights = np.where(mask, weights, 0.0)

    # Step 4: Weighted sum of values
    # Each output[i] is a weighted combination of all value vectors,
    # with weights determined by attention
    output = weights @ value

    return output, weights

Comparing Full and Sparse Attention

To validate our implementation and demonstrate the effectiveness of sparse patterns, let's compare the outputs of full attention versus a combined local + strided pattern:

In[27]:
Code
# Test with random inputs
np.random.seed(42)
seq_len = 16
d_model = 32

query = np.random.randn(seq_len, d_model)
key = np.random.randn(seq_len, d_model)
value = np.random.randn(seq_len, d_model)

# Compare full vs sparse attention
full_mask = np.ones((seq_len, seq_len), dtype=bool)
sparse_mask = create_combined_sparse_mask(seq_len, local_window=2, stride=4)[0]

full_output, full_weights = sparse_attention(query, key, value, full_mask)
sparse_output, sparse_weights = sparse_attention(query, key, value, sparse_mask)
Out[28]:
Console
Sparse vs Full Attention Comparison:

Output shape: (16, 32)
Mean absolute difference: 0.2533
Max absolute difference: 1.8285
Positions with difference > 0.1: 370 / 512

Sparse mask: 120 / 256 pairs attended (46.9%)

The outputs differ, but the differences are modest even though sparse attention computes fewer than half the pairs, and this test uses random inputs with no locality structure at all. On real data, where attention mass concentrates on a few nearby positions, the gap is typically smaller. This reflects the core bet behind sparse attention: most attention scores contribute little to the final output, so we can skip computing them without significantly degrading the representation quality.

Out[29]:
Visualization
Heatmap showing small differences between full and sparse attention outputs across sequence positions and embedding dimensions.
Element-wise difference between full and sparse attention outputs. Most differences are small (light colors), with slightly larger differences for positions that lose access to high-attention neighbors under sparsity.

The heatmap reveals that differences are distributed across all positions, but remain small. This uniform distribution suggests that sparse attention doesn't systematically fail for any particular positions; rather, it introduces small approximation errors throughout the sequence.

Out[30]:
Visualization
Dense heatmap showing attention weights between all position pairs.
Full attention weights: All positions attend to all positions. Dense matrix shows no structural constraints.
Sparse heatmap with visible diagonal band and vertical stripes where attention is allowed.
Sparse attention weights: Only allowed positions have non-zero weights. The combined local+strided pattern is visible.

Complexity Analysis

Let's verify the computational savings from sparse attention empirically.

In[31]:
Code
import time


def benchmark_attention(seq_lens, patterns, num_trials=5):
    """
    Benchmark different attention patterns across sequence lengths.
    """
    results = []
    d_model = 64

    for n in seq_lens:
        query = np.random.randn(n, d_model)
        key = np.random.randn(n, d_model)
        value = np.random.randn(n, d_model)

        for pattern_name, create_mask_fn in patterns.items():
            mask = create_mask_fn(n)

            # Warm up
            _ = sparse_attention(query, key, value, mask)

            # Time multiple trials
            times = []
            for _ in range(num_trials):
                start = time.perf_counter()
                _ = sparse_attention(query, key, value, mask)
                times.append(time.perf_counter() - start)

            avg_time = np.mean(times)
            pairs = mask.sum()

            results.append(
                {
                    "seq_len": n,
                    "pattern": pattern_name,
                    "time_ms": avg_time * 1000,
                    "pairs": pairs,
                    "sparsity": 1 - pairs / (n * n),
                }
            )

    return results


# Define patterns
patterns = {
    "full": lambda n: np.ones((n, n), dtype=bool),
    "local_16": lambda n: create_local_attention_mask(n, 8),
    "sparse": lambda n: create_combined_sparse_mask(n, 4, 8)[0],
}

# Benchmark
seq_lens = [64, 128, 256, 512]
benchmark_results = benchmark_attention(seq_lens, patterns)
Out[32]:
Console
Attention Pattern Benchmark:

 Seq Len Pattern       Time (ms)      Pairs   Sparsity
-------------------------------------------------------
      64 full              0.079      4,096       0.0%
      64 local_16          0.071      1,016      75.2%
      64 sparse            0.073      1,000      75.6%
     128 full             12.605     16,384       0.0%
     128 local_16          9.871      2,104      87.2%
     128 sparse           29.957      3,040      81.4%
     256 full             21.160     65,536       0.0%
     256 local_16         15.178      4,280      93.5%
     256 sparse           18.806     10,192      84.4%
     512 full             50.061    262,144       0.0%
     512 local_16         31.049      8,632      96.7%
     512 sparse           31.225     36,784      86.0%
Out[33]:
Visualization
Line plot comparing execution time of full, local, and sparse attention patterns as sequence length increases.
Execution time vs sequence length for different attention patterns. Full attention shows quadratic growth while sparse patterns scale more favorably, with the gap widening at longer sequences.

The benchmark confirms the theoretical complexity analysis: sparse patterns scale better than full attention, with the advantage growing as sequence length increases. Note that our NumPy implementation doesn't fully exploit sparsity since it still computes the full matrix before masking. Production implementations using optimized sparse kernels would show even larger speedups.

Practical Considerations

When implementing sparse attention in practice, several factors influence the choice of pattern and parameters.

Pattern Selection Guidelines

Different tasks and sequence lengths call for different patterns; a small illustrative helper that turns these guidelines into code follows the list:

  • Short sequences (less than 512 tokens): Full attention is often fast enough. Sparse patterns may not provide significant benefit and add implementation complexity.
  • Medium sequences (512 to 2048 tokens): Local attention with a window of 128-256 tokens works well for most tasks. Add strided or global attention for tasks requiring long-range dependencies.
  • Long sequences (more than 2048 tokens): Combined patterns are essential. Consider factorized attention across heads or hierarchical approaches.
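As one way to operationalize these guidelines, here is a hypothetical helper (the name suggest_attention_config, thresholds, and defaults are illustrative, not prescriptions from any particular paper):

def suggest_attention_config(seq_len):
    """Map a sequence length to a rough starting configuration."""
    if seq_len < 512:
        return {"pattern": "full"}
    if seq_len <= 2048:
        return {"pattern": "local", "window_size": 256}
    # Long sequences: combine local context with strided hubs
    return {
        "pattern": "local+strided",
        "window_size": 256,
        "stride": max(1, round(seq_len ** 0.5)),
    }

print(suggest_attention_config(4096))   # stride of sqrt(4096) = 64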

Memory vs Compute Trade-offs

Sparse attention reduces both memory and compute, but the savings differ:

  • Compute: Scales directly with sparsity. 90% sparsity means 10x fewer FLOPs.
  • Memory: Storing the sparse structure adds overhead, since indices must be kept alongside values. For short sequences or only moderate sparsity, this bookkeeping can outweigh the savings.
In[34]:
Code
def analyze_memory_tradeoff(seq_len, sparsity):
    """
    Analyze memory usage for sparse vs dense attention.
    """
    # Dense attention matrix: n^2 float32 values
    dense_bytes = seq_len * seq_len * 4  # 4 bytes per float32

    # Sparse: store only non-zero values + indices
    num_nonzero = int(seq_len * seq_len * (1 - sparsity))
    # CSR format: values + column indices (int32) + row pointers
    sparse_bytes = num_nonzero * 4 + num_nonzero * 4 + (seq_len + 1) * 4

    return {
        "dense_mb": dense_bytes / 1e6,
        "sparse_mb": sparse_bytes / 1e6,
        "ratio": sparse_bytes / dense_bytes,
    }


memory_analysis = []
for n in [512, 1024, 2048, 4096, 8192]:
    for sparsity in [0.9, 0.95, 0.99]:
        result = analyze_memory_tradeoff(n, sparsity)
        memory_analysis.append(
            (
                n,
                sparsity,
                result["dense_mb"],
                result["sparse_mb"],
                result["ratio"],
            )
        )
Out[35]:
Console
Memory usage: Dense vs Sparse attention matrices

 Seq Len   Sparsity   Dense (MB)  Sparse (MB)    Ratio
-------------------------------------------------------
     512        90%         1.05         0.21     0.20
     512        95%         1.05         0.11     0.10
     512        99%         1.05         0.02     0.02
    1024        90%         4.19         0.84     0.20
    1024        95%         4.19         0.42     0.10
    1024        99%         4.19         0.09     0.02
    2048        90%        16.78         3.36     0.20
    2048        95%        16.78         1.69     0.10
    2048        99%        16.78         0.34     0.02
    4096        90%        67.11        13.44     0.20
    4096        95%        67.11         6.73     0.10
    4096        99%        67.11         1.36     0.02
    8192        90%       268.44        53.72     0.20
    8192        95%       268.44        26.88     0.10
    8192        99%       268.44         5.40     0.02

For high sparsity (95%+) and long sequences, sparse storage provides significant memory savings. However, for shorter sequences or moderate sparsity, the overhead of sparse formats can negate the benefits.

Gradient Flow

Sparse attention patterns can affect gradient flow during training. Positions that are never attended to receive no gradient signal through attention. This is typically acceptable when patterns ensure every position is reachable within a few layers, but can cause issues with overly aggressive sparsity.

The key insight is that multi-layer transformers compose attention patterns. Even if layer 1 uses sparse attention, the effective receptive field grows with depth. A position unreachable in one layer may become reachable through intermediate positions in subsequent layers.
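A small sketch makes this composition concrete, reusing the local mask helper from earlier: stacking two layers of window-$w$ local attention yields an effective receptive field of radius $2w$.

seq_len, w = 32, 2
local = create_local_attention_mask(seq_len, w)

one_layer = local
two_layers = (local.astype(int) @ local.astype(int)) > 0

mid = seq_len // 2
print("reachable after 1 layer: ", one_layer[mid].sum())   # 2w + 1 = 5
print("reachable after 2 layers:", two_layers[mid].sum())  # 4w + 1 = 9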

Limitations and Impact

Sparse attention patterns represent a fundamental advance in efficient transformer design, enabling the processing of sequences that would be impossible with full attention. However, they come with trade-offs that practitioners must understand.

The most significant limitation is the potential loss of long-range dependencies. While patterns like local + strided ensure connectivity, they may not capture all important relationships as effectively as full attention. Tasks requiring precise long-range reasoning, such as mathematical proofs or complex code understanding, may suffer from aggressive sparsity. The solution often involves tuning the pattern parameters: larger windows, smaller strides, or additional global tokens to ensure critical positions remain connected.

Implementation complexity is another practical concern. While conceptually simple, efficient sparse attention requires careful memory management and often custom CUDA kernels to achieve the theoretical speedups. Libraries like xformers and Flash Attention provide optimized implementations, but integrating them into existing codebases requires effort. The block-sparse approach helps here by mapping to hardware-friendly operations, but still requires infrastructure beyond standard dense attention.

Despite these challenges, sparse attention patterns have enabled a new generation of long-context models. Longformer, BigBird, and their successors process documents, books, and codebases that were previously inaccessible to transformers. The efficiency gains are substantial: processing a 4,096-token document with sparse attention uses roughly the same resources as a 512-token document with full attention. This has opened new applications in document understanding, long-form generation, and multi-turn conversation.

Summary

Sparse attention patterns address the quadratic complexity bottleneck of standard attention by restricting each query to attend to a subset of keys. The key patterns are:

  • Local attention: Each position attends to a fixed window of nearby tokens, exploiting the locality of language. Complexity is $O(n \cdot w)$, where $n$ is the sequence length and $w$ is the window size.

  • Strided attention: Positions attend to regularly spaced hub positions, creating highways for information propagation across long sequences; paired with a local pattern, any two positions connect within two hops. Complexity is $O(n \cdot n/s)$, where $s$ is the stride.

  • Block-sparse attention: Groups positions into blocks of size $b$ and defines attention at the block level. Complexity is $O(n \cdot k \cdot b)$, where $k$ is the number of blocks attended to. Hardware-friendly due to regular memory access patterns.

  • Combined patterns: Real systems combine multiple patterns, often using factorized multi-head attention where different heads use different patterns.

The effectiveness of sparse attention rests on a key observation: most attention weights are small, concentrated on a few important positions. By carefully selecting which positions to attend to, sparse patterns preserve most of the representational power of full attention while dramatically reducing computational cost.

These building blocks form the foundation for efficient attention mechanisms like Longformer, BigBird, and Sparse Transformer. The next chapters explore specific architectures that combine sparse patterns with additional techniques like sliding windows and global tokens to achieve even better trade-offs between efficiency and expressiveness.

Key Parameters

When implementing sparse attention patterns, several parameters control the trade-off between efficiency and model quality; a sketch of one concrete configuration follows the list:

  • window_size: The number of positions each query attends to on each side in local attention. Larger windows capture more context but increase computation. Typical values range from 64 to 512 tokens. Start with 256 for most tasks and adjust based on whether the model struggles with local dependencies.

  • stride: The interval between hub positions in strided attention. Smaller strides provide denser global connectivity but increase computation. Common values are 8 to 64. A stride of $\sqrt{n}$ (where $n$ is sequence length) balances coverage and efficiency.

  • block_size: The size of contiguous blocks in block-sparse attention. Must be chosen to align with GPU warp sizes (typically 32 or 64) for optimal hardware utilization. Larger blocks reduce indexing overhead but coarsen the sparsity pattern.

  • sparsity_ratio: The fraction of attention pairs that are masked out. Higher sparsity (0.9+) dramatically reduces computation but may degrade quality on tasks requiring dense interactions. Monitor validation loss when increasing sparsity.

  • num_heads: In factorized multi-head attention, determines how many heads use each pattern type. Splitting evenly between local and strided heads works well as a starting point. Models may benefit from more local heads for tasks with strong locality.

  • pattern: The combination strategy for multiple sparse patterns. Options include union (OR), which is most common, or alternating patterns across layers. The union approach ensures positions blocked by one pattern may still be reached through another.
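Putting the parameters together, a hypothetical configuration for a 4,096-token model might look like the sketch below (values are illustrative starting points, not tuned recommendations):

seq_len = 4096

config = {
    "window_size": 256,                  # local radius per side
    "stride": round(seq_len ** 0.5),     # 64: sqrt(n) heuristic for hub spacing
    "block_size": 64,                    # aligned with typical GPU tile sizes
    "num_heads": 8,                      # e.g. 4 local heads + 4 strided heads
    "pattern": "union",                  # OR of local and strided masks
}
print(config)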
