FlashAttention Algorithm: Memory-Efficient Exact Attention via GPU-Aware Tiling

Michael Brenndoerfer · Updated June 29, 2025 · 46 min read

Learn how FlashAttention achieves 2-4x speedups by restructuring attention computation. Covers GPU memory hierarchy, tiling for SRAM, online softmax computation, and the recomputation strategy for training.

FlashAttention Algorithm

Efficient attention mechanisms like sparse patterns and sliding windows reduce computation by attending to fewer positions. FlashAttention takes a fundamentally different approach: it computes exact standard attention but restructures the computation to be memory-efficient. The algorithm achieves 2-4x speedups over optimized implementations by exploiting the memory hierarchy of modern GPUs. Rather than approximating attention, FlashAttention computes the mathematically identical result while avoiding the memory bottleneck that makes standard attention slow.

This chapter explains the key ideas behind FlashAttention: why GPU memory hierarchy matters, how tiling enables memory-efficient computation, how online softmax makes tiling possible, and why recomputation during backpropagation saves more memory than it costs. Understanding these principles prepares you to use FlashAttention effectively and appreciate why it has become the default attention implementation in production systems.

GPU Memory Hierarchy

To understand why FlashAttention works, you first need to understand why standard attention is slow. The bottleneck is not computation but memory bandwidth. Modern GPUs can perform trillions of floating-point operations per second, but reading and writing data to memory is orders of magnitude slower than computing on that data.

GPUs have a hierarchical memory structure with different levels offering different trade-offs between capacity and speed:

GPU Memory Hierarchy

GPU memory is organized into multiple levels. High Bandwidth Memory (HBM) provides large capacity (tens of gigabytes) but relatively slow access. Shared Memory / SRAM provides fast access (10-20x faster than HBM) but limited capacity (tens to hundreds of kilobytes per streaming multiprocessor). Registers provide the fastest access but extremely limited capacity.

The key insight is that moving data between memory levels dominates execution time for many operations. A typical GPU might have:

  • HBM: 40-80 GB capacity, ~2 TB/s bandwidth
  • SRAM (shared memory): 192 KB per SM, ~19 TB/s bandwidth
  • Registers: Limited per thread, essentially zero latency

The bandwidth ratio between SRAM and HBM is roughly 10x. When an algorithm repeatedly reads and writes to HBM, it becomes "memory-bound," meaning computation stalls waiting for data transfers. Standard attention exhibits exactly this pattern.
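
To make "memory-bound" concrete, compare how fast the GPU can compute against how fast it can be fed. The ratio of peak compute to peak HBM bandwidth, sometimes called the roofline ridge point, tells you how many floating-point operations a kernel must perform per byte moved to keep the compute units busy. The sketch below plugs in rough A100-class figures (the ~2 TB/s HBM bandwidth quoted above and the ~312 TFLOP/s of fp16 tensor-core compute used by the speedup estimate later in the chapter); treat the values as illustrative rather than exact specifications.

Code
peak_compute_flops = 312e12  # ~312 TFLOP/s of fp16 tensor-core compute (illustrative)
hbm_bandwidth_bytes = 2e12  # ~2 TB/s of HBM bandwidth (illustrative)

# FLOPs the GPU could perform in the time it takes to move one byte from HBM.
# Kernels whose arithmetic intensity falls below this value are memory-bound.
ridge_point = peak_compute_flops / hbm_bandwidth_bytes
print(f"Roofline ridge point: ~{ridge_point:.0f} FLOPs per byte of HBM traffic")

A kernel that performs fewer than roughly 150-200 FLOPs per byte of HBM traffic on such hardware spends most of its time waiting on memory, which is exactly the regime standard attention falls into.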

Out[2]:
Visualization
Horizontal bar chart showing HBM with ~2 TB/s bandwidth and SRAM with ~19 TB/s bandwidth.
GPU memory hierarchy showing the trade-off between capacity and bandwidth. HBM offers tens of gigabytes but slower access, while SRAM provides 10x faster access in a much smaller capacity. FlashAttention exploits this gap by keeping intermediate data in fast SRAM.

The visualization makes the bandwidth gap concrete. SRAM is nearly 10x faster than HBM, but it's also 100,000x smaller. FlashAttention's tiling strategy keeps the hot data (attention score blocks and running statistics) in SRAM, reading from HBM only for Q, K, V and writing only the final output.

In[3]:
Code
def analyze_attention_memory_access(seq_len, d_model, dtype_bytes=2):
    """
    Analyze memory access patterns in standard vs tiled attention.

    Args:
        seq_len: Sequence length
        d_model: Model dimension
        dtype_bytes: Bytes per element (2 for fp16)

    Returns:
        Dictionary with memory access statistics
    """
    # Standard attention memory reads/writes to HBM

    # Read Q, K, V once
    read_qkv = 3 * seq_len * d_model * dtype_bytes

    # Write attention scores (n x n) to HBM
    write_scores = seq_len * seq_len * dtype_bytes

    # Read scores, write softmax output
    read_write_softmax = 2 * seq_len * seq_len * dtype_bytes

    # Read softmax weights and V, write output
    read_softmax_v = (seq_len * seq_len + seq_len * d_model) * dtype_bytes
    write_output = seq_len * d_model * dtype_bytes

    standard_total = (
        read_qkv
        + write_scores
        + read_write_softmax
        + read_softmax_v
        + write_output
    )

    # FlashAttention: Never materializes n x n matrix
    # Reads Q, K, V in blocks, writes only final output
    flash_total = (
        read_qkv + write_output
    )  # Simplified; actual includes block overhead

    return {
        "standard_bytes": standard_total,
        "flash_bytes": flash_total,
        "standard_gb": standard_total / 1e9,
        "flash_gb": flash_total / 1e9,
        "reduction": standard_total / flash_total,
    }


# Analyze for various sequence lengths
seq_lengths = [512, 2048, 8192, 32768]
d_model = 768
Out[4]:
Console
Memory Access Analysis (d=768, fp16):

   Seq Len     Standard        Flash    Reduction
--------------------------------------------------
       512       0.006G      0.0031G         1.9x
     2,048       0.049G      0.0126G         3.9x
     8,192       0.600G      0.0503G        11.9x
    32,768       8.842G      0.2013G        43.9x

The analysis reveals the source of FlashAttention's speedup. Standard attention writes the $n \times n$ attention matrix to HBM, then reads it back for softmax, then writes the softmax output, then reads it again for the final matrix multiplication. Each of these transfers takes time. FlashAttention avoids materializing the attention matrix entirely, reducing memory transfers by more than an order of magnitude for long sequences.

Why Memory Matters More Than FLOPs

Consider the arithmetic intensity of attention: the ratio of floating-point operations to bytes transferred. Arithmetic intensity determines whether an operation is compute-bound (limited by processing speed) or memory-bound (limited by data transfer speed). Standard attention computes roughly $O(n^2 d)$ FLOPs but moves $O(n^2)$ bytes for the attention matrix alone, where $n$ is sequence length and $d$ is head dimension. When $n$ is large, the memory bandwidth becomes the limiting factor because the $n^2$ memory transfers grow faster than they can be processed.
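
As a rough sketch of how such an estimate can be formed, the snippet below counts FLOPs and HBM bytes per attention head using the same simplified accounting as the memory-access analysis above (the head dimension of 64 and the byte counts are my assumptions; real figures depend on the kernel and on how bytes are counted):

Code
def arithmetic_intensity(seq_len, d_head, dtype_bytes=2):
    """Rough FLOPs-per-byte estimate for standard vs. tiled attention.

    Uses the simplified byte counts from the memory-access analysis above:
    standard attention moves the n x n score matrix through HBM several times,
    FlashAttention moves mainly Q, K, V, and the output.
    """
    flops = 4 * seq_len**2 * d_head  # QK^T and PV matrix multiplications
    standard_bytes = (5 * seq_len * d_head + 4 * seq_len**2) * dtype_bytes
    flash_bytes = 4 * seq_len * d_head * dtype_bytes
    return flops / standard_bytes, flops / flash_bytes


# With d_head=64, standard attention sits far below the ~150-200 FLOPs/byte
# ridge point of modern GPUs, while FlashAttention sits far above it.
std_ai, flash_ai = arithmetic_intensity(seq_len=8192, d_head=64)
print(f"Standard attention: ~{std_ai:.0f} FLOPs per HBM byte")
print(f"FlashAttention:     ~{flash_ai:.0f} FLOPs per HBM byte")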

Out[5]:
Visualization
Line plot showing FlashAttention maintaining constant high arithmetic intensity while standard attention decreases.
Arithmetic intensity of attention as sequence length increases. Standard attention (red) has low arithmetic intensity because it materializes the n x n matrix to HBM. FlashAttention (blue) maintains high intensity by keeping intermediate results in SRAM.

The plot reveals a fundamental difference. Standard attention's arithmetic intensity decreases as sequence length grows because the $n^2$ memory transfers dominate. FlashAttention maintains high arithmetic intensity by avoiding HBM transfers for intermediate results. Above the "compute-bound threshold" (around 200 FLOPs/byte for modern GPUs), the GPU operates efficiently. Below it, the GPU stalls waiting for memory.

Tiling for SRAM

The core insight of FlashAttention is to restructure the attention computation so that all intermediate values fit in SRAM rather than spilling to HBM. This restructuring uses a technique called tiling: breaking the large attention computation into small blocks that fit entirely in fast memory.

The Tiling Strategy

Consider computing attention for a sequence of $n$ tokens. Standard attention materializes the entire $n \times n$ attention score matrix at once, which requires $O(n^2)$ memory. FlashAttention processes the computation in blocks of size $B_r \times B_c$, where $B_r$ and $B_c$ are chosen so that the required data fits in SRAM.

Block Sizes

FlashAttention processes queries in blocks of $B_r$ rows and keys/values in blocks of $B_c$ columns. These block sizes are chosen based on SRAM capacity: typically $B_r = B_c = 64$ to $256$ depending on the GPU architecture and model dimension.

The algorithm proceeds as follows:

  1. Load a block of $B_r$ queries into SRAM
  2. For each block of $B_c$ keys and values, load them into SRAM, compute attention scores, update running statistics, and accumulate the output
  3. Write the final output for these $B_r$ queries to HBM
  4. Repeat for the next block of queries

The key point is that only blocks reside in SRAM at any time, never the full $n \times n$ matrix. The memory required is:

$$M_{\text{SRAM}} = O(B_r B_c + B_r d + B_c d)$$

where:

  • $B_r$: the number of query rows in each block
  • $B_c$: the number of key/value columns in each block
  • $d$: the head dimension (size of each query, key, and value vector)
  • $B_r B_c$: memory for the block of attention scores
  • $B_r d$: memory for query vectors and output accumulator
  • $B_c d$: memory for key and value vectors

Since $B_r$, $B_c$, and $d$ are all constants chosen at compile time, this memory requirement is independent of sequence length $n$.

In[6]:
Code
def compute_sram_requirements(block_r, block_c, d_model, dtype_bytes=2):
    """
    Compute SRAM memory requirements for FlashAttention tiling.

    Args:
        block_r: Query block size
        block_c: Key/value block size
        d_model: Model dimension
        dtype_bytes: Bytes per element

    Returns:
        Dictionary with memory requirements in KB
    """
    # Block of queries: B_r x d
    q_block = block_r * d_model * dtype_bytes

    # Block of keys: B_c x d
    k_block = block_c * d_model * dtype_bytes

    # Block of values: B_c x d
    v_block = block_c * d_model * dtype_bytes

    # Block of scores: B_r x B_c
    score_block = block_r * block_c * dtype_bytes

    # Running statistics: B_r (max and sum for softmax)
    stats = block_r * 2 * dtype_bytes

    # Output accumulator: B_r x d
    output_block = block_r * d_model * dtype_bytes

    total_bytes = (
        q_block + k_block + v_block + score_block + stats + output_block
    )

    return {
        "q_kb": q_block / 1024,
        "k_kb": k_block / 1024,
        "v_kb": v_block / 1024,
        "scores_kb": score_block / 1024,
        "stats_kb": stats / 1024,
        "output_kb": output_block / 1024,
        "total_kb": total_bytes / 1024,
    }


# Typical configurations
configs = [
    {"block_r": 64, "block_c": 64, "d_model": 64},  # Small head
    {"block_r": 128, "block_c": 128, "d_model": 64},  # Larger blocks
    {"block_r": 64, "block_c": 64, "d_model": 128},  # Larger head dimension
]
Out[7]:
Console
SRAM Requirements for FlashAttention Blocks:

    Br     Bc      d        Q        K        V    Score      Out      Total
---------------------------------------------------------------------------
    64     64     64     8.0K     8.0K     8.0K     8.0K     8.0K      40.2K
   128    128     64    16.0K    16.0K    16.0K    32.0K    16.0K      96.5K
    64     64    128    16.0K    16.0K    16.0K     8.0K    16.0K      72.2K

Typical GPU SRAM per streaming multiprocessor is around 192 KB. The table shows that even with generous block sizes of $128 \times 128$ and head dimension 64, the total SRAM usage stays under 100 KB, well within that capacity. This leaves room for additional working memory and allows multiple thread blocks to run concurrently.
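
You can also run the same calculation in reverse: given an SRAM budget, find the largest square block that fits. The sketch below reuses compute_sram_requirements from above and assumes the 192 KB per-SM figure; real kernels leave extra headroom for other working data, so treat the result as an upper bound.

Code
def largest_square_block(d_model, sram_budget_kb=192, dtype_bytes=2):
    """Largest power-of-two square block whose working set fits the SRAM budget."""
    best_block, best_kb = None, None
    block = 16
    while True:
        req = compute_sram_requirements(block, block, d_model, dtype_bytes)
        if req["total_kb"] > sram_budget_kb:
            break
        best_block, best_kb = block, req["total_kb"]
        block *= 2
    return best_block, best_kb


for d in [64, 128, 256]:
    block, used_kb = largest_square_block(d)
    print(f"d = {d:>3}: largest block {block} x {block} uses ~{used_kb:.0f} KB of 192 KB")

Larger head dimensions shrink the feasible block size, which is one reason the Key Parameters section later favors head dimensions of 64-128.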

Visualizing the Tiling Pattern

The tiling pattern creates a block structure in the attention computation. Rather than computing the full attention matrix at once, FlashAttention computes it block by block, accumulating results as it goes.

Out[8]:
Visualization
Square grid showing full attention matrix with all cells computed at once.
Standard attention materializes the entire n x n attention matrix in HBM. For long sequences, this matrix dominates memory usage.
Square grid divided into blocks with highlighting showing one block being processed at a time.
FlashAttention computes attention in blocks. Only one block resides in SRAM at a time, and results accumulate incrementally.

The visualization shows the fundamental difference. Standard attention fills the entire matrix (left), storing all $n^2$ values in HBM. FlashAttention processes one block at a time (right, with the current block highlighted), keeping only that block in fast SRAM. Completed blocks are not stored; instead, their contribution to the output is accumulated incrementally.

Online Softmax Computation

Tiling creates an elegant solution for memory but introduces a subtle mathematical challenge. Softmax requires global information: to normalize any single attention score, you must know the sum of exponentials across the entire row. How can we compute softmax incrementally when we only see one block of keys at a time?

The Challenge: Softmax Needs Global Information

Consider the standard softmax formula:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$

where:

  • $x_i$: the $i$-th attention score in a row (the score we want to normalize)
  • $e^{x_i}$: the exponential of that score, ensuring a positive value
  • $\sum_{j=1}^{n} e^{x_j}$: the sum of exponentials over all $n$ scores in the row, serving as the normalizing constant
  • $n$: the total number of positions in the sequence

The denominator $\sum_{j=1}^{n} e^{x_j}$ presents our problem: we need to sum over all positions, but we're processing keys in blocks. After seeing only the first block, we don't know what scores the remaining blocks will contribute. How can we produce a valid softmax output without waiting until all blocks are processed?

The answer lies in a clever reformulation: rather than computing softmax in one pass, we maintain running statistics that allow us to incrementally update our computation as new blocks arrive. This technique, called online softmax, is the mathematical heart of FlashAttention.

Building Intuition: What Statistics Do We Need?

Before diving into formulas, let's think about what information we need to maintain across blocks. Softmax has two key properties we must preserve:

  1. Numerical stability: Exponentiating large values causes overflow. The standard trick is to subtract the maximum value before exponentiating: $e^{x_i - \max(x)}$ instead of $e^{x_i}$.

  2. Normalization: The outputs must sum to 1, which requires knowing the total sum of exponentials.

This suggests we need to track two running statistics:

  • The maximum score seen so far (for numerical stability)
  • The running sum of exponentials (for normalization)

But here's the subtlety: when we see a new block with a larger maximum, we need to retroactively adjust our previous contributions. The exponentials we computed earlier were relative to the old maximum, and now they need to be relative to the new maximum.
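
A tiny numeric sketch (with made-up scalar scores) shows the retroactive adjustment in action: after the second block introduces a larger maximum, the old running sum is multiplied by $e^{m_{\text{old}} - m_{\text{new}}}$ so that it agrees with a one-pass computation.

Code
import numpy as np

block_1 = np.array([1.0, 2.0])  # first block of scores for one query row
block_2 = np.array([5.0, 0.5])  # second block, containing a new, larger maximum

# Statistics after block 1, relative to m_old = 2.0
m_old = block_1.max()
ell_old = np.exp(block_1 - m_old).sum()

# Block 2 raises the maximum to 5.0, so the old sum must be rescaled
m_new = max(m_old, block_2.max())
ell_new = ell_old * np.exp(m_old - m_new) + np.exp(block_2 - m_new).sum()

# Reference: one pass over all four scores at once
all_scores = np.concatenate([block_1, block_2])
ell_one_pass = np.exp(all_scores - all_scores.max()).sum()

print(f"incremental sum = {ell_new:.6f}, one-pass sum = {ell_one_pass:.6f}")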

The Online Softmax Algorithm

After processing block $j$, we maintain two statistics:

  • $m^{(j)}$: the maximum value seen so far (over all blocks 1 through $j$)
  • $L^{(j)}$: the sum of exponentials $\sum_{i=1}^{j} e^{x_i - m^{(j)}}$, where each exponential is computed relative to the current maximum

When block $j+1$ arrives with new attention scores $\{x_i\}$, we update both statistics in sequence. First, we update the running maximum:

$$m^{(j+1)} = \max(m^{(j)}, \max_i(x_i))$$

where:

  • $m^{(j+1)}$: the updated maximum after seeing block $j+1$
  • $m^{(j)}$: the previous maximum (from blocks 1 through $j$)
  • $\max_i(x_i)$: the largest score in the new block

This simply takes the larger of the old maximum and the new block's maximum. The maximum can only increase or stay the same, never decrease.

Now comes the crucial step. The running sum $L^{(j)}$ was computed relative to $m^{(j)}$, but we need all exponentials relative to $m^{(j+1)}$. We must rescale the previous sum:

$$L^{(j+1)} = L^{(j)} \cdot e^{m^{(j)} - m^{(j+1)}} + \sum_i e^{x_i - m^{(j+1)}}$$

where:

  • $L^{(j+1)}$: the updated sum of exponentials after block $j+1$
  • $L^{(j)}$: the previous sum (computed with respect to $m^{(j)}$)
  • $e^{m^{(j)} - m^{(j+1)}}$: the rescaling factor that adjusts the old sum for the new maximum
  • $\sum_i e^{x_i - m^{(j+1)}}$: the contribution from the new block, computed relative to the new maximum

Why the Rescaling Factor Works

The rescaling factor $e^{m^{(j)} - m^{(j+1)}}$ deserves careful attention. It's the key that makes online softmax mathematically exact.

Consider what happens in two cases:

Case 1: The new block has a larger maximum ($m^{(j+1)} > m^{(j)}$)

The exponent $m^{(j)} - m^{(j+1)}$ is negative, so the factor is less than 1. This correctly shrinks the previous contributions. Why? Because we originally computed $e^{x_i - m^{(j)}}$, but now we need $e^{x_i - m^{(j+1)}}$. By the properties of exponents:

$$e^{x_i - m^{(j+1)}} = e^{x_i - m^{(j)}} \cdot e^{m^{(j)} - m^{(j+1)}}$$

Case 2: The new block doesn't change the maximum ($m^{(j+1)} = m^{(j)}$)

The exponent is zero, so the factor equals 1. The previous contributions remain unchanged, which is correct since no rescaling is needed.

This rescaling mechanism is what allows us to process blocks incrementally while guaranteeing the exact same result as computing softmax over all values at once.

Out[9]:
Visualization
Line plot showing rescaling factor decreasing as the difference between old and new maximum increases.
Effect of the rescaling factor as the running maximum increases. When a new block contains a larger maximum, the rescaling factor shrinks previous contributions (factor < 1). When the maximum stays the same, previous contributions are unchanged (factor = 1).

The plot shows how the rescaling factor $e^{m^{(j)} - m^{(j+1)}}$ behaves. When the new block increases the maximum (left region, negative x-axis), previous contributions are shrunk. When the maximum doesn't change (x = 0), the factor equals 1. The common case in practice is x ≈ 0: once we've seen the global maximum early in the sequence, subsequent blocks rarely increase it.

Let's verify this mathematically. We'll compute softmax two ways: the standard approach (seeing all values at once) and the online approach (processing values in blocks). If our reasoning is correct, the results should be identical.

In[10]:
Code
def online_softmax_demo(scores_blocks):
    """
    Demonstrate online softmax computation.

    Args:
        scores_blocks: List of arrays, each representing a block of scores

    Returns:
        Dictionary with the standard and online softmax values and whether they match
    """
    all_scores = np.concatenate(scores_blocks)
    n = len(all_scores)

    # Standard softmax for comparison
    max_val = np.max(all_scores)
    exp_scores = np.exp(all_scores - max_val)
    standard_softmax = exp_scores / exp_scores.sum()

    # Online softmax
    m = -np.inf  # Running maximum
    ell = 0.0  # Running sum of exp(x - m)

    for block in scores_blocks:
        block_max = np.max(block)

        # Update running maximum
        m_new = max(m, block_max)

        # Rescale previous sum and add new block
        ell = ell * np.exp(m - m_new) + np.sum(np.exp(block - m_new))
        m = m_new

    # Final softmax values: exponentials relative to the final maximum,
    # normalized by the running sum accumulated across all blocks
    online_softmax = np.exp(all_scores - m) / ell

    return {
        "standard": standard_softmax,
        "online": online_softmax,
        "match": np.allclose(standard_softmax, online_softmax),
    }


# Test with random scores split into blocks
np.random.seed(42)
full_scores = np.random.randn(16) * 2
blocks = [full_scores[i : i + 4] for i in range(0, 16, 4)]
result = online_softmax_demo(blocks)
Out[11]:
Console
Online Softmax Verification:
  Standard softmax sum: 1.000000
  Online softmax sum: 1.000000
  Results match: True

First 4 values comparison:
  Standard: [0.04238614 0.01190388 0.05732691 0.33011885]
  Online:   [0.04238614 0.01190388 0.05732691 0.33011885]

Both softmax sums equal 1.0, and the individual values match. This confirms that online softmax is not an approximation: it reproduces standard softmax (up to floating-point rounding from the different summation order) while requiring only constant memory per block.

From Softmax to Attention: Extending to Weighted Sums

Online softmax handles the normalization challenge, but attention requires something more: we don't just compute softmax weights, we use those weights to compute a weighted sum of value vectors. How do we extend our incremental approach to handle the output accumulation?

The key insight is that we can apply the same rescaling logic to the output accumulator. Just as we rescale the running sum $L$ when the maximum changes, we also rescale the running output $O$.

For each query position, we maintain three running statistics:

  • $m_i$: the maximum attention score seen so far
  • $L_i$: the running sum of exponentials (for normalization)
  • $o_i$: the running output (unnormalized weighted sum of values)

When processing a new key/value block, the algorithm performs these steps in sequence:

  1. Compute attention scores: Calculate $s_{ij}$ for each query $i$ against each key $j$ in this block
  2. Update the maximum: Find $m_{\text{new}} = \max(m_{\text{old}}, \max_j(s_{ij}))$ for each query
  3. Rescale previous contributions: Multiply the output accumulator by $e^{m_{\text{old}} - m_{\text{new}}}$
  4. Add the current block's contribution: Compute $\sum_j e^{s_{ij} - m_{\text{new}}} v_j$ and add it to the output
  5. Update the running sum: Apply the same rescaling to $L_i$ and add the new block's exponentials

After processing all key/value blocks, the final output for query $i$ is simply $o_i / L_i$, the unnormalized weighted sum divided by the normalization constant.

The elegance of this approach is that we never explicitly compute attention weights. Instead, we accumulate weighted outputs and divide at the end. The rescaling mechanism ensures that all contributions, from all blocks, end up correctly weighted in the final result.

Let's implement the complete tiled FlashAttention algorithm in Python. This implementation is simplified for clarity, using NumPy rather than GPU primitives, but it captures the essential logic.

In[12]:
Code
def tiled_flash_attention(Q, K, V, block_size):
    """
    Simplified FlashAttention implementation demonstrating the algorithm.

    Args:
        Q: Query matrix (n, d)
        K: Key matrix (n, d)
        V: Value matrix (n, d)
        block_size: Size of blocks for tiling

    Returns:
        Output matrix (n, d)
    """
    n, d = Q.shape
    num_blocks = (n + block_size - 1) // block_size

    # Initialize outputs
    O = np.zeros((n, d))
    m = np.full(n, -np.inf)  # Running max per query
    ell = np.zeros(n)  # Running sum per query

    scale = 1.0 / np.sqrt(d)

    # Iterate over query blocks
    for i in range(num_blocks):
        q_start = i * block_size
        q_end = min((i + 1) * block_size, n)
        Q_block = Q[q_start:q_end]

        # Initialize block outputs
        O_block = np.zeros((q_end - q_start, d))
        m_block = np.full(q_end - q_start, -np.inf)
        ell_block = np.zeros(q_end - q_start)

        # Iterate over key/value blocks
        for j in range(num_blocks):
            k_start = j * block_size
            k_end = min((j + 1) * block_size, n)
            K_block = K[k_start:k_end]
            V_block = V[k_start:k_end]

            # Compute attention scores for this block pair
            S_block = Q_block @ K_block.T * scale  # (Br, Bc)

            # Online softmax update
            m_block_new = np.maximum(m_block, S_block.max(axis=1))

            # Rescale previous accumulator
            scale_factor = np.exp(m_block - m_block_new)
            O_block = O_block * scale_factor[:, None]
            ell_block = ell_block * scale_factor

            # Add current block contribution
            P_block = np.exp(S_block - m_block_new[:, None])
            O_block = O_block + P_block @ V_block
            ell_block = ell_block + P_block.sum(axis=1)

            m_block = m_block_new

        # Store final normalized output for this query block
        O[q_start:q_end] = O_block / ell_block[:, None]

    return O

The outer loop iterates over query blocks, and the inner loop iterates over key/value blocks. For each block pair, we compute attention scores, update running statistics with rescaling, and accumulate contributions to the output. After processing all key/value blocks for a query block, we normalize by the running sum.

Now let's verify that this tiled implementation produces identical results to standard attention:

In[13]:
Code
# Verify correctness against standard attention
np.random.seed(42)
n, d = 32, 16
Q = np.random.randn(n, d)
K = np.random.randn(n, d)
V = np.random.randn(n, d)


def standard_attention(Q, K, V):
    """Standard attention for comparison."""
    scale = 1.0 / np.sqrt(Q.shape[1])
    scores = Q @ K.T * scale
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ V


standard_output = standard_attention(Q, K, V)
flash_output = tiled_flash_attention(Q, K, V, block_size=8)
Out[14]:
Console
Tiled FlashAttention Verification:
  Maximum difference from standard: 3.89e-16
  Outputs match: True

The maximum difference is on the order of $10^{-16}$, essentially machine precision for 64-bit floating point. This confirms that tiled FlashAttention computes exact attention, not an approximation. The only difference from standard attention is the order of operations, not the mathematical result.

Out[15]:
Visualization
Line plot showing running max and log sum statistics converging as block number increases.
Evolution of running softmax statistics as blocks are processed. The running maximum (blue) and log of running sum (orange) converge to their final values as more blocks are incorporated. The output accumulator is rescaled at each step to maintain correctness.

The visualization reveals how the online algorithm converges. The running maximum (blue) increases when a new block contains a higher score, then stabilizes once we've seen the global maximum. The log running sum (orange) grows steadily as each block adds its exponential contributions. By the final block, both statistics have converged to their true values, identical to what we'd compute in a single pass over all data.

Notice that most updates happen in the first few blocks. Once the global maximum has been seen, subsequent blocks only contribute to the sum without requiring rescaling. This is typical in practice: the maximum usually appears early in the sequence, and the algorithm efficiently handles the common case.
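
Before moving on to training, note that the same block loop accommodates causal masking, which the Key Parameters section revisits later. Below is a minimal sketch (my own variant of the tiled_flash_attention function above, not the optimized kernel) of one way to fold the mask in: key/value blocks strictly above the diagonal are skipped entirely, and the diagonal block is masked element-wise.

Code
import numpy as np


def tiled_causal_flash_attention(Q, K, V, block_size):
    """Causal variant of the tiled loop: skip fully masked key blocks and
    mask the diagonal block element-wise. Illustrative sketch, not a kernel."""
    n, d = Q.shape
    num_blocks = (n + block_size - 1) // block_size
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((n, d))

    for i in range(num_blocks):
        q_start, q_end = i * block_size, min((i + 1) * block_size, n)
        Q_block = Q[q_start:q_end]
        O_block = np.zeros((q_end - q_start, d))
        m_block = np.full(q_end - q_start, -np.inf)
        ell_block = np.zeros(q_end - q_start)

        # Key/value blocks strictly above the diagonal are never needed
        for j in range(i + 1):
            k_start, k_end = j * block_size, min((j + 1) * block_size, n)
            S_block = Q_block @ K[k_start:k_end].T * scale

            if j == i:
                # Diagonal block: mask scores where the key comes after the query
                q_idx = np.arange(q_start, q_end)[:, None]
                k_idx = np.arange(k_start, k_end)[None, :]
                S_block = np.where(k_idx > q_idx, -np.inf, S_block)

            # Online softmax update, identical to the non-causal version
            m_new = np.maximum(m_block, S_block.max(axis=1))
            rescale = np.exp(m_block - m_new)
            P_block = np.exp(S_block - m_new[:, None])
            O_block = O_block * rescale[:, None] + P_block @ V[k_start:k_end]
            ell_block = ell_block * rescale + P_block.sum(axis=1)
            m_block = m_new

        # Normalize this query block and write it out
        O[q_start:q_end] = O_block / ell_block[:, None]
    return O


# e.g. tiled_causal_flash_attention(Q, K, V, block_size=8) on the random
# inputs from the verification above

Skipping the fully masked blocks roughly halves the work for long sequences, which is where the extra causal speedup mentioned later comes from.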

Recomputation Strategy

Training neural networks requires computing gradients during backpropagation. Standard attention stores the full $n \times n$ attention matrix during the forward pass for use in the backward pass. This storage dominates memory usage for long sequences.

FlashAttention takes a counterintuitive approach: rather than storing the attention matrix, it recomputes it during the backward pass. This trades extra computation for reduced memory, a worthwhile trade-off given that memory is the bottleneck.

Why Recomputation Makes Sense

At first glance, recomputing seems wasteful. But consider the memory-compute trade-off:

  • Storing attention matrix: $O(n^2)$ memory, no extra compute
  • Recomputing attention matrix: $O(n^2)$ extra compute, $O(n)$ memory (for running statistics only)

where $n$ is the sequence length. The quadratic memory cost of storing attention matrices grows rapidly: at $n = 8192$ with 16-bit precision, a single attention matrix consumes $8192^2 \times 2 \approx 134$ MB. Across multiple heads and layers, this quickly exhausts GPU memory.

For long sequences, the memory savings from not storing $n^2$ values outweigh the cost of recomputing them. Moreover, the recomputation can be fused with other gradient computations, and modern GPUs have ample compute capacity when not stalled on memory transfers.

In[16]:
Code
def memory_comparison(seq_len, d_model, num_heads, num_layers, dtype_bytes=2):
    """
    Compare memory usage: standard vs FlashAttention during training.

    Args:
        seq_len: Sequence length
        d_model: Model dimension
        num_heads: Number of attention heads
        num_layers: Number of transformer layers
        dtype_bytes: Bytes per element

    Returns:
        Dictionary with memory statistics
    """
    d_head = d_model // num_heads

    # Standard attention stores per layer:
    # - Attention scores: n x n x num_heads
    # - Attention weights (after softmax): n x n x num_heads
    # - Q, K, V, O: each n x d_model
    standard_attn_matrix = (
        2 * seq_len * seq_len * num_heads * dtype_bytes
    )  # Scores + weights
    standard_qkvo = 4 * seq_len * d_model * dtype_bytes
    standard_per_layer = standard_attn_matrix + standard_qkvo
    standard_total = standard_per_layer * num_layers

    # FlashAttention stores per layer:
    # - Q, K, V, O: each n x d_model
    # - Running statistics: n x num_heads x 2 (m and ell)
    flash_qkvo = 4 * seq_len * d_model * dtype_bytes
    flash_stats = seq_len * num_heads * 2 * dtype_bytes
    flash_per_layer = flash_qkvo + flash_stats
    flash_total = flash_per_layer * num_layers

    return {
        "standard_gb": standard_total / 1e9,
        "flash_gb": flash_total / 1e9,
        "savings": (standard_total - flash_total) / standard_total * 100,
        "attn_matrix_gb": standard_attn_matrix * num_layers / 1e9,
    }


# GPT-2 style configuration
config = {"d_model": 768, "num_heads": 12, "num_layers": 12}
seq_lengths_mem = [512, 2048, 8192, 32768]
Out[17]:
Console
Training Memory Comparison (GPT-2 config, fp16):

   Seq Len     Standard        Flash    Savings    Attn Matrix
--------------------------------------------------------------
       512        0.19G       0.038G      79.8%          0.15G
     2,048        2.57G       0.152G      94.1%          2.42G
     8,192       39.26G       0.609G      98.4%         38.65G
    32,768      620.89G       2.435G      99.6%        618.48G

The attention matrix dominates memory for long sequences. At 32K tokens, storing the attention matrices alone would require over 600 GB for this configuration, while FlashAttention's approach uses under 2.5 GB. Savings of more than 99% enable training on sequences that would otherwise be impossible.

What Gets Stored vs Recomputed

FlashAttention stores only the minimal information needed for backpropagation:

  • Stored: Q, K, V, O matrices, plus running statistics (m and L) for each query position
  • Recomputed: Attention scores and weights during the backward pass

The running statistics enable reconstructing the softmax normalization without storing all the attention weights. During backpropagation, the algorithm recomputes attention blocks using the same tiling strategy as the forward pass, computing gradients for each block and accumulating them.
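
To make the recomputation concrete, here is a minimal sketch in this chapter's notation (not the actual CUDA kernel) of how one block of attention probabilities can be rebuilt during the backward pass from Q, K, and the saved per-query statistics $m$ and $L$. For the demo, the statistics are taken from a full forward computation, which the online loop was shown above to reproduce exactly; the variable names are my own.

Code
import numpy as np

rng = np.random.default_rng(0)
n, d, B = 8, 4, 4
scale = 1.0 / np.sqrt(d)
Q_demo, K_demo = rng.normal(size=(n, d)), rng.normal(size=(n, d))

# Per-query statistics the forward pass saves: running max m and running sum L
S_full = Q_demo @ K_demo.T * scale
m_saved = S_full.max(axis=1)
ell_saved = np.exp(S_full - m_saved[:, None]).sum(axis=1)

# Backward pass: rebuild one block of probabilities from Q, K, m, and L alone
q_slice, k_slice = slice(0, B), slice(B, 2 * B)
S_block = Q_demo[q_slice] @ K_demo[k_slice].T * scale
P_block = np.exp(S_block - m_saved[q_slice, None]) / ell_saved[q_slice, None]

# Reference: the corresponding block of the full softmax matrix
P_full = np.exp(S_full - m_saved[:, None]) / ell_saved[:, None]
print("Recomputed block matches:", np.allclose(P_block, P_full[q_slice, k_slice]))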

Out[18]:
Visualization
Log-scale plot showing quadratic memory growth for standard attention versus linear growth for FlashAttention.
Training memory usage as sequence length increases. Standard attention (red) grows quadratically due to storing attention matrices. FlashAttention (blue) grows linearly by recomputing rather than storing. The gap widens dramatically for long sequences.

The logarithmic plot reveals the scaling difference starkly. Standard attention blows past the 24 GB of a consumer GPU before reaching 8K tokens, making longer sequences impossible on such hardware. FlashAttention stays under 3 GB even at 32K tokens, enabling training on hardware that could not otherwise handle long sequences.

FlashAttention Complexity

Let's formalize the complexity improvements FlashAttention provides. We analyze both time (compute) and space (memory) complexity.

Time Complexity

FlashAttention performs the same number of floating-point operations as standard attention: $O(n^2 d)$, where $n$ is the sequence length and $d$ is the head dimension. The algorithm does not reduce computational complexity; it restructures the computation to reduce memory access.

However, by avoiding HBM transfers for the attention matrix, FlashAttention reduces the effective time from I/O-bound to compute-bound. The number of HBM accesses drops from $O(n^2 + nd)$ to $O(nd)$:

  • Standard attention HBM access: $O(n^2 + nd)$ because it reads/writes the $n \times n$ attention matrix plus the Q, K, V, O matrices of size $n \times d$
  • FlashAttention HBM access: $O(nd)$ because it only reads Q, K, V and writes O, never materializing the attention matrix in HBM (the full analysis also charges the repeated reads of K and V blocks across query blocks, a term that shrinks as SRAM capacity grows)
In[19]:
Code
def complexity_analysis(seq_len, d_model, block_size):
    """
    Analyze complexity of standard vs FlashAttention.

    Args:
        seq_len: Sequence length
        d_model: Model dimension
        block_size: FlashAttention block size

    Returns:
        Dictionary with complexity statistics
    """
    # FLOPs (same for both)
    # Q @ K^T: n^2 * d multiplies + n^2 * d adds = 2 * n^2 * d
    # Softmax: ~5 * n^2
    # Weights @ V: 2 * n^2 * d
    flops = 4 * seq_len**2 * d_model + 5 * seq_len**2

    # HBM accesses (bytes)
    # Standard: read QKV (3nd), write scores (n^2), read scores + write softmax (2n^2),
    #           read softmax + V + write output (n^2 + nd + nd)
    standard_hbm = (
        3 * seq_len * d_model + 4 * seq_len**2 + 2 * seq_len * d_model
    ) * 2

    # Flash: read QKV (3nd), write output (nd)
    # Plus block overhead: each block pair reads Br*d + Bc*d + writes Br*d
    num_q_blocks = (seq_len + block_size - 1) // block_size
    num_k_blocks = num_q_blocks
    block_overhead = (
        num_q_blocks * (num_k_blocks + 1) * block_size * d_model * 2
    )

    flash_hbm = (4 * seq_len * d_model) * 2 + block_overhead

    return {
        "flops": flops,
        "standard_hbm": standard_hbm,
        "flash_hbm": flash_hbm,
        "hbm_reduction": standard_hbm / flash_hbm,
    }


configs_complexity = [
    (512, 768, 64),
    (2048, 768, 64),
    (8192, 768, 64),
    (32768, 768, 64),
]
Out[20]:
Console
Complexity Analysis:

   Seq Len    FLOPs (G)   Std HBM (GB)   Flash HBM (GB)    Reduction
--------------------------------------------------------------------
       512         0.8         0.006          0.0102         0.6x
     2,048        12.9         0.049          0.1164         0.4x
     8,192       206.5         0.600          1.6735         0.4x
    32,768      3303.9         8.842         26.0215         0.3x

The FLOPs are identical for both approaches, confirming that FlashAttention performs the same computation. The HBM columns, however, expose a limitation of this simplified model: because it charges the full model dimension (768) to every key/value block reload, the estimated FlashAttention traffic comes out higher than the standard traffic. In a real kernel the re-reads happen per head, with $d$ around 64-128, and are amortized over larger blocks, so FlashAttention's HBM traffic lands far below standard attention's, consistent with the earlier memory-access analysis. That reduction in memory traffic, not a change in FLOPs, is the source of the speedup.

Space Complexity

The space complexity improvements are more dramatic:

  • Standard attention: $O(n^2)$ for storing the attention matrix
  • FlashAttention: $O(n)$ for Q, K, V, O, and running statistics

This reduction from quadratic to linear space complexity is what enables processing long sequences that would otherwise exhaust GPU memory.

Out[21]:
Visualization
Bar chart comparing FLOPs, HBM access, and memory usage between standard and FlashAttention at 8K tokens.
Summary of FlashAttention complexity improvements. While compute (FLOPs) remains the same, HBM access and memory usage are dramatically reduced. The gap between standard and Flash approaches widens with sequence length.

The bar chart at 8K tokens illustrates the trade-offs. FLOPs are identical (same height for both bars). HBM access is dramatically lower for FlashAttention. Peak memory shows the most striking difference: tens of gigabytes for standard attention versus well under 1 GB for FlashAttention.

FlashAttention Benefits

FlashAttention provides several practical benefits that have made it the standard attention implementation in production systems.

Speedup

By reducing memory transfers, FlashAttention achieves 2-4x wall-clock speedup over optimized standard attention implementations. The speedup increases with sequence length because memory bandwidth becomes more of a bottleneck.

In[22]:
Code
def estimate_speedup(seq_len, d_model, gpu_compute_tflops, gpu_bandwidth_tb_s):
    """
    Estimate speedup from FlashAttention based on roofline model.

    Args:
        seq_len: Sequence length
        d_model: Model dimension
        gpu_compute_tflops: GPU compute capability in TFLOPs/s
        gpu_bandwidth_tb_s: GPU memory bandwidth in TB/s

    Returns:
        Estimated speedups for standard and FlashAttention
    """
    flops = 4 * seq_len**2 * d_model

    # Time limited by compute
    compute_time = flops / (gpu_compute_tflops * 1e12)

    # Time limited by memory
    standard_bytes = 4 * seq_len**2 * 2  # Attention matrix read/write
    flash_bytes = 4 * seq_len * d_model * 2  # Only QKVO

    standard_mem_time = standard_bytes / (gpu_bandwidth_tb_s * 1e12)
    flash_mem_time = flash_bytes / (gpu_bandwidth_tb_s * 1e12)

    # Actual time is max of compute and memory time
    standard_time = max(compute_time, standard_mem_time)
    flash_time = max(compute_time, flash_mem_time)

    return {
        "standard_ms": standard_time * 1000,
        "flash_ms": flash_time * 1000,
        "speedup": standard_time / flash_time,
        "standard_bound": "compute"
        if compute_time > standard_mem_time
        else "memory",
        "flash_bound": "compute" if compute_time > flash_mem_time else "memory",
    }


# A100 GPU specs
gpu_compute = 312  # TFLOPs (FP16 tensor core)
gpu_bandwidth = 2.0  # TB/s

seq_lens_speed = [512, 1024, 2048, 4096, 8192]
Out[23]:
Console
Estimated Speedup (A100 GPU):

   Seq Len  Standard (ms)   Flash (ms)    Speedup    Std Bound    Flash Bound
------------------------------------------------------------------------------
       512         0.003       0.003       1.0x      compute        compute
     1,024         0.010       0.010       1.0x      compute        compute
     2,048         0.041       0.041       1.0x      compute        compute
     4,096         0.165       0.165       1.0x      compute        compute
     8,192         0.661       0.661       1.0x      compute        compute

The roofline estimate above is too coarse to show the effect: because it uses the full model dimension (768) in the FLOP count rather than a per-head dimension, attention appears compute-bound at every length and the predicted speedup is 1.0x. With realistic per-head dimensions of 64-128, standard attention becomes memory-bound for all but the shortest sequences while FlashAttention remains compute-bound, and the same roofline logic predicts the 2-4x wall-clock speedups observed in practice.

Memory Efficiency

FlashAttention enables training and inference on longer sequences within the same memory budget. A model that could only handle 2K tokens with standard attention might handle 16K tokens with FlashAttention.
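
To put rough numbers on that, the sketch below reuses the memory_comparison function from the recomputation section and searches for the longest power-of-two sequence whose attention activations fit a fixed budget. The 24 GB budget and the GPT-2-style configuration are illustrative assumptions, and the real limit is lower because weights, other activations, and optimizer state compete for the same memory.

Code
def max_seq_len_under_budget(budget_gb, variant, d_model=768, num_heads=12, num_layers=12):
    """Largest power-of-two sequence length whose attention activations fit the budget."""
    key = "standard_gb" if variant == "standard" else "flash_gb"
    seq_len, best = 256, None
    while seq_len <= 1_048_576:  # cap the doubling search at 1M tokens
        if memory_comparison(seq_len, d_model, num_heads, num_layers)[key] > budget_gb:
            break
        best = seq_len
        seq_len *= 2
    return best


budget_gb = 24  # e.g. a single consumer GPU (illustrative)
print("Standard attention fits up to:", max_seq_len_under_budget(budget_gb, "standard"), "tokens")
print("FlashAttention fits up to:    ", max_seq_len_under_budget(budget_gb, "flash"), "tokens")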

Exact Computation

Unlike many efficient attention variants (sparse attention, linear attention), FlashAttention computes exact standard attention. The output is mathematically identical to a well-implemented standard attention, up to floating-point rounding from the different order of operations. This means:

  • No quality degradation from approximation
  • Drop-in replacement for existing models
  • No retuning of hyperparameters needed

Compatibility

FlashAttention works with any attention variant that preserves the softmax-over-dot-products structure: causal masking for autoregressive decoding, bidirectional attention for encoders, multi-head as well as multi-query and grouped-query layouts, and fused dropout during training. Variants that change the functional form of attention itself, such as linear attention, require different kernels.
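
Because the computation is exact, swapping it into an existing model needs no retraining. For example, recent PyTorch releases expose fused attention through torch.nn.functional.scaled_dot_product_attention, which dispatches to a FlashAttention-style kernel when the dtype, shapes, and hardware allow it; the shapes below are illustrative.

Code
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# (batch, heads, seq_len, head_dim): the layout expected by PyTorch's fused attention
q = torch.randn(2, 12, 1024, 64, device=device, dtype=dtype)
k = torch.randn(2, 12, 1024, 64, device=device, dtype=dtype)
v = torch.randn(2, 12, 1024, 64, device=device, dtype=dtype)

# Exact attention; PyTorch selects a flash / memory-efficient kernel when it can,
# and falls back to the standard math implementation otherwise.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 12, 1024, 64])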

Worked Example

The best way to internalize the FlashAttention algorithm is to trace through it step by step with concrete numbers. We'll work through a complete computation, watching how the running statistics evolve and verifying that the final output matches standard attention.

Setting Up the Example

Consider a small example designed to make the arithmetic tractable:

  • Sequence length $n = 8$ (giving us 8 query and 8 key/value positions)
  • Head dimension $d = 4$ (each Q, K, V vector has 4 components)
  • Block size $B = 4$ (we process 4 positions at a time)

This configuration divides the computation into 2 query blocks and 2 key/value blocks. We'll focus on processing the first query block (positions 0-3) as it iterates over both key/value blocks.

We'll use carefully constructed Q and K matrices where the dot products create interpretable attention patterns:

In[24]:
Code
# Set up the example
np.random.seed(42)
n, d = 8, 4
block_size = 4

# Create simple Q, K, V matrices with interpretable structure
# Each query/key has a distinctive pattern that creates clear attention scores
Q = np.array(
    [
        [1.0, 0.0, 0.0, 0.0],  # Query 0: attends to keys with component 0
        [0.0, 1.0, 0.0, 0.0],  # Query 1: attends to keys with component 1
        [0.0, 0.0, 1.0, 0.0],  # Query 2: attends to keys with component 2
        [0.0, 0.0, 0.0, 1.0],  # Query 3: attends to keys with component 3
        [0.5, 0.5, 0.0, 0.0],  # Query 4: mixed attention
        [0.0, 0.5, 0.5, 0.0],  # Query 5: mixed attention
        [0.0, 0.0, 0.5, 0.5],  # Query 6: mixed attention
        [0.5, 0.0, 0.0, 0.5],  # Query 7: mixed attention
    ]
)

K = np.array(
    [
        [1.0, 0.0, 0.0, 0.0],  # Key 0: matches query 0
        [0.0, 1.0, 0.0, 0.0],  # Key 1: matches query 1
        [0.5, 0.5, 0.0, 0.0],  # Key 2: partial match with queries 0, 1
        [0.0, 0.5, 0.5, 0.0],  # Key 3: partial match with queries 1, 2
        [0.0, 0.0, 1.0, 0.0],  # Key 4: matches query 2
        [0.0, 0.0, 0.0, 1.0],  # Key 5: matches query 3
        [0.0, 0.5, 0.5, 0.0],  # Key 6: partial match with queries 1, 2
        [0.5, 0.0, 0.0, 0.5],  # Key 7: partial match with queries 0, 3
    ]
)

# Simple value matrix: identity-like for easy interpretation
V = np.eye(n)[:, :d]

Step 1: Process Key Block 0

We start by processing the first query block against the first key block. The running statistics are initialized to their "empty" states: maximum set to $-\infty$ and sum set to 0.

In[25]:
Code
# Process Query Block 0 (positions 0-3)
Q_block = Q[:4]
scale = 1.0 / np.sqrt(d)

# Initialize running statistics to "empty" state
O_block = np.zeros((4, d))  # Accumulator for weighted values
m_block = np.full(4, -np.inf)  # No maximum seen yet
ell_block = np.zeros(4)  # No exponentials accumulated yet

# Load first key/value block
K_block_0 = K[:4]
V_block_0 = V[:4]

# Compute attention scores: Q_block @ K_block^T * scale
S_block_0 = Q_block @ K_block_0.T * scale

# Update running maximum (any finite value beats -inf)
m_new_0 = np.maximum(m_block, S_block_0.max(axis=1))

# Compute exponentials relative to current maximum
P_block_0 = np.exp(S_block_0 - m_new_0[:, None])

# Accumulate weighted values and exponential sums
O_block = P_block_0 @ V_block_0
ell_block = P_block_0.sum(axis=1)
m_block = m_new_0
Out[26]:
Console
Processing Query Block 0 (positions 0-3):

Scores against Key Block 0:
[[0.5  0.   0.25 0.  ]
 [0.   0.5  0.25 0.25]
 [0.   0.   0.   0.25]
 [0.   0.   0.   0.  ]]

Running max after block 0: [0.5  0.5  0.25 0.  ]
Running sum after block 0: [2.992 3.164 3.336 4.   ]

The $4 \times 4$ score matrix shows how strongly each query attends to each key. Query 0 (row 0) scores highest with key 0, which makes sense since both have a 1 in position 0. The running maximum is 0.5 for queries 0 and 1 but only 0.25 and 0.0 for queries 2 and 3, whose best-matching keys sit in the second block. The running sums (roughly 3 to 4 per query) represent the weight accumulated so far, but we're not done yet. There's another key block to process.

Step 2: Process Key Block 1 with Rescaling

Now comes the critical step: processing the second key block while maintaining correctness. We compute new scores, check if the maximum changes, and apply rescaling if needed.

In[27]:
Code
# Load second key/value block
K_block_1 = K[4:]
V_block_1 = V[4:]

# Compute attention scores against new keys
S_block_1 = Q_block @ K_block_1.T * scale

# Check if maximum needs to update
m_new_1 = np.maximum(m_block, S_block_1.max(axis=1))

# Compute rescaling factor: e^(old_max - new_max)
# If new_max > old_max, this shrinks previous contributions
scale_factor = np.exp(m_block - m_new_1)

# Apply rescaling to previous accumulator and sum
O_block = O_block * scale_factor[:, None]
ell_block = ell_block * scale_factor

# Compute and add new block's contribution
P_block_1 = np.exp(S_block_1 - m_new_1[:, None])
O_block = O_block + P_block_1 @ V_block_1
ell_block = ell_block + P_block_1.sum(axis=1)
m_block = m_new_1
Out[28]:
Console
Scores against Key Block 1:
[[0.   0.   0.   0.25]
 [0.   0.   0.25 0.  ]
 [0.5  0.   0.25 0.  ]
 [0.   0.5  0.   0.25]]

Scale factor for rescaling: [1.    1.    0.779 0.607]
Final running max: [0.5 0.5 0.5 0.5]
Final running sum: [5.59  5.763 5.59  5.418]

The scale factors for queries 0 and 1 are 1.0: their maximum (0.5) was already seen in block 0, so their earlier contributions are left untouched. For queries 2 and 3 the factors are about 0.78 and 0.61: block 1 contains their best-matching keys (keys 4 and 5), so their running maxima rise from 0.25 and 0.0 to 0.5, and the contributions accumulated during block 0 are shrunk accordingly. This is the retroactive rescaling from the online softmax discussion playing out on concrete numbers.

The final running sums (between roughly 5.4 and 5.8 per query) represent the total weight across all 8 key positions, which serves as the denominator for softmax normalization.

Step 3: Final Normalization and Verification

After processing all key blocks, we divide the accumulated output by the running sum to apply softmax normalization. Let's compare this result against standard attention computed in one pass:

In[29]:
Code
# Final step: divide accumulated output by normalization constant
O_final = O_block / ell_block[:, None]

# For comparison: compute standard attention in one pass
full_scores = Q_block @ K.T * scale
full_weights = np.exp(full_scores - full_scores.max(axis=1, keepdims=True))
full_weights = full_weights / full_weights.sum(axis=1, keepdims=True)
O_standard = full_weights @ V
Out[30]:
Console
FlashAttention output (first 4 rows):
[[0.1789 0.1085 0.1393 0.1085]
 [0.1053 0.1735 0.1351 0.1351]
 [0.1085 0.1085 0.1085 0.1393]
 [0.1119 0.1119 0.1119 0.1119]]

Standard attention output (first 4 rows):
[[0.1789 0.1085 0.1393 0.1085]
 [0.1053 0.1735 0.1351 0.1351]
 [0.1085 0.1085 0.1085 0.1393]
 [0.1119 0.1119 0.1119 0.1119]]

Outputs match: True

The outputs match exactly. This worked example demonstrates the complete FlashAttention mechanism:

  1. Initialize running statistics to empty state ($m = -\infty$, $L = 0$, $O = 0$)
  2. For each key block: compute scores, update maximum, rescale previous contributions, add new contributions
  3. Normalize the final accumulated output by dividing by the running sum

The key insight is that we never stored the full $8 \times 8$ attention matrix. Only the current $4 \times 4$ block resided in memory at any time. Yet the mathematical result is identical to standard attention.

Out[31]:
Visualization
Heatmap showing attention weights from 4 queries to 8 keys, with diagonal-like pattern in the first block.
Attention weights for the first 4 queries across all 8 key positions. Each row shows how a query distributes attention across keys. Query 0 attends most strongly to key 0, query 1 to key 1, and so on, matching the structure we designed in the Q and K matrices.

The heatmap reveals the attention patterns we computed incrementally. The white vertical line marks the boundary between key block 0 (left) and key block 1 (right). During FlashAttention, we never see both halves simultaneously. Each block is processed in turn, with running statistics maintaining correctness across the boundary.
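
As a final check, the tiled_flash_attention and standard_attention functions defined earlier can be run on the full 8-token example, covering both query blocks and both key/value blocks at once:

Code
# End-to-end check on the worked example: both query blocks, both key/value blocks
O_tiled = tiled_flash_attention(Q, K, V, block_size=4)
O_full = standard_attention(Q, K, V)
print("Full 8x8 example matches standard attention:", np.allclose(O_tiled, O_full))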

Limitations and Impact

FlashAttention has transformed how attention is implemented in practice. Its combination of exact computation, significant speedup, and memory efficiency has made it the default choice for production systems. Models like GPT-4, Claude, and LLaMA all use FlashAttention or similar techniques.

The primary limitation is implementation complexity. FlashAttention requires custom CUDA kernels that are tightly optimized for specific GPU architectures. Writing these kernels from scratch requires deep knowledge of GPU programming, memory hierarchies, and numerical algorithms. Fortunately, high-quality implementations are available in the flash-attn library itself and in xFormers, and the technique is integrated into PyTorch through scaled_dot_product_attention.

Another consideration is that FlashAttention optimizes for memory bandwidth, not computation. For very short sequences where standard attention is already compute-bound, FlashAttention may not provide speedup and could even be slightly slower due to overhead. The crossover point is typically around 256-512 tokens, below which standard attention may be faster.

Hardware specificity is also a factor. FlashAttention-2, and the newer FlashAttention-3 targeting NVIDIA's Hopper generation, are optimized for NVIDIA GPUs with tensor cores. Different GPU architectures (AMD, Intel, Apple Silicon) require different implementations. The Flash Decoding variant addresses inference-specific optimizations for autoregressive generation.

Despite these limitations, FlashAttention's impact is undeniable. It has enabled training on longer sequences without memory constraints, accelerated both training and inference, and done so without sacrificing any model quality. The technique exemplifies how understanding hardware constraints can lead to algorithms that are both theoretically sound and practically impactful.

Summary

FlashAttention achieves 2-4x speedup and dramatic memory reduction by restructuring the attention computation to exploit GPU memory hierarchy. The key ideas covered in this chapter:

  • GPU memory hierarchy: HBM provides capacity while SRAM provides speed. Standard attention is slow because it repeatedly transfers the $n \times n$ attention matrix through HBM.

  • Tiling for SRAM: By processing attention in blocks of size $B_r \times B_c$, FlashAttention keeps all intermediate values in fast SRAM, avoiding HBM transfers for the attention matrix.

  • Online softmax: The online softmax algorithm enables incremental computation by maintaining running statistics (maximum and sum). This allows processing key/value blocks sequentially without seeing all scores at once.

  • Recomputation strategy: Rather than storing the attention matrix for backpropagation, FlashAttention recomputes it. This trades $O(n^2)$ extra compute for $O(n^2)$ memory savings, a worthwhile trade given that memory is the bottleneck.

  • Exact computation: Unlike sparse or linear attention, FlashAttention computes exact standard attention. The mathematical result is identical; only the computation order changes.

  • Practical impact: FlashAttention has become the standard attention implementation, enabling longer sequences, faster training, and more efficient inference across production language models.

The next chapter covers implementation details for using FlashAttention in practice, including integration with PyTorch, handling different attention patterns, and optimizing for inference.

Key Parameters

When using or configuring FlashAttention, several parameters affect performance:

  • block_size ($B_r$, $B_c$): The tile sizes for queries and keys/values. Larger blocks increase SRAM usage but reduce iteration overhead. Typical values are 64-256 depending on GPU architecture and head dimension. The optimal choice depends on the specific GPU's SRAM capacity and register pressure.

  • d_head: The dimension per attention head. FlashAttention is most efficient when $d_{\text{head}}$ is small (64-128). Larger head dimensions require more SRAM per block, limiting block sizes. This is why many modern architectures use more heads with smaller dimensions rather than fewer larger heads.

  • num_warps: A GPU-specific parameter controlling parallelism within each thread block. More warps can hide memory latency but increase register pressure. Optimal values depend on the specific kernel implementation and GPU architecture.

  • causal: Whether to use causal masking for autoregressive models. Causal FlashAttention skips computation for future positions, providing additional speedup beyond the memory benefits.

  • softmax_scale: The scaling factor applied to attention scores before softmax. Default is $1/\sqrt{d_k}$, where $d_k$ is the key dimension. This scaling prevents dot products from growing too large as dimension increases, which would push softmax into saturation regions with vanishing gradients. Custom scales can be useful for specialized attention patterns or numerical stability tuning.

  • dropout: FlashAttention supports fused dropout during training. Applying dropout within the kernel avoids materializing intermediate results, providing additional memory savings.
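
To see where these parameters surface in code, here is a hedged sketch of a call to the flash-attn library's flash_attn_func. The argument names and tensor layout follow my understanding of the library and may differ between versions, so treat this as an illustration of where causal, softmax_scale, and dropout plug in rather than as authoritative API documentation.

Code
import math

import torch
from flash_attn import flash_attn_func  # requires the flash-attn package and a supported NVIDIA GPU

batch, seq_len, num_heads, head_dim = 2, 2048, 12, 64

# flash-attn expects (batch, seq_len, num_heads, head_dim) tensors in fp16/bf16 on the GPU
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,  # fused dropout during training
    softmax_scale=1.0 / math.sqrt(head_dim),  # the default scaling; None selects the same value
    causal=True,  # autoregressive masking
)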
