Chapter 8 · Full excerpt
Flash Attention and tiled kernels.
Flash Attention computes the same attention output as the standard kernel without ever materialising the full N × N score matrix. It works through the score table in small tiles, accumulates the softmax through a streaming algorithm that maintains a running maximum and a running sum, and folds each tile's contribution into a partial output before throwing the tile away. The math is identical; only the order of operations changes, so results agree with the standard kernel up to floating-point round-off. That reordering is what trades a little recomputation for a quadratic-to-linear reduction in memory.
This chapter builds tile-based attention from scratch in pure PyTorch, verifies it against the naive version, watches peak memory fall off the quadratic curve, and then breaks the tiling on purpose by removing the running-maximum subtraction so you can see exactly which line of the streaming softmax keeps the entire kernel numerically alive. By the end you can read the production Triton and CUDA versions in flash-attn and recognise the same shape underneath.
This is Chapter 8 of Under The Hood — Build Every Layer of a Large Language Model from Scratch. The full 35-project book is on Leanpub. Code companion at github.com/mechramc/Under-the-hood/tree/main/projects/08-flash-attention.
The concept
Picture a chef with a one-cup pan working a recipe that asks for twenty cups mixed and heated together. The options are: buy a bigger pan, or rewrite the recipe to work a cup at a time with a running total in a separate bowl, never holding more than a cup in the pan at once. Flash Attention is the second option. The pan is your GPU's fast on-chip memory. The bowl is the accumulator. The cup at a time is a tile.
Attention asks the model, for every position in a sequence, to compare itself against every other position and decide who matters. With a sequence of length N, that comparison produces an N × N table of scores. At N = 4096, that table holds about 16.8 million entries per head, and the model has several heads, several layers, and a batch of sequences. Materialising it — actually allocating that N × N tensor in memory and writing every score into it — is the move that breaks at long context.
Flash Attention's central idea is that you never have to materialise the table. The final output of an attention layer is a weighted sum of value vectors, with weights from a softmax over the score table. Compute that output in pieces: take a small block of rows from the query, a small block of columns from the key, compute the partial scores for just that tile, fold them into a running output, throw the tile away, and move on. The full table is never written. The result is mathematically identical.
The trick that makes this possible is the online softmax, also called a streaming softmax. A standard softmax over N numbers needs to see all N before it can produce any output. It computes the maximum (for numerical stability), subtracts that maximum from every number, exponentiates, sums, and divides. Every one of those steps depends on all N numbers being available at once. That is the dependency we want to break.
The online version watches the numbers arrive one block at a time and keeps two running statistics: the running maximum seen so far, and the running sum of exponentials rescaled to that maximum. When a new block arrives, compute its local maximum and local exponential sum, compare against the running maximum, and rescale the older statistics if the new block's maximum is larger. The output you accumulate uses the same rescaling. By the end of the last block, the running statistics give you the same answer as the all-at-once version: exactly in exact arithmetic, and up to round-off from the different order of additions in floating point.
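Here is the same idea on a plain Python list, as a minimal sketch: each score carries a scalar value standing in for a value vector, and the function name streaming_weighted_sum is ours, not from any library. It keeps only a running maximum, a running sum and a running output, and is checked against the all-at-once softmax.

```python
import math

def softmax(scores):
    """All-at-once stable softmax: needs every score before it can emit anything."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def streaming_weighted_sum(scores, values, block=4):
    """One pass over (score, value) pairs, `block` at a time, keeping only a
    running max m_run, running sum l_run and running output o_run."""
    m_run, l_run, o_run = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_blk = max(s_blk)                              # local max of this block
        p_blk = [math.exp(s - m_blk) for s in s_blk]    # local stable exponentials
        m_new = max(m_run, m_blk)
        alpha = math.exp(m_run - m_new)                 # 0.0 on the first block
        beta = math.exp(m_blk - m_new)
        l_run = alpha * l_run + beta * sum(p_blk)
        o_run = alpha * o_run + beta * sum(p * v for p, v in zip(p_blk, v_blk))
        m_run = m_new
    return o_run / l_run

scores = [0.5, 3.0, -1.2, 2.2, 7.5, 0.1, -4.0, 6.9, 1.3, 2.8]
values = [1.0, -2.0, 0.5, 4.0, 3.0, -1.0, 2.0, 0.0, 1.5, -0.5]
reference = sum(w * v for w, v in zip(softmax(scores), values))
for block in (1, 3, 4, 10):
    print(block, streaming_weighted_sum(scores, values, block), reference)
```

Every block size gives the reference value up to the last few bits, which is the "same answer, different order of additions" property in miniature.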
Two terms worth defining before they appear in code. A tile is a small block of a larger tensor, sized to fit in a fast memory region — on a real GPU, the on-chip SRAM, perhaps 100 KB per streaming multiprocessor. A fused kernel combines what would otherwise be several separate passes through memory — scores, softmax, value multiply — into one computation done on one tile while the tile is still in fast memory. The win is not only less memory. It is fewer round trips between fast on-chip memory and slow off-chip memory, which is where most of a GPU's wall-clock time goes during attention.
The original Flash Attention paper, by Tri Dao and colleagues in 2022, calls this property IO-aware: the algorithm is designed around the cost of moving data between memory tiers, not around the cost of arithmetic. Attention is bandwidth-bound at long context lengths, not compute-bound. Reducing memory traffic is the actual lever.
Why it matters
Every chapter after this one assumes long context as a baseline. Without tiled attention, the rest of the book runs into a wall at sequence length somewhere between 2K and 4K on most consumer GPUs. With tiled attention, the wall moves out by a factor of ten or more, and the projects that depend on long context become reachable on the same hardware.
The hand-written attention from Chapter 4 allocates a tensor of shape (batch, heads, seq, seq) to hold the scores. At sequence length 1024 with 8 heads and batch 8 in fp16, that tensor is roughly 128 MB per layer. At 4096 it balloons to 2 GB per layer. A six-layer model wants 12 GB for nothing but score tensors, before parameters, activations, gradients, or optimizer state. On an 8 GB card the run dies before it begins. Parameters fit. Activations fit. The single thing that does not fit is the materialised score table.
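Each of those figures is a single multiplication. A small helper (the name score_tensor_bytes is illustrative) reproduces them, reading the chapter's MB and GB as binary units:

```python
def score_tensor_bytes(batch, heads, seq, bytes_per_elem=2):
    """Size of one materialised (batch, heads, seq, seq) score tensor; fp16 = 2 bytes."""
    return batch * heads * seq * seq * bytes_per_elem

MiB, GiB = 2 ** 20, 2 ** 30
for seq in (1024, 2048, 4096):
    per_layer = score_tensor_bytes(batch=8, heads=8, seq=seq)
    print(f"seq={seq}: {per_layer / MiB:.0f} MiB per layer, "
          f"{6 * per_layer / GiB:.2f} GiB across six layers")
```

Doubling seq quadruples every number, which is the quadratic wall in one line of arithmetic.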
There is a forward-pointing reason too. Chapter 13 introduces the inference-time KV cache, where each new token's single query row attends to the full cached key and value blocks. That decoding-time computation is itself a tiled attention: one query row tile against a long key-value column block. Once you have built the training-time tiled forward here, the inference-time variant in Chapter 13 is a small specialisation of the same kernel, not a new mystery.
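To make that claim concrete, here is a minimal sketch of one decoding step written as a streaming softmax over cache blocks. It assumes a plain (B, H, T, D) key and value cache tensor and a hypothetical decode_attention helper; the cache layout Chapter 13 actually uses may differ. No causal mask is needed, because the new token may attend to every cached position.

```python
import torch

def decode_attention(q_new, k_cache, v_cache, block=128):
    """One decoding step (hypothetical helper): a single query row (B, H, 1, D)
    against a cached (B, H, T, D) key/value block, `block` positions at a time."""
    B, H, _, D = q_new.shape
    scale = D ** -0.5
    m = torch.full((B, H, 1, 1), float("-inf"), dtype=q_new.dtype, device=q_new.device)
    l = torch.zeros((B, H, 1, 1), dtype=q_new.dtype, device=q_new.device)
    o = torch.zeros_like(q_new)
    for j in range(0, k_cache.shape[2], block):
        k_blk = k_cache[:, :, j:j+block]
        v_blk = v_cache[:, :, j:j+block]
        s = (q_new @ k_blk.transpose(-2, -1)) * scale   # (B, H, 1, bc): one row tile
        m_tile = s.max(dim=-1, keepdim=True).values
        p = torch.exp(s - m_tile)
        m_new = torch.maximum(m, m_tile)
        alpha, beta = torch.exp(m - m_new), torch.exp(m_tile - m_new)
        l = alpha * l + beta * p.sum(dim=-1, keepdim=True)
        o = alpha * o + beta * (p @ v_blk)
        m = m_new
    return o / l
```

It is the training-time loop with the query row tile shrunk to a single row, which is the specialisation the paragraph above describes.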
Most explanations of Flash Attention open with the CUDA kernel and leave the algorithm implicit. That order is backwards. Once you understand the algorithm, the kernel is straightforward engineering. Without the algorithm, the kernel reads like incantations.
The build
Extend the attention from Chapter 4 in seven steps. Measure where the naive memory footprint dies. Implement a tiled forward in pure PyTorch — no Triton, no CUDA, no kernel-fusion tricks, just plain tensor operations. Add the streaming softmax, which is the piece that makes the tiling correct. Verify against the naive version, measure memory, and watch the curve flatten. Then wire it into the working GPT from Chapter 5 and benchmark wall-clock.
Step 1 — Measure where the naive version dies
Start from the attention you wrote in Chapter 4. The relevant block is a textbook materialised attention with no causal masking shown for brevity:
```python
import torch

def naive_attention(q, k, v):
    # q, k, v: (batch, heads, seq, dim)
    d = q.shape[-1]
    scores = (q @ k.transpose(-2, -1)) / (d ** 0.5)  # (B, H, N, N)
    weights = torch.softmax(scores, dim=-1)          # (B, H, N, N)
    return weights @ v                               # (B, H, N, D)
```
The two intermediate tensors of shape (B, H, N, N) are the entire problem. With B=2, H=8, N=4096 in fp16, each one is about 0.5 GB, and both scores and weights are alive at the same time inside the function. Write a small probe that measures peak memory at several sequence lengths:
```python
import torch

def measure_naive(seq_lens, batch=2, heads=8, dim=64,
                  device="cuda", dtype=torch.float16):
    results = {}
    for n in seq_lens:
        q = k = v = out = None  # release the previous length's tensors before measuring
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        try:
            q = torch.randn(batch, heads, n, dim, device=device, dtype=dtype)
            k = torch.randn_like(q)
            v = torch.randn_like(q)
            out = naive_attention(q, k, v)
            torch.cuda.synchronize()
            results[n] = torch.cuda.max_memory_allocated() / 1e9  # peak, in GB
        except torch.cuda.OutOfMemoryError:
            results[n] = None
    return results
```
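On an 8 GB card you will see peak memory roughly quadruple every time you double N, and then None. With these defaults each (B, H, N, N) fp16 tensor takes 32 × N² bytes, about 0.5 GB at N = 4096, 2.1 GB at 8192 and 8.6 GB at 16384, and two of them are alive at the peak, so the first None lands at 16384; a larger batch or more heads pulls it in proportionally. That is the wall. Save this dictionary; the tiled version reuses the same probe and produces a strikingly different curve.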
Step 2 — Decompose the computation into tiles
The tiled forward replaces the all-at-once score matrix with a loop over column tiles, processing one row tile of the query at a time. Define two block sizes upfront — one for query rows, one for key and value columns:
```python
BR = 64  # query row tile size
BC = 64  # key/value column tile size
```
These two numbers are the analogue of pan size. Sixty-four is small enough to fit anywhere and large enough to amortise the loop overhead. Production CUDA tuning sweeps these values per GPU; here we pick a value that works and move on.
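A rough way to check whether a tile size fits in the pan is to add up what one step of a fused kernel keeps on chip. The sketch below assumes fp16 Q, K and V tiles, fp32 scores and accumulators, and one (batch, head) pair per tile; real kernels differ in layout and buffering, so treat the numbers as order-of-magnitude only. The helper name is illustrative.

```python
def tile_working_set_bytes(BR=64, BC=64, D=64, in_bytes=2, acc_bytes=4):
    """Approximate on-chip bytes for one (query tile, key/value tile) step,
    assuming fp16 inputs (in_bytes=2) and fp32 accumulation (acc_bytes=4)."""
    q_tile = BR * D * in_bytes          # query row tile
    kv_tiles = 2 * BC * D * in_bytes    # key tile + value tile
    scores = BR * BC * acc_bytes        # score tile s
    out_acc = BR * D * acc_bytes        # running output o
    stats = 2 * BR * acc_bytes          # running max m and running sum l
    return q_tile + kv_tiles + scores + out_acc + stats

for size in (32, 64, 128):
    kib = tile_working_set_bytes(BR=size, BC=size) / 1024
    print(f"BR = BC = {size}: about {kib:.1f} KiB")
```

Under these assumptions BR = BC = 64 lands around 56 KiB, comfortably under the roughly 100 KB per streaming multiprocessor mentioned earlier, while 128 × 128 already overshoots it. That is the kind of trade-off production tuning sweeps over.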
The query tensor gets split along the sequence dimension into N / BR row tiles of shape (B, H, BR, D), rounded up: when BR does not divide N, the last tile is simply shorter, and the slicing handles that without extra code. The key and value tensors get split the same way into N / BC column tiles of shape (B, H, BC, D). For each query row tile, we walk every key-value column tile, accumulate partial outputs and partial softmax statistics, and emit one output row tile at the end. The outer skeleton:
```python
def tiled_attention_skeleton(q, k, v, BR=64, BC=64):
    B, H, N, D = q.shape
    out = torch.zeros_like(q)
    for i in range(0, N, BR):
        q_tile = q[:, :, i:i+BR, :]                          # (B, H, BR, D)
        # running statistics for this query row tile, set up in Step 3
        for j in range(0, N, BC):
            k_tile = k[:, :, j:j+BC, :]                      # (B, H, BC, D)
            v_tile = v[:, :, j:j+BC, :]                      # (B, H, BC, D)
            scores_tile = q_tile @ k_tile.transpose(-2, -1)  # (B, H, BR, BC)
            # update running statistics and partial output using scores_tile
            ...
        out[:, :, i:i+BR, :] = ...
    return out
```
The tensor scores_tile is shape (B, H, BR, BC), never (B, H, N, N). At BR = BC = 64, this tile is the same size regardless of whether N is 1024 or 32768. That is the entire memory win. The score matrix never exists; only one small tile of it exists at a time, and the tile gets overwritten on the next iteration.
Step 3 — Build the streaming softmax
Standard softmax over a row of N scores produces exp(s_j - m) / sum_k exp(s_k - m), where m is the maximum of the row, included for numerical stability. The streaming version computes the same final answer but processes scores one column tile at a time, maintaining three pieces of state per query row: the running maximum m_running, the running rescaled sum l_running, and the running partial output o_running. When a new tile arrives with scores s_tile, compute three things from the tile alone:
```python
m_tile = s_tile.max(dim=-1, keepdim=True).values  # local max of the tile
p_tile = torch.exp(s_tile - m_tile)               # local stable exponentials
l_tile = p_tile.sum(dim=-1, keepdim=True)         # local sum of exponentials
```
Combine local statistics with running statistics. The new running max is the elementwise max of the old running max and the tile's local max. If the new max is larger, the old running sum and output were computed against a now-too-small scaling factor and need to be rescaled by exp(m_running - m_new), which is at most 1.
```python
m_new = torch.maximum(m_running, m_tile)
alpha = torch.exp(m_running - m_new)  # rescaling for old stats
beta = torch.exp(m_tile - m_new)      # rescaling for tile stats
l_new = alpha * l_running + beta * l_tile
o_new = alpha * o_running + beta * (p_tile @ v_tile)
m_running, l_running, o_running = m_new, l_new, o_new
```
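Why this rescaling is exact takes two lines of algebra. Write S for the key columns already processed and T for the incoming tile, assume the running statistics satisfy the invariant on the first line below, and use $p_{\text{tile}} v_{\text{tile}} = \sum_{k \in T} e^{s_k - m_{\text{tile}}} v_k$. Everything else follows from $e^{a-b} e^{b-c} = e^{a-c}$:

```latex
\begin{align*}
l &= \sum_{k \in S} e^{s_k - m}, \qquad o = \sum_{k \in S} e^{s_k - m}\, v_k \\
m_{\text{new}} &= \max(m, m_{\text{tile}}), \qquad
  \alpha = e^{m - m_{\text{new}}}, \qquad \beta = e^{m_{\text{tile}} - m_{\text{new}}} \\
\alpha\, l + \beta\, l_{\text{tile}}
  &= e^{m - m_{\text{new}}} \sum_{k \in S} e^{s_k - m}
   + e^{m_{\text{tile}} - m_{\text{new}}} \sum_{k \in T} e^{s_k - m_{\text{tile}}}
   = \sum_{k \in S \cup T} e^{s_k - m_{\text{new}}} \\
\alpha\, o + \beta\, (p_{\text{tile}} v_{\text{tile}})
  &= \sum_{k \in S \cup T} e^{s_k - m_{\text{new}}}\, v_k \\
\frac{o}{l} &= \frac{\sum_k e^{s_k - m}\, v_k}{\sum_k e^{s_k - m}}
  = \sum_k \operatorname{softmax}(s)_k\, v_k
  \quad \text{(after the last tile, } S \text{ is every column)}
\end{align*}
```

The invariant holds again after each update with $m_{\text{new}}$ as the new reference, so it holds after the last tile, and the final division cancels whatever reference maximum was used.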
At the end of all column tiles for this query row tile, divide the running output by the running sum: o_final = o_running / l_running. alpha rescales the old running statistics down to the new common scale; beta rescales the current tile's local statistics down to the same scale. After the rescaling, the running sum and running output are exactly what you would have got by computing softmax over all scores seen so far. Put it together into a working tiled attention. The scale factor 1 / sqrt(D) and the optional causal mask go in the obvious places:
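```python
import torch

def tiled_attention(q, k, v, BR=64, BC=64, causal=False):
    B, H, N, D = q.shape
    scale = D ** -0.5
    out = torch.zeros_like(q)
    for i in range(0, N, BR):
        q_tile = q[:, :, i:i+BR, :]             # (B, H, br, D); br < BR on a ragged last tile
        br = q_tile.shape[2]
        # Running statistics for this query row tile.
        m = torch.full((B, H, br, 1), float("-inf"),
                       device=q.device, dtype=q.dtype)   # running max
        l = torch.zeros((B, H, br, 1),
                        device=q.device, dtype=q.dtype)  # running sum
        o = torch.zeros_like(q_tile)                     # running (unnormalised) output
        for j in range(0, N, BC):
            k_tile = k[:, :, j:j+BC, :]
            v_tile = v[:, :, j:j+BC, :]
            s = (q_tile @ k_tile.transpose(-2, -1)) * scale  # (B, H, br, bc)
            if causal:
                # Absolute positions in the full sequence, not positions inside the tile.
                row_idx = torch.arange(i, i + br, device=q.device)[:, None]
                col_idx = torch.arange(j, j + k_tile.shape[2], device=q.device)[None, :]
                s = s.masked_fill(col_idx > row_idx, float("-inf"))
            m_tile = s.max(dim=-1, keepdim=True).values
            # With causal=True a row can be fully masked inside a tile (m_tile = -inf).
            # Shift such rows by 0 so exp(-inf - (-inf)) never produces nan. Every row
            # sees key 0 in the first tile, so m_new stays finite and their beta is 0.
            shift = torch.where(torch.isneginf(m_tile), torch.zeros_like(m_tile), m_tile)
            p = torch.exp(s - shift)                         # local stable exponentials
            l_tile = p.sum(dim=-1, keepdim=True)
            m_new = torch.maximum(m, m_tile)
            alpha = torch.exp(m - m_new)                     # rescales the old running stats
            beta = torch.exp(m_tile - m_new)                 # rescales this tile's stats
            l = alpha * l + beta * l_tile
            o = alpha * o + beta * (p @ v_tile)
            m = m_new
        out[:, :, i:i+BR, :] = o / l
    return out
```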
Read the inner loop with the streaming-softmax derivation right next to you the first time. Every line corresponds to one piece of the math. The score tile s is (B, H, BR, BC). The running statistics m, l, o only ever exist for one query row tile at a time. There is no (B, H, N, N) tensor anywhere in this function.
Step 4 — Verify against the naive version
Before benchmarking, prove the tiled version computes the same answer. Run both at small sizes where the naive version still fits comfortably and compare:
```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
B, H, N, D = 2, 4, 256, 32
q = torch.randn(B, H, N, D, device=device)
k = torch.randn(B, H, N, D, device=device)
v = torch.randn(B, H, N, D, device=device)

ref = naive_attention(q, k, v)
out = tiled_attention(q, k, v, BR=32, BC=32)
diff = (ref - out).abs().max().item()
print(f"max absolute difference: {diff:.6e}")
```
For fp32 you should see a difference on the order of 1e-6: floating-point round-off noise from doing the additions in a different order. For fp16 the difference will be larger, perhaps 1e-3, but still well below anything that affects model behaviour. If you see 1e-1 or larger, the streaming softmax is wrong; most often the rescaling factor alpha is applied in the wrong place. Also check that the running maximum starts at -inf rather than zero. Starting at -inf guarantees that the first tile's local maximum becomes the running maximum. Starting at zero pins the reference at zero for any row whose scores are all negative, so beta = exp(m_tile - 0) shrinks; for strongly negative rows it underflows, the running sum collapses to zero, and the final divide returns nan.
A second sanity check is to vary the block sizes. Run at BR=BC=16, 32, 64, 128. The output should be identical to within floating-point noise. If it changes meaningfully with block size, the streaming softmax is wrong.
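Because the Step 1 reference leaves out the causal mask, the causal path of tiled_attention needs a masked reference of its own. A minimal sketch (naive_attention_ref and sweep_block_sizes are illustrative names) that covers both the block-size sweep and the causal case:

```python
import torch

def naive_attention_ref(q, k, v, causal=False):
    """Materialised attention with an optional causal mask, used only as a reference."""
    N, D = q.shape[-2], q.shape[-1]
    s = (q @ k.transpose(-2, -1)) * D ** -0.5
    if causal:
        rows = torch.arange(N, device=q.device)[:, None]
        cols = torch.arange(N, device=q.device)[None, :]
        s = s.masked_fill(cols > rows, float("-inf"))  # keys after the query are hidden
    return torch.softmax(s, dim=-1) @ v

def sweep_block_sizes(attn_fn, q, k, v, sizes=(16, 32, 64, 128), causal=False):
    """Max absolute difference between attn_fn and the reference, per square block size."""
    ref = naive_attention_ref(q, k, v, causal=causal)
    return {bs: (ref - attn_fn(q, k, v, BR=bs, BC=bs, causal=causal)).abs().max().item()
            for bs in sizes}
```

With the chapter's tiled_attention in scope, sweep_block_sizes(tiled_attention, q, k, v, causal=True) should report round-off-sized differences for every block size. A value that jumps with block size points at the streaming softmax or at the causal mask.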
Step 5 — Measure memory versus the naive version
Rerun the memory probe from Step 1, this time on the tiled forward. Keep the same dtype, batch, heads, and head dimension. Sweep the same set of sequence lengths, but extend it further now that you can:
```python
print(measure_tiled([1024, 2048, 4096, 8192, 16384, 32768]))
```
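measure_tiled is not spelled out above. One way to write it, as a sketch, is to make the Step 1 probe take the attention function as an argument; measure_peak and measure_tiled are our names, and the wrapper assumes tiled_attention from Step 3 is defined in the same module.

```python
import torch

def measure_peak(attn_fn, seq_lens, batch=2, heads=8, dim=64,
                 device="cuda", dtype=torch.float16):
    """Peak allocated CUDA memory (GB) for attn_fn(q, k, v) at each length; None on OOM."""
    results = {}
    for n in seq_lens:
        q = k = v = out = None          # drop the previous length's tensors first
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        try:
            q = torch.randn(batch, heads, n, dim, device=device, dtype=dtype)
            k, v = torch.randn_like(q), torch.randn_like(q)
            with torch.no_grad():       # forward only: no autograd buffers
                out = attn_fn(q, k, v)
            torch.cuda.synchronize()
            results[n] = torch.cuda.max_memory_allocated() / 1e9
        except torch.cuda.OutOfMemoryError:
            results[n] = None
    return results

def measure_tiled(seq_lens, **kwargs):
    # Assumes tiled_attention (Step 3) is in scope when this is called.
    return measure_peak(tiled_attention, seq_lens, **kwargs)
```

Passing naive_attention to measure_peak reproduces the Step 1 numbers, so both curves come from literally the same probe.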
Peak memory now grows roughly linearly with sequence length. The slope is set by the input tensors themselves: q, k, v, and out are all (B, H, N, D). The score tile and the running statistics are fixed-size and contribute a negligible constant overhead. Plot the two curves on the same linear axes. The naive line bends sharply upward and then stops where the run goes out of memory. The tiled line is almost straight and keeps going. That visual is the single most important picture in this chapter.
Step 6 — Wire it into the GPT and train briefly
Open the model from Chapter 5. Find the attention module. Replace the call to your naive attention with a call to the tiled version with causal=True. The interface is otherwise identical: same inputs, same output shape. Nothing else in the model changes. Train for 200 steps on the same dataset, with the same hyperparameters and random seed as a successful naive run. The two loss curves should be visually indistinguishable, to within the floating-point noise we characterised in Step 4. Training does not inherit the memory win, though: autograd saves each tile's intermediates for the backward pass, so across all tiles this pure-PyTorch version still stores on the order of N × N values per head. Production kernels avoid that by recomputing the tiles during backward. If your loss curve diverges meaningfully, the most likely cause is a bug in the causal mask. The mask must set to -inf every score where the key's absolute index exceeds the query's absolute index, where "absolute" means relative to the full sequence, not relative to the tile. The snippet in Step 3 computes row_idx = torch.arange(i, ...) and col_idx = torch.arange(j, ...) precisely to handle this. If you write arange starting from 0 instead of from i or j, the mask is wrong on every tile off the diagonal.
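The difference between absolute and tile-relative indices is easy to see on a single tile. In this small sketch, tile_causal_mask is a hypothetical helper and absolute=False reproduces the arange-from-zero bug:

```python
import torch

def tile_causal_mask(i, j, br, bc, absolute=True):
    """Boolean (br, bc) mask for the score tile whose rows start at query position i
    and whose columns start at key position j; True means 'set to -inf'."""
    r0, c0 = (i, j) if absolute else (0, 0)    # absolute=False is the bug
    rows = torch.arange(r0, r0 + br)[:, None]
    cols = torch.arange(c0, c0 + bc)[None, :]
    return cols > rows

good = tile_causal_mask(64, 0, 64, 64)                 # queries 64..127 vs keys 0..63
bad = tile_causal_mask(64, 0, 64, 64, absolute=False)  # same tile, indices from 0
print(good.sum().item(), bad.sum().item())             # → 0 2016
```

The correct mask hides nothing, because every key in that tile precedes every query, while the buggy one hides 2016 visible entries. On the diagonal tile (i == j) the two masks agree, which is why a test whose sequence fits in one tile cannot catch the bug.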
Step 7 — Benchmark wall-clock
Pure-PyTorch tiled attention is more memory-efficient than naive attention at every sequence length, but it can only pull ahead on speed at long sequences, and even there the gain is modest compared to the production CUDA version. Part of the reason is that without kernel fusion, each tile's scores still make a round trip through off-chip memory, so the Python loop saves capacity rather than traffic. Time both versions at 1K, 2K, and 4K for forward only, with warmup and synchronisation handled properly. At 1K, the naive version often wins by a small margin: its single large matmul is well-optimised by cuBLAS, and the tiled version pays a Python loop overhead. At 2K, the two are comparable. Further out, depending on batch size and card, the naive version stops running altogether, and from that length on only the tiled version runs.
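"Warmup and synchronisation handled properly" can be as small as the helper below, a sketch whose name and defaults are our choices rather than a fixed recipe; torch.utils.benchmark is a more thorough alternative.

```python
import time
import torch

def time_forward(attn_fn, q, k, v, warmup=3, iters=10):
    """Mean wall-clock seconds per forward call of attn_fn(q, k, v), after warmup."""
    def sync():
        if q.is_cuda:
            torch.cuda.synchronize()   # GPU launches are async: wait before reading the clock
    with torch.no_grad():
        for _ in range(warmup):        # early calls pay allocation and setup costs
            attn_fn(q, k, v)
        sync()
        start = time.perf_counter()
        for _ in range(iters):
            attn_fn(q, k, v)
        sync()
    return (time.perf_counter() - start) / iters
```

Call it as time_forward(naive_attention, q, k, v) and time_forward(tiled_attention, q, k, v) with the same q, k, v at each length, so both versions see identical inputs.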
The point of the benchmark is not to celebrate a speedup. It is to internalise when tiling helps with wall-clock and when it does not. Tiling buys you memory at every sequence length. Tiling buys you speed only when you are bandwidth-bound, which is at long sequences, large heads, or both. Below the crossover length, the naive version is fine for wall-clock; you might still want the tiled version for memory headroom, but you should not expect it to be faster.
BREAK IT
The break for this chapter targets the one trick the streaming softmax depends on for numerical stability: the running maximum. Remove it, watch the exponentials overflow, and confirm that the failure mode shows up at relatively short sequence lengths. There is also a second, subtler break worth running — keep the maximum but drop the rescaling — which fails for a completely different reason.
Remove the running-maximum subtraction
Find the streaming softmax block inside tiled_attention. Currently the per-tile maximum is subtracted before exponentiation. Remove the subtraction:
```python
# m_tile = s.max(dim=-1, keepdim=True).values  # removed
p = torch.exp(s)                                 # no stability shift
l_tile = p.sum(dim=-1, keepdim=True)
```
Then strip out the rescaling that used to bring old and new statistics onto the same scale:
```python
# m_new = torch.maximum(m, m_tile)
# alpha = torch.exp(m - m_new)
# beta = torch.exp(m_tile - m_new)
# l = alpha * l + beta * l_tile
# o = alpha * o + beta * (p @ v_tile)
# m = m_new
l = l + l_tile
o = o + (p @ v_tile)
```
The forward pass still runs. Shapes are still correct. No assertion fires. The dashboard, if you had one, would show nothing wrong for the first few seconds.
Write down what you expect to see, with numbers, before pressing run. In fp32 the largest finite value is about 3.4e38, so exp(s) overflows once a single score passes about 88.7; exp(89) is already about 4.5e38. Real attention scores at the start of training have magnitudes in the single digits, so in fp32 the first tiles, and often the whole forward, will look fine: softmax is unchanged by a constant shift, so the missing subtraction only changes the answer once something overflows. As soon as a query and a key are strongly aligned, scores climb into the tens and exp(s) into the millions and beyond. Once any score passes about 88.7, p holds inf, l and o follow it to inf, and the final divide produces nan.
In fp16, the maximum representable magnitude is only 65504, so exp(s) overflows once a score passes about 11.1, and exp(12) is already 162754. The running sum l can overflow sooner still, because it adds up many exponentials that are each below the limit. Overflow happens almost immediately. At sequence length 256 in fp16, you should see nan within the first tile or two for any query that is moderately aligned with any key. Print the max of s and the max of p per tile during the broken forward. Watch p cross 1e30 in fp32 or 1e4 in fp16. Watch l follow. Watch o / l become nan shortly after.
The intact version of the kernel, with the running-maximum subtraction in place, holds p in a tight range. Every value is between exp(-large) (close to zero) and exp(0) = 1, regardless of the absolute scale of the scores. The subtraction is the cheap trick that keeps the entire pipeline numerically alive. Without it, the algorithm's "no approximation" property goes from being a virtue to being a death sentence. Exact computation of an overflowing value is still an overflow.
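Both thresholds can be checked on a single row of three scores. This sketch uses NumPy only because its float16 arithmetic behaves the same on any CPU; softmax_rows is an illustrative name, and shift=False stands in for the first break.

```python
import numpy as np

def softmax_rows(s, shift=True):
    """Softmax over the last axis, computed in s's own dtype; shift=False drops the max subtraction."""
    with np.errstate(over="ignore", invalid="ignore"):  # let inf and nan through silently
        if shift:
            s = s - s.max(axis=-1, keepdims=True)         # every exponent becomes <= 0
        p = np.exp(s)
        return p / p.sum(axis=-1, keepdims=True)

for dtype, big in ((np.float32, 90.0), (np.float16, 12.0)):
    mild = np.array([[1.0, 2.0, 3.0]], dtype=dtype)
    spiky = np.array([[1.0, 2.0, big]], dtype=dtype)
    print(dtype.__name__,
          softmax_rows(mild, shift=False),    # matches the shifted result: nothing overflowed
          softmax_rows(spiky, shift=False),   # exp(big) is inf, so the row contains nan
          softmax_rows(spiky, shift=True))    # finite, sums to 1
```

The unshifted version is indistinguishable from the shifted one right up until a single score crosses the dtype's threshold, which is why the bug survives small tests and surfaces later.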
Keep the maximum, drop the rescaling
Restore the per-tile maximum subtraction from Break 1 so each tile's exponentials are stable on their own. But comment out the part that rescales the previous running statistics when a new, larger tile maximum arrives:
```python
m_tile = s.max(dim=-1, keepdim=True).values
p = torch.exp(s - m_tile)
l_tile = p.sum(dim=-1, keepdim=True)
# m_new = torch.maximum(m, m_tile)
# alpha = torch.exp(m - m_new)
# beta = torch.exp(m_tile - m_new)
# instead, just accumulate raw:
l = l + l_tile
o = o + (p @ v_tile)
m = m_tile
```
Nothing overflows this time. The tile-local exponentials are well-behaved because the subtraction is still there. But the running output and the running sum are now being added on incompatible scales. Tile 0 contributes exp(s - m_0). Tile 1 contributes exp(s - m_1). If m_0 and m_1 differ, those exponentials are not on the same scale. Adding them is meaningless. The final divide produces a finite number (no nan, no inf), but the number is wrong.
Run the verification from Step 4. The max absolute difference against the naive version will be 0.1 to 0.5, not 1e-6. The result looks like attention but is not attention. Wire this version into the GPT and the loss curve will trend in the right direction but converge to a clearly worse solution. This is the harder failure to debug: the kernel runs, the model "trains," and only a careful diff against the reference reveals that something is off.
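A scalar version of this break makes the damage visible with four numbers. In this sketch (stream is an illustrative name), one large score sits in the first block and the second block holds only small ones, so the two block maxima differ by 4:

```python
import math

def stream(scores, values, block, rescale=True):
    """Streaming softmax-weighted sum; rescale=False reproduces the second break."""
    m_run, l_run, o_run = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_blk = max(s_blk)
        p = [math.exp(s - m_blk) for s in s_blk]     # stable inside the block either way
        l_blk = sum(p)
        o_blk = sum(pi * vi for pi, vi in zip(p, v_blk))
        if rescale:
            m_new = max(m_run, m_blk)
            a, b = math.exp(m_run - m_new), math.exp(m_blk - m_new)
            l_run, o_run, m_run = a * l_run + b * l_blk, a * o_run + b * o_blk, m_new
        else:
            # Add block sums that were each normalised to a different maximum.
            l_run, o_run, m_run = l_run + l_blk, o_run + o_blk, m_blk
    return o_run / l_run

scores, values = [4.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]
exact = math.exp(4.0) / (math.exp(4.0) + 3.0)  # softmax weight on the only non-zero value
print(exact, stream(scores, values, 2), stream(scores, values, 2, rescale=False))
# roughly 0.948, 0.948, 0.331
```

The broken version still returns a tidy number between 0 and 1, which is exactly what makes this failure hard to spot without a reference.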
The deeper lesson
The streaming softmax does two things at once. It rearranges the order of operations so that a full row of scores never has to be in memory at the same time, and it preserves the numerical-stability trick from the standard softmax inside every tile. Both pieces are load-bearing. Remove the rescaling but keep the subtraction and you get wrong answers, because the running statistics are on incompatible scales. Remove the subtraction, as in the first break, and you get nan, because the raw exponentials overflow before any rescaling could rein them in. Correctness depends on both moves working together: local stable exponentials inside each tile, and rescaling that brings older partial results onto the same scale as newer ones.
Restore both pieces. Run again. Watch nan disappear and the max difference fall back to 1e-6. That confidence — knowing exactly which two lines keep the kernel alive, and what each one prevents — is the lesson. It is the same shape of lesson as Chapter 4's attention BREAK IT: every piece in the formula exists because something breaks without it. With the standard kernel, the breakages live in the math. With the tiled kernel, they live in the running statistics. The discipline is identical.
Questions to answer
- At what sequence length did the naive attention first fail to run on your hardware? At what sequence length did the tiled attention use less than one tenth the memory of the naive version at that same length?
- When you removed the running-maximum subtraction in Break 1, at which tile did the overflow first appear? Did it appear earlier in fp16 than in fp32, and by roughly how much?
- Below the crossover sequence length, the naive version is faster than the tiled version in pure PyTorch. Why? Be specific about which operation in the naive version is faster, and what overhead in the tiled version costs you the speed parity.
- The streaming softmax maintains three running pieces of state per query row tile: the running maximum, the running sum, and the running output. What goes wrong if you only maintain two of the three — say, the running sum and the running output, but not the running maximum?
- If a future project asks you to implement attention with a custom score modifier (say, ALiBi linear biases or a non-standard masking pattern), where in the tiled implementation does the modifier go, and why does the location matter for the streaming softmax to stay correct?
Go further
- Dao-AILab/flash-attention. The production library. Start with the README to map the algorithm sections, then open the Triton implementation in flash_attn/flash_attn_triton.py rather than the CUDA one. The Triton version reads more like Python and the algorithmic structure is closer to what you built. Look for the running-maximum subtraction, the rescaling factor, and the partial output accumulator. The names will be slightly different, but the structure is the same.
- Dao, Fu, Ermon, Rudra, Ré, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022). The original paper. Read the IO-awareness analysis carefully; it is the foundation for why this approach scales the way it does.
- Milakov and Gimelshein, Online normalizer calculation for softmax (2018). The streaming-softmax algorithm in isolation, predating Flash Attention by several years. A clean short read and a useful reference for anyone who needs to fuse softmax into other tiled computations.
- Paged attention (vLLM). The same tiled idea applied to the inference-time KV cache: the cache is split into fixed-size pages and attended to one page at a time. This is the bridge to Chapter 13.
What you now know
The N × N attention score matrix does not need to exist. It can be computed in tiles, accumulated through a running maximum and a running sum, and written back as a final weighted sum of value vectors, all without ever holding the full matrix in memory. You have built that algorithm in pure PyTorch, verified it against the naive version, watched the memory curve flatten from quadratic to linear, and seen what happens when the stability subtraction is removed and what happens when the rescaling is removed instead. You also know what to look for in the production CUDA kernel: the Triton and CUDA implementations are not different algorithms — they are the algorithm you just wrote, fused into a single kernel with block sizes tuned to the host GPU's memory hierarchy and the backward pass set up to recompute softmax weights on the fly. If a tiled attention kernel produces nan at moderate sequence length in fp16, check whether the running-maximum subtraction is present and correctly applied. That diagnostic alone will save more debugging hours than any other single piece of knowledge from this chapter.
This was Chapter 8 of 35.
The full book is 934 pages and 256,587 words. 35 hands-on projects from autograd to fused specialists. PDF and EPUB on Leanpub, lifetime free updates.