Efficient Attention, Multi-Head Reasoning & Normalization

Section 3.5a

They asked me to scale attention to a million tokens. I implemented FlashAttention. They asked for ten million. I implemented Ring Attention. They asked for a hundred million. I implemented retirement.

AttnAttn, Memory-Bound AI Agent
Big Picture

This section continues from Section 3.5, which covered the three architectural families and positional encoding variants. Here we turn to the practical machinery that makes Transformers tractable at scale: efficient attention mechanisms (sparse, linear, FlashAttention, GQA/MQA, Differential Attention), why multiple attention heads matter, and where to place LayerNorm. Pre-Norm with RMSNorm is the configuration that most modern frontier LLMs use, for reasons this section makes precise.

Prerequisites

This section continues from Section 3.5. Familiarity with the architectural families and positional encoding survey there is assumed, along with the attention mechanism from Section 3.1 and the from-scratch Transformer build in Section 3.3.

Having mapped the Transformer family tree and the positional-encoding design space in Section 3.5, we now turn to the components that determine whether a Transformer can actually run at long contexts and at scale: how attention is made efficient, why multiple heads matter, and where to place LayerNorm.

3.5.3 Efficient Attention Mechanisms

Standard attention has $O(T^{2})$ time and memory complexity, where $T$ is the sequence length. For $T = 128\text{K}$ (a common context window in 2024+ models), the naive attention matrix would have $128\text{K} \times 128\text{K} \approx 16$ billion entries per head (about 17 billion if K denotes 1,024), which is over 30 GB in FP16. This section surveys the main approaches to making attention tractable at long sequences.
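To make the scale concrete, the arithmetic above can be checked in a few lines. This is a back-of-the-envelope sketch: the helper names `attn_matrix_entries` and `attn_matrix_bytes` are made up for illustration, and the 2 bytes per entry assumes FP16 or BF16 storage.

```python
def attn_matrix_entries(seq_len: int) -> int:
    """Number of scores in one head's full T x T attention matrix."""
    return seq_len * seq_len


def attn_matrix_bytes(seq_len: int, bytes_per_entry: int = 2) -> int:
    """Bytes needed to store that matrix (2 bytes/entry assumes FP16/BF16)."""
    return attn_matrix_entries(seq_len) * bytes_per_entry


# Quadratic growth: 4x the context costs 16x the memory.
for T in (4_096, 32_768, 128_000):
    gib = attn_matrix_bytes(T) / 2**30
    print(f"T={T:>7,}: {attn_matrix_entries(T):>18,} entries, "
          f"{gib:8.2f} GiB per head")
```

The numbers are per head and per layer, so a real model multiplies them by both counts, which is why no practical system stores the full matrix at long context.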

3.5.3.1 Sparse Attention

Instead of attending to all positions, sparse attention restricts each token to attend to a carefully chosen subset of positions. The challenge is choosing which positions to attend to while preserving the model's ability to capture long-range dependencies.

Note: Mistral's Sliding Window Attention

Mistral 7B uses a sliding window of 4096 tokens. Because information propagates through attention layers, a model with $N$ layers and window size $w$ can theoretically propagate information across $N \times w$ positions. With 32 layers and w=4096, that is 131,072 positions of effective reach.

Key Insight: The Four Canonical Sparse-Attention Patterns

The sparse-attention literature is large but converges on a small set of primitive patterns; real systems (Longformer, BigBird, Sparse Transformer, Mistral) use one of them or combine two or three. The ASCII sketch below shows the score matrix (rows = queries, columns = keys) for 8 queries over 10 key positions under each of four patterns; X marks a position the query is allowed to attend to.

   Full quadratic              Sliding window (w=3)         Dilated sliding (w=3, stride=2)       Global tokens (cols 0, 5)
   q\k 0 1 2 3 4 5 6 7 8 9     q\k 0 1 2 3 4 5 6 7 8 9      q\k 0 1 2 3 4 5 6 7 8 9              q\k 0 1 2 3 4 5 6 7 8 9
    0  X X X X X X X X X X      0  X X . . . . . . . .       0  X . X . X . . . . .               0  X . . . . X . . . .
    1  X X X X X X X X X X      1  X X X . . . . . . .       1  X X . X . X . . . .               1  X X . . . X . . . .
    2  X X X X X X X X X X      2  X X X X . . . . . .       2  X X X . X . X . . .               2  X . X . . X . . . .
    3  X X X X X X X X X X      3  . X X X X . . . . .       3  X X X X . X . X . .               3  X . . X . X . . . .
    4  X X X X X X X X X X      4  . . X X X X . . . .       4  . X X X X . X . X .               4  X . . . X X . . . .
    5  X X X X X X X X X X      5  . . . X X X X . . .       5  . . X X X X . X . X               5  X X X X X X X X X X
    6  X X X X X X X X X X      6  . . . . X X X X . .       6  . . . X X X X . X .               6  X . . . . X X . . .
    7  X X X X X X X X X X      7  . . . . . X X X X .       7  . . . . X X X X . X               7  X . . . . X . X . .

   Time/memory: O(T^2)         Time/memory: O(T * w)        Time/memory: O(T * w)                 Time: O(T * g) with g global anchors

Architectures combine these. Longformer (Beltagy et al., 2020) uses sliding window plus a handful of pre-designated global tokens (typically [CLS] for classification or question tokens for QA). BigBird (Zaheer et al., NeurIPS 2020) layers sliding window, global tokens, and a small number of random keys per query, so that the resulting attention graph stays well connected and a small number of layers can route information between any two positions. Sparse Transformer (Child et al., 2019) factorizes attention into local plus strided heads. Mistral 7B (Jiang et al., 2023) uses pure sliding window and relies on layer-by-layer propagation (an $N$-layer stack with window $w$ has receptive field $N \cdot w$, as in the previous callout). The random-keys ingredient from BigBird is the least obvious one: it adds little computation, and BigBird's analysis uses it, together with the global tokens, to recover theoretical guarantees (such as universal approximation) that pure local attention lacks.
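The patterns above are easy to express as boolean masks. The sketch below builds a symmetric sliding window and a global-token mask and combines them in the Longformer style. The function names, the `radius` parameter, and the toy sizes are assumptions made for illustration; production systems use block-sparse kernels rather than a dense mask.

```python
import torch
import torch.nn.functional as F


def sliding_window_mask(T, radius):
    """Symmetric window: query i sees keys with |i - j| <= radius."""
    i = torch.arange(T).unsqueeze(1)   # query positions, shape (T, 1)
    j = torch.arange(T).unsqueeze(0)   # key positions,   shape (1, T)
    return (i - j).abs() <= radius     # True = may attend


def global_mask(T, global_idx):
    """Global tokens: every query sees them, and they see every key."""
    m = torch.zeros(T, T, dtype=torch.bool)
    m[:, global_idx] = True
    m[global_idx, :] = True
    return m


T = 8
local = sliding_window_mask(T, radius=1)
combined = local | global_mask(T, [0])      # window + one global token

print(combined.int())
print("allowed pairs, window only:  ", int(local.sum()))
print("allowed pairs, window+global:", int(combined.sum()), "of", T * T)

# A boolean mask can be passed to SDPA; True marks an allowed (query, key) pair.
torch.manual_seed(0)
q = k = v = torch.randn(1, 1, T, 16)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=combined)
print(out.shape)
```

Passing a dense mask like this reproduces the sparse pattern's output but not its cost: the full score matrix is still computed. The $O(T \cdot w)$ savings only appear when the kernel skips masked blocks entirely.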

3.5.3.2 Linear Attention

Linear attention replaces the softmax kernel with a decomposable kernel function, allowing the attention computation to be rewritten in O(T) time:

$$\begin{aligned}\text{Standard:}\quad & \operatorname{softmax}\!\left(QK^{T}\right) V && [O(T^{2})] \\ \text{Linear:}\quad & \phi(Q)\left(\phi(K)^{T} V\right) && [O(T)]\end{aligned}$$

The trick is to compute $\phi (K)^{T}V$ first (a $d \times d$ matrix, independent of $T$), then multiply by $\phi (Q)$. Because matrix multiplication is associative, this gives the same result as forming $\phi(Q)\phi(K)^{T}$ first, but it never builds a $T \times T$ matrix. The feature map $\phi$ can be the identity (giving a simple outer-product formulation), an exponential, or a random feature approximation. Linear attention has seen renewed interest through architectures like RWKV and RetNet.
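The reordering can be checked numerically. The sketch below is a minimal non-causal version under stated assumptions: the feature map `elu(x) + 1` is one common choice rather than the only one, and the code includes the per-row normalizer that the compact formula above leaves implicit. Function names are illustrative.

```python
import torch
import torch.nn.functional as F


def phi(x):
    """Feature map elu(x) + 1: keeps all features positive (an assumed choice)."""
    return F.elu(x) + 1.0


def linear_attention(q, k, v):
    """Non-causal linear attention. q, k: (T, d); v: (T, d_v)."""
    qf, kf = phi(q), phi(k)
    kv = kf.transpose(0, 1) @ v          # (d, d_v): size does not depend on T
    z = kf.sum(dim=0)                    # (d,): accumulator for the normalizer
    return (qf @ kv) / (qf @ z).unsqueeze(-1)


def quadratic_reference(q, k, v):
    """Same math computed the O(T^2) way, used only as a check."""
    scores = phi(q) @ phi(k).transpose(0, 1)              # (T, T)
    weights = scores / scores.sum(dim=-1, keepdim=True)   # rows sum to 1
    return weights @ v


torch.manual_seed(0)
T, d = 256, 32
q = torch.randn(T, d, dtype=torch.float64)
k = torch.randn(T, d, dtype=torch.float64)
v = torch.randn(T, d, dtype=torch.float64)

same = torch.allclose(linear_attention(q, k, v), quadratic_reference(q, k, v))
print("linear and quadratic orderings agree:", same)
```

The two functions differ only in the order of multiplication. A causal variant replaces the two sums over all keys with running prefix sums, which is what lets such models decode with constant-size state.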

3.5.3.3 FlashAttention

FlashAttention (Dao et al., 2022) is not an approximation; it computes exact standard attention but with dramatically better hardware utilization. The key insight is that the standard attention implementation is memory-bound: it writes the full T × T attention matrix to GPU global memory (HBM), reads it back for the softmax, writes it again, and reads it for the value multiplication. FlashAttention fuses these operations into a single kernel that keeps the attention matrix in fast on-chip SRAM, never materializing the full T × T matrix in HBM.

The result: 2 to 4x wall-clock speedup and dramatically reduced memory usage (from $O(T^{2})$ to $O(T)$ HBM). FlashAttention-2 further optimizes the kernel with better work partitioning across GPU thread blocks. FlashAttention-3 (2024) leverages Hopper GPU features (warp specialization, FP8 tensor cores) for additional gains. We cover the algorithm in detail in Section 3.6.

Figure 3.5.6: FlashAttention computes exact attention but reorders the work so the T × T score matrix is never written to HBM. Q, K, V are tiled into SRAM-sized blocks, the dot-product, softmax, and value combination happen in one fused kernel, and an online-softmax trick reconciles partial outputs across tiles. The result is identical to the textbook formula at 2 to 4× the throughput.

# Drop-in FlashAttention via PyTorch SDPA. No math change, just a call into
# the fused kernel described above. Requires a CUDA GPU and a half-precision
# dtype; with only the Flash backend enabled, SDPA raises an error rather than
# silently falling back if the kernel cannot be used.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

torch.manual_seed(0)
B, H, T, D = 4, 16, 4096, 64                       # batch, heads, seq, head_dim
q = torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to the FlashAttention backend (PyTorch 2.3+; older releases
# used the now-deprecated torch.backends.cuda.sdp_kernel context manager).
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print("output shape:", out.shape)                  # (B, H, T, D)
print("peak GPU mem MB:",
      torch.cuda.max_memory_allocated() / 1e6)      # grows as O(T), not O(T^2)

Code Fragment 3.5.2a: Calling FlashAttention through PyTorch's scaled_dot_product_attention with the Flash backend selected. The result matches naive attention up to floating-point rounding; only the kernel changes. With these shapes a naive implementation would need 4 × 16 × 4096 × 4096 × 2 bytes ≈ 2.1 GB for the score matrix alone, whereas the fused kernel's peak stays in the hundreds of megabytes, most of it the Q, K, V, and output tensors themselves.

output shape: torch.Size([4, 16, 4096, 64])
peak GPU mem MB: 268.4

Real-World Scenario: Why FlashAttention Wins at Long Context

Consider attention on a single head with sequence length $T = 8192$ and head dimension $D = 64$ in FP16. The textbook implementation materializes the $T \times T$ attention-score matrix in HBM, which occupies $8192 \times 8192 \times 2$ bytes = 128 MiB per head per layer (and is written and read back more than once). A 32-layer, 32-head model therefore pushes at least $128 \text{ MiB} \times 32 \times 32 = 128 \text{ GiB}$ of attention scores through HBM in a single forward pass. FlashAttention avoids the full materialization: with tile sizes $B_r = B_c = 128$, each Q tile meets each K, V tile exactly once inside SRAM, so score tiles never leave the chip. What remains in HBM is $O(T \cdot D)$ per head (Q, K, V, the output, and one softmax statistic per row: a few megabytes), and Dao et al. show that total HBM accesses fall from $\Theta(TD + T^{2})$ to $\Theta(T^{2}D^{2}/M)$ for on-chip memory of size $M$. The result is numerically equivalent to the naive softmax thanks to the online-softmax trick, but the GPU spends its time computing instead of waiting on HBM, which is why the wall-clock speedup is 2 to 4x in practice and tends to grow with sequence length.
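The online-softmax trick mentioned above can be shown in plain PyTorch. This is a sketch of the arithmetic only, not of the CUDA kernel: it tiles just the keys and values, runs on any device, and ignores causal masking. The function name `tiled_attention` and the block size are illustrative assumptions.

```python
import torch


def tiled_attention(q, k, v, block=128):
    """Exact softmax attention for one head, visiting K/V one tile at a time.

    q, k, v: (T, d). Only a (T, block) slice of scores exists at any moment.
    """
    T, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)                                   # unnormalized output
    row_max = torch.full((T, 1), float("-inf"), dtype=q.dtype)  # running max per row
    row_sum = torch.zeros(T, 1, dtype=q.dtype)                  # running softmax denominator
    for start in range(0, T, block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale                                  # scores for this tile
        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)               # rescale earlier tiles
        p = torch.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum


torch.manual_seed(0)
T, d = 512, 64
q = torch.randn(T, d, dtype=torch.float64)
k = torch.randn(T, d, dtype=torch.float64)
v = torch.randn(T, d, dtype=torch.float64)

reference = torch.softmax((q @ k.T) * d ** -0.5, dim=-1) @ v    # textbook attention
matches = torch.allclose(tiled_attention(q, k, v), reference)
print("tiled result equals textbook attention:", matches)
```

Each tile updates a running maximum and a running denominator, and earlier partial results are rescaled whenever the maximum changes. That rescaling is what allows the fused kernel to produce exact softmax outputs without ever seeing a full row of scores at once.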

3.5.3.4 Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)

During auto-regressive inference, the keys and values from all previous positions are cached (the "KV cache"), a technique we examine in detail in Section 9.3: Memory Optimization. With standard multi-head attention and many heads, this cache becomes enormous.

Figure 3.5.7: Multi-Query Attention. Each of the $h$ query heads (eight in the figure) keeps its own projection $Q_i$, but all heads attend against a single shared $(K, V)$ pair. Per-token KV cache memory therefore shrinks by a factor of exactly $h$, which is why PaLM and Falcon adopted MQA before the GQA compromise emerged. The same kernel implements MHA, GQA, and MQA by varying only $n_{kv}$ (Algorithm 3.5.2 below).

In Hugging Face transformers, the same idea appears as a config field. Llama-style configs expose num_key_value_heads, so MQA is just GQA with num_key_value_heads=1 and a single integer flips a model between the three regimes. The original Falcon-7B config predates that convention and expresses MQA through a boolean multi_query flag instead:

from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig

# Falcon-7B ships with MQA: 71 query heads share a single (K, V) pair.
cfg = AutoConfig.from_pretrained("tiiuae/falcon-7b")
print(cfg.num_attention_heads, cfg.multi_query)
# 71 True   -> Multi-Query Attention

# A Llama-style config switches between MQA, GQA, and MHA by editing one
# integer before instantiation. The tiny sizes below are illustrative only.
small = LlamaConfig(hidden_size=512, intermediate_size=1024,
                    num_hidden_layers=2, num_attention_heads=8,
                    vocab_size=1000)
small.num_key_value_heads = 1            # MQA  (Falcon, PaLM style)
# small.num_key_value_heads = 2          # GQA  (Llama-2 70B and Mistral use 8 KV heads)
# small.num_key_value_heads = small.num_attention_heads  # full MHA

model = AutoModelForCausalLM.from_config(small)

Code Fragment 3.5.2b: MQA, GQA, and MHA differ only in the number of KV heads, which Llama-style Hugging Face configs expose as a single integer. Falcon-7B is the canonical MQA reference: 71 query heads but only one shared $(K, V)$ head, which cuts the inference-time KV cache by 71x compared to full multi-head attention, at some cost in model quality.

71 True

Algorithm 3.5.2: MHA, GQA, and MQA as a Single Parameter Family
Algorithm: Generalized grouped attention (parameterized by g)
Input:  X in R^{B x T x d_model}, n_q query heads, n_kv key/value heads with
        n_q mod n_kv == 0 and group size g = n_q / n_kv
Output: Y in R^{B x T x d_model}

  // Projections: queries have full rank, KV is shared in groups
  Q := X @ W_Q                                  // (B, T, n_q,  d_k)
  K := X @ W_K                                  // (B, T, n_kv, d_k)
  V := X @ W_V                                  // (B, T, n_kv, d_k)

  // Broadcast each KV head over its group of g queries
  K_full := repeat_interleave(K, g, axis = 'head')   // (B, T, n_q, d_k)
  V_full := repeat_interleave(V, g, axis = 'head')   // (B, T, n_q, d_k)

  // Per-head scaled-dot-product attention (Algorithm 3.1.1) using K_full, V_full
  Y := concat_heads( SoftmaxAttention(Q_h, K_full_h, V_full_h) for h in 1..n_q ) @ W_O
  Return Y

Specializations:
    n_kv = n_q  (g = 1)        => MHA   (standard multi-head attention)
    n_kv = n_q / g  (1 < g < n_q)  => GQA   (LLaMA-2 70B uses n_q=64, n_kv=8, g=8)
    n_kv = 1     (g = n_q)     => MQA   (PaLM, Falcon)

KV cache per token per layer (FP16): 2 (K and V) * n_kv * d_k * 2 bytes
    MHA  (n_q=32, d_k=128): 16,384 bytes/token   (baseline)
    GQA  (n_kv=8):            4,096 bytes/token   (4x reduction)
    MQA  (n_kv=1):              512 bytes/token   (32x reduction)

Sources: Shazeer, "Fast Transformer Decoding: One Write-Head Is All You Need" (arXiv:1911.02150, 2019) for MQA; Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (arXiv:2305.13245, 2023) for GQA. The same kernel implements all three by varying a single parameter (n_kv), which is why production inference stacks treat them as one fused op.

import torch
from torch import nn

# Grouped-Query Attention (GQA): multiple query heads share fewer KV heads.
# Reduces KV cache size while retaining most of multi-head attention quality.
class GroupedQueryAttention(nn.Module):
    """GQA: groups of query heads share KV heads."""
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # how many Q heads per KV head
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, _ = x.shape
        # Project and split into heads: (B, heads, T, d_k)
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        # Repeat KV heads to match the number of query heads
        # (B, n_kv_heads, T, d_k) -> (B, n_heads, T, d_k)
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        scores = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # Mask disallowed positions with -inf so softmax ignores them
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # Convert scores to attention weights (probabilities summing to 1).
        # This runs whether or not a mask was supplied.
        attn = torch.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out)

Code Fragment 3.5.2c: A minimal GroupedQueryAttention module. K and V are projected to n_kv_heads heads and repeated n_rep times so that each group of query heads shares one KV head; setting n_kv_heads = n_heads recovers MHA and n_kv_heads = 1 recovers MQA.

All the attention variants we have seen so far modify how attention scores are computed or how many heads share parameters. But they all share a common limitation: softmax attention always assigns some weight to every position, even irrelevant ones. The next approach tackles this "attention noise" problem directly.

3.5.3.5 Differential Attention (DIFF Transformer)

A more recent innovation in attention design is Differential Attention, introduced by Ye et al. (2024) in the DIFF Transformer paper. The core idea is that standard softmax attention allocates non-trivial weight to irrelevant context tokens (a phenomenon sometimes called "attention noise"). Even with a clear focus token, the long tail of softmax probabilities still assigns small but non-zero attention to unrelated positions, diluting the signal and contributing to hallucination.

Differential Attention addresses this by computing attention as the difference of two separate softmax attention maps. Each attention head is split into two sub-heads that compute independent attention distributions over the same keys, and the final attention weights are the first map minus the second, with the second scaled by a learnable scalar $\lambda$. Positions that receive similar attention from both sub-heads largely cancel out (analogous to noise cancellation), while positions that one sub-head strongly attends to but the other does not are amplified. The result is a sparser, more focused attention pattern that better distinguishes signal from noise.

import math
import torch.nn.functional as F

# Simplified Differential Attention (conceptual: single head, no masking)
def diff_attention(Q1, K1, Q2, K2, V, lam=0.8):
    """Two sub-attention maps; subtract to cancel noise.

    Q*, K*: (..., T, d_k); V: (..., T, d_v). lam is a learnable scalar in
    Ye et al. (2024); here it is a fixed, illustrative constant.
    """
    d_k = Q1.size(-1)
    attn_1 = F.softmax(Q1 @ K1.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
    attn_2 = F.softmax(Q2 @ K2.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
    # Noise cancellation: weight that both maps place on a position cancels
    diff_weights = attn_1 - lam * attn_2
    return diff_weights @ V

Code Fragment 3.5.3a: Differential attention (diff_attention): two parallel softmax-attention heads whose weighted difference cancels the shared "noise" component, sharpening the signal head.

Empirically, DIFF Transformers show improved performance on long-context tasks, key-value retrieval, and hallucination benchmarks, with particular gains on tasks requiring precise information extraction from noisy contexts. The approach adds minimal overhead (two sets of query/key projections per head instead of one) and is compatible with FlashAttention and GQA. The work comes from Microsoft Research, and the technique represents a promising direction for making attention more robust without fundamentally changing the Transformer architecture.

3.5.4 Multi-Head Attention: Why Multiple Heads?

We established the mechanics of multi-head attention in Section 3.1 and implemented it in Section 3.3. Now let us dig into the why: what does having multiple heads actually buy us?

The short answer is that different heads learn to attend to fundamentally different kinds of relationships, simultaneously. A single attention head produces one attention distribution over the sequence for each query position: it is forced to aggregate all of its relational reasoning into a single scalar weight per (query, key) pair. Multi-head attention runs n_heads independent attention computations in parallel, each in a lower-dimensional subspace, and then concatenates the results. This lets the model represent multiple relational modes at once.

Empirical evidence for this comes from interpretability work. Clark et al. (2019) analyzed BERT's attention patterns and found striking specialization across heads:

Figure 3.5.8: Four types of attention head behavior observed in trained models (Clark et al., 2019). These specializations emerge from gradient descent alone; they are not explicitly programmed. Multiple heads allow the model to simultaneously represent syntactic structure, local context, global routing, and coreference, all within a single forward pass.

This specialization is an emergent property: no training objective explicitly requires a head to learn coreference. It arises because tracking coreference is useful for the training objective (masked-token prediction in BERT's case, next-token prediction in decoder-only models), and the architecture provides enough capacity (enough separate heads) for the model to allocate a head to this purpose without sacrificing others. Single-head attention cannot do this: with only one attention distribution, the model is forced to trade off between all these relational signals at every position.

Grouped-Query Attention (Section 3.5.3.4) reduces the number of key-value heads while retaining most of this multi-head representational power; Section 9.3 covers its use in production serving.

Fun Fact

Nobody asked the heads to specialize. No loss term says "head 7, you handle coreference; head 12, you do syntax." Yet when researchers crack open a trained BERT, they consistently find a head devoted to pronouns, another to direct objects, another that just looks one token to the left like a paranoid neighbor. It is the closest thing in ML to spontaneous division of labor in an ant colony, and it happens because gradient descent is opportunistic: if a head can specialize and reduce loss by 0.0001, it will, every time, forever.

3.5.5 Pre-Norm vs. Post-Norm Layer Normalization

Layer normalization is present in every Transformer block, but where you apply it changes both training stability and model quality in ways that matter at scale. Two placement strategies have competed since 2018, and a simplified normalization variant has emerged as a third option.

3.5.5.1 Post-Norm (Original Vaswani 2017)

The original "Attention Is All You Need" paper placed LayerNorm after the residual add:

$$\text{PostNorm}(x) = \text{LayerNorm}(x + \text{SubLayer}(x))$$

The intuition is clean: normalize the output of each sublayer before passing it forward. However, this placement creates a training instability problem at scale. During backpropagation, gradients flowing back to early layers must pass through the LayerNorm operations at every block. The norm statistics depend on the current activations and can produce high-variance gradients in early training before the model has settled. This forced practitioners to use very careful learning-rate warmup schedules (often thousands of steps at very low learning rate) to stabilize Post-LN training.

3.5.5.2 Pre-Norm (GPT-2 onward, all modern LLMs)

Pre-Norm applies LayerNorm before the sublayer:

$$\text{PreNorm}(x) = x + \text{SubLayer}(\text{LayerNorm}(x))$$

The key difference: the residual path (the x + term) does not pass through LayerNorm. Gradients flowing back through the residual connection reach all earlier layers directly, without being modified by any normalization step. This produces much more stable gradient magnitudes throughout training and essentially eliminates the need for warmup, or at least reduces the sensitivity to its length. Virtually every LLM trained since GPT-2 (Radford et al., 2019) uses Pre-LN. In our Section 3.3 implementation, the two lines x = x + self.attn(self.ln1(x)) and x = x + self.ffn(self.ln2(x)) are both Pre-LN.
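The two placements differ by one line of code. The sketch below wraps an arbitrary sublayer both ways. The class names and the zero-initialized stand-in sublayer are assumptions for illustration; the zero sublayer is used only because it isolates what each block does to the residual stream itself.

```python
import torch
from torch import nn


class PostNormBlock(nn.Module):
    """Post-Norm: LayerNorm(x + SubLayer(x)). The residual passes through the norm."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x):
        return self.norm(x + self.sublayer(x))


class PreNormBlock(nn.Module):
    """Pre-Norm: x + SubLayer(LayerNorm(x)). The residual bypasses the norm."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))


def zero_linear(d_model):
    """Stand-in sublayer that outputs zeros, to isolate the residual path."""
    layer = nn.Linear(d_model, d_model)
    nn.init.zeros_(layer.weight)
    nn.init.zeros_(layer.bias)
    return layer


torch.manual_seed(0)
d_model = 16
x = torch.randn(4, d_model) * 3.0 + 1.0       # deliberately not normalized

pre = PreNormBlock(d_model, zero_linear(d_model))
post = PostNormBlock(d_model, zero_linear(d_model))

pre_is_identity = torch.equal(pre(x), x)       # residual stream untouched
post_is_identity = torch.allclose(post(x), x)  # residual stream re-normalized
print("Pre-Norm leaves the residual unchanged: ", pre_is_identity)
print("Post-Norm leaves the residual unchanged:", post_is_identity)
```

With a sublayer that contributes nothing, the Pre-Norm block is an exact identity map while the Post-Norm block still rescales its input. Stacked over many layers, that identity path is what carries gradients back to early layers unmodified.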

Figure 3.5.9a: Post-Norm (left) places LayerNorm after the residual add; gradients must pass through the norm on the backward pass, causing instability at scale. Pre-Norm (right) places LayerNorm before the sublayer; the raw residual signal bypasses the norm entirely, giving stable gradients throughout training.

3.5.5.3 RMSNorm (Zhang & Sennrich, 2019)

Standard LayerNorm normalizes by subtracting the mean and dividing by the standard deviation, then scales and shifts:

$$\text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}} \cdot \gamma + \beta, \quad \mu = \frac{1}{d}\sum x_i, \quad \sigma = \sqrt{\frac{1}{d}\sum(x_i - \mu)^2}$$

RMSNorm simplifies this by removing the mean-centering step entirely and using only the root-mean-square for scaling:

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma, \quad \text{RMS}(x) = \sqrt{\frac{1}{d}\sum x_i^2}$$

The motivation: the mean-centering in standard LayerNorm re-centers activations at zero, but empirically the centering provides little or no quality benefit while adding computation. RMSNorm drops both the mean subtraction and the shift parameter $\beta$, reducing the operation to a single division and element-wise scale. Zhang and Sennrich (2019) report that RMSNorm achieves quality comparable to LayerNorm while reducing running time by 7% to 64%, depending on the model.
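The two formulas translate directly into code. The sketch below is a minimal RMSNorm module plus a check of the relationship between the two norms: on inputs that already have zero mean, RMSNorm and LayerNorm without its affine parameters should coincide. The small `eps` inside the square root is an assumption added for numerical safety and does not appear in the RMSNorm formula above.

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """RMSNorm: x / RMS(x) * gamma, with no mean subtraction and no shift."""
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps                                   # numerical-safety term (assumed)
        self.weight = nn.Parameter(torch.ones(d_model))  # gamma

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight


torch.manual_seed(0)
d_model = 64
x = torch.randn(8, d_model) * 5.0 + 2.0

rms_norm = RMSNorm(d_model)
layer_norm = nn.LayerNorm(d_model, eps=1e-5, elementwise_affine=False)

# After RMSNorm every row has root-mean-square close to 1.
row_rms = rms_norm(x).pow(2).mean(dim=-1).sqrt()
print("row RMS after RMSNorm:", row_rms)

# On zero-mean rows, the two norms agree.
x_centered = x - x.mean(dim=-1, keepdim=True)
agree = torch.allclose(rms_norm(x_centered), layer_norm(x_centered), atol=1e-5)
print("RMSNorm == LayerNorm on zero-mean input:", agree)
```

The check makes the design trade explicit: the only thing RMSNorm gives up is the re-centering step, so the two differ exactly to the extent that activations have a non-zero mean.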

Modern LLMs have adopted RMSNorm widely. LLaMA 1 through 3, Mistral, Gemma, and Qwen all use RMSNorm in the Pre-LN position. The combination of Pre-LN placement and RMSNorm is now the de facto standard for frontier decoder-only models.

Tip: Which Norm to Use in New Projects

Unless you have a specific reason to deviate, use Pre-LN with RMSNorm. It is the most stable configuration for training deep Transformers, it requires minimal warmup, and it closely matches the configuration of most major open-source LLMs. In PyTorch, torch.nn.RMSNorm(d_model) is available as a built-in module from PyTorch 2.4 onward. For older versions, a single-line implementation is: x / x.pow(2).mean(-1, keepdim=True).add(1e-8).sqrt() * weight.

See Also

For how KV cache memory savings flow into production serving, see Section 9.3 (KV cache and GQA in Practice). For how attention variants appear in modern model families like Llama and Mistral, see Section 7.3 (Modern Architectures).

Research Frontier

Attention itself keeps evolving. Differential Attention (Ye et al., 2024) reduces attention noise by subtracting two softmax distributions, improving precise retrieval and hallucination resistance. FlashAttention-3 (2024) exploits Hopper warp specialization and FP8 tensor cores for further speedups on long contexts. Ring Attention (Liu et al., 2023) distributes attention across devices to enable million-token contexts. Positional encoding research continues with NTK-aware RoPE scaling, YaRN, and learned-extension tricks pushing trained models well beyond their original context windows. See Section 3.8 for the parallel frontier in non-attention paradigms (SSMs, MoE, MLA).

Self-Check
1. How does FlashAttention achieve speedup without approximating attention?
Answer:
FlashAttention fuses the attention computation into a single GPU kernel that keeps intermediate results (the attention matrix) in fast on-chip SRAM rather than writing them to slow HBM (GPU global memory). By tiling the computation and using online softmax (computing softmax incrementally), it avoids materializing the full T x T attention matrix in HBM.
2. Why does Pre-Norm produce more stable gradients than Post-Norm?
Answer:
In Post-Norm, gradients flowing backward through a block must pass through the LayerNorm operation, which modifies their magnitude based on the current activation statistics. Early in training, these statistics are unstable, causing gradient variance to spike. In Pre-Norm, the residual path (the "x +" term) bypasses both the LayerNorm and the sublayer entirely, so gradient magnitude is preserved cleanly back to all earlier layers. This stabilizes training and removes the need for long warmup schedules.

Exercises

Exercise 3.5.1: KV-Cache Memory for GQA (Calculation)

A 70B-class model has 80 layers, d_model=8192, 64 query heads (so d_k=128), and uses GQA with 8 KV-head groups. Compute the KV cache size at sequence length 32K in FP16, and compare against (a) full multi-head attention (64 KV heads), and (b) MQA (1 KV head).

Answer Sketch

The KV cache stores K and V for each layer: 2 (K+V) × seq_len × n_kv_heads × d_k × bytes per value. With FP16 (2 bytes): per layer = 2 × 32768 × n_kv × 128 × 2. For GQA (n_kv=8): 2 × 32768 × 8 × 128 × 2 = 134.2 MB per layer; × 80 layers = 10.74 GB. (a) MHA (n_kv=64): 8x larger = 85.9 GB. (b) MQA (n_kv=1): 8x smaller than GQA = 1.34 GB. So GQA's "8 groups" choice is a sweet spot: it cuts KV memory by 8x compared to MHA, whose 32K-token cache alone would exceed an 80 GB accelerator, while preserving most of the quality benefit of multiple distinct KV representations. MQA saves another 8x, but Ainslie et al. (2023) report a measurable quality drop relative to MHA, most of which GQA recovers.
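The same calculation as a small script, useful for trying other model shapes. The function name `kv_cache_bytes` is an illustrative assumption; the formula and the numbers are exactly those of the answer sketch, for a single sequence (batch size 1).

```python
def kv_cache_bytes(n_layers, seq_len, n_kv_heads, d_k, bytes_per_value=2):
    """Total KV cache size for one sequence.

    The leading factor of 2 counts K and V; bytes_per_value=2 assumes FP16.
    """
    return 2 * n_layers * seq_len * n_kv_heads * d_k * bytes_per_value


# 70B-class shape from the exercise: 80 layers, d_k = 128, 32K context.
N_LAYERS, SEQ_LEN, D_K = 80, 32_768, 128

sizes = {}
for name, n_kv in [("MHA", 64), ("GQA", 8), ("MQA", 1)]:
    sizes[name] = kv_cache_bytes(N_LAYERS, SEQ_LEN, n_kv, D_K)
    print(f"{name} (n_kv={n_kv:>2}): {sizes[name] / 1e9:6.2f} GB")
```

Because the cache is linear in every argument, doubling the context length or the batch size doubles these figures, which is why the number of KV heads is one of the few levers that reduces cache size without touching context length.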

Exercise 3.5.2: Pre-LN, Post-LN, and Warmup (Conceptual)

The original Transformer paper used Post-LN with a 4000-step warmup schedule. Modern Pre-LN models often use 100-1000 step warmup or even none. (a) Walk through gradient flow in Post-LN vs. Pre-LN to explain why warmup is more important for Post-LN. (b) Why might you still use a small warmup with Pre-LN? (c) Could you go even further and remove LayerNorm entirely? What goes wrong?

Answer Sketch

(a) In Post-LN, the residual $x$ is added to SubLayer(x), and the sum is normalized: LN(x + SubLayer(x)). The backward gradient passes through this LN, whose Jacobian depends on the (initially noisy) activation statistics; gradients spike or vanish unpredictably in early training. Warmup gradually ramps the learning rate so these unstable gradients cause smaller weight updates while statistics settle. In Pre-LN, the residual gradient bypasses LN entirely, so gradient magnitudes are stable from step 0. (b) Even Pre-LN benefits from a short warmup (say, 100 steps) because Adam's running estimates of the gradient's first and second moments start from zero and are unreliable for the first steps, which can produce overly large updates. The warmup gives Adam time to build accurate moment estimates. (c) Removing LayerNorm entirely: the residual stream's variance grows with depth (each block adds its own contribution), so later layers operate on inputs of very large magnitude that saturate softmax and drive ReLU/SiLU activations to extreme values. Training typically becomes unstable past roughly 8 to 10 layers. Techniques such as weight standardization and carefully scaled residual branches aim to reduce reliance on LayerNorm, but large models in practice still keep some form of normalization.

What's Next?

In the next section, Section 3.6: GPU Fundamentals & Systems, we examine the hardware and systems-level concepts that determine real-world training and inference performance.

Further Reading

Attention Interpretability

Clark, K., Khandelwal, U., Levy, O., & Manning, C. D. (2019). "What Does BERT Look At? An Analysis of BERT's Attention." BlackboxNLP, ACL 2019. A systematic empirical study of BERT's 144 attention heads (12 layers, 12 heads each). Finds that certain heads specialize in syntactic roles (direct objects, coreferent mentions), local context, and special token routing. The head visualization methodology in this paper inspired the diagrams in this section. Highly readable and a good introduction to mechanistic interpretability.

Normalization

Zhang, B. & Sennrich, R. (2019). "Root Mean Square Layer Normalization." NeurIPS 2019. Shows that mean-centering in standard LayerNorm contributes little to quality while adding computation. RMSNorm, which normalizes by root mean square only, achieves comparable performance while reducing running time by 7% to 64% across the models tested. Now standard in LLaMA, Mistral, Gemma, and most frontier decoder-only models. The paper is concise and the derivation is accessible to anyone familiar with LayerNorm.

Efficient Attention

Dao, T. et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. The paper that made long-context Transformers practical by fusing attention computation into a single GPU kernel, eliminating the O(T^2) memory bottleneck. A must-read for understanding how hardware-aware algorithms changed the game for attention efficiency.
Ainslie, J. et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023. Introduces Grouped-Query Attention, the KV cache reduction technique used in LLaMA 2, Mistral, and most modern LLMs. Useful for understanding how production models balance quality and inference speed.