Open this lesson in your favourite AI. It'll walk you through the why, explain the demo, and quiz you on the try-it list.
The model's raw output isn't a token; it's a vector of logits, one score per vocabulary entry. Turning those into the next token is sampling, and the sampling parameters (temperature, top-p, top-k) matter for serving: they affect output quality, determinism, and even throughput, since top-k and top-p filtering add per-token work over the whole vocabulary. Understanding sampling demystifies why the same prompt gives different answers, why temperature 0 isn't fully deterministic in batched serving, and how to expose the right knobs to users. Sampling is the bridge between the model's probabilities and the text you serve.
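Why isn't temperature 0 fully deterministic? Greedy decoding just takes the argmax, which is deterministic for fixed logits. In batched serving, though, the logits themselves can shift in the last bits: floating-point addition is not associative, and the order in which a GPU kernel accumulates sums can depend on batch size and composition. When two tokens are nearly tied, that tiny shift can flip the argmax. The sketch below mimics this with two summation orders. `token_logit` and `greedy_token` are hypothetical helpers, and the numbers are chosen to produce a near-tie, not taken from a real model.

```python
import numpy as np


def token_logit(parts, grouping):
    """Sum partial contributions to one logit in one of two orders.

    Hypothetical stand-in for a GPU reduction whose accumulation order
    changes with batch size or tiling; real kernels are far more complex.
    """
    a, b, c = parts
    return (a + b) + c if grouping == "left" else a + (b + c)


def greedy_token(grouping):
    parts = [0.1, 0.2, 0.3]  # made-up contributions to token 1's logit
    other = 0.6              # token 0's logit, identical in both runs
    logits = np.array([other, token_logit(parts, grouping)])
    return int(np.argmax(logits))  # temperature 0: always take the argmax


for grouping in ("left", "right"):
    print(grouping, "-> greedy token", greedy_token(grouping))
# prints "left -> greedy token 1": (0.1 + 0.2) + 0.3 == 0.6000000000000001
# prints "right -> greedy token 0": 0.1 + (0.2 + 0.3) == 0.6, a tie,
# and np.argmax returns the first maximum
```

The same prompt, same weights, and same "deterministic" decoding rule can therefore emit different tokens, which is why cached outputs generated at temperature 0 are not guaranteed to be bit-for-bit reproducible.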
The demo implements the sampling pipeline: logits → softmax (with temperature) → top-k/top-p filtering → sample. Each parameter visibly reshapes the distribution before a token is drawn.
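To see each stage reshape the distribution (rather than just the sampled index), the sketch below prints the probabilities themselves. Temperature divides the logits before the softmax, p_i = exp(z_i / T) / Σ_j exp(z_j / T), so T < 1 sharpens and T > 1 flattens. Top-k keeps the k most likely tokens; top-p keeps the smallest most-likely-first prefix whose mass reaches p, including the token that crosses the threshold. The helper names `softmax_t`, `top_k_filter`, and `top_p_filter` are illustrative, and each filter is applied independently to the T = 1 distribution for clarity; real servers may combine them in a different order.

```python
import numpy as np


def softmax_t(logits, temperature):
    """Softmax with temperature (T > 0): p_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    z = np.asarray(logits, dtype=float) / temperature
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()


def top_k_filter(probs, k):
    """Zero out all but the k most likely tokens, then renormalise."""
    idx = np.argsort(probs)[::-1][:k]
    keep = np.zeros_like(probs)
    keep[idx] = probs[idx]
    return keep / keep.sum()


def top_p_filter(probs, p):
    """Keep the smallest most-likely-first prefix whose mass reaches p, then renormalise."""
    order = np.argsort(probs)[::-1]
    cum = np.cumsum(probs[order])
    cutoff = np.searchsorted(cum, p) + 1  # first index with cum >= p, inclusive
    keep = np.zeros_like(probs)
    keep[order[:cutoff]] = probs[order[:cutoff]]
    return keep / keep.sum()


logits = [2.0, 1.0, 0.5, 0.1, -1.0]  # same toy logits as the demo
np.set_printoptions(precision=3, suppress=True)
for t in (0.2, 1.0, 1.5):
    print(f"T={t}:", softmax_t(logits, t))
base = softmax_t(logits, 1.0)
print("top_k=2:  ", top_k_filter(base, 2))
print("top_p=0.9:", top_p_filter(base, 0.9))
```

At T = 1 the cumulative mass of the top three tokens is about 0.889, just under 0.9, so top_p = 0.9 keeps four tokens: the fourth one is the token that pushes the mass past the threshold.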
Use these three in order. Each builds on the one before.
What are logits, and how does sampling turn them into the next token? What do temperature, top-p, and top-k do?
Walk me through the sampling pipeline (softmax with temperature, then top-k/top-p filtering) and how each parameter reshapes the distribution.
Why isn't temperature 0 perfectly deterministic in batched GPU serving, and what are the implications for reproducibility and caching of generated outputs?
```python
import numpy as np

def sample(logits, temperature=1.0, top_p=1.0, top_k=0, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=float)
    if temperature <= 0:  # temperature 0 = greedy decoding: always take the argmax
        return int(np.argmax(logits))
    logits = logits / temperature  # lower temp -> sharper distribution
    probs = np.exp(logits - logits.max())  # subtract the max for numerical stability
    probs /= probs.sum()  # softmax
    order = np.argsort(probs)[::-1]  # token ids, most likely first
    if top_k > 0:  # keep only the k most likely
        order = order[:top_k]
    if top_p < 1.0:  # nucleus: smallest prefix whose (renormalised) mass reaches top_p
        cum = np.cumsum(probs[order]) / probs[order].sum()
        order = order[: np.searchsorted(cum, top_p) + 1]  # include the token that crosses top_p
    keep = np.zeros_like(probs)
    keep[order] = probs[order]
    keep /= keep.sum()  # renormalise over the surviving tokens
    return int(rng.choice(len(keep), p=keep))

logits = [2.0, 1.0, 0.5, 0.1, -1.0]
print("temp 0.2:", sample(logits, 0.2), " temp 1.5:", sample(logits, 1.5, top_p=0.9))
```

Run it with python3 main.py.
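On exposing the right knobs to users, one option is to validate the parameters once, at request time, before any token is sampled. The sketch below assumes a hypothetical `SamplingParams` class (not any particular server's API) that follows the demo's conventions: temperature 0 means greedy, top_k 0 disables top-k, and top_p 1.0 disables nucleus filtering. The optional `seed` makes the random draw reproducible; it cannot remove the logit-level nondeterminism of batched serving, so two runs match only when the probabilities match too.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class SamplingParams:
    """Hypothetical user-facing sampling knobs, validated at request time."""
    temperature: float = 1.0     # 0 means greedy decoding
    top_p: float = 1.0           # 1.0 disables nucleus filtering
    top_k: int = 0               # 0 disables top-k filtering
    seed: Optional[int] = None   # fixes the random draw, not the logits

    def __post_init__(self):
        if self.temperature < 0:
            raise ValueError("temperature must be >= 0")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError("top_p must be in (0, 1]")
        if self.top_k < 0:
            raise ValueError("top_k must be >= 0 (0 disables it)")

    def make_rng(self) -> np.random.Generator:
        # Same seed -> same sequence of draws, given identical probabilities.
        return np.random.default_rng(self.seed)


probs = [0.4, 0.3, 0.15, 0.1, 0.05]
params = SamplingParams(temperature=0.7, top_p=0.9, seed=1234)
draws_a = params.make_rng().choice(len(probs), size=8, p=probs)
draws_b = params.make_rng().choice(len(probs), size=8, p=probs)
print(draws_a.tolist() == draws_b.tolist())  # True: same seed, same probabilities

try:
    SamplingParams(top_p=0.0)
except ValueError as err:
    print("rejected:", err)
```

Rejecting invalid values up front keeps errors out of the decode loop, and logging the seed alongside a cached output records what is needed to replay the random draw.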