GPU Fundamentals & Systems

Section 3.6

My GPU utilization hit 100% and for one beautiful moment, memory bandwidth and compute were perfectly balanced. Then I loaded the next batch.

NormNorm, Bandwidth-Bottlenecked AI Agent
Big Picture

Modern GPUs are fundamentally throughput machines designed to keep thousands of cores busy simultaneously. The main challenge is not compute; it is feeding data to those cores fast enough. Most Transformer operations are memory-bound, meaning they spend more time moving data than computing with it. The PyTorch tensor foundations from Section 0.3 map directly onto these GPU primitives.

3.6.1 Why GPU Architecture Matters

Key Insight
Why: FlashAttention is faster despite computing the same math

FlashAttention's speedup comes from reordering the computation into tiles, not from any algebraic shortcut. Standard attention writes the N × N score matrix to HBM and reads it back to compute the softmax, then writes and reads the N × N probability matrix again to multiply it by V. Dao et al. (2022) showed that the softmax and the V-multiply can be fused into a single pass over tiles small enough to fit in SRAM, using an online-softmax recurrence (Milakov and Gimelshein, 2018) that updates the running max and sum incrementally. The wall-clock win comes from the HBM traffic this avoids, exploiting the roughly order-of-magnitude bandwidth gap between on-chip SRAM and HBM. The math is identical, and the outputs differ from a non-tiled implementation only by floating-point rounding. This is a microcosm of modern GPU optimization: most "kernel improvements" are memory-hierarchy reorderings that autograd frameworks do not perform for you.

Algorithm 3.6.1: Under the Hood: Why FlashAttention Doesn't Materialize N×N

Standard attention computes Q Kᵀ first (an N×N matrix), softmaxes the rows, then multiplies by V. The intermediate N×N matrix dominates memory at long contexts (1M × 1M scores in FP16 = 2 TB per head, impossible). FlashAttention never materializes the full attention matrix. It tiles Q, K, V into blocks that fit in on-chip SRAM (192 KB per SM on an A100), computes one block at a time, and maintains a running "online softmax" using the log-sum-exp trick:

$$m_{\text{new}} = \max(m, \text{row\_max}(S)), \quad \ell_{\text{new}} = e^{m - m_{\text{new}}} \ell + \sum e^{S - m_{\text{new}}}$$

where m is the running maximum (for numerical stability) and ℓ is the running denominator. After processing all K/V blocks, the output is rescaled by 1/ℓ. Beyond the inputs and the output, extra memory grows only as O(N) (the per-row statistics m and ℓ) rather than O(N²), and each SRAM tile has a fixed size. Verification: the output matches standard attention up to floating-point rounding; the exact tolerance depends on the dtype.
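To see why the rescaling is exact rather than approximate, write the recurrence for a single query row and add the output accumulator O alongside m and ℓ. This is a derivation sketch; the superscript (j) indexing K/V blocks and the value rows v_c are notation introduced here for illustration:

$$
\begin{aligned}
m^{(j)} &= \max\Big(m^{(j-1)},\ \max_{c \in \text{block } j} S_c\Big),\\
\ell^{(j)} &= e^{m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \sum_{c \in \text{block } j} e^{S_c - m^{(j)}},\\
O^{(j)} &= e^{m^{(j-1)} - m^{(j)}}\,O^{(j-1)} + \sum_{c \in \text{block } j} e^{S_c - m^{(j)}}\, v_c,
\end{aligned}
\qquad m^{(0)} = -\infty,\ \ell^{(0)} = 0,\ O^{(0)} = 0.
$$

By induction, $\ell^{(j)} = \sum_{c \in \text{blocks } 1..j} e^{S_c - m^{(j)}}$ and $O^{(j)} = \sum_{c \in \text{blocks } 1..j} e^{S_c - m^{(j)}} v_c$: multiplying the old sums by $e^{m^{(j-1)} - m^{(j)}}$ simply re-expresses them relative to the new maximum. After the last block $J$, $m^{(J)}$ is the maximum of the whole row, so

$$
\frac{O^{(J)}}{\ell^{(J)}} = \frac{\sum_c e^{S_c - m^{(J)}}\, v_c}{\sum_{c'} e^{S_{c'} - m^{(J)}}} = \sum_c \operatorname{softmax}(S)_c\, v_c,
$$

which is exactly one row of softmax(S)V. The common factor $e^{-m^{(J)}}$ cancels between numerator and denominator; m exists only to keep every exponent at or below zero so nothing overflows.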

Here is a number that should surprise you: a modern GPU can perform 1,000 trillion floating-point operations per second, but it can only read about 3 trillion bytes from its main memory in that same second. The math does not lie: for many operations, the GPU finishes computing before the data even arrives. A Transformer's theoretical FLOP count tells only half the story. Two implementations with identical FLOP counts can differ by 10x in wall-clock time depending on how well they utilize the GPU's memory hierarchy. Understanding GPU architecture is not optional for anyone who wants to train or serve LLMs efficiently. This section provides the mental model you need to reason about performance bottlenecks and understand optimizations like FlashAttention.

3.6.2 GPU Architecture Overview

The roughly 300 FLOPs-per-byte gap between compute throughput and memory bandwidth on a modern GPU (990 TFLOPS vs. 3.35 TB/s on an H100) is not a bug. It is the result of decades of architectural choices that prioritize throughput over latency. To reason about which Transformer operations will be bottlenecked and which will fly, we first need a vocabulary for the hardware: how the chip is organized into Streaming Multiprocessors, what memory tiers each SM can reach, and how thread scheduling hides latency.

3.6.2.1 Streaming Multiprocessors (SMs)

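A GPU is organized as an array of Streaming Multiprocessors (SMs): 108 on an A100 and 132 on an H100 SXM. Each SM contains:

  - CUDA cores for scalar FP32 and INT32 arithmetic;
  - Tensor Cores for dense matrix multiply-accumulate, which supply the headline FP16/BF16 TFLOPs;
  - a large register file (64K 32-bit registers per SM on A100 and H100);
  - on-chip SRAM split between L1 cache and programmer-managed shared memory;
  - warp schedulers that issue instructions for warps of 32 threads and switch to another ready warp whenever one stalls on memory, which is how the GPU hides latency.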

3.6.2.2 Memory Hierarchy

Figure 3.6.1: GPU memory hierarchy. Each level is faster but smaller than the one below. The key optimization challenge is keeping data in the fast upper levels.
Figure 3.6.2: GPU memory as a city: the fastest SRAM is a small neighborhood at the center, surrounded by larger but slower HBM districts. Data must travel along bandwidth-limited highways between them, making memory access patterns critical for performance.

Prerequisites

This section assumes familiarity with the Transformer architecture from Section 3.1, especially the matrix operations involved in attention computation. The PyTorch foundations from Section 0.3 (tensors, device placement) are helpful. The hardware concepts here connect directly to the inference-optimization and distributed-training themes covered later in the book.

On an A100, the aggregate bandwidth of on-chip SRAM is estimated at around 19 TB/s, versus 1.5 to 2.0 TB/s for HBM (Dao et al., 2022), roughly a 10x gap. The gap between registers and HBM is even larger. This is why memory access patterns dominate GPU performance: a memory-bound operation that reads the same data from HBM three times (like naive attention) can be up to 3x slower than one that reads it once and keeps it in shared memory (FlashAttention).

3.6.3 Compute-Bound vs. Memory-Bound Operations

With the memory hierarchy in hand, we can now classify every Transformer operation along a single axis: is it limited by math or by data movement? The roofline model gives a clean visual answer. It plots attainable performance against arithmetic intensity, so any kernel can be placed on the chart and immediately diagnosed.

3.6.3.1 The Roofline Model

The roofline model characterizes each operation by its arithmetic intensity: the ratio of FLOPs to bytes transferred from memory. If an operation does many FLOPs per byte loaded, it is compute-bound (limited by the number of arithmetic units). If it does few FLOPs per byte, it is memory-bound (limited by memory bandwidth).

$$\text{Arithmetic Intensity} = \text{FLOPs} / \text{Bytes transferred}$$
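Table 3.6.1a: Operation Comparison (as of 2026).

| Operation | FLOPs | Memory traffic | Intensity (FLOPs per element moved) | Bound |
|---|---|---|---|---|
| Matrix multiply (large, M×K by K×N) | 2MNK | 2(MK + KN + MN) bytes in FP16 | High (scales with dims) | Compute |
| LayerNorm | ~5 per element | Read + write all elements | ~2.5 | Memory |
| Softmax | ~5 per element | Read + write all elements | ~2.5 | Memory |
| Dropout | ~1 per element | Read + write all elements | ~0.5 | Memory |
| Element-wise add (residual) | 1 per element | Read 2 + write 1 | ~0.33 | Memory |
| Attention (softmax($QK^{T}$)V) | $O(T^{2}d)$ | $O(T^{2})$ for the attention matrix | Depends on T, d | Usually memory |

For the element-wise rows, intensity counts FLOPs per element read or written; divide by the element size (2 bytes in FP16/BF16) to get FLOPs per byte.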
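The roofline itself is a one-line formula: attainable throughput is the smaller of peak compute and intensity × bandwidth. The sketch below assumes the H100 SXM peak numbers listed in Table 3.6.2b and FP16 operands; the function names are illustrative, not from any library.

```python
# Roofline sketch: attainable throughput = min(peak compute, intensity x bandwidth).
# Peak numbers are the H100 SXM figures from Table 3.6.2b (dense FP16).
PEAK_TFLOPS = 990.0   # TFLOP/s
HBM_TBPS = 3.35       # TB/s


def ridge_point(peak_tflops=PEAK_TFLOPS, bw_tbps=HBM_TBPS):
    """Intensity (FLOPs/byte) where the roofline bends from memory- to compute-bound."""
    return peak_tflops / bw_tbps


def attainable_tflops(intensity, peak_tflops=PEAK_TFLOPS, bw_tbps=HBM_TBPS):
    """TB/s x FLOPs/byte = TFLOP/s, capped at the compute peak."""
    return min(peak_tflops, intensity * bw_tbps)


def matmul_intensity(m, n, k, bytes_per_elem=2):
    """(M x K) @ (K x N): 2MNK FLOPs; read A and B once, write C once."""
    flops = 2 * m * n * k
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


# Element-wise ops in FP16: FLOPs per element divided by bytes moved per element
examples = {
    "matmul 4096x4096x4096": matmul_intensity(4096, 4096, 4096),
    "softmax (~5 FLOPs, read+write)": 5 / (2 * 2),
    "residual add (read 2, write 1)": 1 / (3 * 2),
}

for name, ai in examples.items():
    bound = "compute" if ai >= ridge_point() else "memory"
    print(f"{name:32s} {ai:9.2f} FLOPs/B -> {attainable_tflops(ai):6.1f} TFLOP/s ({bound}-bound)")
```

The matmul lands at about 1,365 FLOPs/byte, far past the ridge point of about 295, while softmax at 1.25 FLOPs/byte can reach only about 4 TFLOP/s, well under 1% of peak no matter how the kernel is written.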
Key Insight: Most Operations Are Memory-Bound

In a typical Transformer forward pass, the large matrix multiplications (QKV projections, FFN layers, output projection) are compute-bound and keep the GPU busy. But everything else (LayerNorm, softmax, dropout, residual adds, attention score computation for moderate sequence lengths) is memory-bound. This is why kernel fusion (combining multiple memory-bound operations into a single kernel) is so effective.

3.6.4 The FlashAttention Algorithm

Fun Fact

FlashAttention computes the exact same result as naive attention but 2 to 4 times faster, simply by being smarter about memory access patterns. It is the computational equivalent of realizing you can grocery shop faster by organizing your list by aisle rather than alphabetically.

FlashAttention is arguably the most important GPU optimization for Transformers. It computes exact standard attention while reducing HBM reads and writes from $O(Td + T^{2})$ to $O(T^{2}d^{2}/M)$, where M is the size of on-chip SRAM (Dao et al., 2022). For typical head dimensions, d² is several times smaller than M, which in practice means roughly a 5 to 10x reduction in memory traffic.

Warning: Common Misconception

FlashAttention is often mistaken for an approximation (like linear attention, Performer, or sparse attention) that trades quality for speed. It is not. Its forward pass performs essentially the same FLOPs as standard scaled dot-product attention (the backward pass recomputes attention scores instead of storing them, adding FLOPs while still saving time), and its outputs match a non-tiled reference up to floating-point rounding. The speedup comes entirely from reordering memory access; the algebra is identical. This matters for production: swapping in FlashAttention is loss-neutral, but swapping in a sparse approximation is not.

3.6.4.1 The Problem with Naive Attention

The standard attention implementation performs these steps, each reading from and writing to HBM:

  1. Compute $S = QK^{T}/\sqrt{d}$, write S to HBM. Size: $O(T^{2})$.
  2. Read S from HBM, apply mask, compute P = softmax(S), write P to HBM. Size: $O(T^{2})$.
  3. Apply dropout to P, write back to HBM.
  4. Read P from HBM, compute O = PV, write O to HBM.

The T × T matrices S and P are the bottleneck. For T=4096, each matrix is 64 MB in FP32 per head per batch element (independent of d). With 32 heads and batch size 4, that is 8 GB for S and another 8 GB for P, just for the intermediate attention matrices.
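This arithmetic is easy to get wrong by a factor of two, since S and P are separate matrices. A small sketch (the helper name is hypothetical); sizes use binary units, matching the 64 MB and 8 GB figures above:

```python
def attention_matrix_bytes(seq_len, n_heads, batch, bytes_per_elem):
    """Bytes for ONE materialized T x T matrix (S or P) across all heads and sequences.
    Note that the head dimension d does not appear: the matrix is T x T regardless."""
    return seq_len * seq_len * bytes_per_elem * n_heads * batch


MiB, GiB = 2**20, 2**30
per_head = attention_matrix_bytes(4096, n_heads=1, batch=1, bytes_per_elem=4)   # FP32
one_matrix = attention_matrix_bytes(4096, n_heads=32, batch=4, bytes_per_elem=4)

print(per_head / MiB, "MiB per head per sequence")
print(one_matrix / GiB, "GiB for S alone")
print(2 * one_matrix / GiB, "GiB for S and P together")
```

Switching to FP16 halves every figure, but the quadratic growth in T remains: doubling the context to 8192 quadruples them.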

3.6.4.2 The Tiling Approach

FlashAttention processes the attention computation in tiles. Instead of computing the full T × T attention matrix at once, it processes blocks of size $B_{r}$ × $B_{c}$ that fit in SRAM. The challenge is computing the softmax correctly when you only see part of each row at a time.

Online Softmax

The critical algorithmic trick is online softmax: computing the softmax incrementally as new blocks arrive. For a row of attention scores being computed in blocks, the algorithm maintains running statistics (the current maximum and the running sum of exponentials) and rescales previous partial results as new maxima are discovered:

# Online softmax for FlashAttention: one query row, keys/values processed in blocks
# (NumPy sketch; query has shape (d_k,), keys (T, d_k), values (T, d_v))
import numpy as np

def online_softmax_row(query, keys, values, block_size):
    d_k = query.shape[0]
    max_so_far = -np.inf
    sum_exp = 0.0
    output_accumulator = np.zeros(values.shape[1])
    for start in range(0, keys.shape[0], block_size):
        keys_block_j = keys[start:start + block_size]
        values_block_j = values[start:start + block_size]
        # Compute attention scores for this block
        scores_j = keys_block_j @ query / np.sqrt(d_k)
        # Update running max
        new_max = max(max_so_far, scores_j.max())
        # Rescale previous accumulator with correction factor (0 on the first block)
        correction = np.exp(max_so_far - new_max)
        exp_scores = np.exp(scores_j - new_max)
        sum_exp = sum_exp * correction + exp_scores.sum()
        output_accumulator = output_accumulator * correction + exp_scores @ values_block_j
        max_so_far = new_max
    # Final normalization
    return output_accumulator / sum_exp
Code Fragment 3.6.1b: Online softmax for a single query row, processing keys and values block by block.
Input: Q, K, V matrices in HBM; tile sizes Br, Bc fitting in SRAM
Output: O = softmax(QKᵀ / √dk) V, written to HBM
// Partition Q into T/Br blocks of Br rows, and K and V into T/Bc blocks of Bc rows
for each Q block Qi (rows i*Br to (i+1)*Br):
    Load Qi from HBM to SRAM
    Initialize: Oi = 0, li = 0, mi = −∞
    // l = running sum of exponentials, m = running row-wise max
    for each K,V block (Kj, Vj):
        Load Kj, Vj from HBM to SRAM
        Sij = Qi Kjᵀ / √dk            // computed in SRAM
        mnew = max(mi, rowmax(Sij))
        Pij = exp(Sij − mnew)
        // Rescale previous accumulator for updated max (α holds one factor per row)
        α = exp(mi − mnew)
        li = α · li + rowsum(Pij)
        Oi = α · Oi + Pij Vj
        mi = mnew
    Oi = Oi / li                      // final normalization, once per Q block
    Write Oi to HBM
return O
Code Fragment 3.6.2a: FlashAttention forward pass: tiling over Q and K/V blocks with online softmax.

This is mathematically equivalent to computing the full softmax (it differs only by floating-point rounding) but needs only O(B) working memory for a block of B scores at any point, rather than the full O(T) row. The correction factor ensures that as we discover larger values, all previous exponentials are rescaled consistently.
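The tiled algorithm in Code Fragment 3.6.2a can be checked on the CPU. This NumPy sketch follows the same loop structure but keeps everything in ordinary memory, so it demonstrates that tiling plus online softmax is exact, not that it is fast; the tile sizes are arbitrary and the function names are illustrative.

```python
import numpy as np


def naive_attention(Q, K, V):
    """Reference implementation: materializes the full T x T score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V


def tiled_attention(Q, K, V, Br=16, Bc=16):
    """Tiled forward pass with online softmax; only Br x Bc score tiles ever exist."""
    T, d = Q.shape
    O = np.empty((T, V.shape[1]))
    for i in range(0, T, Br):                        # outer loop: Q blocks
        Qi = Q[i:i + Br]
        m = np.full(Qi.shape[0], -np.inf)            # running row-wise max
        l = np.zeros(Qi.shape[0])                    # running sum of exponentials
        Oi = np.zeros((Qi.shape[0], V.shape[1]))     # unnormalized output accumulator
        for j in range(0, K.shape[0], Bc):           # inner loop: K/V blocks
            Sij = Qi @ K[j:j + Bc].T / np.sqrt(d)    # Br x Bc tile of scores
            m_new = np.maximum(m, Sij.max(axis=1))
            Pij = np.exp(Sij - m_new[:, None])
            alpha = np.exp(m - m_new)                # per-row rescale (0 on the first block)
            l = alpha * l + Pij.sum(axis=1)
            Oi = alpha[:, None] * Oi + Pij @ V[j:j + Bc]
            m = m_new
        O[i:i + Br] = Oi / l[:, None]                # final normalization per Q block
    return O


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((100, 32)) for _ in range(3))
diff = np.abs(tiled_attention(Q, K, V) - naive_attention(Q, K, V)).max()
print(f"max abs difference: {diff:.1e}")
```

T = 100 is deliberately not a multiple of the tile sizes, so the ragged last tiles are exercised too; in FP64 the difference is at the level of rounding error.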

# Triton fused softmax kernel: each row is read from HBM once,
# normalized on-chip, and written back once.
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(
    output_ptr, input_ptr,
    input_row_stride, output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused softmax: one kernel, one pass through HBM per row."""
    # Each program handles one row; columns are assumed contiguous (stride 1)
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # Load the row on-chip; padding lanes get -inf so exp() turns them into 0
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row = tl.load(row_start_ptr + col_offsets, mask=mask, other=float('-inf'))
    # Step 1: Subtract max for numerical stability
    row_max = tl.max(row, axis=0)
    row = row - row_max
    # Step 2: Exponentiate
    numerator = tl.exp(row)
    # Step 3: Normalize
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back to HBM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    tl.store(output_row_start_ptr + col_offsets, softmax_output, mask=mask)


def fused_softmax(x):
    """Apply softmax to each row of a 2-D CUDA tensor x using the fused Triton kernel."""
    n_rows, n_cols = x.shape
    # BLOCK_SIZE must be a power of 2 >= n_cols, since one program holds a whole row;
    # very wide rows will run out of on-chip storage.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    output = torch.empty_like(x)
    softmax_kernel[(n_rows,)](
        output, x,
        x.stride(0), output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output


if __name__ == "__main__":
    x = torch.randn(1024, 768, device="cuda")
    y = fused_softmax(x)
    print(f"Input shape: {tuple(x.shape)}, dtype: {x.dtype}")
    print(f"Fused softmax output shape: {tuple(y.shape)}")
    print("Row sums (should be ≈ 1.0):", y.sum(dim=1).cpu())
    max_diff = (y - torch.softmax(x, dim=1)).abs().max().item()
    print(f"Max abs diff vs. torch.softmax: {max_diff:.2e}")
Example output (abridged; exact values vary by GPU and run):
Input shape: (1024, 768), dtype: torch.float32
Fused softmax output shape: (1024, 768)
Row sums (should be ≈ 1.0): tensor([1.0000, 1.0000, 1.0000, ..., 1.0000])
Max abs diff vs. torch.softmax: 4.77e-07
Code Fragment 3.6.3: Triton fused softmax kernel: each row is read from HBM once, normalized on-chip, and written back once. The fused_softmax wrapper launches one program per row with a power-of-two block size.

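The key insight: in FlashAttention each Q block is loaded once, K and V are streamed from HBM once per outer-loop iteration (once per Q block), and the T × T attention matrix is never materialized in HBM. Dao et al. (2022) show that total HBM access is $O(T^{2}d^{2}/M)$ instead of $O(Td + T^{2})$, where M is the SRAM size; for typical head dimensions d² is several times smaller than M. This reduction in HBM traffic is where the 2 to 4x speedup comes from.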

Figure 3.6.3a: FlashAttention tiles the attention computation into blocks that fit in on-chip SRAM. The full T x T attention matrix is never materialized in HBM.

3.6.5 Introduction to Triton

Writing GPU kernels in CUDA requires managing threads, warps, shared memory, synchronization, and memory coalescing at a low level. Triton (developed at OpenAI) provides a higher-level abstraction: you write kernels in a Python-like language that operates on blocks of data rather than individual threads. Triton handles the complex details of thread mapping, shared memory management, and memory coalescing automatically.

3.6.5.1 A Simple Example: Vector Addition

This Triton kernel performs element-wise vector addition, illustrating the block-based programming model.

# Triton element-wise vector addition: each program instance handles one block.
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles BLOCK_SIZE elements
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask to handle the case where n_elements is not a multiple of BLOCK_SIZE
    mask = offsets < n_elements
    # Load, compute, store
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)


def triton_add(x, y):
    """Add two same-shape, contiguous CUDA tensors with add_kernel."""
    output = torch.empty_like(x)
    n_elements = x.numel()
    BLOCK_SIZE = 1024
    # Launch enough program instances to cover every element
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output
Code Fragment 3.6.4: Triton element-wise vector addition kernel: each program instance handles BLOCK_SIZE elements with no manual thread management.

Notice: no thread management, no shared memory allocation, no warp-level operations. Triton programs are specified in terms of program instances that each process a block of elements. The compiler handles the mapping to GPU hardware.

3.6.5.2 Lab: Fused Softmax in Triton


Objective

Implementing softmax as a fused kernel eliminates multiple HBM round-trips. In a naive implementation, computing softmax requires three passes over the data: one to find the max, one to compute exponentials and their sum, and one to normalize, with intermediates written to HBM and read back between passes. A fused kernel loads each row from HBM once, performs all three passes on the on-chip copy, and writes the result once.
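A simple traffic model makes the benefit concrete. The sketch below runs in NumPy on the CPU, so it illustrates the accounting rather than GPU speed, and it assumes that every pass of an unfused implementation reads its inputs from HBM and writes its outputs back; real kernels and caches complicate this picture. The function names are hypothetical.

```python
import numpy as np


def three_pass_softmax(x):
    """Row-wise softmax as three separate passes, the way an unfused GPU implementation
    would launch them. Returns the result and a simplified count of elements read from
    plus written to HBM, assuming every intermediate makes a round-trip."""
    traffic = 0
    # Pass 1: read x, write one max per row
    row_max = x.max(axis=1, keepdims=True)
    traffic += x.size + row_max.size
    # Pass 2: read x and the maxes, write the exponentials and one sum per row
    exps = np.exp(x - row_max)
    row_sum = exps.sum(axis=1, keepdims=True)
    traffic += x.size + row_max.size + exps.size + row_sum.size
    # Pass 3: read the exponentials and sums, write the normalized output
    out = exps / row_sum
    traffic += exps.size + row_sum.size + out.size
    return out, traffic


def fused_traffic(x):
    """A fused kernel reads each row once and writes it once."""
    return 2 * x.size


x = np.random.default_rng(0).standard_normal((1024, 768))
out, unfused = three_pass_softmax(x)
print(f"unfused / fused HBM traffic ≈ {unfused / fused_traffic(x):.2f}x")
```

Under this model the unfused version moves about 5 elements for every 2 the fused kernel moves, a 2.5x traffic reduction for a purely memory-bound operation.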

Setup

You need a CUDA-capable GPU with Triton installed (Linux CUDA builds of PyTorch 2.x install it as a dependency). The reference implementation in Code Fragment 3.6.3 above is the starting point; copy it into your environment and confirm that the baseline torch.softmax runs on tensors of the Transformer shapes you care about, with attention scores of shape batch × heads × seq_len × seq_len flattened to 2-D rows.

Steps

This fused kernel reads each row from HBM exactly once, performs the entire softmax on-chip, and writes the result back once. A naive softmax built from separate PyTorch ops (max, subtract, exp, sum, divide) launches several kernels and round-trips intermediates through HBM; torch.softmax itself dispatches to a single fused CUDA kernel, so benchmark against both. The implementation walk-through is in Code Fragment 3.6.3 above; reproduce it end-to-end, benchmark it, and report the speedup on the GPU available to you.
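Timing GPU code naively is a common source of wrong speedup numbers, because CUDA kernel launches return before the kernel finishes. A minimal harness, assuming you pass torch.cuda.synchronize as the sync callback when timing GPU code; the function name benchmark is hypothetical, and Triton also ships a helper, triton.testing.do_bench, which you may prefer.

```python
import statistics
import time


def benchmark(fn, *args, warmup=3, iters=20, sync=None):
    """Median wall-clock time of fn(*args) in milliseconds.

    sync: optional callable such as torch.cuda.synchronize. Without it, an
    asynchronous GPU launch would be timed instead of the kernel itself.
    """
    for _ in range(warmup):          # warm-up absorbs JIT compilation and cache effects
        fn(*args)
    if sync is not None:
        sync()
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()                   # wait until queued GPU work has finished
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


# Hypothetical usage on a CUDA machine (not executed here):
#   x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
#   def naive_softmax(t):                          # several separate kernels
#       e = torch.exp(t - t.amax(dim=1, keepdim=True))
#       return e / e.sum(dim=1, keepdim=True)
#   for f in (naive_softmax, lambda t: torch.softmax(t, dim=1), fused_softmax):
#       print(benchmark(f, x, sync=torch.cuda.synchronize))
```

The median is used rather than the mean so that a single slow iteration (for example, one interrupted by another process) does not distort the result.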

Expected Output

On an A100, a fused softmax kernel can be 2 to 4x faster than a naive composition of PyTorch ops for typical Transformer shapes; against torch.softmax, which is already fused, expect a much smaller gap. The numerical output should match torch.softmax within standard FP16 or BF16 tolerance; report both the wall-clock speedup and the maximum absolute difference against the PyTorch reference.

3.6.6 Key GPU Metrics for LLM Practitioners

Table 3.6.2b: Key GPU metrics for LLM practitioners (as of 2026).

| GPU | HBM capacity | HBM bandwidth | FP16/BF16 TFLOPs (dense) | TF32 TFLOPs (dense) |
|---|---|---|---|---|
| A100 (80GB) | 80 GB | 2.0 TB/s | 312 | 156 |
| H100 SXM | 80 GB | 3.35 TB/s | 990 | 495 |
| H200 | 141 GB | 4.8 TB/s | 990 | 495 |
| B200 | 192 GB | 8.0 TB/s | 2250 | 1125 |
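Two quick consequences follow from this table: each GPU's ridge point (peak FLOPs divided by bandwidth), and a bandwidth floor on autoregressive decoding. The sketch assumes batch-size-1 decoding of a hypothetical 7B-parameter BF16 model in which reading every weight once per token dominates HBM traffic; KV-cache reads and kernel overheads are ignored, so real latencies will be higher. Function names are illustrative.

```python
# Numbers from Table 3.6.2b: (HBM bandwidth in TB/s, dense FP16 TFLOP/s)
GPUS = {
    "A100 (80GB)": (2.0, 312),
    "H100 SXM": (3.35, 990),
    "H200": (4.8, 990),
    "B200": (8.0, 2250),
}


def ridge_flops_per_byte(bw_tbps, tflops):
    """FLOPs per byte an operation needs before it becomes compute-bound."""
    return tflops / bw_tbps


def decode_floor_ms(weight_bytes, bw_tbps):
    """Lower bound on per-token latency if every weight is read from HBM once per token."""
    return weight_bytes / (bw_tbps * 1e12) * 1e3


WEIGHT_BYTES = 7e9 * 2  # hypothetical 7B-parameter model stored in BF16

for name, (bw, tflops) in GPUS.items():
    ms = decode_floor_ms(WEIGHT_BYTES, bw)
    print(f"{name:12s} ridge ≈ {ridge_flops_per_byte(bw, tflops):6.1f} FLOPs/B   "
          f"batch-1 decode ≥ {ms:4.2f} ms/token (≤ {1e3 / ms:4.0f} tokens/s)")
```

Batch-1 decoding does about 2 FLOPs per 2-byte weight, roughly 1 FLOP/byte, far below every ridge point, so throughput tracks bandwidth rather than TFLOPs; the H200 decodes about 43% faster than the H100 despite identical compute. Batching many sequences reuses each weight read and is the main lever for raising decode throughput.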
Note: Model Flops Utilization (MFU)

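MFU measures what fraction of the GPU's theoretical peak FLOPS your training run actually achieves, counting only the FLOPs the model itself requires (recomputation from gradient checkpointing does not count). Good Transformer training typically reaches 40% to 60% MFU, and anything above 60% is excellent. Values below 30% usually indicate a bottleneck (memory-bound operations, communication overhead, or poor batch size selection). The PaLM paper, which introduced the MFU metric, reported 46.2% MFU for its 540B-parameter training run (57.8% when recomputation FLOPs are also counted, a measure it calls hardware FLOPs utilization).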
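MFU can be estimated from measured training throughput. A common approximation (Kaplan et al., 2020) counts about 6 FLOPs per parameter per training token: 2 for the forward pass and 4 for the backward pass, ignoring attention-score FLOPs. The sketch below assumes that approximation; the throughput figure is hypothetical.

```python
def model_flops_per_token(n_params):
    """~6 FLOPs per parameter per training token (2 forward + 4 backward),
    ignoring attention-score FLOPs and any recomputation."""
    return 6 * n_params


def mfu(n_params, tokens_per_sec, peak_flops):
    """Fraction of the GPU's peak FLOP/s spent on the model's required FLOPs."""
    return model_flops_per_token(n_params) * tokens_per_sec / peak_flops


# Hypothetical run: 7B model, 3,000 tokens/s per GPU on an A100 (312 dense BF16 TFLOP/s)
print(f"MFU = {mfu(7e9, 3_000, 312e12):.1%}")
```

Because recomputation is excluded by construction, turning on gradient checkpointing lowers measured MFU for the same token throughput, which is exactly the point: MFU rewards useful work, not busy hardware.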

3.6.7 Practical Considerations

The theory is now sufficient to predict performance on paper. Turning that prediction into a real training run that hits 50%+ MFU requires three workflow-level choices: which precision format to compute in, how to budget device memory across parameters and activations, and how to debug when reality drifts from the roofline. We address each in turn.

3.6.7.1 Mixed Precision Training

Modern LLM training uses mixed precision: forward and backward passes use FP16 or BF16, while the master weights and optimizer states are kept in FP32. BF16 is preferred because it has the same 8-bit exponent as FP32, so it covers the same dynamic range and rarely overflows or underflows, at the cost of lower precision (7 vs. 23 explicit mantissa bits). BF16 Tensor Core operations are natively supported on A100 and later GPUs.
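The range-versus-precision tradeoff is easy to see numerically. NumPy has no bfloat16 type, so this sketch emulates BF16 rounding in pure Python by keeping the top 16 bits of a float32 (the helper to_bf16 is hypothetical, written for illustration) and compares it with NumPy's float16.

```python
import struct

import numpy as np


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even) and return it as a float.

    bfloat16 is the top 16 bits of an IEEE float32: 1 sign, 8 exponent, 7 mantissa bits.
    x is first rounded to float32 by struct.pack; NaN/Inf inputs are not handled here.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # float32 bit pattern
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)           # round-to-nearest-even
    bits = (bits + rounding_bias) & 0xFFFF0000            # drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for label, value in [("large (1e30)", 1e30), ("tiny (1e-8)", 1e-8), ("1 + 2**-8", 1 + 2**-8)]:
    with np.errstate(over="ignore"):
        fp16 = float(np.float16(value))
    print(f"{label:14s} fp16={fp16!r:>22}  bf16={to_bf16(value)!r}")
```

The three rows show the tradeoff directly: 1e30 overflows to inf in FP16 but survives in BF16; 1e-8 flushes to zero in FP16 but not in BF16; and 1 + 2⁻⁸ is exact in FP16 (10 mantissa bits) but rounds to 1.0 in BF16 (7 mantissa bits).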

3.6.7.2 Memory Budgeting

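Training a Transformer model requires storing: (1) model parameters, (2) optimizer states (Adam keeps two extra values per parameter, the first- and second-moment estimates), (3) gradients, and (4) activations (for the backward pass). For large models, the optimizer states are often the dominant cost. For a 7B parameter model in mixed precision:

  - BF16 parameters: 7B × 2 bytes = 14 GB
  - BF16 gradients: 7B × 2 bytes = 14 GB
  - FP32 Adam moments: 7B × 2 × 4 bytes = 56 GB
  - FP32 master weights (if kept): 7B × 4 bytes = 28 GB
  - Activations: variable, depending on batch size and sequence length

That is 84 GB before master weights and activations (112 GB with an FP32 master copy), more than a single 80 GB GPU can hold.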

This is why training a 7B model requires multiple GPUs even when the model parameters alone would fit on a single 80 GB GPU. Techniques like ZeRO (distributed optimizer states), gradient checkpointing (recomputing activations instead of storing them), and offloading help manage this memory pressure. We cover distributed training strategies in Section 6.6 and inference-time memory optimization (including quantization) in Chapter 9.
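The accounting above is easy to script for other model sizes. A minimal sketch, assuming BF16 weights and gradients, FP32 AdamW moments, an optional FP32 master copy, and no sharding; the function name and dictionary keys are illustrative, and real mixed-precision setups vary.

```python
def training_memory_gb(n_params, master_fp32=True):
    """Per-GPU memory (decimal GB) for weights, gradients and AdamW state, before activations."""
    gb = 1e9
    budget = {
        "bf16 weights": 2 * n_params / gb,
        "bf16 gradients": 2 * n_params / gb,
        "fp32 adam m and v": 2 * 4 * n_params / gb,
    }
    if master_fp32:
        budget["fp32 master weights"] = 4 * n_params / gb
    # Sum the entries above (computed before the total key is added)
    budget["total (excl. activations)"] = sum(budget.values())
    return budget


for item, size in training_memory_gb(7e9).items():
    print(f"{item:28s} {size:6.1f} GB")
```

With the FP32 master copy the state alone costs 16 bytes per parameter, so even a model whose BF16 weights fit five times over on an 80 GB GPU cannot be trained on it without sharding or offloading.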

Real-World Scenario
Profiling a Transformer Inference Pipeline

A team running Llama-2 7B inference on an A100 GPU noticed throughput was lower than expected. Using torch.profiler, they found that 68% of GPU time was spent on memory-bound operations (layer normalization, residual additions, and softmax), not the matrix multiplications. By switching to FlashAttention and fusing the layer norm with the subsequent linear projection using a custom Triton kernel, they reduced per-token latency from 12 ms to 7 ms, a 42% improvement, without changing the model itself. The key lesson: profiling the roofline characteristics of each operation, rather than optimizing blindly, directs engineering effort to where it matters most.

Tip: Match Layer Norm Placement to Your Reference

Pre-norm (LayerNorm applied to the input of each attention and FFN sublayer) and post-norm (LayerNorm applied after the residual addition, as in the original Transformer) give different training dynamics. Most modern LLMs use pre-norm for stability. If you are reproducing a paper, check which variant it used; swapping them silently degrades performance.

Key Insight
The Memory Wall as a Fundamental Physical Constraint

The dominance of memory bandwidth over compute capacity in LLM inference is not an engineering oversight; it reflects physical constraints. Off-chip bandwidth is limited by package pin counts and signaling power, and fetching an operand from off-chip DRAM can cost on the order of 100 times more energy than performing a floating-point operation on it. This "memory wall" was named by Wulf and McKee in 1995, and the gap has only widened since. FlashAttention's contribution lies in recognizing that the bottleneck is data movement, not computation, and restructuring the algorithm to minimize it. The principle generalizes far beyond attention: from databases to scientific simulations to neural networks, the dominant cost is increasingly data movement rather than arithmetic. This explains why algorithmic improvements like FlashAttention's tiling strategy can yield 2 to 4x speedups without changing the mathematical result, and why hardware innovations such as high-bandwidth memory, near-memory computing, and in-memory processing are central to the future of AI acceleration.

Research Frontier

Hardware-software co-design is accelerating. NVIDIA's Blackwell (B200/GB200) GPUs introduce a second-generation Transformer Engine with FP4 support, roughly doubling peak inference throughput relative to FP8. Google's TPU v5p and AMD's MI300X offer competitive alternatives. On the software side, Triton (Tillet et al., 2019; developed and open-sourced by OpenAI) lets researchers write custom GPU kernels in Python, dramatically lowering the barrier to hardware optimization. ThunderKittens (Stanford, 2024) provides tile-level primitives embedded in CUDA C++ for writing fast kernels such as attention. Disaggregated inference architectures (separating prefill and decode across different GPU pools) are emerging as a key pattern for cost-efficient LLM serving, as explored in Section 9.1.

Self-Check
1. Why is standard attention considered memory-bound rather than compute-bound?
Each element of the T × T attention matrix takes part in only a handful of FLOPs, yet the matrix is written to and read from HBM several times (for the QKᵀ product, masking, softmax, dropout, and the final multiplication with V). Its arithmetic intensity is therefore far below the GPU's ridge point, so the compute units sit idle waiting for memory. FlashAttention addresses this by computing the matrix tile by tile in fast on-chip SRAM and never writing it to HBM.
2. What is the "online softmax" trick and why is it needed for FlashAttention?
Online softmax computes the softmax incrementally as new blocks of the attention row arrive. It maintains running statistics (current max, running sum of exponentials) and rescales previous partial results when a new maximum is found. This is needed because FlashAttention processes the attention matrix in tiles; it never has the complete row available at once, so it cannot compute the global max and sum in advance.
3. Why does mixed precision training use BF16 rather than FP16?
BF16 has the same exponent width (8 bits) as FP32, which prevents overflow and underflow during training. FP16 has only 5 exponent bits, giving a much smaller dynamic range that can cause training instabilities (loss scaling is usually required). BF16 trades mantissa precision (7 explicit bits vs. FP16's 10) for this exponent range, which is a good tradeoff since gradient and activation values need wide range more than high precision.
4. A GPU has 3.35 TB/s HBM bandwidth and 990 TFLOPS of FP16 compute. What is the arithmetic intensity threshold for an operation to be compute-bound?
The threshold is 990 TFLOPS / 3.35 TB/s ≈ 295 FLOPs per byte. Operations with arithmetic intensity above 295 FLOPs/byte are compute-bound; those below are memory-bound. For reference, large matrix multiplies can easily reach 1000+ FLOPs/byte, while element-wise operations in FP16 sit at roughly 0.1 to a few FLOPs/byte.
5. Why does training a 7B parameter model require much more than 14 GB of GPU memory?
Beyond the 14 GB for parameters (in BF16), training also requires: optimizer states (AdamW stores two FP32 states per parameter, adding ~56 GB), gradients (~14 GB in BF16), and activations (variable, can be many GB depending on batch size and sequence length). The total is roughly 84 GB before activations (112 GB if an FP32 master copy of the weights is kept), which is why multi-GPU training and memory optimization techniques like ZeRO and gradient checkpointing are essential.

Exercises

Exercise 4.4.1: Arithmetic Intensity Classification Calculation

For each of the following, compute the arithmetic intensity in FLOPs/byte (assume FP16, 2 bytes) and classify as compute-bound or memory-bound on a GPU with a 295 FLOPs/byte threshold. (a) Multiplying two 4096x4096 matrices to produce a 4096x4096 output. (b) Softmax over a 4096-element row. (c) LayerNorm over a vector of length 4096.

Answer Sketch

(a) FLOPs = 2 × 4096³ ≈ 1.37 × 10¹¹. Bytes = 2 × (3 × 4096²) = 100.7 MB (read A, read B, write C). Intensity = 4096/3 ≈ 1365 FLOPs/byte. Compute-bound (well above 295). (b) Softmax FLOPs ≈ 5n = 20,480 (max, subtract, exp, sum, divide). Bytes = 2 × 2 × 4096 = 16,384 (read row, write row). Intensity = 1.25 FLOPs/byte. Strongly memory-bound. (c) LayerNorm FLOPs ≈ 7n ≈ 28,672 (mean, subtract, square, mean, divide, scale, shift). Bytes ≈ 16,384 (read and write the vector, ignoring the small scale and shift parameters). Intensity ≈ 1.75 FLOPs/byte. Memory-bound. Conclusion: matrix multiplication is the only one that saturates compute; everything else is bandwidth-limited, which is exactly why kernel fusion (FlashAttention, fused LayerNorm) is so impactful.

Exercise 4.4.2: FlashAttention Memory Reduction Predictive

Compute the HBM I/O for standard attention vs. FlashAttention at T=8192, d=128, 32 heads, batch=4 (FP16). (a) Standard attention reads/writes the attention matrix multiple times; estimate total HBM bytes moved through it. (b) FlashAttention reduces HBM I/O by what factor approximately? (c) Why does this translate to a wall-clock speedup of "only" 2-4x rather than the larger ratio of byte reduction?

Answer Sketch

(a) The T × T attention matrix per head is 8192² × 2 bytes = 128 MB. Standard attention reads/writes it ~4 times (compute scores, mask, softmax, multiply by V), so 4 × 128 = 512 MB per head per batch element. Total: 512 MB × 32 × 4 = 64 GB of HBM traffic just for the attention matrix. (b) FlashAttention never materializes the full matrix in HBM. An idealized lower bound reads Q, K, V and writes O once each: 4 × 8192 × 128 × 2 bytes = 8 MB per head per batch element, or 8 MB × 32 × 4 = 1 GB in total, a 64x reduction. In practice K and V are re-streamed from HBM once per Q block (the $O(T^{2}d^{2}/M)$ term in the paper's analysis), so the realized reduction is smaller, though still large. (c) Wall-clock speedup is "only" 2-4x because (i) FlashAttention performs the same FLOPs, and the QKᵀ and PV matmuls still take real time; (ii) other parts of the layer (FFN, projections) are unaffected; (iii) the tiled kernel has its own overhead (rescaling, synchronization between blocks, and re-reading K and V). The lesson: HBM I/O reduction is only useful when HBM I/O is the bottleneck, which it is for attention but not for everything.

Exercise 4.4.3: Online Softmax Numerical Sketch Calculation

Walk through the online softmax algorithm with two blocks of attention scores: block A = [1, 3, 2], block B = [5, 0]. Compute (a) the running max, sum_exp, and rescaling factor after each block, and (b) verify the final normalized probabilities equal the direct softmax over [1, 3, 2, 5, 0].

Answer Sketch

After block A: max_so_far = 3, sum_exp = e^(1-3) + e^(3-3) + e^(2-3) = 0.135 + 1 + 0.368 = 1.503. Partial output accumulator = sum of (e^(score-3) × value); we just track sum_exp here. Processing block B: new_max = max(3, 5) = 5. Correction = e^(3-5) = e^(-2) = 0.135. Rescale: sum_exp = 1.503 × 0.135 + e^(5-5) + e^(0-5) = 0.203 + 1 + 0.0067 = 1.210. (b) Direct softmax: numerators e^(x-5) = [0.0183, 0.135, 0.0498, 1, 0.0067]; sum = 1.210 (matches!). Probabilities: divide each by 1.210, giving [0.0151, 0.112, 0.0411, 0.826, 0.0056]. The trick of always subtracting the running max keeps all exponentials between 0 and 1 (numerically safe), and the correction factor e^(old_max - new_max) compensates exactly for the regrouping.
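The hand calculation above can be checked mechanically. This pure-Python sketch streams the two blocks, prints the running statistics after each one, and compares the resulting probabilities with a direct softmax; the function name is illustrative.

```python
import math


def online_softmax_stats(blocks):
    """Stream blocks of scores, tracking the running max and rescaled sum of exponentials."""
    max_so_far, sum_exp = -math.inf, 0.0
    for block in blocks:
        new_max = max(max_so_far, max(block))
        correction = math.exp(max_so_far - new_max)   # exp(-inf) == 0.0 on the first block
        sum_exp = sum_exp * correction + sum(math.exp(s - new_max) for s in block)
        max_so_far = new_max
        print(f"after block {block}: max={max_so_far}, "
              f"correction={correction:.3f}, sum_exp={sum_exp:.4f}")
    return max_so_far, sum_exp


scores = [1, 3, 2, 5, 0]
m, total = online_softmax_stats([[1, 3, 2], [5, 0]])
probs = [math.exp(s - m) / total for s in scores]
direct = [math.exp(s) / sum(math.exp(t) for t in scores) for s in scores]
print([round(p, 4) for p in probs])
```

The streamed probabilities agree with the direct softmax to floating-point precision, even though the direct version exponentiates the raw scores without subtracting any maximum.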

Exercise 4.4.4: Memory Budget for 7B Training Calculation

You want to train a 7B-parameter model on a single A100 80GB GPU. Compute the memory for (a) BF16 model weights, (b) FP32 AdamW optimizer states, (c) BF16 gradients, and (d) sum. (e) What is the budget left for activations and how would you proceed if it's insufficient?

Answer Sketch

(a) 7B × 2 bytes = 14 GB. (b) AdamW stores 2 FP32 moments per parameter, so 7B × 2 × 4 = 56 GB. (Mixed-precision setups usually also keep an FP32 master copy of the weights: another 28 GB.) (c) 7B × 2 bytes = 14 GB. (d) Total without FP32 master weights: 84 GB. With FP32 master copy: 112 GB. The A100's 80 GB cannot hold either. (e) Budget left is roughly 80 - 84 = -4 GB (negative, infeasible) even before activations. Mitigations: (i) ZeRO Stage 2/3 shards optimizer states and gradients (and, at Stage 3, parameters) across GPUs, dividing their per-GPU memory roughly by the number of data-parallel GPUs at the cost of inter-GPU communication; (ii) gradient checkpointing trades roughly 30% extra compute (one additional forward pass) for several-fold less activation memory; (iii) 8-bit AdamW stores the two moments in 8 bits instead of 32 (saving about 7B × 2 × 3 bytes = 42 GB here, at a minor convergence cost); (iv) LoRA or QLoRA fine-tuning freezes the base model and trains only a small fraction of parameters. In practice, full pretraining of a 7B model is typically run on many GPUs (8 or more A100s) with FSDP/ZeRO.

Exercise 4.4.5: When FlashAttention Does NOT Help Analysis

FlashAttention is "free speed" for most workloads, but there are settings where it provides little benefit. Describe two scenarios where FlashAttention's speedup over standard attention is < 20%, and explain why.

Answer Sketch

Scenario 1: Very short sequences (T < 512). At T=128, the attention matrix is 128² = 16K entries per head, only 32 KB in FP16. This fits comfortably in on-chip cache, so HBM traffic for the attention matrix is small; standard attention is already fast, and the fused kernel's fixed overheads can make the difference negligible. The crossover point depends on the GPU, head dimension, and batch size, so benchmark it rather than assume it. Scenario 2: KV cache decoding with batch size 1 (autoregressive generation). At each step, we process only a single query token attending over T cached key positions. The Q matrix is (1, d) and the K matrix is (T, d), so the score "matrix" is a single row. Total HBM bytes are O(T × d), the same with or without FlashAttention (no T² matrix is ever needed). The bottleneck is loading the KV cache from HBM, not the attention computation, so FlashAttention's tiling cannot help. Decode-specific variants target other inefficiencies: Flash-Decoding splits the KV sequence across thread blocks to keep the GPU occupied when there is only one query, and PagedAttention (vLLM) stores the KV cache in fixed-size pages to reduce memory fragmentation. Neither removes the cost of streaming the KV cache from HBM.

Exercise 4.4.6: BF16 vs. FP16 Failure Mode Predictive

You inherit a training codebase that uses FP16 with loss scaling. You try to train a new architecture and observe NaN losses after ~500 steps. Diagnose this in terms of FP16's specific limitation, predict whether switching to BF16 would fix it, and describe what additional change (besides switching dtypes) you might still need to make.

Answer Sketch

Diagnosis: FP16 has a 5-bit exponent (max value ~65,504; smallest normal value ~6 × 10⁻⁵; smallest subnormal ~6 × 10⁻⁸). Any forward activation or gradient above 65,504 overflows to infinity, and values below the subnormal range round to zero. Loss scaling multiplies the loss by a large constant (e.g., 2¹⁵) so small gradients stay representable, then divides the gradients back; dynamic loss scaling also skips steps whose gradients overflow. Loss scaling does nothing for the forward pass, though, so persistent NaNs usually come from activations (for example attention logits or FFN outputs in a deeper or wider model) exceeding the FP16 range. Switching to BF16: BF16 has the same 8-bit exponent as FP32 (range ~10⁻³⁸ to 10³⁸), so overflow and underflow become rare regardless of model size, and training is typically stable without loss scaling. Additional change needed: BF16 has only 7 mantissa bits vs. FP16's 10, which can cause small accuracy losses in (i) LayerNorm/RMSNorm statistics when summing many small contributions, and (ii) the final loss aggregation. Standard fix: keep normalization statistics and the final loss in FP32 (mixed precision), which autocast-style frameworks typically handle automatically. So the typical migration is: remove loss scaling, switch matmuls to BF16, keep norms and loss in FP32.

What's Next?

In the next section, Section 3.7: Transformer Expressiveness Theory, we explore the theoretical expressiveness of Transformers, understanding what these models can and cannot compute.

Further Reading

Foundational Papers

Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. arXiv:2205.14135. The original FlashAttention paper; the canonical reference for IO-aware attention and the tiling trick described in this section.
Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv:2307.08691. The second-generation FlashAttention with improved work partitioning across warps; the version shipped in PyTorch SDPA.
Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision." arXiv:2407.08608. The H100-targeted third generation that exploits TMA and FP8; state of the art for production attention kernels at its release.
Milakov, M., & Gimelshein, N. (2018). "Online Normalizer Calculation for Softmax." arXiv:1805.02867. The online softmax recurrence that FlashAttention builds on; essential reading for understanding the streaming softmax inside the kernel.

Kernels and Compilers

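Tillet, P., Kung, H. T., & Cox, D. (2019). "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations." MAPL 2019 (ACM SIGPLAN International Workshop on Machine Learning and Programming Languages). The original Triton paper; Triton was subsequently developed and open-sourced by OpenAI, and it is the foundation for the Triton kernels written in this section.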
Ivanov, A., Dryden, N., Ben-Nun, T., Li, S., & Hoefler, T. (2021). "Data Movement Is All You Need: A Case Study on Optimizing Transformers." MLSys 2021. Quantifies how much Transformer wall-clock time is spent moving data rather than computing; the empirical motivation for IO-aware kernels.