Scaled dot-product attention
From the normalized residual stream \(X \in \mathbb{R}^{T \times d}\), three learned projections produce queries, keys and values: \(Q = XW_Q\), \(K = XW_K\), \(V = XW_V\). Every query is compared against every key by dot product; the resulting scores, softmaxed, become mixing weights over the values:
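In the standard notation of Vaswani et al. (2017), that sentence corresponds to the expression below. The additive mask \(M\) is an assumption of this sketch (the symbol is not the source's); set \(M = 0\) for unmasked attention. The softmax is applied to each row separately.

```latex
\[
\operatorname{Attention}(Q, K, V)
  = \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right) V
\]
```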
# EQ 3.1, complete: scaled dot-product attention with a causal mask
import numpy as np
rng = np.random.default_rng(0)
T, dk = 6, 8
Q, K, V = rng.normal(0, 1, (3, T, dk))
S = Q @ K.T / np.sqrt(dk) # T x T relevance scores
S += np.triu(np.full((T, T), -np.inf), 1) # causal: futures unreachable
A = np.exp(S - S.max(-1, keepdims=True))
A /= A.sum(-1, keepdims=True) # softmax, row by row
out = A @ V # each row: a blend of values
np.set_printoptions(precision=2, suppress=True)
print("attention weights A (rows = queries, cols = keys):")
print(A)
print("\nrow sums:", A.sum(1).round(6), " <- every row a convex blend")
print("row 0 can only see itself, so out[0] == v_0 exactly:",
      np.allclose(out[0], V[0]))
Why √d_k, and what softmax is doing
The \(\sqrt{d_k}\) is not cosmetic. If query and key components are independent with zero mean and unit variance, their dot product over \(d_k\) dimensions has variance \(d_k\):
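The variance claim can be worked out in one line. This is a sketch under exactly the assumptions stated above (independent components, zero mean, unit variance), so every cross term vanishes and each product \(q_i k_i\) contributes variance 1:

```latex
\[
\operatorname{Var}(q \cdot k)
  = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i)
  = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2]
  = d_k,
\qquad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1 .
\]
```

Dividing by \(\sqrt{d_k}\) therefore keeps the typical score at order 1 regardless of head width.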
# the sqrt(d_k) experiment: score variance vs head width (EQ 3.2)
import numpy as np
rng = np.random.default_rng(0)
def topw(sigma): # 2-way softmax of a typical (+1 sigma) score vs 0
    return 1 / (1 + np.exp(-sigma))
print(" d_k var(q.k) var(scaled) softmax raw softmax scaled")
for dk in (4, 64, 1024):
    q, k = rng.normal(0, 1, (2, 4000, dk))
    dots = (q * k).sum(1)
    raw, scaled = dots.var(), (dots / np.sqrt(dk)).var()
    print(f"{dk:4d} {raw:10.1f} {scaled:13.3f} {topw(np.sqrt(raw)):13.5f}"
          f" {topw(np.sqrt(scaled)):15.3f}")
print("\nvar(q.k) = d_k, as EQ 3.2 predicts; dividing by sqrt(d_k) pins it at 1.")
print("at d_k=1024 the unscaled softmax reads 1.00000 -- saturated, zero gradient;")
print("the scaled column sits near 0.73 at every width: always trainable.")
Softmax with temperature \(\tau\), that is \(\operatorname{softmax}(s/\tau)\), interpolates between two regimes: \(\tau \to 0\) recovers a hard \(\arg\max\) lookup (a dictionary); \(\tau \to \infty\) gives uniform averaging (a bag of words). Trained attention lives between the two: sharp enough to bind, soft enough to be differentiable.
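The two limits are easy to see numerically. The sketch below is illustrative only: the function name `softmax_t` and the score vector are assumptions, not values from the text.

```python
import numpy as np

def softmax_t(scores, tau):
    """Softmax of scores / tau, stabilized by subtracting the max."""
    z = np.asarray(scores, dtype=float) / tau
    z = z - z.max()
    w = np.exp(z)
    return w / w.sum()

scores = np.array([2.0, 1.0, 0.5, -1.0])  # illustrative scores (assumed)
for tau in (0.01, 1.0, 100.0):
    print(f"tau = {tau:>6}: {softmax_t(scores, tau).round(3)}")
```

At small \(\tau\) nearly all the weight lands on the largest score; at large \(\tau\) the weights approach \(1/4\) each.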
Multi-head attention
One softmax produces one mixing pattern per position — but a token may simultaneously need its syntactic head, an earlier coreferent, and the previous token. Multi-head attention runs \(h\) attentions in parallel in subspaces of size \(d_k = d/h\), then concatenates and projects:
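In the usual formulation (Vaswani et al., 2017), head \(i\) computes \(\operatorname{Attention}(XW_Q^{(i)}, XW_K^{(i)}, XW_V^{(i)})\) and the output is \(\operatorname{Concat}(\text{head}_1, \dots, \text{head}_h)\,W_O\). A minimal NumPy sketch follows. It assumes the per-head projections are stored as column slices of single \(d \times d\) matrices, and it applies a causal mask; the function names and sizes are illustrative.

```python
import numpy as np

def softmax(S):
    S = S - S.max(-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(-1, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    T, d = X.shape
    dk = d // h
    # Project once, then split the model dimension into h heads: (h, T, dk).
    split = lambda M: M.reshape(T, h, dk).transpose(1, 0, 2)
    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    S = Q @ K.transpose(0, 2, 1) / np.sqrt(dk)       # (h, T, T) scores
    S = S + np.triu(np.full((T, T), -np.inf), 1)     # causal mask, broadcast over heads
    A = softmax(S)                                   # one mixing pattern per head
    heads = A @ V                                    # (h, T, dk)
    concat = heads.transpose(1, 0, 2).reshape(T, d)  # back to (T, d)
    return concat @ W_O, A

rng = np.random.default_rng(0)
T, d, h = 6, 32, 4
X = rng.normal(0, 1, (T, d))
W_Q, W_K, W_V, W_O = rng.normal(0, d ** -0.5, (4, d, d))
out, A = multi_head_attention(X, W_Q, W_K, W_V, W_O, h)
print(out.shape, A.shape)
```

Each of the `h` slices of `A` is an independent \(T \times T\) pattern, which is the point of the construction: different heads can attend to different positions for the same token.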
Causal masking
A language model must not see its own future — position \(t\) may only attend to positions \(\le t\). This is enforced before the softmax with an additive mask:
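Written out, with \(M\) as an assumed name for the mask, the usual form is:

```latex
\[
\tilde{S} = \frac{QK^{\top}}{\sqrt{d_k}} + M,
\qquad
M_{ij} =
\begin{cases}
0 & j \le i \\
-\infty & j > i
\end{cases}
\]
```

Because \(e^{-\infty} = 0\), every masked entry receives exactly zero weight after the softmax, and each row still sums to 1 over the positions it is allowed to see.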
Masking is also why training parallelizes (Chapter 01): with the triangle in place, all \(T\) positions can be predicted simultaneously from one forward pass without information leaking backward from labels.
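The no-leak property can be checked directly: change the later tokens and confirm that earlier outputs do not move. This is a single-head sketch with illustrative sizes and random weights, not a full transformer layer.

```python
import numpy as np

def causal_attention(Q, K, V):
    T, dk = Q.shape
    S = Q @ K.T / np.sqrt(dk)
    S = S + np.triu(np.full((T, T), -np.inf), 1)  # additive causal mask
    A = np.exp(S - S.max(-1, keepdims=True))
    A /= A.sum(-1, keepdims=True)
    return A @ V

rng = np.random.default_rng(0)
T, d = 6, 8
X = rng.normal(0, 1, (T, d))
W_Q, W_K, W_V = rng.normal(0, d ** -0.5, (3, d, d))

out = causal_attention(X @ W_Q, X @ W_K, X @ W_V)

# Overwrite the last two tokens with unrelated vectors and run again.
X2 = X.copy()
X2[4:] = rng.normal(0, 1, (2, d))
out2 = causal_attention(X2 @ W_Q, X2 @ W_K, X2 @ W_V)

print("positions 0..3 unchanged:", np.allclose(out[:4], out2[:4]))
print("positions 4..5 changed:  ", not np.allclose(out[4:], out2[4:]))
```

All six outputs come from one matrix product, yet each row depends only on its own prefix, which is what lets every position serve as a training example in the same pass.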
The KV cache: inference's real currency
During generation, step \(t\) needs the keys and values of all previous positions. Recomputing them at every step would repeat \(O(T^2)\) projection work over a \(T\)-token generation, so they are cached. The price is memory, and it grows linearly in every factor (layers, KV heads, head width, context length, batch size):
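As a formula, reconstructed here from the parameters of the `kv_cache_gb` function in this section (the symbols \(B\) for batch size and \(b\) for bytes per element are assumed names):

```latex
\[
\text{KV cache bytes} = 2 \cdot L \cdot h_{kv} \cdot d_k \cdot T \cdot B \cdot b
\]
```

The leading 2 counts keys and values. For the Llama-2-70B figures used in this section (\(L = 80\), \(h_{kv} = 8\), \(d_k = 128\), \(b = 2\)), one token costs \(2 \cdot 80 \cdot 8 \cdot 128 \cdot 2 = 327{,}680\) bytes, so a 4,096-token context takes about 1.34 GB per user.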
# kv_cache_gb: EQ 3.5 as a function -- the number that sizes serving fleets
def kv_cache_gb(L, h_kv, d_k, T, batch=1, bytes_per=2):
    return 2 * L * h_kv * d_k * T * batch * bytes_per / 1e9
llama2_70b = dict(L=80, h_kv=8, d_k=128) # GQA-8, FP16
per_tok = 2 * 80 * 8 * 128 * 2
print(f"Llama-2-70B: {per_tok:,} bytes of cache per token, every token")
for T in (4096, 8192, 32768, 131072):
    print(f" T = {T:>7,}: {kv_cache_gb(T=T, **llama2_70b):8.2f} GB per user")
full_mha = kv_cache_gb(T=8192, L=80, h_kv=64, d_k=128)
print(f"\nsame model, full MHA (h_kv = 64) at 8K: {full_mha:.1f} GB -- 8x worse;")
print("four such users fill an 80 GB H100 before one weight is loaded.")
monster = kv_cache_gb(T=1_000_000, **llama2_70b)
fleet = kv_cache_gb(T=1_000_000, batch=32, **llama2_70b)
print(f"\n1M-token context: {monster:,.0f} GB for ONE user (4+ H100s of pure cache);")
print(f"a batch of 32 such users: {fleet/1000:,.1f} TB. This is why §3.6 exists.")
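What the cache buys is easiest to see in a decode loop. The following is a minimal single-layer, single-head sketch with assumed shapes and random weights: at each step only the new token is projected, its key and value are appended, and the result matches recomputing everything from scratch.

```python
import numpy as np

def attend(q, K, V):
    """One query vector against all keys/values seen so far."""
    s = K @ q / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max())
    return (w / w.sum()) @ V

rng = np.random.default_rng(0)
T, d = 6, 8
X = rng.normal(0, 1, (T, d))
W_Q, W_K, W_V = rng.normal(0, d ** -0.5, (3, d, d))

# Cached decode: project only the new token, append, attend.
# (np.vstack copies on every step; a real cache would preallocate.)
K_cache, V_cache = np.empty((0, d)), np.empty((0, d))
cached = []
for t in range(T):
    x = X[t]
    K_cache = np.vstack([K_cache, x @ W_K])
    V_cache = np.vstack([V_cache, x @ W_V])
    cached.append(attend(x @ W_Q, K_cache, V_cache))
cached = np.array(cached)

# Reference: recompute every key and value at every step.
recomputed = np.array([
    attend(X[t] @ W_Q, X[: t + 1] @ W_K, X[: t + 1] @ W_V) for t in range(T)
])
print("cached == recomputed:", np.allclose(cached, recomputed))
```

The cached path does one key and one value projection per step instead of \(t + 1\), at the cost of holding `K_cache` and `V_cache` in memory for the life of the sequence.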
Shrinking the cache: MQA → GQA → MLA
Queries are free at decode time — only K and V are cached. So the variants attack \(h_{kv}\):
- Multi-Query Attention (MQA). All \(h\) query heads share one K/V head: \(h_{kv} = 1\), an \(h\times\) cache reduction. Fast but measurably lossy at scale.
- Grouped-Query Attention (GQA). The production compromise: the \(h\) query heads are split into groups of \(g\), and each group shares one K/V head, so \(h_{kv} = h/g\). Llama 3.1 405B uses 128 query heads against 8 KV heads (\(g = 16\)), a 16× reduction at near-zero quality cost.
- Multi-head Latent Attention (MLA). DeepSeek's reformulation: instead of caching K and V at all, cache a single low-rank latent \(c_t\) per position and reconstruct keys and values from it on the fly.
| Variant | KV heads cached | Cache vs MHA | Used by |
|---|---|---|---|
| MHA | h (= 32–128) | 1× | GPT-2/3 era |
| MQA | 1 | 1/h | PaLM, Falcon |
| GQA | h/g (= 8 typical) | 1/g (e.g. 1/16) | Llama 2/3, Mistral, Qwen |
| MLA | 1 latent (d_c) | ≈ 1/30 | DeepSeek V2/V3/R1 |
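The head-sharing variants differ only in how many K/V heads are stored. A minimal sketch, assuming consecutive query heads share a K/V head and using illustrative sizes; MHA and MQA fall out as the cases \(h_{kv} = h\) and \(h_{kv} = 1\). MLA is not covered here.

```python
import numpy as np

def grouped_query_attention(Q, K, V):
    """Q: (h, T, dk); K, V: (h_kv, T, dk), with h divisible by h_kv."""
    h, T, dk = Q.shape
    h_kv = K.shape[0]
    g = h // h_kv                      # query heads per shared K/V head
    K = np.repeat(K, g, axis=0)        # expanded at compute time only;
    V = np.repeat(V, g, axis=0)        # the cache still holds h_kv heads
    S = Q @ K.transpose(0, 2, 1) / np.sqrt(dk)
    S = S + np.triu(np.full((T, T), -np.inf), 1)
    A = np.exp(S - S.max(-1, keepdims=True))
    A /= A.sum(-1, keepdims=True)
    return A @ V

rng = np.random.default_rng(0)
h, T, dk = 8, 5, 16
Q = rng.normal(0, 1, (h, T, dk))
for name, h_kv in (("MHA", 8), ("GQA", 2), ("MQA", 1)):
    K, V = rng.normal(0, 1, (2, h_kv, T, dk))
    out = grouped_query_attention(Q, K, V)
    cached = K.size + V.size           # elements that would live in the cache
    print(f"{name}: h_kv = {h_kv}, output {out.shape}, cached elements = {cached}")
```

The output shape is the same in all three cases; only the stored K/V footprint shrinks, in proportion to \(h_{kv}\).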
FlashAttention, sliding windows, sparsity
FlashAttention changed nothing mathematically and everything practically. The insight: attention is bottlenecked not by FLOPs but by reading and writing the \(T \times T\) score matrix to GPU main memory (HBM). FlashAttention never materializes that matrix — it processes K/V in tiles resident in fast on-chip SRAM, maintaining a running softmax via the online softmax identities:
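In the usual presentation of online softmax, each row keeps a running max \(m\), a running denominator \(\ell\), and an unnormalized output \(o\). When a new tile of scores \(s\) arrives, \(m' = \max(m, \max s)\), \(\ell' = \ell\, e^{m - m'} + \sum e^{s - m'}\), and \(o' = o\, e^{m - m'} + e^{s - m'} V_{\text{tile}}\); the final output is \(o / \ell\). The NumPy sketch below shows only this arithmetic, under simplifying assumptions: no causal mask, tiling over K/V only, and no GPU memory hierarchy.

```python
import numpy as np

def tiled_attention(Q, K, V, block=4):
    """Exact softmax(QK^T / sqrt(dk)) V, visiting K/V one tile at a time."""
    T, dk = Q.shape
    m = np.full((T, 1), -np.inf)      # running row-wise max of the scores
    l = np.zeros((T, 1))              # running softmax denominator
    o = np.zeros((T, V.shape[1]))     # running unnormalized output
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(dk)    # only a T x block tile exists at once
        m_new = np.maximum(m, S.max(-1, keepdims=True))
        scale = np.exp(m - m_new)     # rescales statistics from earlier tiles
        P = np.exp(S - m_new)
        l = l * scale + P.sum(-1, keepdims=True)
        o = o * scale + P @ Vb
        m = m_new
    return o / l

def full_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(S - S.max(-1, keepdims=True))
    return (A / A.sum(-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = rng.normal(0, 1, (3, 10, 8))
print("tiled == full:", np.allclose(tiled_attention(Q, K, V), full_attention(Q, K, V)))
```

The full \(T \times T\) matrix is never formed, yet the result is exact rather than approximate, because the rescaling factor corrects every earlier tile whenever the running max changes.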
Restricting the pattern
- Sliding-window attention: attend only to the last \(w\) positions (Mistral 7B: \(w = 4096\)). Cost becomes \(O(Tw)\); stacked layers extend the effective reach to roughly \(L \times w\). Recent models often interleave local and global layers: Gemma 3 uses five local layers per global layer, and gpt-oss alternates the two.
- Attention sinks: keep the first few tokens always visible; their removal destabilizes streaming generation because softmax needs somewhere to park probability mass.
- Sparse / native sparse attention: learned or structured subsets of the full pattern (block-sparse, DeepSeek's NSA), trading exactness for near-linear scaling — increasingly important at million-token contexts (Chapter 09).
- Linear attention & kernel methods: replace softmax with feature maps so attention becomes associative and \(O(T)\) — historically a quality trade-off, now resurfacing inside hybrid architectures (Chapter 09).
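The first two patterns in this list reduce to a choice of mask. A minimal sketch, where the function name `window_mask` and the sizes are illustrative assumptions: a causal sliding window of width \(w\), optionally with a few always-visible sink tokens at the start.

```python
import numpy as np

def window_mask(T, w, n_sink=0):
    """Boolean (T, T) mask: True where query i may attend to key j."""
    i = np.arange(T)[:, None]
    j = np.arange(T)[None, :]
    causal = j <= i
    local = (i - j) < w        # the last w positions, including i itself
    sink = j < n_sink          # initial tokens kept visible to every query
    return causal & (local | sink)

T, w = 12, 4
M = window_mask(T, w)
M_sink = window_mask(T, w, n_sink=2)
print(M_sink.astype(int))
print("window only:", M.sum(), "pairs; with 2 sinks:", M_sink.sum(),
      "; full causal:", T * (T + 1) // 2)
```

To apply the mask, replace disallowed scores with \(-\infty\) before the softmax, for example `np.where(M, S, -np.inf)`. The number of allowed pairs grows as roughly \(Tw\) rather than \(T^2/2\).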
Architecture is settled; now it must learn. Chapter 04: the data, the scaling laws that tell you how big to build, the optimizer, and the art of spreading one training run across tens of thousands of GPUs.
Further reading
- Vaswani et al. (2017). Attention Is All You Need. — scaled dot-product and multi-head attention as defined here, including the √d scaling.
- Bahdanau, Cho & Bengio (2015). Neural Machine Translation by Jointly Learning to Align and Translate. — the original additive attention that the dot-product form streamlined.
- Shazeer (2019). Fast Transformer Decoding: One Write-Head is All You Need. — multi-query attention, the first big cut to KV-cache size.
- Ainslie et al. (2023). GQA: Training Generalized Multi-Query Transformer Models. — grouped-query attention, the production middle ground.
- DeepSeek-AI (2024). DeepSeek-V2. — multi-head latent attention (MLA), compressing the KV cache via low-rank projection.
- Dao, Fu, Ermon, Rudra & Ré (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. — the IO-aware kernel that made long-context attention practical.
- Beltagy, Peters & Cohan (2020). Longformer: The Long-Document Transformer. — sliding-window / sparse attention for long sequences.