KV cache and fast inference.
Build it. Break it. Measure it.
Once the model trains, the next problem is making it generate fast enough to use. This cluster of Under The Hood walks through the small set of ideas that turn slow autoregressive decoding into the kind of inference real systems run: a key-value cache, a draft-and-verify scheme that skips most forward passes without changing the output distribution, and a paged allocator that lets one GPU serve many users at once.
What the KV cache actually does.
Lead with the answer. At inference time, the bottleneck is not the matrix multiplications themselves. It is the recomputation. A decoder-only transformer asked to produce token 501 from a 500-token prompt will, if you wrote the naive version, rebuild the keys and values for tokens 1 through 500 every decode step. That work was already done. The cache exists to stop doing it twice.
The shape of the savings is precise. Without a cache, generating T tokens from a prompt pushes roughly 1 + 2 + 3 + ... + T token positions through the model, which is quadratic in the sequence length. With a cache, each new step processes one new position against stored history, and the number of positions processed drops to linear. Only the attention lookup over stored keys and values still grows with context. That is the difference between "this chatbot stalls after a paragraph" and "this feels instant enough to use."
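To make the count concrete, here is a minimal sketch that tallies how many token positions get pushed through the model under each strategy. The function names are illustrative, not taken from the book's code, and the tally treats every position as one unit of work and ignores prompt prefill.

```python
def positions_without_cache(num_tokens: int) -> int:
    """Step t re-runs the model over all t positions seen so far."""
    return sum(range(1, num_tokens + 1))


def positions_with_cache(num_tokens: int) -> int:
    """Step t runs the model on the single new position only."""
    return num_tokens


for n in (10, 100, 1000):
    print(n, positions_without_cache(n), positions_with_cache(n))
```

At 1,000 generated tokens the naive loop processes 500,500 positions against 1,000 for the cached loop. The attention read over stored keys and values is not counted here; that part still scales with context length at every step.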
The cache itself is not subtle. For each transformer block, you keep two tensors: a growing stack of keys and a growing stack of values, one entry per past token, per head. When a new token arrives, the model computes its hidden state, projects out a fresh query, key, and value, appends the new key and value to that layer's stack, and runs attention with the new query against everything stored so far. Old queries are not cached because old queries are not used again — each step only needs the query for the one new position. Old keys and old values are cached because every future step still consults them.
Most explanations frame this as a compute win. It mostly is not. Once you stop recomputing the prefix, the real bottleneck moves to memory bandwidth — how fast you can read the cached K and V out of HBM at every decode step. Cache size grows linearly with context length, layer count, batch size, and head dimension. The phrase context window stops being an architectural choice and becomes a hardware statement: a model can only carry as much past as memory allows.
The projects this cluster covers.
Four projects, in the order the book teaches them. Project 13 builds the cache from scratch. Project 14 reduces the number of main-model forward passes you need. Project 16 extends usable context past the length the model was trained on. Project 17 wires the whole inference path into a serving stack that handles many users at once.
The mini example.
The whole trick fits in a short attention function. Here is what one decode step looks like once the cache is wired through the model — one new token's worth of hidden state goes in, the cache for this layer is appended to in place, and attention runs against the full stored history:
    import math

    # Assumes q_proj, k_proj, v_proj, split_heads, and merge_heads are defined
    # for this layer elsewhere in the model.
    def decode_step(x_new, k_cache, v_cache, pos):
        # x_new: [B, 1, d_model], only the new token
        q_new = q_proj(x_new)
        k_new = k_proj(x_new)
        v_new = v_proj(x_new)
        q_new = split_heads(q_new)  # [B, H, 1, d_head]
        k_new = split_heads(k_new)  # [B, H, 1, d_head]
        v_new = split_heads(v_new)  # [B, H, 1, d_head]
        # Append into the preallocated cache at position `pos`.
        k_cache[:, :, pos:pos+1, :] = k_new
        v_cache[:, :, pos:pos+1, :] = v_new
        # Read back everything stored so far, including the new entry.
        k_all = k_cache[:, :, :pos+1, :]
        v_all = v_cache[:, :, :pos+1, :]
        scores = q_new @ k_all.transpose(-2, -1) / math.sqrt(q_new.size(-1))
        att = scores.softmax(dim=-1)
        return merge_heads(att @ v_all)
Notice there is no causal mask. The slice :pos+1 already enforces it — the cache only contains past and present, never future. That is one of the cleaner moments in the build. The mask becomes implicit in the cache length, which is what you want.
Why BREAK IT matters here.
The cache is one of the cleanest examples in the book of an optimization whose failure modes look nothing like its math. The math is bookkeeping. The failure modes are persistent state corruption, and they propagate in ways that feel spooky until you understand the shape.
"After about 100 generated tokens, silently overwrite one cached key vector in the middle of the sequence with random values. Then continue decoding as if nothing happened. You will usually get fluent-looking text for a few tokens, then increasing incoherence. The model loses track of referents, repeats odd phrases, switches topic, drifts into locally fluent garbage. The corruption is one bad note in the notepad — but every later token keeps consulting it."
The second break is subtler and worth doing right after. Keep the K and V tensor contents correct, but write them to the wrong slot — shift the write index by one. The cache now holds real values attached to wrong positions. The model attends to context as if the timeline were slightly scrambled, and output looks eerily close to correct at first before decaying into contradictions and strange repetitions. That proves something different from the first break. Cache integrity is not just tensor contents. It is alignment: the right values, in the right layer, for the right head, at the right position, for the right request. Anything less is broken inference, and "broken" can mean "still fluent for forty tokens."
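Both breaks can be reproduced in miniature. The sketch below is a toy under stated assumptions, not the book's BREAK IT code: one head, one layer, precomputed random q, k, v vectors instead of a real model, and a hypothetical `run` helper that takes a corruption step and a write shift. Because the toy has no positional encoding, the off-by-one write shows up as the newest entry landing outside the read slice rather than as scrambled positions.

```python
import math
import random


def attend(q, keys, values):
    """Single-head scaled dot-product attention for one query."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def run(queries, keys, values, corrupt_step=None, corrupt_slot=None, write_shift=0):
    """Cached decoding over precomputed q, k, v, with optional sabotage."""
    n, d = len(keys), len(keys[0])
    k_cache = [[0.0] * d for _ in range(n + write_shift)]
    v_cache = [[0.0] * d for _ in range(n + write_shift)]
    outputs = []
    for pos in range(n):
        # write_shift=0 is the correct write; write_shift=1 is the off-by-one bug.
        k_cache[pos + write_shift] = keys[pos]
        v_cache[pos + write_shift] = values[pos]
        if pos == corrupt_step:
            k_cache[corrupt_slot] = [9.0] * d  # stand-in for random garbage
        outputs.append(attend(queries[pos], k_cache[:pos + 1], v_cache[:pos + 1]))
    return outputs


def rand_rows(rng, n, d):
    return [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]


rng = random.Random(0)
n_tokens, d_head = 10, 4
queries = rand_rows(rng, n_tokens, d_head)
keys = rand_rows(rng, n_tokens, d_head)
values = rand_rows(rng, n_tokens, d_head)

clean = run(queries, keys, values)
corrupted = run(queries, keys, values, corrupt_step=5, corrupt_slot=2)
shifted = run(queries, keys, values, write_shift=1)

first_bad = next(i for i, (a, b) in enumerate(zip(clean, corrupted)) if a != b)
print("corruption first visible at step", first_bad)
print("shifted-write steps that differ from clean:",
      sum(a != b for a, b in zip(clean, shifted)))
```

Every step before the overwrite matches the clean run exactly, and the damage appears from the overwrite step onward, since later queries keep reading the bad slot. The shifted run is wrong from the first step, even though every vector it stored was computed correctly.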
This is also why one-user code fails at hundred-user load. Two requests cannot silently share a cache without producing nonsense. The weights are static; the cache is live, growing, per-request state. Half the engineering of any serving stack is keeping that distinction clean.
FAQ
How big does the KV cache get for long contexts?
Cache size grows linearly with sequence length, number of layers, number of heads, head dimension, batch size, and bytes per element. A rough estimate is 2 * N * B * T * d_model * bytes_per_element, where N is the layer count, B the batch size, T the sequence length, d_model the number of heads times the head dimension, and the leading 2 covers keys plus values. For a twelve-layer model at d_model 768, FP16, batch size one, at 4,096 tokens you are storing about 144 MiB (roughly 151 MB) just for the cache. Double the context, double the cache. Double the batch, double it again. That is why long-context generation is a memory problem before it is a compute problem.
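The estimate is easy to check in code. This is a small sketch of the formula above; the function name is illustrative, and it assumes standard multi-head attention where the number of heads times the head dimension equals d_model (grouped-query or multi-query variants would store less).

```python
def kv_cache_bytes(num_layers, batch_size, seq_len, d_model, bytes_per_element):
    # Factor of 2: one tensor of keys and one tensor of values per layer.
    return 2 * num_layers * batch_size * seq_len * d_model * bytes_per_element


size = kv_cache_bytes(num_layers=12, batch_size=1, seq_len=4096,
                      d_model=768, bytes_per_element=2)  # FP16 = 2 bytes
print(size, "bytes =", size / 2**20, "MiB")  # 150994944 bytes = 144.0 MiB
```

Doubling `seq_len` or `batch_size` doubles the result, which matches the linear scaling described above.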
Why doesn't the cache help training?
Training computes a forward pass over a fixed-length sequence and then computes a backward pass through it. There is no growing prefix to reuse. Every position is computed once and the gradients flow back through all of them. The KV cache only helps when you are doing repeated forward passes that share a prefix, which is the shape of autoregressive decoding, not the shape of a training step.
What is speculative decoding doing differently?
The KV cache cuts per-token cost by reusing stored state. Speculative decoding cuts the number of main-model forward passes you need. A small draft model proposes K tokens autoregressively, the main model verifies all K in a single parallel forward pass, and an exact acceptance-and-resample rule keeps the output distribution identical to the main model. The two techniques stack: speculative decoding still needs a KV cache for both the draft and the main model.
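The acceptance-and-resample rule can be sketched for a single drafted token. The version below is the standard rule as it is commonly stated: accept with probability min(1, p/q), otherwise resample from the normalized positive part of p minus q. The function names and the three-token example distributions are assumptions made for illustration, with p standing for the main model's probabilities and q for the draft model's.

```python
import random


def residual_distribution(p, q):
    """Normalized max(p - q, 0): mass the main model wants beyond the draft's."""
    raw = [max(pi - qi, 0.0) for pi, qi in zip(p, q)]
    total = sum(raw)
    if total == 0.0:
        return list(p)  # p == q; acceptance is certain, so this is never sampled
    return [r / total for r in raw]


def accept_or_resample(token, p, q, rng):
    """Verify one token that was sampled from the draft distribution q."""
    if rng.random() < min(1.0, p[token] / q[token]):
        return token, True
    residual = residual_distribution(p, q)
    return rng.choices(range(len(p)), weights=residual)[0], False


def output_distribution(p, q):
    """Exact distribution of the token this rule emits, computed analytically."""
    accept_mass = [min(pi, qi) for pi, qi in zip(p, q)]  # q(x) * min(1, p(x)/q(x))
    reject_prob = 1.0 - sum(accept_mass)
    residual = residual_distribution(p, q)
    return [a + reject_prob * r for a, r in zip(accept_mass, residual)]


p = [0.6, 0.3, 0.1]  # main model (illustrative numbers)
q = [0.2, 0.5, 0.3]  # draft model (illustrative numbers)
print(output_distribution(p, q))
print(accept_or_resample(1, p, q, random.Random(0)))
```

The analytic check is the point of the sketch: min(p, q) plus the rejected mass redistributed over max(p - q, 0) adds back up to p, so the emitted token follows the main model's distribution regardless of how poor the draft is. A worse draft only lowers the acceptance rate.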
When does PagedAttention matter?
PagedAttention matters when you are serving many concurrent requests on one GPU. A single contiguous KV cache per request, sized to the maximum possible sequence length, wastes memory on the requests that finish early. Paging the cache into fixed-size blocks (typically 16 tokens) and managing them through a per-sequence page table cuts the worst case to fifteen tokens of waste per request: one block minus the single token that forced its allocation. For single-user inference, contiguous allocation is fine. For a serving stack handling dozens or hundreds of users on one card, paging is what makes the math work.
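The bookkeeping behind paging is small enough to sketch. The class below is a hypothetical toy allocator, not the API of any real serving library: it tracks only block IDs and token counts, with no tensors, to show how a per-request page table maps logical token positions onto physical blocks.

```python
BLOCK_SIZE = 16  # tokens per block, the typical figure mentioned above


class PagedKVAllocator:
    """Toy block allocator: hands out fixed-size KV blocks on demand."""

    def __init__(self, num_blocks):
        self.free_blocks = list(range(num_blocks))
        self.page_tables = {}  # request id -> list of physical block ids
        self.lengths = {}      # request id -> number of tokens stored

    def append_token(self, request_id):
        """Reserve a slot for one new token; returns (block id, offset in block)."""
        table = self.page_tables.setdefault(request_id, [])
        length = self.lengths.get(request_id, 0)
        if length == len(table) * BLOCK_SIZE:  # last block is full, or none yet
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop())
        self.lengths[request_id] = length + 1
        return table[length // BLOCK_SIZE], length % BLOCK_SIZE

    def release(self, request_id):
        """Return a finished request's blocks to the free pool."""
        self.free_blocks.extend(self.page_tables.pop(request_id, []))
        self.lengths.pop(request_id, None)

    def wasted_slots(self, request_id):
        return len(self.page_tables[request_id]) * BLOCK_SIZE - self.lengths[request_id]


alloc = PagedKVAllocator(num_blocks=4)
for _ in range(17):
    slot = alloc.append_token("req-a")
print(slot, alloc.wasted_slots("req-a"))
alloc.release("req-a")
print(len(alloc.free_blocks))
```

Seventeen tokens occupy two blocks, so the second block holds one token and fifteen empty slots, which is the worst case for a 16-token block. Releasing the request returns both blocks to the pool immediately, where another request can claim them.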
Does the KV cache change the output of the model?
It should not. With the same weights, the same prompt, and greedy decoding (or, when sampling, the same temperature and random seed), naive generation and cached generation must produce the same token sequence. If they do not, the cache implementation has a bug, usually a wrong write position, a wrong slice, a mismatched positional index in RoPE, or accidentally including future positions in the attention. Equality on short deterministic runs is the first correctness test to add before chasing speed.
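A scaled-down version of that equality test looks like the following. It is a sketch under simplifying assumptions: one attention head in one layer, random projection matrices in place of trained weights, and a comparison of attention outputs rather than token IDs. A full-model test would compare greedy token sequences in the same way. All names are illustrative.

```python
import math
import random


def project(W, x):
    """Matrix-vector product: one linear projection of a token's hidden state."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Single-head scaled dot-product attention for one query."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def naive_outputs(xs, Wq, Wk, Wv):
    outs = []
    for t in range(len(xs)):
        # Rebuild keys and values for the whole prefix at every step.
        keys = [project(Wk, x) for x in xs[:t + 1]]
        vals = [project(Wv, x) for x in xs[:t + 1]]
        outs.append(attend(project(Wq, xs[t]), keys, vals))
    return outs


def cached_outputs(xs, Wq, Wk, Wv):
    k_cache, v_cache, outs = [], [], []
    for x in xs:
        # Project only the new token and append to the cache.
        k_cache.append(project(Wk, x))
        v_cache.append(project(Wv, x))
        outs.append(attend(project(Wq, x), k_cache, v_cache))
    return outs


def rand_matrix(rng, rows, cols):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d = 4
Wq, Wk, Wv = (rand_matrix(rng, d, d) for _ in range(3))
xs = rand_matrix(rng, 8, d)  # eight token hidden states

naive = naive_outputs(xs, Wq, Wk, Wv)
cached = cached_outputs(xs, Wq, Wk, Wv)
match = all(math.isclose(a, b, abs_tol=1e-12)
            for n_row, c_row in zip(naive, cached)
            for a, b in zip(n_row, c_row))
print("naive and cached agree:", match)
```

Keeping a check like this in the test suite means a wrong slice or write index fails loudly on an eight-token run instead of surfacing as drifting text much later.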
Open Chapter 13 tonight.
Four projects get you from a slow naive decoder to a paged serving stack that lets a single GPU host dozens of users. The book is on Leanpub with lifetime updates.