Most neural network bugs are shape bugs. You'll spend more time reading shapes than reading code. A scalar is rank 0, a vector is rank 1, a matrix is rank 2, and anything with more axes we just call a tensor — the word is not reserved for physics. Once you can look at [batch=32, seq=256, hidden=768] and know what each axis means, half the challenge of reading ML code is gone.
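In NumPy, rank is `ndim` (the number of axes) and shape is `shape` (the length of each axis). The LLM meanings in the comments are illustrative, reusing the [batch=32, seq=256, hidden=768] sizes above.

```python
import numpy as np

# Rank = number of axes (ndim); shape = length of each axis.
scalar = np.array(3.14, dtype=np.float32)              # rank 0, e.g. one loss value
vector = np.zeros(768, dtype=np.float32)               # rank 1, e.g. one token's hidden state
matrix = np.zeros((256, 768), dtype=np.float32)        # rank 2, e.g. one sequence of hidden states
tensor = np.zeros((32, 256, 768), dtype=np.float32)    # rank 3, e.g. a batch of sequences

for name, t in [("scalar", scalar), ("vector", vector),
                ("matrix", matrix), ("tensor", tensor)]:
    print(f"{name}: rank={t.ndim}, shape={t.shape}")
# scalar: rank=0, shape=()
# vector: rank=1, shape=(768,)
# matrix: rank=2, shape=(256, 768)
# tensor: rank=3, shape=(32, 256, 768)
```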
Below, we create a 3D tensor with shape (2, 3, 4) in NumPy: picture 2 "sheets", each a 3×4 matrix. We print its shape, total element count, and dtype, plus one indexed element and the size of a full activation tensor. You'll use this mental model thousands of times.
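To make the "sheets" picture concrete, here is a short NumPy sketch of how indexing walks the three axes. It keeps integer values (no float cast) so the printed numbers are easy to read.

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)  # 2 sheets, each 3 rows x 4 columns

# One index on the first axis selects a whole sheet: a rank-2 matrix.
print(x[0].shape)     # (3, 4)
print(x[1][0])        # [12 13 14 15], the first row of the second sheet

# Three indices select a single element: sheet 0, row 1, column 2.
print(x[0, 1, 2])     # 6

# Integer indices drop an axis, slices keep it: x[:, 0] is row 0 of every sheet.
print(x[:, 0].shape)  # (2, 4)
```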
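The canonical LLM activation shape is (batch, sequence_length, hidden_dim). Memorize that order. Most LLM code in PyTorch, JAX, and MLX follows it, but it is a convention, not a guarantee: PyTorch's built-in `nn.LSTM`, `nn.MultiheadAttention`, and `nn.Transformer` default to (seq, batch, hidden) unless you pass `batch_first=True`. Always check.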
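When a layer expects the other order, swapping the first two axes converts between them. A minimal NumPy sketch with arbitrary sizes; the (seq, batch, hidden) array stands in for whatever sequence-first API you are feeding.

```python
import numpy as np

batch, seq, hidden = 4, 8, 16
rng = np.random.default_rng(0)
x_bsh = rng.random((batch, seq, hidden), dtype=np.float32)  # (batch, seq, hidden)

# Swap axes 0 and 1: (batch, seq, hidden) -> (seq, batch, hidden).
x_sbh = x_bsh.transpose(1, 0, 2)
print(x_sbh.shape)                        # (8, 4, 16)

# Same numbers, new positions: element [b, s, h] now lives at [s, b, h].
print(x_sbh[3, 1, 5] == x_bsh[1, 3, 5])   # True

# transpose returns a view with permuted strides, not a copy...
print(np.shares_memory(x_sbh, x_bsh))     # True
# ...so the result is not C-contiguous; ascontiguousarray makes a packed copy.
print(x_sbh.flags["C_CONTIGUOUS"])        # False
x_sbh_packed = np.ascontiguousarray(x_sbh)
print(x_sbh_packed.flags["C_CONTIGUOUS"]) # True
```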
Try it: create a (batch=4, seq=8, hidden=16) tensor in your chosen language, compute its expected memory in bytes (fp32), and check that it matches what the library reports.
Use these three prompts in order. Each builds on the one before.
Define scalar, vector, matrix, and tensor. Give the rank and shape of each, and one example from LLMs (not physics).
Explain how a multidimensional tensor is laid out in contiguous memory with strides. Walk through how `x[i, j, k]` translates into a flat offset for row-major and column-major layouts.
For a transformer activation of shape (batch=32, seq=2048, hidden=4096), compare how much GPU memory switching from fp32 to fp16 saves during training with how much shrinking the batch size saves. Do the arithmetic separately for weights, activations, gradients, and optimizer state.
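For the strides prompt, a small sketch: two plain-Python helpers (hypothetical names `row_major_offset` and `col_major_offset`) compute the flat element offset of an index in each layout, and NumPy's `strides`, which are reported in bytes, agree with them. `ravel(order="K")` reads elements in the order they sit in memory, which lets us check the offsets directly.

```python
import numpy as np

def row_major_offset(idx, shape):
    """Flat element offset when the LAST axis varies fastest (C order)."""
    offset, stride = 0, 1
    for i, n in zip(reversed(idx), reversed(shape)):
        offset += i * stride
        stride *= n
    return offset

def col_major_offset(idx, shape):
    """Flat element offset when the FIRST axis varies fastest (Fortran order)."""
    offset, stride = 0, 1
    for i, n in zip(idx, shape):
        offset += i * stride
        stride *= n
    return offset

shape = (2, 3, 4)
x = np.arange(24, dtype=np.float32).reshape(shape)  # C (row-major) layout
xf = np.asfortranarray(x)                           # same values, column-major layout

print(row_major_offset((0, 1, 2), shape))  # 6  = 0*12 + 1*4 + 2*1
print(col_major_offset((0, 1, 2), shape))  # 14 = 0*1 + 1*2 + 2*6

# NumPy strides are in BYTES: element strides times itemsize (4 for float32).
print(x.strides)   # (48, 16, 4)
print(xf.strides)  # (4, 8, 24)

# Reading each buffer in memory order confirms both formulas hit x[0, 1, 2].
assert x.ravel(order="K")[row_major_offset((0, 1, 2), shape)] == x[0, 1, 2]
assert xf.ravel(order="K")[col_major_offset((0, 1, 2), shape)] == x[0, 1, 2]
```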
```python
# main.py: tensor shapes in NumPy
import numpy as np

# arange(24) fills the tensor with each element's row-major flat index, 0..23.
x = np.arange(24).reshape(2, 3, 4).astype(np.float32)
print("shape:", x.shape)           # (2, 3, 4)
print("elements:", x.size)         # 24
print("dtype:", x.dtype)           # float32
print("x[0, 1, 2] =", x[0, 1, 2])  # 6.0, since 0*12 + 1*4 + 2 = 6

# The canonical LLM activation: (batch, seq, hidden)
activations = np.zeros((32, 256, 768), dtype=np.float32)
# 32 * 256 * 768 elements, 4 bytes each
print("activation bytes:", activations.nbytes)  # 25165824 (24 MiB)
```

Run it with `python3 main.py`.
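The same element-count-times-bytes arithmetic answers the try-it exercise and scales up to the fp16 prompt. This is a back-of-envelope sketch under stated assumptions: it sizes only one activation tensor of the prompt's shape (a real transformer keeps many such tensors per layer for the backward pass), and the per-parameter figures assume Adam with an fp32 master copy of the weights in mixed precision. `tensor_gib` is a hypothetical helper, not a library function.

```python
import numpy as np
from math import prod

# Try-it check: (batch=4, seq=8, hidden=16) in fp32.
small = np.zeros((4, 8, 16), dtype=np.float32)
expected = 4 * 8 * 16 * 4          # 512 elements x 4 bytes
print(small.nbytes, expected)      # 2048 2048

# Prompt-sized activation, computed by hand rather than allocated (1 GiB in fp32).
GIB = 2**30

def tensor_gib(shape, bytes_per_elem):
    """Size in GiB of a dense tensor: element count times bytes per element."""
    return prod(shape) * bytes_per_elem / GIB

big = (32, 2048, 4096)
print(tensor_gib(big, 4))                # fp32: 1.0
print(tensor_gib(big, 2))                # fp16: 0.5
print(tensor_gib((16, 2048, 4096), 4))   # fp32 with half the batch: 0.5

# Per-parameter training state, ASSUMING Adam (two moment buffers per weight).
fp32_bytes = 4 + 4 + 4 + 4        # weights, grads, Adam m, Adam v
mixed_bytes = 2 + 2 + 4 + 4 + 4   # fp16 weights and grads, fp32 master weights, m, v
print(fp32_bytes, mixed_bytes)    # 16 16
```

Under these assumptions, fp16 and halving the batch shrink that activation tensor by the same factor of 2, while mixed precision leaves per-parameter state at 16 bytes. Which term dominates depends on the model's parameter count and how many activations it keeps, which is what the prompt asks you to work out.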