Optimizers — SGD, momentum, Adam, AdamW
Every optimizer answers one question: given the gradient \(g_t = \nabla_\theta \mathcal{L}\) at the current parameters, how far and in what direction do we step? The answers form a short, important lineage. Stochastic gradient descent is the bare minimum. It steps downhill by a fixed multiple \(\eta\) (the learning rate) of the gradient computed on a mini-batch:

\[
\theta_{t+1} = \theta_t - \eta\, g_t
\]
The first fix is momentum: accumulate an exponentially-decaying running average of past gradients (a velocity \(v_t\)) and step along that instead. Consistent directions reinforce; oscillating ones cancel.
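Written out, one common convention for this update is the heavy-ball form below, the same one the optimizer demo later in this section uses (`m = 0.9*m + g`). The decay symbol \(\mu\) is our notation; some libraries instead weight the new gradient by \(1-\mu\), which only rescales the effective learning rate:

\[
v_t = \mu\, v_{t-1} + g_t, \qquad \theta_{t+1} = \theta_t - \eta\, v_t, \qquad \mu \approx 0.9
\]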
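The second fix is adaptivity: give each parameter its own effective learning rate, scaled down where gradients have been large. Adam combines this with momentum. It maintains a first moment \(m_t\) (the momentum-like mean of gradients) and a second moment \(v_t\) (a mean of squared gradients), bias-corrects both, and divides the step by the root of the second moment. Squares and roots are taken elementwise, \(\beta_1, \beta_2\) are the decay rates, and \(\epsilon\) is a small constant that guards against division by zero:

\[
m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2
\]
\[
\hat m_t = \frac{m_t}{1-\beta_1^{\,t}}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^{\,t}}, \qquad \theta_{t+1} = \theta_t - \eta\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}
\]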
AdamW is the variant you should actually reach for. The issue it fixes is subtle: classical weight decay was implemented as an L2 penalty added to the loss, so its gradient \(\lambda\theta\) flows through Adam's adaptive denominator and gets rescaled per parameter, coupling the regularization strength to each coordinate's gradient history. Loshchilov & Hutter showed that decoupling the decay, applying it directly to the weights outside the adaptive step, restores the intended behavior and consistently generalizes better:

\[
\theta_{t+1} = \theta_t - \eta\left(\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\,\theta_t\right) \tag{N7.4}
\]
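To make the coupling concrete, here is a minimal NumPy sketch (not taken from any library) of one update under Adam with an L2 penalty and under AdamW. The function names `adam_l2_step` and `adamw_step` and the hand-set optimizer state are illustrative assumptions. Both coordinates hold the same weight and currently see a zero loss gradient, but they have very different gradient histories:

```python
import numpy as np

def adam_l2_step(theta, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.1):
    """Adam with classical L2 decay: wd*theta is folded into the gradient,
    so the decay term is also divided by sqrt(v_hat)."""
    g = g + wd * theta
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat, v_hat = m / (1 - b1**t), v / (1 - b2**t)
    return theta - lr * m_hat / (np.sqrt(v_hat) + eps), m, v

def adamw_step(theta, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.1):
    """AdamW: the moments see only the loss gradient; decay is applied directly."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat, v_hat = m / (1 - b1**t), v / (1 - b2**t)
    return theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * theta), m, v

# Hand-set state: equal weights and zero loss gradient now, but coordinate 0
# has a history of large gradients (big v) and coordinate 1 of tiny ones.
theta = np.array([1.0, 1.0])
g = np.zeros(2)
m0, v0, t = np.zeros(2), np.array([1.0, 1e-4]), 100

shrink_l2 = theta - adam_l2_step(theta, g, m0, v0, t)[0]
shrink_w = theta - adamw_step(theta, g, m0, v0, t)[0]
print("Adam + L2 shrink per coordinate:", shrink_l2)
print("AdamW     shrink per coordinate:", shrink_w)
```

Under Adam with L2, the coordinate with the large gradient history is barely decayed while the quiet one is decayed far more; AdamW shrinks both by the same \(\eta\lambda\theta\).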
| Optimizer | State per parameter | Strength | Weakness |
|---|---|---|---|
| SGD | none | Cheapest; flat minima; strong final accuracy on vision with a good schedule | Slow on ill-conditioned loss surfaces; very LR-sensitive |
| SGD + momentum | 1 (velocity) | Accelerates persistent directions, damps oscillation; the CNN workhorse | Can overshoot; still one global \(\eta\) |
| Adam | 2 (\(m\), \(v\)) | Per-parameter adaptive; robust across layer types; fast early progress | 2× optimizer memory; L2 decay misbehaves |
| AdamW | 2 (\(m\), \(v\)) | Adam with correct weight decay; default for transformers | Same memory cost; still needs a schedule |
Adam's two extra moments cost real memory: at fp32 they add 8 bytes per parameter on top of the 4-byte weight and 4-byte gradient — the "16 bytes/param" rule that sizes training clusters (and the reason 8-bit optimizers and ZeRO sharding exist). The contested point worth flagging: on some vision benchmarks well-tuned SGD+momentum still generalizes slightly better than Adam, so "Adam always wins" is folklore, not law — it wins on convenience and on transformers, where SGD struggles.
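The arithmetic behind the 16 bytes/param rule is easy to script. A minimal sketch, assuming plain fp32 storage with no sharding or 8-bit optimizer state and ignoring activations; the 7-billion-parameter size is a hypothetical example, not a figure from this text:

```python
def training_state_bytes(n_params, weight_bytes=4, grad_bytes=4,
                         state_bytes=4, n_states=2):
    """Bytes for weights + gradients + optimizer state (activations excluded).
    n_states=2 models Adam's m and v; n_states=1 models SGD + momentum."""
    return n_params * (weight_bytes + grad_bytes + n_states * state_bytes)

n = 7_000_000_000                                   # hypothetical 7B-parameter model
per_param = training_state_bytes(1)
total_gb = training_state_bytes(n) / 1e9
momentum_per_param = training_state_bytes(1, n_states=1)
print(f"Adam: {per_param} bytes/param -> {total_gb:.0f} GB before activations")
print(f"SGD + momentum: {momentum_per_param} bytes/param")
```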
The demo below compares the three update rules on a deliberately ill-conditioned quadratic:

```python
# SGD vs momentum vs Adam on an ill-conditioned 2D quadratic
# Loss = 0.5*(a*x^2 + b*y^2); steep in x (a=20), flat in y (b=1).
import numpy as np

a, b = 20.0, 1.0

def grad(p):
    return np.array([a * p[0], b * p[1]])      # gradient of the quadratic

def loss(p):
    return 0.5 * (a * p[0]**2 + b * p[1]**2)

def run(kind, lr, steps=300):
    p = np.array([1.0, 1.0]); m = np.zeros(2); v = np.zeros(2)
    for t in range(1, steps + 1):
        g = grad(p)
        if kind == "sgd":
            p = p - lr * g
        elif kind == "mom":                      # heavy-ball momentum, mu = 0.9
            m = 0.9 * m + g
            p = p - lr * m
        else:                                    # adam: beta1=0.9, beta2=0.999
            m = 0.9 * m + 0.1 * g
            v = 0.999 * v + 0.001 * g * g
            mh = m / (1 - 0.9**t); vh = v / (1 - 0.999**t)   # bias correction
            p = p - lr * mh / (np.sqrt(vh) + 1e-8)
    return loss(p)

# Each optimizer gets its own hand-picked stable lr (plain SGD diverges for lr > 2/a = 0.1)
results = {}
for kind, lr in [("sgd", 0.04), ("mom", 0.02), ("adam", 0.20)]:
    results[kind] = run(kind, lr)
    print(f"{kind:5s} (lr={lr:.2f}) final loss after 300 steps: {results[kind]:.2e}")

print(f"\nlowest final loss: {min(results, key=results.get)}")
print("Adam scales x and y independently, so the steep x-direction and the flat")
print("y-direction no longer share the single global step size that hobbles SGD here.")
```
Learning-rate schedules — warmup, cosine, cyclical
The single learning rate \(\eta\) is the most consequential hyperparameter in deep learning, and the best value is not constant over a run. Two facts shape the schedule: early on, weights are random and gradients are large and chaotic, so a big step can blow up; late on, you want small steps to settle into a minimum. The modern default answers both with a warmup followed by a cosine decay.
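Written out, the schedule that the code below implements (its variables `T`, `Tw`, `eta_max`, `eta_min` map to \(T\), \(T_w\), \(\eta_{\max}\), \(\eta_{\min}\)) is a linear ramp followed by a half-period cosine:

\[
\eta_t =
\begin{cases}
\eta_{\max}\,\dfrac{t}{T_w}, & t < T_w \\[8pt]
\eta_{\min} + \tfrac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\!\left(\pi\,\dfrac{t - T_w}{T - T_w}\right)\right), & t \ge T_w
\end{cases}
\tag{N7.5}
\]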
Why a cosine rather than a straight line or exponential? Empirically the cosine's slow start (it lingers near \(\eta_{\max}\)) buys more exploration before annealing, and its slow finish lets the model fine-settle — and it consistently beats step decay on large language and vision models. The cyclical / warm-restart family (SGDR) takes the idea further, resetting the schedule periodically so the rate jumps back up; each restart can knock the model out of a mediocre basin into a better one, and the snapshots make a cheap ensemble. The contested part: with a good cosine, restarts rarely help large single-run pretraining, so they have fallen out of fashion for frontier models while remaining useful for smaller budgets.
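A minimal sketch of cosine annealing with warm restarts, written from the description above. The cycle length `T0` and growth factor `T_mult` are assumed parameter names in the spirit of SGDR, and `sgdr_lr` is an illustrative function, not a library API:

```python
import math

def sgdr_lr(step, eta_max=1e-3, eta_min=0.0, T0=100, T_mult=2):
    """Cosine annealing with warm restarts (SGDR-style sketch).

    The rate follows a half cosine from eta_max to eta_min over a cycle of
    T_i steps, then jumps back to eta_max; each new cycle is T_mult times longer.
    """
    T_i, t = T0, step
    while t >= T_i:            # walk forward to the cycle that contains `step`
        t -= T_i
        T_i *= T_mult
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t / T_i))

# Cycles cover steps 0-99, 100-299, 300-699, ...; restarts at 100 and 300
for s in [0, 50, 99, 100, 200, 299, 300]:
    print(f"step {s:3d}: lr = {sgdr_lr(s):.2e}")
```

Setting `T_mult=1` gives equal-length cycles; saving a checkpoint just before each restart yields the snapshot ensemble mentioned above.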
The warmup-plus-cosine schedule is easy to build and inspect directly:

```python
# Warmup + cosine learning-rate schedule (EQ N7.5): build and inspect it
import numpy as np
import matplotlib.pyplot as plt

T, Tw = 1000, 50                 # total steps, warmup steps (5%)
eta_max, eta_min = 1e-3, 0.0

def lr_at(t):
    if t < Tw:                                   # linear warmup
        return eta_max * t / Tw
    prog = (t - Tw) / (T - Tw)                   # 0..1 through the decay
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * prog))

ts = np.arange(T)
eta = np.array([lr_at(t) for t in ts])
print("step 0:  ", f"{lr_at(0):.2e}  (warmup starts at 0)")
print("step 50: ", f"{lr_at(50):.2e}  (peak = eta_max at end of warmup)")
print("step 525:", f"{lr_at(525):.2e}  (midpoint of decay, steepest part)")
print("step 999:", f"{lr_at(999):.2e}  (approximately eta_min)")
print(f"\npeak step is {ts[eta.argmax()]} -> rate peaks exactly at warmup's end")

plt.plot(ts, eta)                                # the classic ramp-then-cosine shape
plt.xlabel("step"); plt.ylabel("learning rate")
plt.show()
```
Regularization & early stopping
A network with millions of parameters can memorize its training set outright. Regularization is the set of pressures that push it to generalize instead — to fit the signal, not the noise. The deep-learning toolkit is small and well-understood.
- Weight decay (the \(\lambda\theta\) term of EQ N7.4). Shrinks weights toward zero each step, favoring simpler, smaller-norm solutions. Use the decoupled form via AdamW; exclude biases and norm scales.
- Dropout. During training, zero each activation independently with probability \(p\) and rescale the survivors by \(1/(1-p)\) (so the expected activation is unchanged). This prevents co-adaptation — no neuron can rely on any specific other — and approximates training an ensemble of subnetworks. At inference, dropout is off. Transformers use light dropout (\(p \approx 0.0\!-\!0.1\)); large-data pretraining often sets it to zero.
- Data augmentation. The cheapest regularizer: expand the effective dataset with label-preserving transforms (crops, flips, mixup/cutmix for vision; token masking for text). More data beats every other trick.
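- Label smoothing. Replace one-hot targets with a mix of the one-hot vector and the uniform distribution: \(1-\varepsilon+\varepsilon/K\) on the true class and \(\varepsilon/K\) on each of the other \(K-1\) classes, so the targets still sum to 1. This discourages the model from becoming over-confident and improves calibration.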
- Early stopping. Track a held-out validation loss; keep the checkpoint at its minimum and stop once it has stopped improving for a patience window. It is regularization by when you quit.
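Two of the regularizers above reduce to a few lines of array code. A minimal NumPy sketch of inverted dropout and label smoothing; the function names are illustrative, and the smoothing uses the common uniform-mixture form, which puts \(1-\varepsilon+\varepsilon/K\) on the true class so each target row sums to 1:

```python
import numpy as np

def inverted_dropout(x, p, rng, training=True):
    """Zero each activation with probability p; scale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return x                                  # dropout is off at inference
    mask = rng.random(x.shape) >= p               # keep with probability 1-p
    return x * mask / (1.0 - p)

def smooth_labels(y, K, eps):
    """Mix one-hot targets with the uniform distribution over K classes."""
    onehot = np.eye(K)[y]
    return (1.0 - eps) * onehot + eps / K

rng = np.random.default_rng(0)
x = np.ones(1_000_000)
y = inverted_dropout(x, p=0.1, rng=rng)
print(f"fraction zeroed: {np.mean(y == 0):.3f}, mean activation: {y.mean():.3f}")

targets = smooth_labels(np.array([2]), K=4, eps=0.1)
print(targets)   # true class 0.925, others 0.025 each
```

The mean activation stays close to 1 because the \(1/(1-p)\) rescaling compensates for the zeroed units, which is what lets inference simply switch dropout off.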
The signature of overfitting is a validation loss that bottoms out and then rises while the training loss keeps falling: the model is now learning the training set's idiosyncrasies. Underfitting is the opposite: both losses sit high and flat because the model lacks capacity, the right features, or enough training. Early stopping catches the first; more capacity, better features, or longer training fixes the second.
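Early stopping with a patience window can be sketched over a recorded validation curve. The function name and the curve below are hypothetical; a real loop would save a checkpoint whenever a new best appears:

```python
def early_stopping(val_losses, patience=3):
    """Return (best_epoch, stop_epoch): keep the checkpoint at the minimum
    validation loss and stop after `patience` epochs without improvement."""
    best_epoch, best = 0, float("inf")
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best_epoch, best = epoch, loss      # new best: save a checkpoint here
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch            # patience exhausted: stop
    return best_epoch, len(val_losses) - 1      # ran out of epochs first

# Hypothetical validation curve: falls, bottoms out at epoch 4, then rises (overfitting)
val = [1.00, 0.80, 0.65, 0.58, 0.55, 0.56, 0.58, 0.61, 0.65, 0.70]
print(early_stopping(val, patience=3))   # -> (4, 7)
```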
Mixed precision & numerical stability
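Modern GPUs run dramatically faster in 16-bit than in 32-bit, and 16-bit tensors halve memory. Mixed-precision training captures both wins while keeping fp32 where precision is non-negotiable. The catch is dynamic range: the older float16 format has only 5 exponent bits, so its largest finite value is \(65{,}504\) and small gradients underflow to zero. The fix is loss scaling: multiply the loss by a factor \(S\) before backpropagation, so by the chain rule every gradient is scaled by \(S\) and lifted above the underflow floor, then divide by \(S\) before the optimizer step:

\[
\tilde{\mathcal{L}} = S\,\mathcal{L} \;\Longrightarrow\; \nabla_\theta \tilde{\mathcal{L}} = S\,g_t, \qquad g_t = \tfrac{1}{S}\,\nabla_\theta \tilde{\mathcal{L}} \tag{N7.7}
\]

Dynamic loss scaling chooses \(S\) automatically: it grows \(S\) while steps stay finite and halves it, skipping that step, whenever an inf/NaN appears. Three practices keep mixed precision numerically safe: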
- Keep an fp32 master copy of the weights. Updates are tiny relative to the weights; adding a small fp16 step to an fp16 weight rounds to nothing. The optimizer updates the fp32 master, then casts to fp16 for the next forward pass.
- Run reductions in fp32. Softmax, layer-norm statistics, and loss accumulation sum many terms; do them in fp32 to avoid catastrophic cancellation, even when the matmuls run in 16-bit.
- Prefer bfloat16 when the hardware has it. bf16 keeps fp32's 8 exponent bits (the same ~\(10^{38}\) range) at the cost of mantissa precision, so it almost never overflows and usually needs no loss scaling; that is why it is the default for large-model training on recent accelerators. fp8 pushes further still and is now used for the heaviest matmuls, with per-tensor scaling.
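The first two practices can be seen directly in NumPy. A small sketch with illustrative values; NumPy has no built-in bfloat16 type, so the third practice is not shown:

```python
import numpy as np

# 1) fp32 master weights: a small update vanishes in fp16 arithmetic
w16 = np.float16(1.0)
step = np.float16(1e-4)          # fp16 spacing near 1.0 is 2**-10 ~ 9.8e-4
print("fp16 weight after update:", w16 + step)   # rounds back to 1.0
w32 = np.float32(1.0) + np.float32(1e-4)
print("fp32 master after update:", w32)

# 2) fp32 reductions: a running fp16 sum stalls once its spacing outgrows the addend
vals = np.full(10_000, 0.1, dtype=np.float16)
acc16 = np.float16(0.0)
for x in vals:
    acc16 = np.float16(acc16 + x)               # accumulate in fp16
acc32 = np.sum(vals, dtype=np.float32)          # accumulate in fp32
print(f"fp16 running sum: {acc16}  vs fp32 sum: {acc32:.2f}  (exact ~999.76)")
```

Once the fp16 accumulator reaches 256, adjacent fp16 values are 0.25 apart, so adding 0.1 rounds back to the same number and the sum stops growing.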
Most "my loss went to NaN" failures are numeric, not algorithmic. The usual suspects: fp16 gradient overflow (use loss scaling or switch to bf16); a learning rate high enough to send weights to inf in a few steps; \(\log(0)\) or \(0/0\) in a hand-written loss (add an \(\epsilon\), use the log-sum-exp trick); and un-clipped gradients on a spiky batch. Gradient clipping — rescale the gradient so \(\lVert g\rVert \le c\) (typically \(c = 1.0\)) — is cheap insurance against the last one and is standard in transformer recipes.
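Global-norm clipping computes one norm over all gradient tensors together and rescales them all by the same factor, so the update direction is preserved. A minimal NumPy sketch in the same spirit as framework utilities such as PyTorch's `torch.nn.utils.clip_grad_norm_`; the function name and the small `eps` guard are assumptions here:

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale a list of gradient arrays so their combined L2 norm is <= max_norm."""
    total = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    scale = min(1.0, max_norm / (total + eps))   # never scale gradients up
    return [g * scale for g in grads], total

grads = [np.array([3.0, 4.0]), np.array([12.0])]   # global norm = 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(float(np.sum(g * g)) for g in clipped))
print(f"norm before: {norm_before:.1f}, after: {norm_after:.4f}")
```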
The largest finite float16 value is \((2 - 2^{-10})\times 2^{15} = 65{,}504\); anything larger becomes inf in fp16. That narrow range, with overflow above and underflow below, is exactly why loss scaling and, better still, bf16's fp32-sized exponent exist. The demo below shows the underflow side and how scaling rescues it.

```python
# Why loss scaling exists: fp16 underflow, and how scaling rescues gradients
import numpy as np

FP16_MAX = 65504.0  # largest finite fp16; above this -> inf (overflow)

# A batch of tiny gradients, the kind deep nets produce late in training.
# fp16's smallest positive (subnormal) value is 2**-24 ~ 6e-8; values below
# about half of that round to exactly zero.
g = np.array([1e-3, 2e-5, 5e-7, 4e-8, 9e-9])

# Cast to fp16 with NO scaling -> the smallest entry flushes to zero (underflow)
g_fp16 = g.astype(np.float16)
lost = int(np.sum((g != 0) & (g_fp16 == 0)))
print("raw gradients :", g)
print("naive fp16    :", g_fp16.astype(np.float32))
print(f"-> {lost} of {g.size} gradients underflowed to exactly 0\n")

# Loss scaling: multiply by S before fp16, divide back after (EQ N7.7).
# Scaling the loss by S scales every gradient by S, so g is scaled directly here.
S = 2**15
scaled = g * S
g_scaled = scaled.astype(np.float16).astype(np.float32) / S
recovered = int(np.sum((g_fp16 == 0) & (g_scaled != 0)))
overflow = bool(np.any(np.abs(scaled) > FP16_MAX))
print(f"with loss scale S={S}:", g_scaled)
print(f"-> {recovered} previously-lost gradient(s) recovered; overflow? {overflow}")
print("\nScaling lifts tiny gradients above fp16's underflow floor, then")
print("unscales them after backprop: same math, far less underflow.")
```
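The demo above uses a fixed \(S\); in practice the scale is adjusted on the fly. A minimal sketch of that dynamic scheme: halve \(S\) and skip the update when any gradient is inf/NaN, and double it after a run of finite steps. The class name, its `step` method, and the default values (growth every 2000 steps, factors 2 and 0.5) are assumptions for illustration, not a particular library's API:

```python
import numpy as np

class DynamicLossScaler:
    """Sketch of dynamic loss scaling: back off on inf/NaN, grow after
    `growth_interval` consecutive finite steps."""

    def __init__(self, init_scale=2.0**15, growth_interval=2000,
                 growth_factor=2.0, backoff_factor=0.5):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.good_steps = 0

    def step(self, scaled_grads):
        """Return unscaled fp32 grads, or None if the optimizer step must be skipped."""
        s = self.scale                                   # scale used for this backward pass
        if not all(np.all(np.isfinite(g)) for g in scaled_grads):
            self.scale *= self.backoff_factor            # overflow: shrink S
            self.good_steps = 0
            return None                                  # caller skips this update
        self.good_steps += 1
        if self.good_steps % self.growth_interval == 0:
            self.scale *= self.growth_factor             # long finite run: try a larger S
        return [g.astype(np.float32) / s for g in scaled_grads]

scaler = DynamicLossScaler(growth_interval=3)
bad = [np.array([np.inf, 1.0], dtype=np.float16)]       # simulated overflow
print(scaler.step(bad), scaler.scale)                   # None 16384.0
good = [np.array([0.5, 2.0], dtype=np.float16)]
for _ in range(3):
    grads = scaler.step(good)
print(scaler.scale, grads)
```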
A practical recipe & debugging
Theory converges; in practice the failures are mundane and repetitive. Here is a default that survives contact with reality for most supervised deep-learning tasks, followed by the debugging loop that finds the bug when it does not.
```
# Defaults that work for most from-scratch deep-net training
optimizer:    AdamW · β1=0.9 · β2=0.999 (0.95 for big transformers) · ε=1e-8
weight_decay: 0.1 on weights · 0.0 on biases & norm/scale params
lr:           tune η_max first (it dominates); 3e-4 is a sane transformer start
schedule:     linear warmup 1–5% of steps → cosine decay to ~0
batch:        as large as memory allows; raise lr with batch (linear or sqrt scaling rule)
precision:    bf16 if available (no loss scaling); else fp16 + dynamic loss scaling
grad_clip:    global-norm clip at 1.0; cheap insurance against spikes
regularize:   dropout 0.0–0.1 · augmentation · early-stop on val loss
init:         scaled init (He/Xavier or per-arch); verify activations don't explode
```
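Two lines of the recipe translate into small helpers. A minimal sketch, assuming plain dicts of NumPy arrays shaped loosely like the parameter groups that frameworks such as PyTorch accept; the ndim-based rule (decay only on tensors with 2 or more dimensions, so biases and norm scales get 0) is a common heuristic assumed here, and all parameter names are hypothetical:

```python
import numpy as np

def split_decay_groups(params, weight_decay=0.1):
    """Decay matrices/kernels only; 1-D tensors (biases, norm scales) get 0.0."""
    decay = {n: p for n, p in params.items() if p.ndim >= 2}
    no_decay = {n: p for n, p in params.items() if p.ndim < 2}
    return [{"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0}]

def scale_lr(base_lr, base_batch, batch, rule="linear"):
    """Linear rule: lr grows in proportion to batch; sqrt rule: with its square root."""
    ratio = batch / base_batch
    return base_lr * (ratio if rule == "linear" else np.sqrt(ratio))

# Hypothetical parameter names and shapes for a tiny MLP block
params = {"fc1.weight": np.zeros((64, 32)), "fc1.bias": np.zeros(64),
          "norm.scale": np.ones(64), "fc2.weight": np.zeros((32, 64))}
groups = split_decay_groups(params)
print({g["weight_decay"]: sorted(g["params"]) for g in groups})
print(scale_lr(3e-4, 256, 1024), scale_lr(3e-4, 256, 1024, rule="sqrt"))
```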
When a run misbehaves, work the ladder from cheapest check to most expensive — most bugs are caught in the first three rungs:
- Overfit one batch. Before anything else, train on a single mini-batch until the loss hits (near) zero. If it cannot, the bug is in the model, the loss, or the data pipeline — not the hyperparameters. This one test catches a remarkable fraction of failures.
- Sanity-check the initial loss. For \(K\)-class classification with random weights, cross-entropy should start near \(\ln K\). If it starts far off, your labels, logits, or loss are wired wrong.
- Read the loss curve (Instrument N7.3). NaN/spike → lower LR, clip gradients, check for fp16 overflow. Flat-and-high → underfit: more capacity/LR/steps. Val turns up → overfit: regularize or early-stop.
- Do an LR sweep. The learning rate dominates every other knob. Sweep it over a few orders of magnitude (or use an LR-range test) before touching architecture.
- Watch gradient and activation norms. Exploding norms → clip, lower LR, check init/normalization. Vanishing norms → check residual connections, normalization placement, and activation functions.
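The initial-loss check takes a few lines. A minimal sketch, assuming small random logits as a stand-in for a freshly initialized classifier; `cross_entropy` is an illustrative helper that also uses the log-sum-exp trick mentioned earlier for numerical safety:

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross-entropy from raw logits, via a numerically stable log-sum-exp."""
    z = logits - logits.max(axis=1, keepdims=True)        # shift for stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

K, N = 10, 512
rng = np.random.default_rng(0)
labels = rng.integers(0, K, size=N)
logits = 0.01 * rng.standard_normal((N, K))   # small logits, as at a sane init
loss0 = cross_entropy(logits, labels)
print(f"initial loss {loss0:.4f} vs ln K = {np.log(K):.4f}")
```

A starting loss far above \(\ln K\) usually means over-confident initial logits or mismatched labels; one far below suggests label leakage.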
You can now train a network that fits a fixed dataset; the next volume removes the dataset. Reinforcement learning replaces "minimize a loss on labeled examples" with "maximize a reward signal an agent must discover by acting" — a setting where the data is generated by the very policy you are optimizing. RL · 01 opens with the formalism that makes that tractable: the Markov decision process, states, actions, rewards, and the discounting that ties a future payoff to a present choice.
References
- Kingma, D. P. & Ba, J. (2014). Adam: A Method for Stochastic Optimization.
- Loshchilov, I. & Hutter, F. (2017). Decoupled Weight Decay Regularization.
- Micikevicius, P. et al. (2017). Mixed Precision Training.
- Loshchilov, I. & Hutter, F. (2016). SGDR: Stochastic Gradient Descent with Warm Restarts.
- Srivastava, N. et al. (2014). Dropout: A Simple Way to Prevent Neural Networks from Overfitting.
- Keskar, N. S. et al. (2016). On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima.
- Goodfellow, I., Bengio, Y. & Courville, A. (2016). Deep Learning. MIT Press.