Chapter 18 · Full excerpt
Mixture of experts.
A mixture-of-experts layer replaces one general-purpose feedforward network with many specialist feedforward networks, then puts a learned router in front of them that picks the top k experts for each token. The model keeps the full capacity of all N experts in its weights, but for any single token only k of them actually run. That is the trick: more total parameters, sparse active compute per token, and a router that has to learn who to call.
This chapter builds an MoE layer from scratch — one expert at a time, then the router, then top-2 selection, then the weighted sum — trains it on a small corpus inside nanochat, and then deliberately breaks the router. You will watch routing collapse, see two of four experts go dead, and learn why a perfectly balanced router can still be a broken one.
This is Chapter 18 of Under The Hood — Build Every Layer of a Large Language Model from Scratch. The full 35-project book is on Leanpub. Code companion at github.com/mechramc/Under-the-hood.
The concept
Start with a normal transformer block. After attention, each token goes through a feedforward network, usually called the FFN. You can think of the FFN as a workshop that every token visits alone. Attention decides what context matters, the FFN turns that context into something useful. It is where a lot of the model's factual and pattern-matching behavior lives.
In a dense model, every token enters the same workshop. The token "def", the token "mitochondria", and the token "therefore" all use the same set of weights. The workshop has to be decent at everything because it serves everyone. Every token pays for every bit of knowledge the FFN has compressed into its weights, even when most of that knowledge is irrelevant to the token in front of it.
Mixture of Experts changes the workshop into a building with multiple rooms, each with its own specialist team. One room might get better at code-like patterns. Another might drift toward names and entities. Another becomes good at punctuation-heavy structure. Nobody assigns these roles by hand. The model discovers them during training, sometimes in ways that look nothing like the clean taxonomy you imagined. The important design change is simple: each token does not go through all rooms. A router looks at the token's hidden state, scores the experts, picks the top few, and sends the token only there.
I originally assumed expert specialization would look like recognizable domains. Code expert here, prose expert there. After staring at routing histograms for a few weeks during a long MoE-LoRA ablation, I gave that up. The categories the router invents are weirder, narrower, and sometimes embarrassingly token-level. One expert in one of our runs became, basically, the "open-paren" expert. That is not a joke.
If dense FFN is one huge call center where every question hits the same staff, MoE is a dispatch system with specialists. The dispatcher sends the question to the top one or two departments. The company employs far more expertise overall without putting the full staff on every call.
Now the confusion that usually shows up. If only some experts handle each token, doesn't that make the model fragile? And doesn't the router just learn to send everything to one expert? Left unchecked, yes, on both counts. If one expert is great, why not let it do all the work? Because then the other experts stop learning, and your fancy specialist building turns back into one crowded room with extra rent. That failure has a name: router collapse. The router starts favoring one or two experts for most tokens. Those experts get more gradient updates, so they improve faster. Because they improve faster, the router prefers them even more. Positive feedback kicks in. The rich get richer. The neglected experts become dead weight.
This is why MoE is not "add more FFNs." The router is now part of the learning problem. Most beginner explanations of MoE are honestly terrible because they describe the experts and treat the router as a footnote. The router is the project. Get the router wrong and the experts may as well not exist.
What the router actually computes
The hidden state of a token is a list of numbers representing what the model currently "thinks" about that token in context. Call it x. The router is just a linear layer that takes x and produces one score per expert. If you have 4 experts, the router turns one token representation into 4 scores. Those scores are not yet probabilities. They are raw preferences. Apply softmax, and the scores become numbers between 0 and 1 that add up to 1 — a routing distribution over experts.
If expert 2 gets 0.55, expert 0 gets 0.30, expert 1 gets 0.10, and expert 3 gets 0.05, then for top-2 routing you keep experts 2 and 0, run the token through those two experts, and combine their outputs with weights based on the routing scores. This token mostly wants expert 2, also wants some help from expert 0, and the other experts do not get involved. The router behaves like a dispatcher who only learned the job by watching which calls did not result in a complaint. There is no manager telling it the right call. There is only the loss yelling at it after the fact.
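To make the arithmetic in that example concrete, here is a small worked sketch in plain Python. The function name `top_k_route` is hypothetical, not something from nanochat, and rescaling the kept weights so they sum to 1 is one common choice that the build section covers later.

```python
def top_k_route(probs, k=2):
    """Pick the k largest routing probabilities and renormalize them.

    `probs` is one token's routing distribution (already softmaxed).
    Returns a list of (expert_index, mixing_weight) pairs, best first.
    """
    # Rank expert indices by probability, highest first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    chosen = ranked[:k]
    # Rescale the kept probabilities so they sum to 1.
    total = sum(probs[i] for i in chosen)
    return [(i, probs[i] / total) for i in chosen]


# Expert 0: 0.30, expert 1: 0.10, expert 2: 0.55, expert 3: 0.05
routing = [0.30, 0.10, 0.55, 0.05]
picks = top_k_route(routing, k=2)
for expert, weight in picks:
    print(f"expert {expert}: weight {weight:.3f}")
```

Expert 2 ends up with 0.55 / 0.85, about 0.647 of the mix, and expert 0 with 0.30 / 0.85, about 0.353. Experts 1 and 3 never run for this token.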
Why it matters
Without MoE, growing model capacity usually means growing per-token compute — every token pays the full bill. Double the width of dense feedforward layers and you double a big chunk of the work done for every token at training and inference time. MoE breaks that link. You can increase total parameters by adding experts while keeping only top-1 or top-2 active, which means more storage of learned behavior without making every token expensive in the same way a dense expansion would.
It also creates a new failure surface that dense models do not have. A dense FFN cannot suffer routing collapse because there is nowhere to route. MoE can fail in ways dense models cannot: one expert absorbs nearly all traffic, some experts become undertrained, top-1 routing becomes brittle, expert assignments drift wildly during training, and load-balancing loss fights task loss too hard and hurts quality. You do not really understand MoE until you watch it fail.
And it makes "parameter count" slippery. You now need at least three numbers to describe an MoE model: total parameters in the checkpoint, active parameters per token, and how many experts are actually doing useful work. A sparse 8×7B model does not behave like "56B dense." It usually does not even hold 56B parameters, because only the feedforward weights are replicated per expert while attention and embeddings are shared. It stores more knowledge than a 7B dense model, but each token still activates only a small slice, and that active path is what sets the per-token compute.
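A back-of-the-envelope sketch of the first two numbers, counting only the feedforward weights of a single block. The sizes (`d_model=512`, `d_ff=2048`) are illustrative assumptions and the function names are hypothetical. Attention, embeddings, and norms are left out because they are the same in the dense and MoE cases.

```python
def ffn_params(d_model: int, d_ff: int) -> int:
    """Weights plus biases of a two-layer FFN: d_model -> d_ff -> d_model."""
    up = d_model * d_ff + d_ff        # first linear layer
    down = d_ff * d_model + d_model   # second linear layer
    return up + down


def moe_ffn_params(d_model: int, d_ff: int, num_experts: int, top_k: int):
    """Return (total, active_per_token) parameter counts for one MoE layer."""
    expert = ffn_params(d_model, d_ff)
    router = d_model * num_experts    # bias-free linear router
    total = num_experts * expert + router
    active = top_k * expert + router  # only top_k experts run per token
    return total, active


dense = ffn_params(512, 2048)
total, active = moe_ffn_params(512, 2048, num_experts=4, top_k=2)
print(f"dense FFN:  {dense:,}")
print(f"MoE total:  {total:,}")
print(f"MoE active: {active:,}")
```

With these sizes the MoE layer stores roughly four times the dense FFN's weights while each token touches roughly twice as many. The third number, how many experts are doing useful work, cannot be computed from shapes; it has to be measured during training.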
Strong opinion: most published MoE explanations spend their pages on the experts. The router is one paragraph. That ratio is exactly inverted from where the engineering actually lives.
The build
You are starting from nanochat, where a transformer block already contains attention, normalization, a feedforward network, and residual connections. Your job is to replace one dense feedforward layer with an MoE version. Do not convert the whole model at once yet. Change one block first. That keeps the experiment legible. If the model gets worse, you can tell whether your MoE block is helping or hurting. If you change every block at once, you lose the ability to reason about the failure. I learned this the painful way — my first attempt swapped every block at once because I was excited, the loss curve looked plausible, and I could not tell which block was doing the work or whether any of them were broken. I threw the run out and started over with one block.
Step 1 — Isolate the FFN boundary
Before you write any MoE code, find the place in the transformer block where the dense feedforward network runs. In most GPT-style code, the pattern looks like this:
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
That second line is the seam. The FFN gets a tensor shaped like B x T x d (batch, sequence length, hidden size) and returns the same shape. That interface is the contract your MoE layer must preserve. If your dense FFN maps B x T x d -> B x T x d, your MoE layer must do the same. The rest of the block should not need to care whether the work came from one dense FFN or several experts. The cleaner that interface, the easier it is to run dense versus MoE comparisons without rewriting the rest of the model.
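One way to pin that contract down before writing any MoE code is a tiny shape check you can run against both the dense FFN and, later, the MoE layer. This is a minimal sketch: `check_ffn_contract` is a hypothetical helper, and the stand-in FFN sizes are arbitrary.

```python
import torch
import torch.nn as nn


def check_ffn_contract(layer: nn.Module, d_model: int, batch: int = 2, seq: int = 5) -> None:
    """Raise ValueError if `layer` does not map (B, T, d) to (B, T, d)."""
    x = torch.randn(batch, seq, d_model)
    with torch.no_grad():
        y = layer(x)
    if y.shape != x.shape:
        raise ValueError(f"expected {tuple(x.shape)}, got {tuple(y.shape)}")


# A stand-in dense FFN; the sizes are illustrative only.
dense_ffn = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
check_ffn_contract(dense_ffn, d_model=64)  # passes silently
```

Running the same check on the finished MoE layer is a cheap way to confirm it is a drop-in replacement.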
Step 2 — Build one expert
An expert is just an FFN. Do not make it exotic on the first pass. Use the same structure as your dense feedforward network:
import torch
import torch.nn as nn


class Expert(nn.Module):
    """One expert: a plain two-layer FFN, d_model -> d_ff -> d_model."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
If your dense FFN already uses SwiGLU or another gated design, you can mirror that later. For the first working version, plain FFN shape parity is enough. Each expert has its own weights. They are not shared copies. That is the whole point.
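For when you do get to the gated variant, a SwiGLU-style expert could look roughly like the sketch below. The class name `SwiGLUExpert` and the bias-free linears are assumptions made for illustration; the chapter does not say whether nanochat's FFN is gated, so treat this as one possible shape rather than the reference design.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUExpert(nn.Module):
    """Gated expert computing down(silu(gate(x)) * up(x))."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        # Bias-free projections are an assumption, common in gated FFNs.
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The SiLU-activated gate scales the up-projection elementwise.
        return self.down(F.silu(self.gate(x)) * self.up(x))


x = torch.randn(2, 5, 64)
y = SwiGLUExpert(d_model=64, d_ff=128)(x)
print(y.shape)
```

The interface is the same as the plain expert: input and output are both (B, T, d), so nothing downstream has to change.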
Step 3 — Add multiple experts
Now build a list of experts. For a first experiment, use 4 experts. That is enough to show specialization and collapse behavior, but small enough that you can inspect it by hand.
self.experts = nn.ModuleList([
    Expert(d_model, d_ff) for _ in range(num_experts)
])
At this point you have four separate workshops, but nothing decides who goes where yet.
Step 4 — Add the router
The router is a linear projection from hidden size d to number of experts N. No activation, no bias on the first pass:
self.router = nn.Linear(d_model, num_experts, bias=False)
Its input shape is B x T x d and its output shape is B x T x N. For each token, you now have N expert scores telling you which experts the router prefers. Nothing about this layer looks special — it is exactly the same primitive as any other linear projection in the block. The thing that makes it different is what its output gets used for downstream.
Step 5 — Convert scores into routing probabilities
Take the router logits and apply softmax over experts:
router_logits = self.router(x) # (B, T, N)
router_probs = torch.softmax(router_logits, dim=-1)
Each token now has a probability distribution over experts. For a given token you might get [0.05, 0.62, 0.28, 0.05]. That means expert 1 is the main pick and expert 2 is the backup. The axis matters: softmax over the expert dimension, not over the batch or the sequence. A bug I have seen more than once is a softmax over the wrong axis, which produces shapes that look right but spreads probability across the wrong dimension. Print the shape and a row of probabilities the first time you wire this up, and verify that each row sums to 1.
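The wrong-axis bug is easy to reproduce on random logits. The sketch below uses made-up sizes and shows why a shape check alone does not catch it, while a row-sum check does.

```python
import torch

torch.manual_seed(0)
B, T, N = 2, 3, 4
router_logits = torch.randn(B, T, N)

right = torch.softmax(router_logits, dim=-1)  # over experts
wrong = torch.softmax(router_logits, dim=1)   # over sequence positions (the bug)

# Same shape either way, so a shape check alone cannot catch the bug.
print(right.shape == wrong.shape)

# Only the correct axis gives one distribution per token.
row_sums_right = right.sum(dim=-1)
row_sums_wrong = wrong.sum(dim=-1)
print(torch.allclose(row_sums_right, torch.ones(B, T)))
print(torch.allclose(row_sums_wrong, torch.ones(B, T)))
```

Turning the row-sum check into an assertion during early debugging costs almost nothing and fails loudly the moment the axis is off.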
Step 6 — Select the top-k experts
Start with top-2 routing:
topk_probs, topk_idx = torch.topk(router_probs, k=2, dim=-1)
For every token, this returns the top 2 routing probabilities and the indices of the winning experts. One practical detail matters here: those top-2 probabilities no longer sum to 1 because you dropped the rest, so renormalize them.
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
Now the two chosen experts split 100% of the routing mass between them. This is a small detail with large consequences. If you skip renormalization, tokens that had a confident top-2 may get larger expert output magnitudes than tokens whose probability mass was spread more evenly, which changes the scale of the MoE output in a way you did not intend. You will see this in your downstream activations as a slow drift in norms that has no obvious origin.
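A plain-Python sketch of the scale problem, using two made-up routing distributions: one confident, one spread out. The helper `top2_weights` is hypothetical.

```python
def top2_weights(probs, renormalize):
    """Return the two largest routing probabilities, optionally rescaled to sum to 1."""
    top = sorted(probs, reverse=True)[:2]
    if renormalize:
        total = sum(top)
        top = [p / total for p in top]
    return top


confident = [0.70, 0.25, 0.03, 0.02]   # most mass on two experts
spread = [0.30, 0.28, 0.22, 0.20]      # mass spread across all four

for name, probs in [("confident", confident), ("spread", spread)]:
    raw = sum(top2_weights(probs, renormalize=False))
    fixed = sum(top2_weights(probs, renormalize=True))
    print(f"{name}: raw weight sum {raw:.2f}, renormalized weight sum {fixed:.2f}")
```

Without renormalization the confident token mixes its experts with total weight 0.95 and the spread token with 0.58. If the expert outputs had similar norms, the second token's MoE output would come out at roughly 61% of the first token's scale, purely as a side effect of routing confidence.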
Step 7 — Run only the chosen experts
There are two ways to do this. The clean teaching version runs all experts on all tokens, masks out the unused ones, and sums the selected outputs. The efficient production version gathers tokens per expert, runs only those tokens through that expert, and scatters results back to original positions. For this chapter, start with the clean version. It wastes compute, but it makes the mechanics obvious. I spent hours trying to debug a scatter-gather implementation in a fusion run before realizing one tensor was being appended at the wrong position index and only some tokens were getting routed correctly. The bug was invisible until I went back and reimplemented the layer in the dumb-but-obvious form, then watched the histograms agree.
A simple teaching implementation:
expert_outputs = torch.stack(
    [expert(x) for expert in self.experts], dim=2
)
# shape: (B, T, N, d)
Now you have every expert's output for every token. Next, build a sparse mask for the selected experts:
mask = torch.zeros_like(router_probs)
mask.scatter_(-1, topk_idx, topk_probs)
# shape: (B, T, N)
This mask contains nonzero weights only at the selected experts. Then combine:
y = (expert_outputs * mask.unsqueeze(-1)).sum(dim=2)
# shape: (B, T, d)
That is your MoE output: only the top-2 experts contribute to each token, and their outputs are mixed in proportion to the renormalized router probabilities. The shape matches the dense FFN exactly, so you can swap it in without touching the surrounding block.
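Assembled into one module, Steps 2 through 7 might look like the sketch below. The class name `MoELayer`, the method names, and the sizes are assumptions for illustration. It also includes a simple gather/scatter path, written as a plain loop over experts rather than an optimized kernel, so you can check that the two versions agree on the same input.

```python
import torch
import torch.nn as nn


class MoELayer(nn.Module):
    """Top-k MoE feedforward layer: (B, T, d) -> (B, T, d)."""

    def __init__(self, d_model: int, d_ff: int, num_experts: int = 4, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        ])
        self.router = nn.Linear(d_model, num_experts, bias=False)

    def route(self, x):
        probs = torch.softmax(self.router(x), dim=-1)              # (B, T, N)
        topk_probs, topk_idx = torch.topk(probs, k=self.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
        return probs, topk_probs, topk_idx

    def forward(self, x):
        """Teaching path: run every expert, then mask and sum."""
        probs, topk_probs, topk_idx = self.route(x)
        outs = torch.stack([e(x) for e in self.experts], dim=2)    # (B, T, N, d)
        mask = torch.zeros_like(probs).scatter(-1, topk_idx, topk_probs)
        return (outs * mask.unsqueeze(-1)).sum(dim=2)

    def forward_sparse(self, x):
        """Gather/scatter path: each expert sees only the tokens routed to it."""
        B, T, d = x.shape
        _, topk_probs, topk_idx = self.route(x)
        flat_x = x.reshape(-1, d)                                   # (B*T, d)
        flat_idx = topk_idx.reshape(-1, self.top_k)
        flat_w = topk_probs.reshape(-1, self.top_k)
        y = torch.zeros_like(flat_x)
        for e, expert in enumerate(self.experts):
            # Rows are token positions, columns are which top-k slot chose expert e.
            token_ids, slot = (flat_idx == e).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue
            out = expert(flat_x[token_ids]) * flat_w[token_ids, slot].unsqueeze(-1)
            y.index_add_(0, token_ids, out)                         # scatter back
        return y.reshape(B, T, d)


torch.manual_seed(0)
layer = MoELayer(d_model=32, d_ff=64)
x = torch.randn(2, 7, 32)
dense_out = layer(x)
sparse_out = layer.forward_sparse(x)
print(dense_out.shape, torch.allclose(dense_out, sparse_out, atol=1e-5))
```

Keeping both paths in the same class gives you a regression test for free: whenever the sparse path is changed, compare it against the teaching path on the same input.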
Step 8 — Wire it into one block
Now wire the MoE layer into a single transformer block. Your block still looks like:
x = x + self.attn(self.norm1(x))
x = x + self.moe(self.norm2(x))
From outside, nothing changed except the implementation behind the second sublayer. A clean swap means dense versus MoE is a one-line config change, not a structural rewrite.
Step 9 — Track expert utilization
Do not train yet without instrumentation. If you cannot see where tokens are going, you are flying blind. Expert utilization means how often each expert was selected. For top-2 routing, every token contributes two selections. Count the selected expert indices and build a histogram:
with torch.no_grad():
    flat_idx = topk_idx.reshape(-1)
    counts = torch.bincount(flat_idx, minlength=num_experts)
Store these counts across batches and compute fractions. If 4 experts are used evenly under top-2 routing, you expect each expert to receive about 25% of all selections over time. Not exactly, but roughly. Because each token picks two distinct experts, no single expert can exceed 50% of selections. If you see Expert 0: 49%, Expert 1: 38%, Expert 2: 9%, Expert 3: 4%, you have a collapse problem. That histogram is one of the main metrics for this project, and it deserves the same attention as the loss.
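Accumulating the counts across batches can be as simple as the sketch below. `UtilizationTracker` is a hypothetical helper, and the two hand-written batches stand in for real `topk_idx` tensors from the layer.

```python
import torch


class UtilizationTracker:
    """Accumulates top-k selection counts across batches."""

    def __init__(self, num_experts: int):
        self.counts = torch.zeros(num_experts, dtype=torch.long)

    def update(self, topk_idx: torch.Tensor) -> None:
        # topk_idx: (B, T, k) integer expert indices from torch.topk
        flat = topk_idx.reshape(-1).cpu()
        self.counts += torch.bincount(flat, minlength=self.counts.numel())

    def fractions(self) -> torch.Tensor:
        total = self.counts.sum().clamp(min=1)  # avoid dividing by zero
        return self.counts.float() / total


tracker = UtilizationTracker(num_experts=4)
# Two hand-written batches of top-2 picks, each of shape (1, 2, 2).
tracker.update(torch.tensor([[[0, 1], [0, 2]]]))
tracker.update(torch.tensor([[[0, 3], [0, 1]]]))
print(tracker.fractions())
```

Here expert 0 is picked by all four tokens, which puts it at 4 of 8 selections, the 50% ceiling for top-2 routing. Experts 1, 2, and 3 split the other half as 25%, 12.5%, and 12.5%.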
Step 10 — Add the load-balancing auxiliary loss
The main training objective is still next-token prediction, but you add an extra term that punishes uneven expert usage. Think of it as a tax on sending too much traffic to the same room. You want two things to stay near uniform: the router's average probability mass across experts (importance) and the actual top-k assignment counts (load):
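importance = router_probs.mean(dim=(0, 1))           # (N,) mean router probability per expert
# Hard 0/1 selections: `mask` holds mixing weights, not assignment counts.
selected = torch.zeros_like(router_probs).scatter_(-1, topk_idx, 1.0)
load = selected.sum(dim=(0, 1)) / selected.sum()     # (N,) share of top-k assignments
aux_loss = num_experts * torch.sum(importance * load)
loss = lm_loss + alpha * aux_loss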
For 4 experts, both vectors should look near [0.25, 0.25, 0.25, 0.25] when routing is healthy. alpha is small — 1e-2 or 1e-3. Readers often ask: "Why are we hurting the model's freedom to choose the best expert?" Because without this pressure, the router may choose the same few experts almost all the time, which prevents the unused experts from ever becoming worth choosing later. You are not forcing perfect equality. You are preventing a positive-feedback loop from running away.
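A quick numeric check of that formula in plain Python, using hand-picked `importance` and `load` vectors rather than values from a real run. The function name `aux_loss` simply mirrors the variable in the snippet above.

```python
def aux_loss(importance, load):
    """num_experts * sum_i importance[i] * load[i]."""
    n = len(importance)
    return n * sum(p * f for p, f in zip(importance, load))


balanced = aux_loss([0.25] * 4, [0.25] * 4)
# Top-2 routing that has collapsed onto experts 0 and 1.
collapsed = aux_loss([0.5, 0.5, 0.0, 0.0], [0.5, 0.5, 0.0, 0.0])
print(balanced, collapsed)
```

The balanced case gives 1.0 and the collapsed case gives 2.0. The floor of this term is 1.0, not 0, so when you log it, watch the gap above 1.0 rather than the raw value.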
BREAK IT
Building MoE is useful. Breaking it proves what the pieces are actually doing. There are two failures worth forcing on purpose: freeze two of four experts and watch the remaining pair absorb all the traffic, then look at a different breakage entirely — the case where the router balances perfectly but the experts have nothing inside them. The first teaches you how robust specialization actually is. The second teaches you that a clean histogram is not the same thing as a healthy layer.
Freeze two experts and force all routing onto the rest
Pick experts 2 and 3 from your trained 4-expert checkpoint. You will stop their weights from updating, stop their outputs from contributing, and stop the router from sending tokens to them. Do all three. If you freeze them but still let the router send traffic there, you are creating a dead route where tokens disappear into silent rooms and your measurement gets muddied. Mask them at routing time:
dead = torch.tensor([0, 0, 1, 1], device=x.device, dtype=torch.bool)
router_logits = self.router(x)
router_logits[..., dead] = -1e9
Then freeze their parameters so the optimizer cannot keep updating them anyway:
for p in self.experts[2].parameters():
    p.requires_grad = False
for p in self.experts[3].parameters():
    p.requires_grad = False
Continue training briefly. What you should observe is that quality degrades but does not collapse. The remaining experts still form valid computation paths. The router can still send every token somewhere useful, and the model loses diversity of specialization but not the whole FFN function. Expect validation loss to get worse, training to recover part of that loss if you continue, and tokens that previously leaned on experts 2 and 3 to suffer first. Selection counts for experts 0 and 1 will sit at exactly 50/50, because with top-2 routing and only two live experts every token has to pick both; the informative signal moves to the renormalized mixing weights between them. If you inspect samples, you may notice more repetitive phrasing, more confusion on rarer pattern types, and weaker handling of the specific token types that likely had specialized support before.
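Before you trust the numbers from this experiment, it is worth confirming that no selections leak to the dead experts. The standalone sketch below uses small linear layers as stand-in experts and `masked_fill` instead of in-place indexing; both choices are assumptions for illustration, not the chapter's layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_experts, d_model = 4, 16
router = nn.Linear(d_model, num_experts, bias=False)
# Stand-in experts; real ones would be full FFNs.
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])

dead = torch.tensor([False, False, True, True])

# 1) Freeze the dead experts so the optimizer cannot update them.
for i in dead.nonzero().flatten().tolist():
    for p in experts[i].parameters():
        p.requires_grad = False

# 2) Mask them at routing time; softmax turns -inf logits into exact zeros.
x = torch.randn(8, 10, d_model)
logits = router(x).masked_fill(dead, float("-inf"))
probs = torch.softmax(logits, dim=-1)
_, topk_idx = torch.topk(probs, k=2, dim=-1)

counts = torch.bincount(topk_idx.reshape(-1), minlength=num_experts)
print(counts.tolist())
```

With 80 tokens and top-2 routing there are 160 selections, and they land as 80 each on experts 0 and 1 and none on experts 2 and 3.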
What this proves is that MoE stores useful knowledge in multiple semi-independent specialist paths, but those paths are not brittle glass. Losing some experts hurts, yet the remaining experts can absorb the traffic and keep the system working. Robustness with degradation, not catastrophic failure. People often imagine expert specialization as if each expert contains one irreplaceable organ. That is usually too dramatic. Experts overlap, and the router plus the residual path give the model room to adapt.
Balanced but dead — the failure that fools the histogram
I had braced for router collapse, but this is the breakage I actually hit first in my own MoE-LoRA ablations. The setup is subtle: you train with a strong auxiliary load-balancing loss and a weak language modeling signal, which in practice means short runs, small batches, or a balancing coefficient cranked too high. The utilization histogram looks beautiful. Each of the four experts gets close to 25% of all selections. Every routing metric you look at says the routing is healthy.
And the loss is terrible. Sample generations are mush. The experts are not specialized; they are interchangeable averages. The router fanned traffic out almost perfectly evenly across all experts, which sounds healthy, except none of them were learning anything useful because no token had committed to any of them. Balanced is not the same as alive. That was a humbling Tuesday.
To see it on purpose, set alpha to something aggressive (0.5 or 1.0 instead of 1e-2), train for a short window, and watch:
# strong balancing pressure
loss = lm_loss + 1.0 * aux_loss
The diagnosis is not in the histogram. It is in the comparison between the histogram and the loss. A healthy MoE layer has a histogram that is roughly balanced — not perfectly flat — and a loss that is improving. A balanced-but-dead MoE layer has a histogram that is suspiciously flat and a loss that has plateaued early. The fix is to weaken the balancing pressure, let the router commit to preferences, and accept some skew as evidence the experts are learning to specialize. This is also why measuring expert utilization alone is not enough. You need utilization plus validation loss plus, ideally, a look at sample generations from the model.
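Reading the two instruments together can be turned into a crude automated check. The sketch below is a heuristic only: the function names and all three thresholds are illustrative guesses, not values from the chapter's runs, and the example histograms and loss curves are invented.

```python
import math


def normalized_entropy(fractions):
    """Entropy of the utilization histogram scaled to [0, 1]; 1.0 is perfectly flat."""
    h = -sum(f * math.log(f) for f in fractions if f > 0)
    return h / math.log(len(fractions))


def diagnose(fractions, val_losses, skew_threshold=0.8,
             flat_threshold=0.995, min_improvement=0.01):
    """Very rough health check from a histogram and a validation-loss history."""
    flatness = normalized_entropy(fractions)
    # Relative improvement from the first to the last recorded loss.
    improvement = (val_losses[0] - val_losses[-1]) / val_losses[0]
    if flatness < skew_threshold:
        return "skewed: check for router collapse"
    if flatness >= flat_threshold and improvement < min_improvement:
        return "balanced but possibly dead: try a smaller alpha"
    return "no red flag from these two signals"


print(diagnose([0.25, 0.25, 0.25, 0.25], [5.80, 5.79, 5.79]))
print(diagnose([0.31, 0.27, 0.22, 0.20], [5.80, 4.90, 4.20]))
print(diagnose([0.49, 0.38, 0.09, 0.04], [5.80, 5.10, 4.60]))
```

The first case is flagged as balanced but possibly dead, the second raises no flag, and the third is flagged as skewed. A check like this does not replace reading sample generations; it only tells you when to go and look.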
The deeper lesson behind both breaks is that MoE is a routing problem under load, not a collection of FFNs. The router can fail by concentrating too much (Break 1's pre-conditions) or by spreading too evenly while learning nothing (Break 2). Both look fine on some metrics and broken on others. You only catch them by reading multiple instruments together. A working loss curve does not mean your experts are healthy. A working utilization histogram does not mean your experts are healthy either. Both, together, with a real downstream evaluation, are the closest you get to ground truth.
Questions to answer
- Before you added the auxiliary loss, which expert in your 4-expert run got overloaded first? Did that overload grow steadily across thousands of steps, or did it appear suddenly within a few hundred?
- After you added load balancing with alpha = 1e-2, did validation loss improve, stay flat, or get slightly worse? What does that say about the tension between next-token quality and healthy traffic distribution?
- In your runs, how large was the quality gap between top-1 and top-2 routing on the same data and same expert count? Was it smaller than you expected, and what does that tell you about the value of the second expert?
- When you froze experts 2 and 3, what failed first: validation loss, sample quality, or routing balance among the remaining experts? Why is the order informative?
- If your utilization histogram ends a run looking like [0.25, 0.25, 0.25, 0.25], what additional measurement would you make before concluding that the MoE layer is healthy?
Go further
- Shazeer et al., Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. The paper that put sparse routing into modern deep learning. Read it for the original load-balancing loss and noise-injection tricks.
- Fedus et al., Switch Transformers. Top-1 routing at scale, with the cleanest writeup of capacity factors and expert dropout. Pair it with the Mixtral paper to see how top-2 plays out a few years later.
- Project 13 — Fast Inference: The KV Cache. Sister chapter on the inference side: MoE changes parameter-vs-compute economics; KV caching changes generation latency. Both are about which work you can skip without hurting quality.
- Project 32 — Fusing Independently Trained Specialists. What if your experts came from independently trained models instead of joint MoE training? Different routing problem, same underlying question about how specialists combine.
What you now know
You can write a sparse MoE layer from a blank file. You can name what each piece prevents: the router exists because a dense FFN forces every token to pay for every parameter; top-k selection exists to make active compute sublinear in total parameters; renormalization of the top-k probabilities exists to keep the output scale predictable; the load-balancing auxiliary loss exists to stop one expert from eating the whole training signal. You have watched a four-expert layer collapse onto two when you forced it to, and you have seen the more insidious failure where a perfectly balanced histogram hides experts that never learned anything in the first place. The next time you read a sparse-LLM paper that quotes "47B parameters, 12B active," you will know exactly which router, which top-k rule, and which balancing loss are doing the work behind that number — and which silent failures they are paid to prevent.