To optimize serving, you first need to know where the work goes. A transformer forward pass spends its FLOPs in two places: the weight matmuls (the attention projections and the large feed-forward (MLP) layers), whose cost scales with parameter count, and the attention computation itself, whose cost scales with sequence length. Per token, the matmuls cost roughly 2 FLOPs per parameter, while attention over the context grows linearly with context length, so across a whole sequence attention grows with sequence length squared. That rough breakdown explains why bigger models and longer contexts cost more, and where techniques like quantization (shrinking the weights) and attention optimizations (taming the n²) pay off. It is the quantitative intuition behind every later optimization.
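A worked version of where the n² comes from, assuming the same per-token estimate the demo code uses: about 2 FLOPs per parameter for the weight matmuls, plus an attention term of 2 · n_layers · context · hidden. Here N is the parameter count, L the number of layers, d the hidden size, and t the position of the current token. With causal attention, the token at position t attends to t positions, so summing over a sequence of n tokens gives:

```latex
C_{\text{token}}(t) \approx 2N + 2\,L\,t\,d

C_{\text{seq}}(n) = \sum_{t=1}^{n} \left( 2N + 2\,L\,t\,d \right)
                  = 2Nn + L\,d\,n(n+1)
                  \approx 2Nn + L\,d\,n^{2}
```

The first term is linear in n (every token pays for every weight once), while the attention term is quadratic, which is why it eventually dominates at long contexts.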
The demo estimates FLOPs per token (~2 × parameters for the matmuls) and shows how attention's share depends on sequence length, giving you a back-of-envelope for why a 70B model costs ~10× as much per token as a 7B and why long contexts strain attention.
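To make the ~10× claim concrete, here is a minimal sketch that reuses the demo's estimate. The 7B shape matches the demo's defaults (32 layers, hidden size 4096); the 70B shape (80 layers, hidden size 8192) is an assumption roughly in line with published 70B models, so treat it as hypothetical. The helper name `total_flops_per_token` is also illustrative.

```python
# Hypothetical shapes: (params, n_layers, hidden). The 7B row matches the demo's
# defaults; the 70B row is an assumption roughly in line with published 70B models.
MODELS = {"7B": (7e9, 32, 4096), "70B": (70e9, 80, 8192)}


def total_flops_per_token(params, n_layers, hidden, seq_len):
    """Forward FLOPs per token: ~2 per weight plus the demo's attention term."""
    return 2 * params + 2 * n_layers * seq_len * hidden


for seq_len in (1_000, 32_000):
    small = total_flops_per_token(*MODELS["7B"], seq_len)
    big = total_flops_per_token(*MODELS["70B"], seq_len)
    print(f"{seq_len:>6} ctx: 70B/7B = {big / small:.2f}x")
```

Under these assumed shapes the ratio sits near 10× at 1k tokens, where the weight matmuls dominate, and drops to roughly 8× at 32k, because the attention term scales with n_layers × hidden (5× larger here) rather than with total parameters (10× larger).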
Use these three prompts in order; each builds on the one before.
In an LLM forward pass, where do the compute (FLOPs) and time actually go?
Explain the rough FLOP breakdown (matmuls ~2×params/token, attention growing with sequence length) and why model size and context length drive cost.
Using the FLOP breakdown, explain where quantization and attention optimizations pay off, and why decode is memory-bound while prefill is compute-bound despite both running the same weights.
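As a rough check on the third prompt, this sketch compares arithmetic intensity (FLOPs performed per byte of weights read) for decode and prefill. It assumes, for illustration only, that every weight is read from memory once per forward pass and that attention and KV-cache traffic can be ignored; the function name and the 2048-token prompt length are hypothetical.

```python
# Back-of-envelope arithmetic intensity for one forward pass over the weights.
# Assumptions (hypothetical, for illustration only): attention/KV-cache traffic is
# ignored, and every weight is read from memory once per forward pass.


def arithmetic_intensity(tokens_per_pass, bytes_per_param):
    """FLOPs performed per byte of weights read, for a pass over `tokens_per_pass` tokens."""
    flops_per_param = 2 * tokens_per_pass      # ~2 FLOPs per weight per token
    return flops_per_param / bytes_per_param   # weights read once, reused for every token


# Decode: one new token per pass (batch size 1). Prefill: the whole prompt in one pass.
print("decode,  fp16:", arithmetic_intensity(1, 2.0))     # 1.0 FLOP/byte
print("prefill, fp16:", arithmetic_intensity(2048, 2.0))  # 2048.0 FLOPs/byte
print("decode,  int4:", arithmetic_intensity(1, 0.5))     # 4.0 FLOPs/byte
```

A GPU's balance point (peak FLOP/s divided by memory bandwidth) is typically on the order of a hundred or more FLOPs per byte on current datacenter parts. Decode at ~1 FLOP/byte sits far below it, so its speed is set by how fast the weights stream from memory; prefill over thousands of tokens sits above it, so it is limited by compute. Quantizing weights to 4 bits cuts the bytes read per token by 4× relative to fp16, which is why it helps memory-bound decode so directly.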
```python
def flops_per_token(params_billion, seq_len, hidden=4096, n_layers=32):
    # dominant term: ~2 FLOPs per parameter per token for the weight matmuls (forward pass)
    matmul = 2 * params_billion * 1e9
    # per token, attention over the context adds a term linear in seq_len;
    # summed over all seq_len tokens it grows as seq_len^2 (the n^2 intuition)
    attn = 2 * n_layers * seq_len * hidden
    return {"matmul_flops": matmul, "attn_flops": attn,
            "attn_share": round(attn / (matmul + attn), 4)}


print("7B @ 1k ctx:", flops_per_token(7, 1000))
print("7B @ 32k ctx:", flops_per_token(7, 32000))  # attention's share grows with context
```

Run it with `python3 main.py`.
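One way to extend the demo: under the same estimate and default shapes, find the context length at which attention's per-token FLOPs catch up with the matmul FLOPs, and sweep a few context lengths. Setting 2 · n_layers · seq_len · hidden equal to 2 · params gives seq_len = params / (n_layers · hidden). The helpers `attention_crossover` and `attn_share` are hypothetical names for this sketch.

```python
# Under the demo's estimate, attention's per-token FLOPs equal the matmul FLOPs when
#   2 * n_layers * seq_len * hidden == 2 * params  ->  seq_len = params / (n_layers * hidden)


def attention_crossover(params_billion, hidden=4096, n_layers=32):
    """Context length at which the demo's attention term matches the matmul term."""
    return params_billion * 1e9 / (n_layers * hidden)


def attn_share(params_billion, seq_len, hidden=4096, n_layers=32):
    """Fraction of per-token forward FLOPs spent in the attention term."""
    attn = 2 * n_layers * seq_len * hidden
    return attn / (2 * params_billion * 1e9 + attn)


print(f"7B crossover: ~{attention_crossover(7):,.0f} tokens")  # ~53,406 tokens
for ctx in (1_000, 8_000, 32_000, 128_000):
    print(f"{ctx:>7} ctx: attention share {attn_share(7, ctx):.1%}")
```

For the 7B defaults the crossover lands around 53k tokens; beyond it, this estimate says attention dominates per-token compute, which is where attention optimizations matter most.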