CHAPTER 03 / 10

Attention

Attention performs a differentiable soft lookup. Every position publishes what it holds (keys, values) and what it wants (queries), and information flows wherever query meets key. This chapter covers the mechanism exactly, then the production variants (multi-head, MQA, GQA, MLA, sliding-window, FlashAttention) and the KV cache that dominates inference memory.

READING TIME ≈ 30 MIN · BUILDS ON CH 02 · INSTRUMENTS: HEATMAP · KV CALC · GQA
3.1

Scaled dot-product attention

From the normalized residual stream \(X \in \mathbb{R}^{T \times d}\), three learned projections produce queries, keys and values: \(Q = XW_Q\), \(K = XW_K\), \(V = XW_V\). Every query is compared against every key by dot product; the resulting scores, softmaxed, become mixing weights over the values:

EQ 3.1 — THE EQUATION OF THE DECADE $$ \mathrm{Attention}(Q, K, V) \;=\; \mathrm{softmax}\!\left( \frac{Q K^{\top}}{\sqrt{d_k}} + M \right) V $$
\(QK^\top\) is a \(T \times T\) matrix of relevance scores. \(M\) is the causal mask (§3.4). Each output row is a convex combination of value vectors — attention never invents content, it routes and blends what positions already offer. Computational cost: \(O(T^2 d)\) — the quadratic that drives an entire sub-industry of optimizations.
A query and key give a raw dot product \( q \cdot k = 12 \), and the head dimension is \( d_k = 16 \). What is the scaled attention score \( \dfrac{q \cdot k}{\sqrt{d_k}} \)?
\( \sqrt{d_k} = \sqrt{16} = 4 \), so the scaled score \( = 12 / 4 = \) 3.
PYTHON · RUNNABLE IN-BROWSER
# EQ 3.1, complete: scaled dot-product attention with a causal mask
import numpy as np
rng = np.random.default_rng(0)
T, dk = 6, 8
Q, K, V = rng.normal(0, 1, (3, T, dk))

S = Q @ K.T / np.sqrt(dk)                      # T x T relevance scores
S += np.triu(np.full((T, T), -np.inf), 1)      # causal: futures unreachable
A = np.exp(S - S.max(-1, keepdims=True))
A /= A.sum(-1, keepdims=True)                  # softmax, row by row
out = A @ V                                    # each row: a blend of values

np.set_printoptions(precision=2, suppress=True)
print("attention weights A (rows = queries, cols = keys):")
print(A)
print("\nrow sums:", A.sum(1).round(6), " <- every row a convex blend")
print("row 0 can only see itself, so out[0] == v_0 exactly:",
      np.allclose(out[0], V[0]))
INSTRUMENT 3.1 — ATTENTION INSPECTOR · HOVER TOKENS · CAUSAL · 1 HEAD
Hover a token to set the query row. Try “it” — the head resolves the pronoun to “ball” and “robot”. Lower the temperature and watch softmax sharpen toward a hard lookup; raise it and attention diffuses toward a uniform average. Upper triangle is the causal mask: futures are unreachable.
3.2

Why \(\sqrt{d_k}\), and what softmax is doing

The \(\sqrt{d_k}\) is not cosmetic. If query and key components are independent with zero mean and unit variance, their dot product over \(d_k\) dimensions has variance \(d_k\):

EQ 3.2 — SCORE VARIANCE $$ \mathrm{Var}\!\left( q \cdot k \right) = \sum_{i=1}^{d_k} \mathrm{Var}(q_i k_i) = d_k \quad\Longrightarrow\quad \mathrm{Var}\!\left( \frac{q \cdot k}{\sqrt{d_k}} \right) = 1 $$
Unscaled, scores grow like \(\sqrt{d_k}\) in magnitude, the softmax saturates to near one-hot, and gradients through it vanish. Dividing by \(\sqrt{d_k}\) keeps the score distribution in softmax's responsive regime — the same role temperature plays at sampling time.
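One step behind EQ 3.2 is worth spelling out, since the per-term variance is used there but not derived. Assuming, as above, that \(q_i\) and \(k_i\) are independent with zero mean and unit variance:

$$ \mathrm{Var}(q_i k_i) \;=\; \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] \;-\; \big(\mathbb{E}[q_i]\,\mathbb{E}[k_i]\big)^2 \;=\; 1 \cdot 1 - 0 \;=\; 1 $$

Under the same independence assumption the \(d_k\) terms are uncorrelated, so their variances add to \(d_k\) and the standard deviation of the raw score is \(\sqrt{d_k}\).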
Two keys produce scaled scores \( (1,\ 0) \). What softmax weight does the first key receive? (Use \( e^1 = 2.718,\ e^0 = 1 \).)
\( \dfrac{e^1}{e^1 + e^0} = \dfrac{2.718}{2.718 + 1} = \dfrac{2.718}{3.718} = \) 0.731 — comfortably inside softmax's responsive regime, unlike the saturated unscaled case.
PYTHON · RUNNABLE IN-BROWSER
# the sqrt(d_k) experiment: score variance vs head width (EQ 3.2)
import numpy as np
rng = np.random.default_rng(0)

def topw(sigma):                  # 2-way softmax of a typical (+1 sigma) score vs 0
    return 1 / (1 + np.exp(-sigma))

print(" d_k   var(q.k)   var(scaled)   softmax raw   softmax scaled")
for dk in (4, 64, 1024):
    q, k = rng.normal(0, 1, (2, 4000, dk))
    dots = (q * k).sum(1)
    raw, scaled = dots.var(), (dots / np.sqrt(dk)).var()
    print(f"{dk:4d} {raw:10.1f} {scaled:13.3f} {topw(np.sqrt(raw)):13.5f}"
          f" {topw(np.sqrt(scaled)):15.3f}")

print("\nvar(q.k) = d_k, as EQ 3.2 predicts; dividing by sqrt(d_k) pins it at 1.")
print("at d_k=1024 the unscaled softmax reads 1.00000 -- saturated, zero gradient;")
print("the scaled column sits near 0.73 at every width: always trainable.")

Softmax with temperature \(\tau\) interpolates between two regimes you just explored in the instrument above: \(\tau \to 0\) recovers a hard \(\arg\max\) lookup (a dictionary); \(\tau \to \infty\) gives uniform averaging (a bag of words). Trained attention lives between — sharp enough to bind, soft enough to be differentiable.
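The two limits can be checked numerically. The sketch below is illustrative only: the helper name `softmax_with_temperature` and the score vector are assumptions made for this example, not part of the chapter's instruments.

```python
import numpy as np

def softmax_with_temperature(scores, tau):
    """Softmax of scores / tau, stabilized by subtracting the max."""
    z = np.asarray(scores, dtype=float) / tau
    z = z - z.max()                 # largest entry becomes 0, avoiding overflow
    w = np.exp(z)
    return w / w.sum()

scores = np.array([2.0, 1.0, 0.0, -1.0])    # illustrative scaled scores
for tau in (0.05, 1.0, 100.0):
    w = softmax_with_temperature(scores, tau)
    print(f"tau = {tau:6.2f} ->", w.round(3))
```

At small \(\tau\) nearly all weight lands on the top-scoring key (a hard lookup); at large \(\tau\) the weights approach a uniform average over the four keys.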

3.3

Multi-head attention

One softmax produces one mixing pattern per position — but a token may simultaneously need its syntactic head, an earlier coreferent, and the previous token. Multi-head attention runs \(h\) attentions in parallel in subspaces of size \(d_k = d/h\), then concatenates and projects:

EQ 3.3 — MULTI-HEAD $$ \mathrm{head}_i = \mathrm{Attention}\!\left(XW_Q^{(i)},\, XW_K^{(i)},\, XW_V^{(i)}\right), \qquad \mathrm{MHA}(X) = \big[\mathrm{head}_1; \cdots; \mathrm{head}_h\big]\, W_O $$
Same total FLOPs as one full-width head — the work is sliced, not multiplied. Trained heads specialize into recognizable roles: previous-token heads, syntactic heads, induction heads (find an earlier occurrence of the current pattern and copy what followed it — the circuit behind in-context learning), and many that resist naming.
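EQ 3.3 can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions: random weights, toy sizes, a causal mask, and a hypothetical helper `multi_head_attention`. It projects once at full width and then slices the result into \(h\) heads, which is equivalent to applying per-head matrices \(W^{(i)}\) that are column blocks of the full matrices.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    T, d = X.shape
    d_k = d // h                                        # per-head width, d_k = d / h
    # project at full width, then slice columns into h heads: (h, T, d_k)
    split = lambda M: M.reshape(T, h, d_k).transpose(1, 0, 2)
    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    S = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)         # (h, T, T) scores
    S = S + np.triu(np.full((T, T), -np.inf), 1)        # causal mask, broadcast over heads
    A = softmax(S)                                      # one mixing pattern per head
    heads = A @ V                                       # (h, T, d_k)
    concat = heads.transpose(1, 0, 2).reshape(T, d)     # [head_1; ...; head_h]
    return concat @ W_O, A

rng = np.random.default_rng(0)
T, d, h = 5, 16, 4                                      # toy sizes
X = rng.normal(0, 1, (T, d))
W_Q, W_K, W_V, W_O = rng.normal(0, d ** -0.5, (4, d, d))
out, A = multi_head_attention(X, W_Q, W_K, W_V, W_O, h)
print("output:", out.shape, " attention patterns:", A.shape)
```

Each of the `h` slices of `A` is an independent \(T \times T\) pattern, which is what lets different heads attend to different things at the same position.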
3.4

Causal masking

A language model must not see its own future — position \(t\) may only attend to positions \(\le t\). This is enforced before the softmax with an additive mask:

EQ 3.4 — CAUSAL MASK $$ M_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases} $$
\(e^{-\infty} = 0\): masked positions receive exactly zero weight after softmax. The same trick implements padding masks and (with a band pattern) sliding-window attention. The triangle of grey cells in Instrument 3.1 is \(M\) made visible.

Masking is also why training parallelizes (Chapter 01): with the triangle in place, all \(T\) positions can be predicted simultaneously from one forward pass without information leaking backward from labels.
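The no-leak property is easy to test directly: overwrite one token and confirm that outputs at earlier positions do not move. The sketch assumes a single head with random weights; `causal_attention` is a hypothetical helper written for this check.

```python
import numpy as np

def causal_attention(X, W_Q, W_K, W_V):
    T, d_k = X.shape[0], W_Q.shape[1]
    S = (X @ W_Q) @ (X @ W_K).T / np.sqrt(d_k)
    S = S + np.triu(np.full((T, T), -np.inf), 1)    # EQ 3.4: -inf above the diagonal
    A = np.exp(S - S.max(-1, keepdims=True))
    A /= A.sum(-1, keepdims=True)
    return A @ (X @ W_V)

rng = np.random.default_rng(0)
T, d = 6, 8
X = rng.normal(0, 1, (T, d))
W_Q, W_K, W_V = rng.normal(0, d ** -0.5, (3, d, d))

base = causal_attention(X, W_Q, W_K, W_V)
X_edit = X.copy()
X_edit[4] = rng.normal(0, 1, d)                     # overwrite the token at position 4
edited = causal_attention(X_edit, W_Q, W_K, W_V)

unchanged = np.allclose(base[:4], edited[:4])       # positions 0..3 cannot see position 4
changed = not np.allclose(base[4:], edited[4:])     # positions 4..5 can
print("positions before the edit unchanged:", unchanged)
print("positions at or after the edit changed:", changed)
```

Because every row of the output depends only on its own prefix, all \(T\) next-token predictions in one forward pass are computed from exactly the context they would have at generation time.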

3.5

The KV cache: inference's real currency

During generation, step \(t\) needs the keys and values of all previous positions. Recomputing them every step would cost \(O(T^2)\) redundant work — so they are cached. The price is memory, and it grows linearly with everything:

EQ 3.5 — KV-CACHE SIZE $$ \mathrm{bytes} \;=\; 2 \times L \times h_{kv} \times d_k \times T \times b \times (\text{bytes/elem}) $$
2 for K and V; \(L\) layers; \(h_{kv}\) key-value heads; \(d_k\) head dim; \(T\) sequence length; \(b\) batch size. This buffer — not the weights — is what limits how many concurrent users fit on a GPU, which is exactly why §3.6 exists.
Using EQ 3.5 per single token (\( T = 1 \), \( b = 1 \)): \( L = 32 \), \( h_{kv} = 8 \), \( d_k = 128 \), FP16 (2 bytes/element). How many KB of KV cache does one token need?
Elements \( = 2 \times 32 \times 8 \times 128 = 65{,}536 \). Bytes \( = 65{,}536 \times 2 = 131{,}072 \). In KB: \( 131{,}072 / 1024 = \) 128 KB per token.
PYTHON · RUNNABLE IN-BROWSER
# kv_cache_gb: EQ 3.5 as a function -- the number that sizes serving fleets
def kv_cache_gb(L, h_kv, d_k, T, batch=1, bytes_per=2):
    return 2 * L * h_kv * d_k * T * batch * bytes_per / 1e9

llama2_70b = dict(L=80, h_kv=8, d_k=128)              # GQA-8, FP16
per_tok = 2 * 80 * 8 * 128 * 2
print(f"Llama-2-70B: {per_tok:,} bytes of cache per token, every token")
for T in (4096, 8192, 32768, 131072):
    print(f"  T = {T:>7,}: {kv_cache_gb(T=T, **llama2_70b):8.2f} GB per user")

full_mha = kv_cache_gb(T=8192, L=80, h_kv=64, d_k=128)
print(f"\nsame model, full MHA (h_kv = 64) at 8K: {full_mha:.1f} GB -- 8x worse;")
print("four such users fill an 80 GB H100 before one weight is loaded.")

monster = kv_cache_gb(T=1_000_000, **llama2_70b)
fleet = kv_cache_gb(T=1_000_000, batch=32, **llama2_70b)
print(f"\n1M-token context: {monster:,.0f} GB for ONE user (4+ H100s of pure cache);")
print(f"a batch of 32 such users: {fleet/1000:,.1f} TB. This is why §3.6 exists.")
INSTRUMENT 3.2 — KV-CACHE CALCULATOR · EQ 3.5 · LIVE
[Interactive readouts: total KV cache · per token (all layers) · footprint]
Defaults ≈ Llama-2-70B with GQA-8. Set KV heads to 64 to feel why pure MHA died at long context — then drop precision to FP8 and watch serving capacity double.
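The caching mechanism itself, as opposed to its size, can be sketched as a decode loop. The example assumes one layer, one head, random weights and float64 storage, and the names `K_cache` and `V_cache` are illustrative. It checks that step-by-step decoding against a growing cache reproduces full causal attention.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

rng = np.random.default_rng(0)
T, d = 8, 16                                     # toy sizes: one layer, one head, d_k = d
X = rng.normal(0, 1, (T, d))
W_Q, W_K, W_V = rng.normal(0, d ** -0.5, (3, d, d))

# reference: full causal attention computed in one shot
S = (X @ W_Q) @ (X @ W_K).T / np.sqrt(d)
S = S + np.triu(np.full((T, T), -np.inf), 1)
full = softmax(S) @ (X @ W_V)

# incremental decode: project only the new token, append its K and V to the cache
K_cache = np.empty((0, d))
V_cache = np.empty((0, d))
steps = []
for t in range(T):
    x_t = X[t:t + 1]                             # (1, d): the newest token only
    K_cache = np.vstack([K_cache, x_t @ W_K])
    V_cache = np.vstack([V_cache, x_t @ W_V])
    q_t = x_t @ W_Q                              # the query is used once, never cached
    a_t = softmax(q_t @ K_cache.T / np.sqrt(d))  # attends to positions <= t by construction
    steps.append(a_t @ V_cache)
cached = np.vstack(steps)

print("cached decode matches full attention:", np.allclose(full, cached))
print("cache bytes:", K_cache.nbytes + V_cache.nbytes)
```

The byte count agrees with EQ 3.5 for \(L = 1\), \(h_{kv} = 1\), \(d_k = 16\), \(T = 8\), \(b = 1\) and 8 bytes per element. No mask is needed in the loop because the cache never contains future positions.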
3.6

Shrinking the cache: MQA → GQA → MLA

Queries are free at decode time — only K and V are cached. So the variants attack \(h_{kv}\):

  • Multi-Query Attention (MQA). All \(h\) query heads share one K/V head: \(h_{kv} = 1\), an \(h\times\) cache reduction. Fast but measurably lossy at scale.
  • Grouped-Query Attention (GQA). The production compromise: \(h_{kv} = h/g\) KV heads, with each group of \(g\) query heads sharing one K/V pair. Llama 3.1 405B uses 128 query heads against 8 KV heads, a 16× reduction at near-zero quality cost.
  • Multi-head Latent Attention (MLA). DeepSeek's reformulation: instead of caching K and V at all, cache a single low-rank latent \(c_t\) per position and reconstruct keys and values from it on the fly.
EQ 3.6 — MLA: CACHE A LATENT, NOT K AND V $$ c_t = W_{DKV}\, h_t \in \mathbb{R}^{d_c}, \qquad k_t^{(i)} = W_{UK}^{(i)} c_t, \quad v_t^{(i)} = W_{UV}^{(i)} c_t \qquad (d_c \ll h \cdot d_k) $$
DeepSeek-V3: \(d_c = 512\) versus \(h \cdot d_k = 16{,}384\), a ~32× compression. MLA outperformed full MHA in DeepSeek's reported ablations, and the reconstruction adds little inference cost because the up-projections can be absorbed into neighboring matrices (\(W_{UK}\) into the query projection, \(W_{UV}\) into the output projection). A decoupled RoPE component rides alongside the latent to preserve relative position.
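The absorption argument can be verified on a toy example. The sketch below makes several assumptions: random weights, toy sizes, row-vector convention (so \(c_t = h_t W_{DKV}\) rather than \(W_{DKV} h_t\)), and no RoPE component. It is not DeepSeek's implementation; it shows only that scoring against reconstructed keys equals scoring a transformed query against the cached latent.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, h, d_k, d_c = 6, 64, 4, 16, 8               # toy sizes with d_c << h * d_k
H = rng.normal(0, 1, (T, d))                      # hidden states h_t, one row per position
W_DKV = rng.normal(0, d ** -0.5, (d, d_c))        # shared down-projection
W_UK = rng.normal(0, d_c ** -0.5, (h, d_c, d_k))  # per-head key up-projections
W_Q = rng.normal(0, d ** -0.5, (h, d, d_k))       # per-head query projections

C = H @ W_DKV                                     # (T, d_c): the only tensor cached

# naive path: rebuild every head's keys from the latent, then score
K = np.einsum('tc,hck->htk', C, W_UK)             # (h, T, d_k)
q = np.einsum('d,hdk->hk', H[-1], W_Q)            # query for the last position
scores_naive = np.einsum('hk,htk->ht', q, K)

# absorbed path: fold W_UK into the query, then score against the latent directly
q_abs = np.einsum('hk,hck->hc', q, W_UK)          # (h, d_c)
scores_abs = np.einsum('hc,tc->ht', q_abs, C)

print("scores identical:", np.allclose(scores_naive, scores_abs))
print("cached elements per token:", d_c, "vs", 2 * h * d_k, "for full K and V")
```

The same associativity lets \(W_{UV}\) fold into the output projection, so in this sketch the per-head keys and values never need to be materialized at decode time.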
A model has \( h = 64 \) query heads but uses GQA with only \( h_{kv} = 8 \) cached KV heads. What fraction of the full-MHA KV cache does it keep? (Answer as a decimal: \( h_{kv}/h \).)
Cache scales with KV heads, so the fraction is \( h_{kv}/h = 8/64 = 1/8 = \) 0.125 — an 8× reduction at near-zero quality cost.
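Head sharing can be sketched by repeating each cached KV head across its group of query heads. The function `gqa` and the toy sizes are assumptions for illustration; production kernels avoid physically copying K and V.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def gqa(Q, K, V):
    """Q: (h, T, d_k); K, V: (h_kv, T, d_k), with h divisible by h_kv."""
    h, T, d_k = Q.shape
    g = h // K.shape[0]                      # query heads per shared KV head
    K_rep = np.repeat(K, g, axis=0)          # query head i reads KV head i // g
    V_rep = np.repeat(V, g, axis=0)
    S = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d_k)
    S = S + np.triu(np.full((T, T), -np.inf), 1)
    return softmax(S) @ V_rep

rng = np.random.default_rng(0)
h, h_kv, T, d_k = 8, 2, 5, 4                 # toy sizes: 4 query heads per KV head
Q = rng.normal(0, 1, (h, T, d_k))
K, V = rng.normal(0, 1, (2, h_kv, T, d_k))
out = gqa(Q, K, V)

cache_fraction = (K.size + V.size) / (2 * h * T * d_k)
print("output:", out.shape, " cache vs full MHA:", cache_fraction)
```

Setting `h_kv = 1` gives MQA and `h_kv = h` gives ordinary multi-head attention, so the one function spans the whole family.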
| Variant | KV heads cached | Cache vs MHA | Used by |
|---|---|---|---|
| MHA | h (= 32–128) | 1 | GPT-2/3 era |
| MQA | 1 | 1/h | PaLM, Falcon |
| GQA | h/g (= 8 typical) | 1/g (e.g. 1/16) | Llama 2/3, Mistral, Qwen |
| MLA | 1 latent (d_c) | ≈ 1/30 | DeepSeek V2/V3/R1 |
INSTRUMENT 3.3 — HEAD SHARING · 32 QUERY HEADS · MHA → GQA → MQA
[Interactive readouts: query heads (color = shared KV group) · cached KV heads · regime · cache reduction · KV @ 8K ctx (70B-class, FP16)]
Slide left to MQA (one KV head serving all 32 queries) and right back to full MHA. The middle — GQA-8 — is where nearly every model since 2023 has landed.
3.7

FlashAttention, sliding windows, sparsity

FlashAttention changed nothing mathematically and everything practically. The insight: attention is bottlenecked not by FLOPs but by reading and writing the \(T \times T\) score matrix to GPU main memory (HBM). FlashAttention never materializes that matrix — it processes K/V in tiles resident in fast on-chip SRAM, maintaining a running softmax via the online softmax identities:

EQ 3.7 — ONLINE SOFTMAX (PER TILE UPDATE) $$ m^{\text{new}} = \max(m, \tilde{m}), \qquad \ell^{\text{new}} = e^{\,m - m^{\text{new}}}\,\ell + e^{\,\tilde{m} - m^{\text{new}}}\,\tilde{\ell}, \qquad O^{\text{new}} = \frac{e^{\,m - m^{\text{new}}}\,\ell\, O + e^{\,\tilde{m} - m^{\text{new}}}\,\tilde{\ell}\,\tilde{O}}{\ell^{\text{new}}} $$
Running max \(m\), normalizer \(\ell\), and output \(O\) are corrected as each new tile \((\tilde{m}, \tilde{\ell}, \tilde{O})\) arrives: the exact softmax, computed without ever holding all scores at once. Memory drops from \(O(T^2)\) to \(O(T)\), wall-clock speedups of 2–4× are typical, and the backward pass recomputes score tiles rather than storing them. FlashAttention-2 refines parallelism and work partitioning; FlashAttention-3 exploits asynchrony and FP8 on Hopper.
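The update in EQ 3.7 can be exercised for a single query row. This is a NumPy sketch of the arithmetic only, under the assumption of a hypothetical helper `attention_row_tiled` and an arbitrary tile size. It says nothing about SRAM residency or kernel fusion, which is where the real speedup comes from.

```python
import numpy as np

def attention_row_tiled(q, K, V, tile=4):
    """Attention output for one query, visiting K/V one tile at a time."""
    d_k = q.shape[0]
    m, l = -np.inf, 0.0                      # running max and normalizer
    O = np.zeros(V.shape[1])                 # running (normalized) output
    for start in range(0, K.shape[0], tile):
        s = K[start:start + tile] @ q / np.sqrt(d_k)    # scores for this tile only
        m_t = s.max()
        p = np.exp(s - m_t)
        l_t = p.sum()
        O_t = p @ V[start:start + tile] / l_t           # tile-local softmax output
        m_new = max(m, m_t)
        a = np.exp(m - m_new) * l                       # old mass, rescaled to the new max
        b = np.exp(m_t - m_new) * l_t                   # tile mass, rescaled to the new max
        l_new = a + b
        O = (a * O + b * O_t) / l_new                   # EQ 3.7 output update
        m, l = m_new, l_new
    return O

rng = np.random.default_rng(0)
T, d_k = 10, 8                               # T is deliberately not a multiple of the tile
q = rng.normal(0, 1, d_k)
K, V = rng.normal(0, 1, (2, T, d_k))

s = K @ q / np.sqrt(d_k)                     # reference: all scores held at once
w = np.exp(s - s.max())
reference = (w / w.sum()) @ V

tiled = attention_row_tiled(q, K, V)
print("tiled result matches full softmax:", np.allclose(tiled, reference))
```

On the first tile the running max is \(-\infty\), so the old-mass factor evaluates to zero and the state is simply initialized from that tile.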

Restricting the pattern

  • Sliding-window attention: attend only to the last \(w\) positions (Mistral: \(w = 4096\)). Cost becomes \(O(Tw)\); stacked layers extend effective reach to roughly \(L \times w\). Recent models often interleave local and global layers at a fixed ratio (Gemma 3 uses 5 local : 1 global; gpt-oss alternates the two).
  • Attention sinks: keep the first few tokens always visible; their removal destabilizes streaming generation because softmax needs somewhere to park probability mass.
  • Sparse / native sparse attention: learned or structured subsets of the full pattern (block-sparse, DeepSeek's NSA), trading exactness for near-linear scaling — increasingly important at million-token contexts (Chapter 09).
  • Linear attention & kernel methods: replace softmax with feature maps so attention becomes associative and \(O(T)\) — historically a quality trade-off, now resurfacing inside hybrid architectures (Chapter 09).
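Of the patterns listed above, the sliding window is the simplest to write down: it is the causal mask of EQ 3.4 with a lower band edge added. The helper `sliding_window_mask` below is an illustrative assumption, and it counts the current position as one of the \(w\) visible keys.

```python
import numpy as np

def sliding_window_mask(T, w):
    """Additive mask: 0 where i - w < j <= i, -inf everywhere else."""
    i = np.arange(T)[:, None]                # query positions (rows)
    j = np.arange(T)[None, :]                # key positions (columns)
    visible = (j <= i) & (j > i - w)         # causal, and within the last w positions
    return np.where(visible, 0.0, -np.inf)

M = sliding_window_mask(T=6, w=3)
visible_per_query = (M == 0).sum(1)
print((M == 0).astype(int))                  # 1 = visible, 0 = masked
print("keys visible per query:", visible_per_query)
```

Each row has at most `w` visible keys, which is where the \(O(Tw)\) cost comes from; with `w = T` the band covers the whole lower triangle and the mask reduces to plain causal masking.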
NEXT

Architecture is settled; now it must learn. Chapter 04: the data, the scaling laws that tell you how big to build, the optimizer, and the art of spreading one training run across tens of thousands of GPUs.

§

Further reading

  • Vaswani et al. (2017). Attention Is All You Need. — scaled dot-product and multi-head attention as defined here, including the √d scaling.
  • Bahdanau, Cho & Bengio (2015). Neural Machine Translation by Jointly Learning to Align and Translate. — the original additive attention that the dot-product form streamlined.
  • Shazeer (2019). Fast Transformer Decoding: One Write-Head is All You Need. — multi-query attention, the first big cut to KV-cache size.
  • Ainslie et al. (2023). GQA: Training Generalized Multi-Query Transformer Models. — grouped-query attention, the production middle ground.
  • DeepSeek-AI (2024). DeepSeek-V2. — multi-head latent attention (MLA), compressing the KV cache via low-rank projection.
  • Dao, Fu, Ermon, Rudra & Ré (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. — the IO-aware kernel that made long-context attention practical.
  • Beltagy, Peters & Cohan (2020). Longformer: The Long-Document Transformer. — sliding-window / sparse attention for long sequences.