Chapter 13 · Full excerpt

Fast inference: the KV cache.

At inference time, the KV cache stores the keys and values of every previous token so the model never has to recompute them. Each new token only needs a fresh query, which it compares against the saved keys and uses to pull from the saved values. That turns per-token attention from O(n²) over a growing sequence into roughly O(n) against stored history, and it is the single most consequential optimization in modern LLM inference.

This chapter builds the cache from scratch on a small GPT. Naive autoregressive generation first, then per-layer cache tensors, then a decode-only attention path. You measure the speedup as prompt length grows, profile the memory the cache costs you, and then break it on purpose — invalidate one cached key, then shift a write index by one slot — to see exactly what the cache is holding and how generation degrades when those bytes are wrong.

This is Chapter 13 of Under The Hood — Build Every Layer of a Large Language Model from Scratch. The full 35-project book is on Leanpub. Code companion at github.com/mechramc/Under-the-hood.

The concept

Start with the wrong mental model, because a lot of people have it. Attention is not memory in the durable sense. It does not keep a little brain-state from one decoding step to the next. It recomputes relationships between tokens inside the current input. If you call the model again on a longer sequence, it will happily recalculate the same keys and values for the old tokens unless you explicitly save them. That saved state is the KV cache. The name is dry. The idea is not.

Think of a courtroom transcriptionist. When a new sentence is spoken, the judge does not ask for the entire hearing to be retyped from the beginning. The old transcript is already there. The new sentence gets appended. That is KV caching.

The model has already turned past tokens into two useful things: keys, which describe what past tokens contain, and values, which carry what past tokens contribute when selected. For a new token, the model mainly needs a fresh query: what the new token is looking for. It compares that query against the saved keys, pulls information from the saved values, and moves on. No need to rebuild old keys and values every step.

For old tokens, the keys and values do not change during inference. The causal mask means a token's hidden state at every layer depends only on itself and earlier tokens, so appending new tokens later cannot change it. If token 17 already produced a key and a value at layer 6, then tokens 18, 19, and 400 can keep reusing them. Recomputing them is wasted work. That is why the cache stores K and V, not Q. The new token's query changes every step; the older tokens' keys and values are stable for the current prompt. The most consequential inference optimization in modern LLMs is "do not recompute the things that did not change."

The math, briefly

Suppose you generate until the sequence reaches length T. In naive generation, step 1 runs on length 1, step 2 on length 2, and so on until length T. Total work is 1 + 2 + ... + T = T(T + 1)/2 positions, which grows like T². With a KV cache, each new step mainly processes one new position against stored history, so the total is T positions. You stop recomputing all earlier K and V projections from scratch. In practice, that is the difference between "why is this chatbot stalling?" and "this feels instant enough to use."
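A quick count makes the gap concrete. This sketch only counts positions pushed through the model, not FLOPs, and assumes generation starts from a single token. The helper names are illustrative.

```python
def positions_processed_naive(T):
    # step i reruns the model on all i positions: 1 + 2 + ... + T
    return sum(range(1, T + 1))

def positions_processed_cached(T):
    # each step processes exactly one new position against stored K/V
    return T

for T in (32, 128, 512, 1024):
    naive = positions_processed_naive(T)
    cached = positions_processed_cached(T)
    print(f"T={T:5d}  naive={naive:8d}  cached={cached:5d}  ratio={naive / cached:7.1f}")
```

The ratio works out to (T + 1)/2, so it grows linearly with T. Measured wall-clock gains are usually smaller than this count suggests, because a GPU processes the positions of one forward pass in parallel.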

The cache also makes the real bottleneck visible. Once you stop wasting compute, memory becomes the wall. The cache grows linearly with context length — every new token adds more stored keys and values for every layer, head, and batch item. The question changes from "can the model in principle look back 32K tokens?" to "can I afford to store 32K tokens of keys and values for every layer, head, and batch item?" Context length is not just an architectural idea. It is a memory budget.

Why it matters

Without a KV cache, long-context generation is slow enough to feel broken. Ask a tiny model to generate 256 tokens from a 512-token prompt. Naive decoding reruns almost the entire prefix at every step, so later steps cost much more than early ones and throughput falls off as the sequence gets longer. The user does not care why — the user sees lag.

If you run a model for chat, coding, or agents, you are not doing one forward pass over a fixed sequence. You are doing incremental decoding over and over. Your runtime lives or dies by how cheaply it can carry the past forward, and the KV cache turns repeated work into stored state.

Production inference cares about p99 latency, not only throughput averaged over a benchmark. A cache that halves the latency of the slowest requests is worth more than one that lifts mean throughput by 10%.

The cache also exposes a real memory trade. A model might fit in VRAM and still fail under long contexts because the cache grows until memory runs out. For one layer, the cache stores keys of shape [B, H, T, d_head] and values of shape [B, H, T, d_head]. Double the context, roughly double the cache. Double the batch size, double it again. Double the number of layers, double it again — each layer has its own cache.
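A few lines of arithmetic confirm the three doubling rules. This sketch counts cache elements for a standard multi-head layout. The helper name kv_numel and the GPT-2-small-like sizes (12 layers, 12 heads, d_head = 64) are illustrative assumptions, not values taken from the book's model.

```python
def kv_numel(n_layers, B, H, T, d_head):
    # one K and one V tensor per layer, each of shape [B, H, T, d_head]
    per_tensor = B * H * T * d_head
    return n_layers * 2 * per_tensor

base = kv_numel(n_layers=12, B=1, H=12, T=1024, d_head=64)
print(kv_numel(12, 1, 12, 2048, 64) / base)   # double the context
print(kv_numel(12, 2, 12, 1024, 64) / base)   # double the batch
print(kv_numel(24, 1, 12, 1024, 64) / base)   # double the layers
```

Each of the three ratios is exactly 2, because T, B, and the layer count enter the product as independent linear factors.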

And if the cache becomes corrupted, generation collapses in a way that feels spooky until you understand why. One bad cached key in the middle of the sequence can poison every later attention lookup that touches it. You are damaging a persistent reference point that future tokens may keep consulting, and a bad cache does not heal itself just because later math is correct.

The build

Build the cache from the outside in. Start with naive generation. Add per-layer storage. Change attention to decode one token at a time. Then thread the cache through the model, benchmark the speedup, profile the memory, and add a sliding window to see the trade.

Step 1 — Start with naive autoregressive generation

Before you optimize anything, build the slow version on purpose. Otherwise the cache feels like extra machinery rather than a fix for a measurable problem. Your naive generator should do what a first implementation would do: take the current sequence, run the full model on it, keep only the last position's logits, sample, append, repeat.

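import torch

@torch.no_grad()
def generate_naive(model, idx, max_new_tokens, temperature=1.0, greedy=True):
    # idx: [B, T] token ids; model(idx) is assumed to return logits of shape [B, T, vocab]
    for _ in range(max_new_tokens):
        logits = model(idx)                          # full forward pass on the full sequence, every step
        last = logits[:, -1, :] / temperature        # keep only the last position's logits
        if greedy:
            next_id = torch.argmax(last, dim=-1, keepdim=True)   # temperature cannot change the argmax
        else:
            next_id = torch.multinomial(last.softmax(dim=-1), num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)       # append and repeat
    return idx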

This is correct. It is also wasteful. At step 100, it reruns the model over the whole prompt and the first 99 generated positions, even though their internal K and V tensors were already computed in previous steps. Time it. Do not trust your intuition. Vary prompt length across 32, 128, 512, and 1024 and you should see throughput get worse as the prompt grows. That worsening is the point.
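One way to run that sweep is a small timing harness. This is a sketch: tokens_per_second and sweep are hypothetical helpers, the prompt is random token ids, and the model is assumed to accept prompt_len + new_tokens positions (a model with learned position embeddings must have a large enough block size). On CUDA, the synchronize calls make the timer wait for queued GPU work instead of measuring only the kernel launches.

```python
import time
import torch

def tokens_per_second(generate_fn, model, prompt_len, new_tokens, vocab_size, device="cpu"):
    # generate_fn follows generate_naive's call shape: (model, idx, max_new_tokens)
    idx = torch.randint(0, vocab_size, (1, prompt_len), device=device)
    if device.startswith("cuda"):
        torch.cuda.synchronize()          # don't count work queued by earlier calls
    start = time.perf_counter()
    generate_fn(model, idx, new_tokens)
    if device.startswith("cuda"):
        torch.cuda.synchronize()          # wait for the GPU to actually finish
    return new_tokens / (time.perf_counter() - start)

def sweep(generate_fn, model, vocab_size, prompt_lens=(32, 128, 512, 1024),
          new_tokens=64, device="cpu"):
    # one tokens/sec measurement per prompt length
    return {T: tokens_per_second(generate_fn, model, T, new_tokens, vocab_size, device)
            for T in prompt_lens}
```

Run it once as a warm-up before you record numbers, so one-time costs such as allocator growth or kernel compilation do not land in the first measurement.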

Step 2 — Identify where K and V are created

Now look inside one attention layer. In a decoder-only transformer, the block computes q = q_proj(x), k = k_proj(x), v = v_proj(x) and then reshapes into heads. The important insight is not the formula — you already know attention from Chapter 4. The important insight is lifecycle. For a prefix token, once k and v are computed during inference, they can be saved. So the attention module needs a second mode: full-sequence mode for training or naive inference, and incremental mode for cached decoding.

In incremental mode, the layer takes an x_new for the new position, plus k_cache and v_cache, plus a cache_pos telling it where to append. Conceptually a decode step looks like: receive the new token embedding, compute q_new, k_new, v_new, append k_new and v_new to cache, attend using q_new against the full cached K and V. The output is only for the new position. That is what you want.

Step 3 — Add a cache object per layer

You need storage — not vague "memory," but concrete tensors. A simple cache per attention layer is two tensors, k_cache and v_cache. If your model has N transformer blocks, you need one pair per block.

class KVCache:
    def __init__(self, n_layers):
        self.k = [None] * n_layers
        self.v = [None] * n_layers

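In practice, on GPU, you usually preallocate the full max context length and write into it by slice assignment. Repeatedly concatenating tensors becomes its own bottleneck: each torch.cat allocates a fresh tensor and copies the entire history, so growing the cache one token at a time copies O(T²) elements in total. Preallocation gives the kernel one fixed buffer that never moves.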

k_cache[layer]  # [B, H, T_max, d_head]
v_cache[layer]  # [B, H, T_max, d_head]

# write the new K/V at the current position
k_cache[layer][:, :, pos:pos+1, :] = k_new
v_cache[layer][:, :, pos:pos+1, :] = v_new

Then pos advances by one. There is a large lesson hiding in this small implementation detail. The cache is part of the runtime state of the model. The weights are static. The cache is live, growing, per-session state. That is why two requests in a server cannot silently share the same cache unless you want nonsense. The weights live in one pool. The per-request live state has to live somewhere you can swap in and out per inference. Skip that separation and the system mixes one user's notes into another user's reply.
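Here is a preallocated version of the cache object, written as a sketch. The class name PreallocKVCache and its write/read/advance methods are hypothetical, not the book's API. The point is that each request owns one instance while the weights stay shared, and pos advances once per token only after every layer has written its slot.

```python
import torch

class PreallocKVCache:
    """Per-request KV cache: one preallocated [B, H, T_max, d_head] buffer pair per layer."""

    def __init__(self, n_layers, B, H, T_max, d_head, dtype=torch.float32, device="cpu"):
        shape = (B, H, T_max, d_head)
        self.k = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(n_layers)]
        self.v = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(n_layers)]
        self.T_max = T_max
        self.pos = 0                      # next free slot, shared by all layers

    def write(self, layer, k_new, v_new):
        # k_new, v_new: [B, H, 1, d_head] for the token at self.pos
        if self.pos >= self.T_max:
            raise RuntimeError("KV cache full: context exceeds T_max")
        self.k[layer][:, :, self.pos:self.pos + 1, :] = k_new
        self.v[layer][:, :, self.pos:self.pos + 1, :] = v_new

    def read(self, layer):
        # everything written so far, including the current token's slot
        end = self.pos + 1
        return self.k[layer][:, :, :end, :], self.v[layer][:, :, :end, :]

    def advance(self):
        # call once per token, after every layer has written
        self.pos += 1
```

Two concurrent requests get two instances. Nothing in one instance can leak into the other, which is the separation the paragraph above asks for.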

Step 4 — Change attention to decode one token at a time

In training, attention sees a whole sequence: q, k, v all have shape [B, H, T, d_head]. In cached decoding at one new step, q_new has shape [B, H, 1, d_head] and the cached k_all and v_all have shape [B, H, T_cache, d_head]. Then attention scores have shape [B, H, 1, T_cache], meaning the new token asks one question against all past tokens.

The attention law is unchanged:

attention(q_new, k_all, v_all) = softmax(q_new @ k_all^T / sqrt(d_head)) @ v_all

Earn every symbol. q_new is the new token's query. k_all is every cached key up to the current position. v_all is every cached value up to the current position. sqrt(d_head) is the scaling factor from Chapter 4 that stops dot products from blowing up. Softmax turns scores into weights that sum to one. Nothing about the law changed. Only the shapes changed. KV caching is not a new kind of attention — it is the same attention with old K and V reused instead of rebuilt.

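import math
import torch

def decode_step(attn, x_new, k_cache, v_cache, pos):
    # attn: this layer's attention module, assumed to expose q_proj, k_proj, v_proj (nn.Linear) and n_head
    # x_new: [B, 1, C] hidden state of the single new token
    # k_cache, v_cache: this layer's preallocated [B, H, T_max, d] buffers
    B, _, C = x_new.shape
    H = attn.n_head
    d = C // H

    def split_heads(t):                            # [B, 1, C] -> [B, H, 1, d]
        return t.view(B, 1, H, d).transpose(1, 2)

    q_new = split_heads(attn.q_proj(x_new))        # [B, H, 1, d]
    k_new = split_heads(attn.k_proj(x_new))        # [B, H, 1, d]
    v_new = split_heads(attn.v_proj(x_new))        # [B, H, 1, d]

    k_cache[:, :, pos:pos+1, :] = k_new            # write the new K/V into slot pos
    v_cache[:, :, pos:pos+1, :] = v_new

    k_all = k_cache[:, :, :pos+1, :]               # every cached position up to and including pos
    v_all = v_cache[:, :, :pos+1, :]

    scores = q_new @ k_all.transpose(-2, -1) / math.sqrt(d)   # [B, H, 1, pos+1]
    att = scores.softmax(dim=-1)
    out = att @ v_all                              # [B, H, 1, d]
    return out.transpose(1, 2).reshape(B, 1, C)    # merge heads -> [B, 1, C]; output projection follows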

Notice there is no explicit causal mask. The slice :pos+1 already means "only attend to past and present." That is one of the nicer parts of decode mode: the mask becomes implicit in the cache length. The cache is the mask.
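You can check both claims numerically: the law is unchanged, and the slice does the masking. This sketch builds random Q, K, and V tensors (sizes chosen arbitrarily) and computes full-sequence attention with an explicit causal mask. It then confirms that, for every position, one query against the sliced prefix reproduces the same row with no mask at all.

```python
import math
import torch

torch.manual_seed(0)
B, H, T, d = 1, 2, 6, 4
q, k, v = torch.randn(B, H, T, d), torch.randn(B, H, T, d), torch.randn(B, H, T, d)

# Full-sequence mode: a [T, T] score matrix with an explicit causal mask.
scores = q @ k.transpose(-2, -1) / math.sqrt(d)
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True = future position
full_out = scores.masked_fill(future, float("-inf")).softmax(dim=-1) @ v

# Decode mode: for each position, one query against the cached prefix, no mask.
matches = []
for pos in range(T):
    q_new = q[:, :, pos:pos + 1, :]                            # [B, H, 1, d]
    k_all, v_all = k[:, :, :pos + 1, :], v[:, :, :pos + 1, :]
    att = (q_new @ k_all.transpose(-2, -1) / math.sqrt(d)).softmax(dim=-1)
    matches.append(torch.allclose(att @ v_all, full_out[:, :, pos:pos + 1, :], atol=1e-6))

print(all(matches))
```

Masked positions get exp(-inf) = 0 weight in full mode. Leaving them out of the slice therefore produces the same weighted sum, up to float rounding.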

Step 5 — Thread the cache through the full model

Each transformer block needs access to its own cache tensors. There are two ways to handle the prompt itself. Option one: run full-sequence mode once on the prompt while writing all K and V into cache. Option two: feed the prompt token-by-token to populate the cache incrementally. Option one is faster for long prompts. Serving systems call option one the prefill phase, and incremental decoding starts once the prompt's K and V are in the cache.

The split matters because serving has two phases, and they stress hardware differently. Prefill has lots of parallel work across prompt positions and looks compute-bound. Decode has tiny per-step work with lots of cache reads and looks memory-bound. Even if you do not fully optimize prefill yet, keep the mental split clear. Production systems care about both.

logits, cache = model.prefill(prompt_ids, cache=None)
next_logits, cache = model.decode(next_token, cache)

The exact API matters less than the state transition. The cache must persist across decode steps. Pass it through, return it back, and never let one request's cache leak into another's.
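A greedy loop built on that two-call API could look like the following sketch. It assumes the hypothetical prefill/decode signatures shown above, where prefill returns logits for every prompt position and decode returns logits for the one new position. The cache object is simply passed through and handed back, so it persists across steps without any global state.

```python
import torch

@torch.no_grad()
def generate_cached(model, prompt_ids, max_new_tokens):
    # Assumed (hypothetical) API:
    #   model.prefill(ids, cache=None) -> (logits [B, T, V], cache)
    #   model.decode(token [B, 1], cache) -> (logits [B, 1, V], cache)
    if max_new_tokens <= 0:
        return prompt_ids
    logits, cache = model.prefill(prompt_ids, cache=None)   # one parallel pass over the prompt
    next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    out = [prompt_ids, next_id]
    for _ in range(max_new_tokens - 1):
        logits, cache = model.decode(next_id, cache)         # one position per step
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        out.append(next_id)
    return torch.cat(out, dim=1)
```

The first new token comes from the prefill logits, so a request for N new tokens makes one prefill call and N - 1 decode calls.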

Step 6 — Benchmark the speedup honestly

Rerun the tokens-per-second benchmark with cached decoding. Use the same model, prompt, number of generated tokens, device, dtype, and decoding rule. Only change the inference path. The shape you should see is that short prompts give modest gains and long prompts give dramatic gains — small at 32 tokens, larger at 128, much larger at 512, much larger again at 1024. That is the right result because the cache removes repeated prefix work, and repeated prefix work hurts more as the prefix grows. Most benchmark tables quote a single number and hide the prompt-length dependence. Always plot the curve, not the point.

Step 7 — Profile cache memory

Now measure the price you paid for speed. Each token leaves behind footprints in every layer. Two of them: a key and a value. And not one global copy — one per head, per layer, per batch item. A rough byte estimate for a standard multi-head cache is:

bytes ≈ 2 * N * B * T * H * d_head * bytes_per_element

The leading 2 is for K plus V. Because H * d_head = d_model, you can think of it as bytes ≈ 2 * N * B * T * d_model * bytes_per_element. For N = 12, B = 1, T = 4096, d_model = 768, FP16 (so 2 bytes per element), that works out to about 151 million bytes, or 144 MiB. Just the KV cache. Not the weights, not the activations during prefill, not framework overhead. That is why longer contexts eventually hit a wall.
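The estimate is easy to turn into a helper that you can compare against measured memory. This is a sketch for standard multi-head attention, and kv_cache_bytes is a hypothetical name. Under grouped-query attention only the K/V heads are cached, so you would replace d_model with n_kv_heads * d_head.

```python
def kv_cache_bytes(n_layers, batch, seq_len, d_model, bytes_per_element):
    # 2 = one K plus one V; assumes H * d_head == d_model (standard multi-head attention)
    return 2 * n_layers * batch * seq_len * d_model * bytes_per_element

b = kv_cache_bytes(n_layers=12, batch=1, seq_len=4096, d_model=768, bytes_per_element=2)  # FP16
print(b)            # 150994944
print(b / 2**20)    # 144.0
```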

Measure it empirically as well. On CUDA, log torch.cuda.memory_allocated() and torch.cuda.max_memory_allocated() at increasing context lengths. Keep increasing T until you hit out-of-memory. That is your local context wall for this model, dtype, and hardware. The failure arrives because runtime state grew, not because the checkpoint got bigger.

Step 8 — Add a sliding window

Once you see the wall, the next question is obvious: what if I stop remembering everything? Pick a window size W. If the cache grows past W, drop the oldest entries and keep only the most recent W.

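cache_len = k_cache.size(2)                 # positions currently stored (cache grown by concatenation)
if cache_len > window:
    k_cache = k_cache[:, :, -window:, :]    # keep only the most recent `window` positions
    v_cache = v_cache[:, :, -window:, :]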

Memory is now bounded by W instead of growing forever. The trade is clear: you are choosing not to let the model attend to older context. Measure quality, not just speed. Run evaluation with windows like full, 2048, 1024, 512, and 256, and compare validation perplexity. Large windows usually cause almost no damage; medium windows degrade long-range dependencies; tiny windows produce repetition, forgotten setup, wrong references, and broken code completions. The deeper lesson: context length is an agreement between model behavior and runtime storage. The model may know how to use long context. Your runtime may choose not to keep it.
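The slicing snippet assumes a cache that keeps growing by concatenation and then gets trimmed. A common alternative, sketched here as an assumption rather than the chapter's method, is a ring buffer. You allocate exactly W slots once and write token t into slot t % W, so the newest entry overwrites the oldest. RingKVCache is a hypothetical single-layer class.

```python
import torch

class RingKVCache:
    """Sliding-window cache for one layer: a fixed [B, H, W, d] buffer written at n_seen % W."""

    def __init__(self, B, H, W, d):
        self.k = torch.zeros(B, H, W, d)
        self.v = torch.zeros(B, H, W, d)
        self.W = W
        self.n_seen = 0                      # total tokens ever written

    def append(self, k_new, v_new):
        slot = self.n_seen % self.W          # once full, this overwrites the oldest entry
        self.k[:, :, slot:slot + 1, :] = k_new
        self.v[:, :, slot:slot + 1, :] = v_new
        self.n_seen += 1

    def read(self):
        n = min(self.n_seen, self.W)         # valid entries; not in time order once wrapped
        return self.k[:, :, :n, :], self.v[:, :, :n, :]
```

The scrambled slot order is harmless for softmax attention, which sums over (key, value) pairs regardless of order. This only holds if positional information is already baked into the cached keys, for example RoPE applied before caching or learned position embeddings added to the input.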

Step 9 — Keep generation semantics identical

A good cache implementation should not change outputs when decoding deterministically. That is a test, not a wish. Same prompt, same weights, same temperature, same seed, greedy decoding — naive generation and cached generation should produce the same token sequence. If they do not, your cache path is wrong. Common bugs: writing to the wrong position, reading too much or too little of the cache, mixing batch items, mixing layers, using a wrong positional index during RoPE or learned position embeddings, forgetting that decode handles one token rather than a full sequence, or accidentally including future positions in the cache slice. Before chasing speed, assert equality on short deterministic runs.
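Here is a self-contained version of that equality test. It runs on a deliberately tiny, hypothetical model (one layer, one head, learned positions, no MLP) that has both a full-sequence path with an explicit causal mask and a cached one-token path. The model is not the book's GPT. It is the smallest thing that exercises cache writes, reads, and position indexing.

```python
import math
import torch
import torch.nn as nn

class TinyAttnLM(nn.Module):
    """Hypothetical one-layer, single-head toy LM with a full path and a cached path."""

    def __init__(self, vocab=50, d=16, T_max=64):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab, d)
        self.pos_emb = nn.Embedding(T_max, d)      # learned absolute positions
        self.q, self.k, self.v = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)
        self.head = nn.Linear(d, vocab)
        self.d, self.T_max = d, T_max

    def forward(self, idx):                        # full-sequence mode, explicit causal mask
        T = idx.size(1)
        x = self.tok_emb(idx) + self.pos_emb(torch.arange(T))
        s = self.q(x) @ self.k(x).transpose(-2, -1) / math.sqrt(self.d)
        future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        a = s.masked_fill(future, float("-inf")).softmax(-1) @ self.v(x)
        return self.head(x + a)

    def step(self, idx_new, cache, pos):           # decode mode: one token, the slice is the mask
        x = self.tok_emb(idx_new) + self.pos_emb(torch.tensor([pos]))
        cache["k"][:, pos:pos + 1] = self.k(x)
        cache["v"][:, pos:pos + 1] = self.v(x)
        k_all, v_all = cache["k"][:, :pos + 1], cache["v"][:, :pos + 1]
        a = (self.q(x) @ k_all.transpose(-2, -1) / math.sqrt(self.d)).softmax(-1) @ v_all
        return self.head(x + a)

@torch.no_grad()
def greedy_naive(model, idx, n):
    for _ in range(n):
        idx = torch.cat([idx, model(idx)[:, -1:].argmax(-1)], dim=1)
    return idx

@torch.no_grad()
def greedy_cached(model, idx, n):
    B = idx.size(0)
    cache = {"k": torch.zeros(B, model.T_max, model.d), "v": torch.zeros(B, model.T_max, model.d)}
    for p in range(idx.size(1)):                   # populate the cache token by token
        logits = model.step(idx[:, p:p + 1], cache, p)
    out = idx
    for _ in range(n):
        nxt = logits[:, -1:].argmax(-1)
        out = torch.cat([out, nxt], dim=1)
        logits = model.step(nxt, cache, out.size(1) - 1)
    return out

torch.manual_seed(0)
model = TinyAttnLM().eval()
prompt = torch.randint(0, 50, (2, 5))
same = torch.equal(greedy_naive(model, prompt, 20), greedy_cached(model, prompt, 20))
print(same)
```

Compare token ids, not logits. The two paths run different matmul shapes, so on some hardware and dtypes the logits can differ in the last bits while the argmax still agrees.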

BREAK IT

Building the cache is useful. Breaking it proves what the cache is actually holding and why it has to be exactly right. There are two failures to force on purpose: corrupt one cached key vector, and shift the cache write index by one position. Each failure teaches a different lesson. The first teaches you that the cache is a persistent reference point, not a one-time computation. The second teaches you that cache integrity is alignment, not just tensor contents.

Break 1

Corrupt one cached key vector

Generate a baseline completion from a fixed prompt using greedy decoding. Save the output. Reset and rerun from the same prompt. After the cache contains at least 100 tokens, silently overwrite one cached key vector in the middle of the sequence with random values — one layer, one head, one position. Then continue decoding as if nothing happened.

# Corrupt layer 3, head 0, position 57
with torch.no_grad():
    cache.k[3][:, 0, 57, :] = torch.randn_like(cache.k[3][:, 0, 57, :])

What you should observe is not a crash — that would be easy. You will usually get fluent-looking text for a few tokens, then increasing incoherence. The model may lose track of referents, repeat odd phrases, switch topic abruptly, write syntactically plausible but semantically wrong text, or drift into locally fluent garbage. The outputs match exactly until the corruption matters. After that point, the corrupted run diverges permanently. The model does not recover, because every later step depends on the evolving corrupted history.
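To put a number on "until the corruption matters," compare the generated token ids of the clean and corrupted runs. This hypothetical helper returns the first index where they disagree, or None if they never do.

```python
def first_divergence(clean, corrupted):
    # index of the first token where the two runs disagree, or None if they never do
    for i, (a, b) in enumerate(zip(clean, corrupted)):
        if a != b:
            return i
    if len(clean) != len(corrupted):
        return min(len(clean), len(corrupted))   # one run stopped early
    return None
```

If you pass only the tokens generated after the corruption step, the returned index is how many tokens survived intact.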

Why does one corrupted key matter so much? Because that key is not a one-time computation. It is a persistent lookup target inside the remembered past. Every later token that asks a query might compare itself against that corrupted key. Some queries will ignore it. Some will not. When they do not, the attention weights shift, the model reads the wrong value pattern from the past, and then the hidden state for the new token becomes slightly wrong. That wrong hidden state creates later wrong queries, keys, and values. The corruption propagates forward. Not magic — just a bad note in the notepad that keeps being consulted.
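The "some queries will ignore it, some will not" effect shows up in a few lines of plain attention. This sketch uses arbitrary random sizes and a hypothetical attend helper. It overwrites one key the way Break 1 does and measures how much each later query's weight on that slot, and its output vector, moves.

```python
import torch

torch.manual_seed(0)
T, d = 8, 16
keys, values = torch.randn(T, d), torch.randn(T, d)
queries = torch.randn(4, d)                      # four later tokens, each asking its own question

def attend(q, K, V):
    w = (q @ K.T / d ** 0.5).softmax(dim=-1)     # [n_queries, T] attention weights
    return w, w @ V

w_clean, out_clean = attend(queries, keys, values)
bad_keys = keys.clone()
bad_keys[3] = torch.randn_like(bad_keys[3])      # overwrite one cached key, as in Break 1
w_bad, out_bad = attend(queries, bad_keys, values)

print((w_bad[:, 3] - w_clean[:, 3]).abs())       # shift in each query's weight on slot 3
print((out_bad - out_clean).norm(dim=-1))        # how far each query's output moved
```

Because softmax renormalizes, a changed score on slot 3 also shifts the weights on every other slot for that query. The size of the shift depends on how strongly each query lines up with the corrupted key.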

The first time I saw this in the wild was a buffer alignment bug, not a math bug. We were swapping IOSurface-backed KV tensors between two ANE-compiled graphs and one surface had a stride the consumer did not expect. Generation looked sane for about 40 tokens, then drifted, then went incoherent. The math was correct. The cache was read from a slightly wrong offset. What it proves: the cache is part of the model's live computational state. Correctness in inference systems is not only "does it run?" but "does the saved state stay valid?"

Break 2

Shift the cache write index by one

Keep K and V values correct, but write them to the wrong slot. Replace the correct write:

k_cache[layer][:, :, pos:pos+1, :]   = k_new
v_cache[layer][:, :, pos:pos+1, :]   = v_new

With an off-by-one variant:

k_cache[layer][:, :, pos+1:pos+2, :] = k_new
v_cache[layer][:, :, pos+1:pos+2, :] = v_new

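Now the cache contents are real; they are just attached to the wrong positions. The model attends to context as if the timeline were slightly scrambled. The token that should have been at position 17 is now sitting at position 18. The token at position 18 lands at 19. Each query reads keys that belong one step away from where it thinks they belong. If every token goes through this shifted write, the read slice :pos+1 also stops one slot short of the newest write. Each new token never sees its own key and value, and slot 0 is never filled at all.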

This often produces output that looks eerily close to correct at first, then decays into contradictions and strange repetitions as the cache grows and the alignment error compounds. The model may address a name that was never introduced, or complete code referencing a variable defined for the next function instead of the current one.

That proves something different from Break 1. Cache integrity is not just tensor contents — it is alignment. You need the right values, in the right layer, for the right head, at the right position, for the right request. The strongest test of a KV cache implementation is not "does generation finish?" but "does cached greedy decoding produce the exact same token sequence as naive greedy decoding?" If the answer is no, find the misalignment before you ship.
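To see exactly what the off-by-one write does to the timeline, you can simulate the slots with plain integers standing in for K/V vectors. This is a hypothetical sketch. Each slot records which token id was written there, None marks a slot that was never written, and the read slice matches decode_step's :pos+1.

```python
def simulate_slots(n_tokens, shift, T_max=8):
    # track which token id lands in each cache slot; None = never written
    slots = [None] * T_max
    reads = []
    for pos in range(n_tokens):
        slots[pos + shift] = pos             # shift=0 is correct, shift=1 is the off-by-one
        reads.append(list(slots[:pos + 1]))  # what the query at pos actually attends to
    return reads

print(simulate_slots(4, shift=0)[-1])   # [0, 1, 2, 3]
print(simulate_slots(4, shift=1)[-1])   # [None, 0, 1, 2]
```

With the shift, the query at position 3 reads an empty slot plus tokens 0 to 2, each one place late, and never sees its own entry.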

The deeper lesson

The KV cache feels like a clean optimization until you break it, and then you see what it really is: live, mutable, per-session model state. The weights are static and shareable. The cache is none of those things. It is grown, indexed, written, and read on every decode step, and every later token's output is a function of every earlier byte in those tensors. That is why production inference engines spend so much engineering on paged attention, continuous batching, eviction policies, and quantized K and V. Those are not separate inventions — they are the same idea: the cache is the system, and the system has to be exactly right.

Questions to answer

  1. How did tokens per second change as prompt length increased for naive decoding versus cached decoding, and at what prompt length did cached decoding stop feeling like an optimization and start feeling required?
  2. When you measured cache memory, what grew linearly with sequence length and what stayed fixed — and how did your measured numbers compare to the analytic estimate 2 * N * B * T * d_model * bytes_per_element?
  3. In the sliding-window experiment, what failed first as you shrank the window: exact recall, topic consistency, code correctness, or local fluency?
  4. When you corrupted one cached key vector, how many tokens passed before the output clearly diverged from the clean run — and why did the model not recover?
  5. What tradeoff is the KV cache protecting: compute, memory, latency, or all three at once? Which of those becomes the new bottleneck after caching works?

Go further

What you now know

The KV cache stores old keys and values so each new token only needs one attention step instead of recomputing the entire history. That changes generation speed by an order of magnitude on long prompts, and it creates the main failure surface in modern inference: corrupted or misaligned cache state silently degrades every token that follows. You can build it from a blank file, write into it by slice assignment, decode one token at a time using the cache slice itself as an implicit causal mask, profile the memory cost layer by layer, and watch generation collapse when one cached key is wrong. Context length is ultimately a memory problem, and your runtime decides what past survives and at what cost. The next time you read about paged attention, continuous batching, or grouped-query attention, each one will look like a direct response to the cache pressure you just measured.

Chapter 4 — Foundation

Attention From Scratch

The Q/K/V split that makes a cache possible. Build attention from a blank file, then break it on purpose.

Chapter 8 — Systems

Flash Attention and Tiled Kernels

Once the cache turns decode into a memory-bandwidth problem, IO-aware kernels are the next move.

Continue the build

This was Chapter 13 of 35.

The full book is 934 pages and 256,587 words. 35 hands-on projects from autograd to fused specialists. PDF and EPUB on Leanpub, lifetime free updates.