GPU Kernel Programming for LLM Optimization

Section 9.9

The fastest code is the code the GPU never has to wait for.

Quant, Kernel Obsessed AI Agent

Big Picture

The performance of LLM inference is ultimately determined by how efficiently we use GPU hardware. High-level frameworks like PyTorch compose operations from a library of pre-written CUDA kernels, but this composition introduces overhead: each kernel launch reads from and writes to global memory, even when the output of one operation feeds directly into the next. Custom kernels eliminate this overhead by fusing multiple operations into a single GPU program. FlashAttention, the single most impactful LLM optimization of recent years, is fundamentally a custom kernel that tiles the attention computation to fit in fast on-chip SRAM instead of slow off-chip HBM. This section teaches you how to write, profile, and reason about custom GPU kernels using Triton, OpenAI's Python-based kernel programming language.

Prerequisites

This section assumes familiarity with the transformer architecture from Section 3.3, the attention mechanism in particular. The quantization and memory optimization techniques from Section 9.1 and Section 9.3 provide important context for understanding why custom kernels are needed. Basic familiarity with GPU hardware concepts (threads, blocks, shared memory) is helpful but not required.

9.9.1 Why Custom Kernels Matter

A standard PyTorch attention implementation performs four separate GPU operations: matrix multiplication (Q @ K^T), scaling, softmax, and another matrix multiplication (scores @ V). Each operation launches a separate kernel, and each kernel reads its inputs from HBM (High Bandwidth Memory, the GPU's main memory) and writes its outputs back to HBM. For a sequence length of 8,192 in FP16, the attention scores matrix for a single head is 8192 x 8192 x 2 bytes = 128 MB, a size that depends only on the sequence length, not on the hidden dimension. This matrix is written to HBM by the first kernel and read back by the softmax kernel, even though it is a temporary intermediate that is never needed again.

Fun Fact: Kernels Are the Engine Room

On a cruise ship, passengers see the dining room and the deck; almost no one visits the engine room where the heat and the noise live. GPU kernels are the engine room of LLM inference: ugly C++ and PTX that push data through warps and shared memory while the Python layer above sips coffee on deck. A kernel that is even ten percent faster shows up as a measurable change in your inference bill.

The key insight is that most LLM operations on modern GPUs are limited by memory bandwidth, not by compute. An NVIDIA A100 can perform 312 TFLOPS of dense FP16 tensor-core computation, but its HBM delivers only about 2 TB/s (the 80 GB model; the 40 GB model delivers about 1.55 TB/s). The ratio of peak compute to memory bandwidth (the arithmetic intensity threshold) is therefore about 312 / 2 = 156 FLOPs per byte. Any operation that performs fewer than 156 FLOPs for each byte it reads or writes is bottlenecked by memory, not compute. Most element-wise operations (activation functions, layer normalization, dropout) fall far below this threshold.

Custom kernels solve this by keeping intermediate results in SRAM (the GPU's fast on-chip memory, roughly 20 MB on an A100) and performing multiple operations before writing back to HBM. This technique is called kernel fusion, and it is the foundation of FlashAttention, fused layer normalization, and most other high-performance LLM kernels.
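
As a rough illustration of why fusion helps, the sketch below counts HBM bytes for a chain of element-wise operations run as separate kernels versus one fused kernel. The function name `elementwise_chain_traffic` is introduced here for illustration, and the model assumes each unfused kernel reads and writes the full tensor exactly once; it is a simplification, not a measurement.

```python
def elementwise_chain_traffic(n_elements, n_ops, dtype_bytes=2):
    """Return (unfused_bytes, fused_bytes) for a chain of element-wise ops.

    Assumption: every unfused kernel reads its input from HBM and writes its
    output back; the fused kernel reads once, keeps intermediates on chip,
    and writes once.
    """
    per_pass = 2 * n_elements * dtype_bytes  # one read + one write of the tensor
    unfused = n_ops * per_pass
    fused = per_pass
    return unfused, fused


# Example: three chained element-wise ops on a 4096 x 4096 FP16 activation
unfused, fused = elementwise_chain_traffic(4096 * 4096, n_ops=3)
print(f"unfused: {unfused / 2**20:.0f} MiB, fused: {fused / 2**20:.0f} MiB, "
      f"ratio: {unfused / fused:.0f}x")
```

Under this model the saving scales with the number of operations fused, which is why long chains of cheap element-wise work are the most rewarding fusion targets.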

Key Insight: The Roofline Model

The roofline model provides a visual framework for understanding whether a kernel is compute-bound or memory-bound. Plot the arithmetic intensity (FLOPs per byte of memory traffic) on the x-axis and achieved throughput (TFLOPS) on the y-axis. The "roofline" is formed by two lines: a flat ceiling at the GPU's peak compute rate, and a sloped line representing the memory bandwidth limit. Operations to the left of the intersection point (the ridge point) are memory-bound; operations to the right of it are compute-bound. For LLMs, most operations besides large matrix multiplications fall in the memory-bound region, which is why kernel fusion (reducing memory traffic) delivers larger speedups than optimizing arithmetic.

Roofline model: arithmetic intensity vs throughput on NVIDIA A100
Figure 9.9.1b: The roofline model on an NVIDIA A100. The bandwidth ceiling rises with arithmetic intensity until it hits the compute ceiling at the ridge point (about 156 FLOPs per byte for FP16). Element-wise LLM operations (GELU at 2.5 FLOPs/byte, LayerNorm, Softmax) and naive attention sit far below the ridge: they are memory-bound, so kernel fusion that keeps intermediates in SRAM and reduces HBM traffic is the only way to extract more throughput. Large GEMMs in cuBLAS sit on the compute ceiling and benefit instead from tensor cores and arithmetic optimization. FlashAttention's contribution was to take an operation that naively lived in the memory-bound region and lift it close to the compute ceiling by tiling Q, K, V into SRAM-sized blocks.
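
The roofline described above can be written as a one-line formula: attainable throughput is the smaller of the compute ceiling and bandwidth times arithmetic intensity. The sketch below is a minimal version; the function names are illustrative, and the defaults are the A100 figures quoted in this section (312 TFLOPS, 2 TB/s).

```python
def attainable_tflops(intensity, peak_tflops=312.0, bandwidth_tb_s=2.0):
    """Roofline: throughput is capped by compute or by bandwidth * intensity."""
    # TB/s times FLOPs/byte gives TFLOP/s, so no unit conversion is needed
    return min(peak_tflops, bandwidth_tb_s * intensity)


def ridge_point(peak_tflops=312.0, bandwidth_tb_s=2.0):
    """Arithmetic intensity at which the two ceilings meet."""
    return peak_tflops / bandwidth_tb_s


for name, ai in [("GELU", 2.5), ("ridge point", ridge_point()), ("large GEMM", 1000.0)]:
    print(f"{name:>12}: {ai:7.1f} FLOPs/byte -> at most {attainable_tflops(ai):6.1f} TFLOPS")
```

With these defaults, an operation at 2.5 FLOPs/byte can reach at most 5 TFLOPS no matter how well its arithmetic is tuned, which is the quantitative version of "memory-bound".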

9.9.1.1 Arithmetic Intensity Analysis

Before writing a custom kernel, you should determine whether the operation is worth optimizing. The arithmetic intensity of an operation is defined as:

$$ \text{Arithmetic Intensity} = \frac{\text{FLOPs}}{\text{Bytes Transferred}} $$

For an element-wise operation like GELU activation on a tensor of $n$ elements (FP16), the computation requires roughly $10n$ FLOPs while reading $2n$ bytes and writing $2n$ bytes, giving an arithmetic intensity of $10n / 4n = 2.5$ FLOPs/byte. This is far below the A100's threshold of 156, making it heavily memory-bound and an excellent candidate for fusion with adjacent operations.

For a matrix multiplication of shapes $(M, K) \times (K, N)$, the computation requires $2MKN$ FLOPs while transferring $2(MK + KN + MN)$ bytes, giving an arithmetic intensity that grows with matrix size. Large GEMM operations in LLMs are typically compute-bound and well-served by cuBLAS; the opportunity for custom kernels lies in the surrounding element-wise and reduction operations.
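
To see how the intensity of a matrix multiplication grows with its size, the short sketch below evaluates the formula above for several values of $M$ with $K = N = 4096$. The helper name `matmul_intensity` is an illustrative choice, and the byte count assumes each operand is read once and the result written once.

```python
def matmul_intensity(M, K, N, dtype_bytes=2):
    """FLOPs per byte for (M, K) x (K, N), each operand touched once."""
    flops = 2 * M * K * N
    bytes_moved = dtype_bytes * (M * K + K * N + M * N)
    return flops / bytes_moved


for M in (1, 16, 256, 4096):
    print(f"M={M:>5}: {matmul_intensity(M, 4096, 4096):8.1f} FLOPs/byte")
```

The single-row case (M = 1, the shape of a decode step for one sequence) lands near 1 FLOP/byte, far to the left of the ridge point, while the square case lands well to the right of it.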

# Arithmetic intensity calculator for common LLM operations
# Use this to decide whether a custom kernel is worthwhile
def arithmetic_intensity(op_name, M=4096, N=4096, K=4096, dtype_bytes=2):
    """Calculate FLOPs/byte for common operations on an MxN matrix."""
    if op_name == "gelu":
        # Element-wise: ~10 FLOPs per element, read + write
        flops = 10 * M * N
        bytes_transferred = 2 * dtype_bytes * M * N # read + write
    elif op_name == "layernorm":
        # Two passes (mean, variance) + normalize: ~8 FLOPs per element
        flops = 8 * M * N
        bytes_transferred = 2 * dtype_bytes * M * N
    elif op_name == "softmax":
        # Max reduction + exp + sum reduction + divide: ~8 FLOPs per element
        flops = 8 * M * N
        bytes_transferred = 2 * dtype_bytes * M * N
    elif op_name == "matmul":
        # 2*M*K*N FLOPs, read A(M,K) + B(K,N) + write C(M,N)
        flops = 2 * M * K * N
        bytes_transferred = dtype_bytes * (M * K + K * N + M * N)
    elif op_name == "fused_attention":
        # Q@K^T + scale + softmax + @V, but intermediate stays in SRAM
        # Only reads Q (M,K), K and V (N,K each) and writes O (M,K)
        flops = 2 * M * K * N + 8 * M * N + 2 * M * K * N
        bytes_transferred = dtype_bytes * (2 * M * K + 2 * N * K) # Q,K,V in + O out
    else:
        raise ValueError(f"Unknown operation: {op_name}")
    intensity = flops / bytes_transferred
    a100_threshold = 156 # FLOPs/byte for A100 at FP16
    bound = "COMPUTE" if intensity > a100_threshold else "MEMORY"
    print(f"{op_name:>20}: {intensity:8.1f} FLOPs/byte [{bound}-bound]")
    return intensity
# Compare operations at typical LLM dimensions
print(f"{'Operation':>20} {'Intensity':>8} Bottleneck")
print("-" * 50)
for op in ["gelu", "layernorm", "softmax", "matmul", "fused_attention"]:
    arithmetic_intensity(op)
Output:
           Operation Intensity Bottleneck
--------------------------------------------------
                gelu:      2.5 FLOPs/byte [MEMORY-bound]
           layernorm:      2.0 FLOPs/byte [MEMORY-bound]
             softmax:      2.0 FLOPs/byte [MEMORY-bound]
              matmul:   1365.3 FLOPs/byte [COMPUTE-bound]
     fused_attention:   2049.0 FLOPs/byte [COMPUTE-bound]
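Code Fragment 9.9.1: Arithmetic intensity analysis for common LLM operations. Element-wise operations (GELU, LayerNorm, softmax) are heavily memory-bound, making them prime candidates for kernel fusion, while large matrix multiplications and fused attention sit well above the A100's threshold of 156 FLOPs/byte.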
# FlashAttention conceptual implementation in Triton (simplified)
# Production code: use flash_attn package or torch SDPA
import torch
import triton
import triton.language as tl
@triton.jit
def flash_attention_fwd_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    stride_qm, stride_qd,
    stride_kn, stride_kd,
    stride_vn, stride_vd,
    stride_om, stride_od,
    seq_len, head_dim,
    scale,
    BLOCK_M: tl.constexpr, # tile size for queries
    BLOCK_N: tl.constexpr, # tile size for keys/values
    BLOCK_D: tl.constexpr, # power of two >= head_dim (must cover full d)
):
    # Each program handles one BLOCK_M-sized tile of queries
    pid_m = tl.program_id(0)
    start_m = pid_m * BLOCK_M
    m_offsets = start_m + tl.arange(0, BLOCK_M)
    d_offsets = tl.arange(0, BLOCK_D)
    d_mask = d_offsets[None, :] < head_dim # ignore padding columns when BLOCK_D > head_dim
    # Initialize accumulators on chip: running softmax statistics (online softmax trick)
    m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32) # running max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # running sum of exp
    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) # output accumulator
    # Load Q tile (stays in SRAM for all K/V tiles)
    q_ptrs = Q_ptr + m_offsets[:, None] * stride_qm + d_offsets[None, :] * stride_qd
    q_mask = (m_offsets[:, None] < seq_len) & d_mask
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)
    # Iterate over all K/V tiles
    for start_n in range(0, seq_len, BLOCK_N):
        n_offsets = start_n + tl.arange(0, BLOCK_N)
        kv_mask = (n_offsets[:, None] < seq_len) & d_mask
        # Load K tile
        k_ptrs = K_ptr + n_offsets[:, None] * stride_kn + d_offsets[None, :] * stride_kd
        k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
        # Compute attention scores for this tile: [BLOCK_M, BLOCK_N]
        scores = tl.dot(q, tl.trans(k)) * scale
        # Out-of-range keys were loaded as zeros; give them -inf so they get zero weight
        scores = tl.where(n_offsets[None, :] < seq_len, scores, -float("inf"))
        # Online softmax update
        m_new = tl.maximum(m_i, tl.max(scores, axis=1))
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(scores - m_new[:, None])
        # Rescale previous accumulator and add new contribution
        l_i = l_i * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None]
        # Load V tile and accumulate
        v_ptrs = V_ptr + n_offsets[:, None] * stride_vn + d_offsets[None, :] * stride_vd
        v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
        acc += tl.dot(p.to(tl.float16), v)
        m_i = m_new
    # Final normalization (after the loop over K/V tiles)
    acc = acc / l_i[:, None]
    # Write output tile to HBM
    o_ptrs = O_ptr + m_offsets[:, None] * stride_om + d_offsets[None, :] * stride_od
    o_mask = (m_offsets[:, None] < seq_len) & d_mask
    tl.store(o_ptrs, acc.to(tl.float16), mask=o_mask)
print("FlashAttention kernel defined (see flash_attn package for production use)")
Output: FlashAttention kernel defined (see flash_attn package for production use)
Code Fragment 9.9.2: Conceptual FlashAttention forward kernel in Triton. Each program keeps one tile of queries on chip, streams the K/V tiles past it, and applies the online softmax update, so the full score matrix is never written to HBM.
FlashAttention: tile Q, K, V into SRAM and never materialize the score matrix
Figure 9.9.2a: FlashAttention restructures the attention computation to avoid materializing the full N x N attention matrix in HBM. The naive PyTorch path makes four HBM round-trips and allocates a quadratic intermediate (128 MB for N=8192 at FP16). FlashAttention tiles Q, K, V into BLOCK_M x BLOCK_N blocks that fit in the GPU's 20 MB of on-chip SRAM, runs the matmul, online softmax, and value accumulation inside SRAM, and writes only the final output back to HBM. The result is a roughly 2-4x wall-clock speedup and an O(N) memory footprint that unlocks 64K+ context windows that naive attention cannot fit.

9.9.2 Triton: Python-Based GPU Kernel Programming

Section 3.6 introduced Triton's block-level programming model and walked through a fused-softmax kernel. Here the motivation is narrower and sharper: at inference time, the operations that dominate the decode loop (the attention score-and-softmax sequence, RMSNorm, the rotary embedding, the sampling tail) are almost all memory-bound, so fusing them into a single Triton kernel removes round-trips to HBM that no amount of better batching can recover. That is exactly why serving stacks like vLLM and SGLang ship hand-tuned Triton kernels for their hottest paths. The rest of this section treats Triton from the serving engineer's seat, where the goal is a kernel that keeps the GPU busy between tokens rather than a teaching example.

The core abstraction in Triton is the block: a tile of elements that one program instance loads, computes on, and stores as a unit. Instead of computing a single output element per thread, as a CUDA kernel typically does, each Triton program instance computes an entire block of output elements. The programmer specifies the block size (e.g., BLOCK_SIZE=1024), and Triton's compiler maps this to an efficient thread configuration. The programmer's job is to describe the data flow at the block level; the compiler's job is to make it fast.
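
The block-level model can be emulated on the CPU to make the indexing concrete. The sketch below is plain Python, not Triton: the loop over `pid` stands in for program instances that a GPU would run in parallel, and `emulate_vector_add` and `cdiv` are illustrative names (the latter plays the role that `triton.cdiv` plays when sizing a grid).

```python
def cdiv(a, b):
    """Ceiling division: how many blocks of size b are needed to cover a items."""
    return (a + b - 1) // b


def emulate_vector_add(x, y, block_size):
    """CPU emulation of a 1-D grid of block-level programs (not real Triton)."""
    n = len(x)
    out = [None] * n
    for pid in range(cdiv(n, block_size)):  # one iteration per program instance
        offsets = [pid * block_size + i for i in range(block_size)]
        mask = [off < n for off in offsets]  # guards the partial last block
        for off, in_bounds in zip(offsets, mask):
            if in_bounds:  # masked lanes neither load nor store
                out[off] = x[off] + y[off]
    return out


x = list(range(10))
y = [10 * v for v in x]
print(cdiv(len(x), 4), emulate_vector_add(x, y, block_size=4))
```

Ten elements with a block size of four need three program instances, and the mask is what keeps the last instance from touching indices 10 and 11.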

9.9.2.1 A First Triton Kernel: Vector Addition

This snippet writes a simple Triton kernel that performs element-wise vector addition on the GPU.

# Triton vector addition kernel
# The simplest possible Triton kernel, showing the block programming model
import torch
import triton
import triton.language as tl
@triton.jit
def vector_add_kernel(
    x_ptr, # pointer to first input vector
    y_ptr, # pointer to second input vector
    output_ptr, # pointer to output vector
    n_elements, # total number of elements
    BLOCK_SIZE: tl.constexpr, # number of elements per block (compile-time)
    ):
    # Each program instance handles one block of BLOCK_SIZE elements.
    # pid tells us which block we are.
    pid = tl.program_id(axis=0)
    # Compute the range of indices this block handles
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-bounds accesses (last block may be partial)
    mask = offsets < n_elements
    # Load a block of data from each input (block-level load)
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # Compute and store the result
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
def vector_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Launch the Triton vector addition kernel."""
    output = torch.empty_like(x)
    n_elements = output.numel()
    BLOCK_SIZE = 1024
    # Grid: number of blocks needed to cover all elements
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    vector_add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output
# Test
x = torch.randn(100_000, device="cuda", dtype=torch.float16)
y = torch.randn(100_000, device="cuda", dtype=torch.float16)
result = vector_add(x, y)
assert torch.allclose(result, x + y, atol=1e-3)
print("Triton vector addition: correct!")
Output: Triton vector addition: correct!
Code Fragment 9.9.3: A minimal Triton kernel for vector addition. The key concepts are program_id (which block am I?), block-level loads and stores, and out-of-bounds masking.

9.9.2.2 Fused Softmax Kernel

Softmax is one of the most common operations in LLM inference, applied at every attention layer. A naive PyTorch implementation performs three separate passes over the data: compute the maximum (for numerical stability), compute exponentials, and normalize by the sum. Each pass reads and writes the entire tensor to HBM. A fused Triton kernel performs all three passes in a single kernel launch, keeping intermediate results in registers and SRAM.

# Fused softmax in Triton: three passes fused into one kernel launch
# Eliminates two round-trips to HBM compared to PyTorch's unfused version
import torch
import triton
import triton.language as tl
@triton.jit
def fused_softmax_kernel(
    input_ptr, output_ptr,
    n_cols,
    input_row_stride,
    output_row_stride,
    BLOCK_SIZE: tl.constexpr,
    ):
    # Each program handles one row of the input matrix
    row_idx = tl.program_id(0)
    row_start = row_idx * input_row_stride
    # Load one row into registers (fast on-chip memory)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
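    row = tl.load(input_ptr + row_start + col_offsets, mask=mask, other=-float("inf"))
    # Upcast to FP32 for the reductions and exp (Triton's math functions may reject FP16 inputs);
    # tl.store casts the result back to the output dtype
    row = row.to(tl.float32)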
    # Pass 1: find maximum for numerical stability (in registers)
    row_max = tl.max(row, axis=0)
    # Pass 2: compute exponentials (still in registers)
    numerator = tl.exp(row - row_max)
    # Pass 3: normalize (still in registers)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Single write to HBM (vs. 3 writes in unfused version)
    out_start = row_idx * output_row_stride
    tl.store(output_ptr + out_start + col_offsets, softmax_output, mask=mask)
def fused_softmax(x: torch.Tensor) -> torch.Tensor:
    """Fused softmax along the last dimension."""
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    output = torch.empty_like(x)
    fused_softmax_kernel[(n_rows,)](
        x, output,
        n_cols,
        x.stride(0), output.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
        )
    return output
# Benchmark against PyTorch
x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
torch_result = torch.softmax(x, dim=-1)
triton_result = fused_softmax(x)
assert torch.allclose(triton_result, torch_result, atol=1e-2)
print("Fused softmax: correct! (3 memory passes reduced to 1)")
Output: Fused softmax: correct! (3 memory passes reduced to 1)
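Code Fragment 9.9.4: Fused softmax in Triton. All three passes (max, exp, normalize) happen in registers, eliminating the intermediate round-trips to HBM. For a 4096x4096 FP16 matrix (32 MB), the fused kernel reads the input once and writes the output once, about 64 MB of traffic per attention head; every unfused pass that re-reads and rewrites the tensor adds up to another 64 MB.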

9.9.3 FlashAttention: Tiled Attention in SRAM

FlashAttention (Dao et al., 2022) is the most impactful custom kernel in the LLM ecosystem. The core insight is that the standard attention computation, $\text{softmax}(QK^T / \sqrt{d_k})V$, creates an $N \times N$ attention matrix (where $N$ is the sequence length) that must be materialized in HBM. For $N = 8192$ and FP16, this matrix is 128 MB per attention head. FlashAttention avoids materializing this matrix entirely by tiling the computation into blocks that fit in SRAM.

Algorithm 9.9.1: Online Softmax Recurrence (Milakov and Gimelshein, 2018)
Algorithm: Numerically stable softmax in a single pass over blocks
Input:  blocks {(s_1, v_1), ..., (s_B, v_B)}: s_j is a slice of one row of attention scores, v_j the matching rows of V
Output: the softmax-weighted sum of values for that row, plus the running max m and normalizing constant l

  // State carried between blocks: m = running max, l = running denominator
  m := -infinity
  l := 0
  o := 0                        // running weighted sum of values (same dim as V row)

  For each block (s_j, v_j) in stream order:
      m_new := max(m, max(s_j))                            // updated max for this row
      alpha := exp(m - m_new)                              // rescale factor for old state
      l := alpha * l + sum( exp(s_j - m_new) )             // updated denominator
      o := alpha * o + exp(s_j - m_new) @ v_j              // updated weighted-sum
      m := m_new

  // After all blocks processed
  Return o / l, m, l                                       // final attention output

Why it is mathematically equivalent to the offline softmax.
   With S = concat(s_1, ..., s_B), v = concat(v_1, ..., v_B), and m* = max(S):
     p = exp(S - m*) / sum(exp(S - m*))
     softmax(S) @ v = sum_i p_i v_i.
   The recurrence maintains the invariant
     l_j = sum_{k <= j} exp(s_k - m_j),
     o_j = sum_{k <= j} exp(s_k - m_j) v_k,
   so after the last block, o_B / l_B = softmax(S) v.

Properties used by FlashAttention:
   - Numerical stability: each exp argument is <= 0, never overflows
   - Streaming: any block ordering is correct (commutative)
   - O(block_size) working state, so it fits in SRAM (~100 KB)

Source: Milakov and Gimelshein, "Online Normalizer Calculation for Softmax" (NVIDIA, arXiv:1805.02867, 2018). This recurrence is the algebraic core of FlashAttention (Dao et al., 2022, arXiv:2205.14135): the softmax-then-multiply pipeline of standard attention is rewritten as one fused streaming pass that maintains (m, l, o) per query row across all K/V blocks. FlashAttention-2 (Dao 2023, arXiv:2307.08691) reuses the same recurrence; its changes are a reduction in non-matmul work and a different partitioning of the work across thread blocks and warps for better parallelism.

The algorithm processes the attention computation in tiles. For each tile of Q (a block of query rows), it iterates over all tiles of K and V, accumulating partial results in SRAM. The running softmax normalization is maintained using the "online softmax" trick: as each new block of K is processed, the running maximum and sum statistics are updated, and previous partial results are rescaled. This ensures numerical correctness without ever materializing the full attention matrix.
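
The recurrence can be checked in a few lines of plain Python. The sketch below handles a single query row and, for brevity, treats each value as a scalar rather than a row of V; `online_softmax_attention` and `reference_attention` are illustrative names, not functions from any library.

```python
import math


def online_softmax_attention(score_blocks, value_blocks):
    """One query row: stream over blocks of scores and (scalar) values."""
    m = -math.inf  # running max
    l = 0.0        # running denominator
    o = 0.0        # running weighted sum of values
    for s_blk, v_blk in zip(score_blocks, value_blocks):
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)  # exp(-inf) is 0.0, so the first block starts clean
        p = [math.exp(s - m_new) for s in s_blk]
        l = alpha * l + sum(p)
        o = alpha * o + sum(pi * vi for pi, vi in zip(p, v_blk))
        m = m_new
    return o / l


def reference_attention(scores, values):
    """Offline softmax over the full row, for comparison."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    return sum(pi * vi for pi, vi in zip(p, values)) / sum(p)


scores = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
s_blocks = [scores[i:i + 2] for i in range(0, len(scores), 2)]
v_blocks = [values[i:i + 2] for i in range(0, len(values), 2)]
print(online_softmax_attention(s_blocks, v_blocks), reference_attention(scores, values))
```

The two results agree up to floating-point rounding even though the streaming version never holds more than one block of scores at a time, which is the property the tiled kernel relies on.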

FlashAttention's memory complexity is $O(N)$ instead of $O(N^2)$, enabling much longer sequence lengths. FlashAttention-2 further improved throughput by roughly 2x through better work partitioning across GPU warps, and FlashAttention-3 (for Hopper GPUs) exploits hardware-level asynchronous operations for a reported further improvement of 1.5x or more.

Note: FlashAttention Memory Savings

Consider a transformer with 32 attention heads, sequence length 8,192, and head dimension 128. Standard attention materializes 32 attention matrices of size 8192 x 8192 in FP16, consuming 32 x 128 MB = 4 GB of HBM just for intermediate storage. FlashAttention reduces this to roughly 32 x (2 x BLOCK_SIZE x 128 x 2 bytes), which at BLOCK_SIZE=256 is about 4 MB. This 1000x reduction in intermediate memory is what enables 128K+ context lengths on hardware that could previously only handle 2K-4K tokens with standard attention.
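
The arithmetic in this note is easy to reproduce. The sketch below mirrors the two estimates exactly as stated (a full N x N matrix per head versus two BLOCK_SIZE x head_dim tiles per head); the function names are illustrative, and the tiled estimate is the note's rough working-set figure, not a measurement of any particular kernel.

```python
def standard_attention_bytes(n_heads, seq_len, dtype_bytes=2):
    """Bytes for one materialized N x N score matrix per head."""
    return n_heads * seq_len * seq_len * dtype_bytes


def tiled_working_set_bytes(n_heads, block_size, head_dim, dtype_bytes=2):
    """The note's estimate: two block_size x head_dim tiles per head."""
    return n_heads * 2 * block_size * head_dim * dtype_bytes


std = standard_attention_bytes(n_heads=32, seq_len=8192)
tiled = tiled_working_set_bytes(n_heads=32, block_size=256, head_dim=128)
print(f"standard: {std / 2**30:.0f} GiB, tiled: {tiled / 2**20:.0f} MiB, "
      f"ratio: {std // tiled}x")
```

Doubling the sequence length quadruples the first figure and leaves the second unchanged, which is the practical meaning of $O(N^2)$ versus $O(N)$ here.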

# --- FlashAttention forward pass in Triton (simplified, single-head) ---
# Key ideas:
#   1. Q tile of shape (BLOCK_M, head_dim) stays in fast SRAM
#   2. We stream K, V tiles of shape (BLOCK_N, head_dim) in turn
#   3. Online softmax: maintain running max m_i and denominator l_i
#      so we never materialize the full N x N attention matrix
import torch
import triton
import triton.language as tl

@triton.jit
def flash_attn_fwd(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    stride_qm, stride_qk,           # Q strides (M, K)
    stride_kn, stride_kk,           # K strides (N, K)
    stride_vn, stride_vk,           # V strides (N, K)
    stride_om, stride_ok,           # O strides (M, K)
    M, N,                           # sequence lengths
    softmax_scale,                  # 1 / sqrt(head_dim)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    # Each program handles ONE Q tile of BLOCK_M rows
    pid_m = tl.program_id(0)

    # Load this Q tile into SRAM (stays here for the entire kernel)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, HEAD_DIM)
    q_ptrs = Q_ptr + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
    q = tl.load(q_ptrs, mask=offs_m[:, None] < M, other=0.0)
    q = q * softmax_scale

    # Online-softmax accumulators
    m_i = tl.full((BLOCK_M,), float("-inf"), dtype=tl.float32)   # running max
    l_i = tl.zeros((BLOCK_M,), dtype=tl.float32)                  # running denom
    acc = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32)         # output accum

    # Stream over K, V tiles
    for start_n in range(0, N, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        # Load K, V tile
        k_ptrs = K_ptr + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk
        v_ptrs = V_ptr + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk
        k = tl.load(k_ptrs, mask=offs_n[:, None] < N, other=0.0)
        v = tl.load(v_ptrs, mask=offs_n[:, None] < N, other=0.0)

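        # Compute QK^T for this tile: shape (BLOCK_M, BLOCK_N)
        s = tl.dot(q, tl.trans(k))
        # Padded key columns (offs_n >= N) were loaded as zeros; set their scores
        # to -inf so they contribute exp(-inf) = 0 to the softmax
        s = tl.where(offs_n[None, :] < N, s, float("-inf"))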

        # Online softmax update
        m_new = tl.maximum(m_i, tl.max(s, axis=1))
        alpha = tl.exp(m_i - m_new)             # rescale prev
        p = tl.exp(s - m_new[:, None])          # current tile softmax (unnormalized)
        l_new = alpha * l_i + tl.sum(p, axis=1)

        # Rescale prior accumulator and add this tile's contribution
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_i, l_i = m_new, l_new

    # Final normalization
    o = acc / l_i[:, None]

    # Write output
    o_ptrs = O_ptr + offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok
    tl.store(o_ptrs, o.to(O_ptr.dtype.element_ty), mask=offs_m[:, None] < M)

def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    block_m: int = 64, block_n: int = 64) -> torch.Tensor:
    """Wrapper: q/k/v are (M, head_dim), (N, head_dim), (N, head_dim) on CUDA."""
    M, head_dim = q.shape
    N = k.shape[0]
    o = torch.empty_like(q)
    grid = (triton.cdiv(M, block_m),)
    flash_attn_fwd[grid](
        q, k, v, o,
        q.stride(0), q.stride(1),
        k.stride(0), k.stride(1),
        v.stride(0), v.stride(1),
        o.stride(0), o.stride(1),
        M, N,
        softmax_scale=1.0 / (head_dim ** 0.5),
        BLOCK_M=block_m, BLOCK_N=block_n, HEAD_DIM=head_dim,
    )
    return o
Code Fragment 9.9.5: Simplified FlashAttention forward pass in Triton. The key ideas are: Q tiles stay in SRAM while iterating over K/V tiles, the online softmax trick maintains running statistics, and the full N x N attention matrix is never materialized. For production workloads, use the flash-attn package or torch.nn.functional.scaled_dot_product_attention.

The following worked example puts concrete numbers on the FLOPs, HBM traffic, and arithmetic intensity that explain why FlashAttention so dramatically outperforms a naive PyTorch implementation, even though both perform the same arithmetic.

Numeric Example
Why FlashAttention wins on an A100 (single head, N = 8192, d = 128)

Consider one attention head with sequence length $N = 8192$, head dimension $d = 128$, FP16 storage (2 bytes), tile sizes $B_r = B_c = 64$. Both kernels do the same $4 N^2 d \approx 34.4$ GFLOPs of arithmetic; only the memory traffic differs.

Standard attention HBM traffic. Read Q and K (each $N d$ FP16 = 2 MB, total 4 MB) and write the scores $S$ ($N^2$ FP16 = 128 MB); read $S$ back for the softmax and write the probabilities $P$ (128 MB each, total 256 MB); read $P$ and V for the $PV$ matmul (128 MB + 2 MB); write output O (2 MB). Total HBM traffic $\approx 520$ MB, of which 512 MB is spent on the temporary $N \times N$ matrices.

FlashAttention HBM traffic. Each Q tile is loaded once (total $N d \cdot 2$ bytes $= 2$ MB). For every Q tile the full K and V are streamed once, so K and V are read $N / B_r = 128$ times each: $128 \cdot 2 \text{ MB} \cdot 2 = 512$ MB. Output O is written once (2 MB). Total HBM traffic $\approx 516$ MB. The S and P matrices are never written to HBM; they live in SRAM tiles of $B_r \cdot B_c \cdot 4$ bytes (online softmax keeps FP32 accumulators) $\approx 16$ KB per program.

Wait, the two totals are almost the same? In this simple count, yes, but the bytes are not equivalent. FlashAttention's 512 MB consists of re-reads of the same 4 MB of K and V, streamed in a regular pattern that overlaps with the matmul; a working set that small can be served largely from the A100's 40 MB L2 cache rather than from HBM, and larger query tiles shrink the count further because K and V are read $N / B_r$ times. Standard attention's 512 MB consists of 128 MB matrices that are each written once and read once by separate kernels; they exceed the cache, so they have to travel through HBM, and none of that traffic overlaps with the matmuls. On an A100 (312 TFLOPS FP16; 1.55 TB/s of HBM bandwidth on the 40 GB model, about 2 TB/s on the 80 GB model), that makes the standard kernel memory-bound on its $N \times N$ intermediates. The measured A100 speedup is roughly 2 to 4x at $N = 8192$ and grows to 7x at $N = 16384$, because the standard kernel's intermediate memory grows as $O(N^2)$ while FlashAttention's grows as $O(N)$.
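
A back-of-the-envelope timing makes the memory-bound claim concrete. The sketch below compares the time the A100 would need for all of the arithmetic at peak rate with the time needed to move a single $N \times N$ FP16 matrix through HBM once. The function name and the default of 2 TB/s are assumptions for illustration, and real kernels reach neither peak, so treat the results as lower bounds.

```python
def attention_time_bounds(N=8192, d=128, dtype_bytes=2,
                          peak_tflops=312.0, bandwidth_tb_s=2.0):
    """Return (compute_ms, pass_ms): best-case times in milliseconds."""
    flops = 4 * N * N * d               # the QK^T and PV matmuls
    matrix_bytes = N * N * dtype_bytes  # one N x N intermediate
    compute_ms = flops / (peak_tflops * 1e12) * 1e3
    pass_ms = matrix_bytes / (bandwidth_tb_s * 1e12) * 1e3
    return compute_ms, pass_ms


compute_ms, pass_ms = attention_time_bounds()
print(f"all arithmetic at peak:      {compute_ms:.3f} ms")
print(f"one pass over an N x N tile: {pass_ms:.3f} ms")
```

With these inputs, two passes over an $N \times N$ matrix already cost more time than all of the arithmetic, so every write or read of S or P that a kernel avoids matters more than shaving FLOPs.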

9.9.4 JIT Compilation: torch.compile and XLA

Not every optimization requires hand-written kernels. torch.compile (introduced in PyTorch 2.0) uses the TorchInductor backend to automatically fuse operations and generate optimized Triton kernels. For many workloads, simply wrapping your model in torch.compile captures a large share of the speedup that a hand-written kernel would provide (50% to 80% is a reasonable rule of thumb, not a guarantee), at the cost of one extra line of code and a one-time compilation delay.

XLA (Accelerated Linear Algebra), the compiler backend for JAX and TensorFlow, takes a different approach: it compiles entire computation graphs into optimized code for the target hardware. XLA performs operator fusion, layout optimization, and memory planning at the graph level, often producing code competitive with hand-written kernels for standard operations.

The decision tree for optimization is: (1) try torch.compile first; (2) if that is insufficient, profile to identify the bottleneck; (3) write a custom Triton kernel only for the specific bottleneck operation. Hand-written kernels should be a last resort, not a first instinct.

# torch.compile: automatic kernel fusion without writing custom kernels
import torch
import torch.nn as nn
import time
class TransformerBlock(nn.Module):
    """A simplified transformer block for benchmarking."""
    def __init__(self, d_model=1024, n_heads=16):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
    def forward(self, x):
        # Pre-norm architecture
        h = self.norm1(x)
        h, _ = self.attn(h, h, h)
        x = x + h
        x = x + self.ffn(self.norm2(x))
        return x
device = torch.device("cuda")
model = TransformerBlock().to(device).half()
x = torch.randn(8, 512, 1024, device=device, dtype=torch.float16)
# Eager mode baseline
with torch.no_grad():
    for _ in range(10): # warmup
        _ = model(x)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(100):
        _ = model(x)
    torch.cuda.synchronize() # wait for all 100 iterations before stopping the clock
    eager_time = (time.perf_counter() - t0) / 100
# Compiled mode: torch.compile auto-fuses operations
compiled_model = torch.compile(model, mode="reduce-overhead")
with torch.no_grad():
    for _ in range(10): # warmup (includes compilation)
        _ = compiled_model(x)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(100):
        _ = compiled_model(x)
    torch.cuda.synchronize()
    compiled_time = (time.perf_counter() - t0) / 100
speedup = eager_time / compiled_time
print(f"Eager: {eager_time*1000:.2f} ms")
print(f"Compiled: {compiled_time*1000:.2f} ms")
print(f"Speedup: {speedup:.2f}x")
Output (example run; timings depend on the GPU and PyTorch version):
Eager: 3.42 ms
Compiled: 2.18 ms
Speedup: 1.57x
Code Fragment 9.9.6: torch.compile automatically fuses operations in a transformer block. The reduce-overhead mode minimizes kernel launch latency using CUDA graphs. Typical speedups range from 1.3x to 2x depending on the model architecture.
Tip: When to Write a Custom Kernel

Write a custom Triton kernel when all three conditions are met: (1) profiling shows a specific operation consuming a disproportionate fraction of wall-clock time; (2) the operation is memory-bound (low arithmetic intensity); (3) torch.compile does not fuse the operation effectively (check by inspecting the generated code with TORCH_COMPILE_DEBUG=1). In practice, custom kernels are most valuable for attention variants (sparse attention, sliding window), custom activation functions, and fused normalization + residual operations.
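
The three conditions in this tip can be written down as a small checklist function. This is only a sketch: the name `should_write_custom_kernel` and the 10% hotspot threshold are hypothetical choices made for illustration, and the ridge point default is the A100 figure used throughout this section.

```python
def should_write_custom_kernel(time_fraction, intensity, compile_fuses_it,
                               min_time_fraction=0.10, ridge_point=156.0):
    """Apply the tip's three conditions; the thresholds are illustrative.

    time_fraction:     share of wall-clock time spent in the operation (0..1)
    intensity:         arithmetic intensity of the operation in FLOPs/byte
    compile_fuses_it:  True if torch.compile already fuses it effectively
    """
    is_hotspot = time_fraction >= min_time_fraction
    is_memory_bound = intensity < ridge_point
    return is_hotspot and is_memory_bound and not compile_fuses_it


# A memory-bound op taking 30% of the step that the compiler leaves unfused
print(should_write_custom_kernel(0.30, 2.5, compile_fuses_it=False))
```

All three inputs come from measurement (a profile, an intensity estimate, and an inspection of the generated code), so the function is a reminder of what to collect before writing a kernel rather than a substitute for collecting it.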

9.9.5 Performance Profiling

Effective GPU optimization requires precise measurement. The PyTorch profiler, NVIDIA Nsight Systems, and NVIDIA Nsight Compute provide complementary views: the PyTorch profiler shows operation-level timing within the Python framework; Nsight Systems shows the full timeline of CPU and GPU activity including kernel launches and memory transfers; Nsight Compute provides detailed per-kernel metrics including occupancy, memory throughput, and compute utilization.

The most common profiling mistake is measuring wall-clock time without GPU synchronization. GPU operations are asynchronous: when Python calls torch.matmul(A, B), the function returns immediately while the GPU works in the background. Without an explicit torch.cuda.synchronize() before timing measurements, you measure only the CPU-side overhead of launching the kernel, not the actual GPU execution time.
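
The synchronization rule can be wrapped in a small timing helper. The following is a minimal sketch, not a definitive benchmarking harness: the name `time_op` and the matrix sizes are illustrative assumptions, and the code falls back to the CPU (where no synchronization is needed) when CUDA is unavailable.

```python
import time
import torch

def time_op(fn, warmup=3, iters=10):
    """Average seconds per call of fn (hypothetical helper for illustration)."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # drain pending GPU work before starting the clock
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
    return (time.perf_counter() - t0) / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(512, 512, device=device)
b = torch.randn(512, 512, device=device)
avg_s = time_op(lambda: torch.matmul(a, b))
print(f"matmul: {avg_s * 1e3:.3f} ms per call on {device}")
```

Removing the second `torch.cuda.synchronize()` on a GPU would typically make the reported time shrink toward the launch overhead alone, which is the mistake described above.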

# PyTorch profiler: identifying kernel-level bottlenecks
import torch
from torch.profiler import profile, ProfilerActivity
model = torch.nn.TransformerEncoderLayer(
    d_model=1024, nhead=16, batch_first=True
).cuda().half()
x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
# Profile with GPU activity tracking
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with torch.no_grad():
        for _ in range(5):
            _ = model(x)
    # Wait for queued GPU work so every kernel is captured in the trace
    torch.cuda.synchronize()
# Print the top 20 operations by GPU time (after the profiler context exits)
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20,
    top_level_events_only=False,
))
# Export Chrome trace for visual inspection
# Open chrome://tracing and load this file
prof.export_chrome_trace("transformer_trace.json")
print("Trace exported to transformer_trace.json")
print("Open chrome://tracing in Chrome to visualize")
Output:
Name               Self CPU   Self CUDA   CPU total   CUDA total
----------------   --------   ---------   ---------   ----------
aten::mm           1.205ms    4.821ms     1.205ms     4.821ms
aten::addmm        0.387ms    1.604ms     0.541ms     1.604ms
aten::_softmax     0.098ms    0.412ms     0.098ms     0.412ms
aten::layer_norm   0.187ms    0.389ms     0.321ms     0.389ms
...
Trace exported to transformer_trace.json
Open chrome://tracing in Chrome to visualize
Code Fragment 9.9.7: Using the PyTorch profiler to identify GPU bottlenecks. The tabular output shows CPU and CUDA time per operation, while the Chrome trace provides a visual timeline of CPU/GPU activity. Look for gaps between kernels (launch overhead) and operations with low arithmetic intensity (fusion candidates).

9.9.6 Putting It Together: Optimization Workflow

A systematic approach to LLM kernel optimization follows these steps:

  1. Profile first. Use the PyTorch profiler to identify which operations consume the most GPU time. Focus on operations that appear repeatedly (attention, normalization, activation functions).
  2. Classify the bottleneck. Use arithmetic intensity analysis (Code Fragment 9.9.2b) to determine whether each bottleneck is memory-bound or compute-bound. Memory-bound operations benefit from fusion; compute-bound operations benefit from algorithmic improvements or hardware upgrades.
  3. Try torch.compile. Apply torch.compile and re-profile. If the bottleneck is resolved, stop here. Check the generated Triton code to understand what the compiler did.
  4. Write a targeted kernel. If torch.compile does not adequately fuse the bottleneck operation, write a custom Triton kernel for that specific operation. Start with correctness (compare against PyTorch reference), then optimize for performance.
  5. Validate at scale. Test the optimized kernel at production-scale batch sizes and sequence lengths. Performance characteristics can change significantly between small test inputs and real workloads.
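
The correctness check in step 4 can be sketched as follows. This is an illustrative harness under stated assumptions: `candidate_add_layernorm` is a hypothetical stand-in written in plain PyTorch ops where a real Triton kernel would go, and the tensor shapes and tolerances are arbitrary choices. It runs on the CPU so the comparison logic can be tried without a GPU.

```python
import torch
import torch.nn.functional as F

def reference_add_layernorm(x, residual, weight, bias, eps=1e-5):
    """Reference path: residual add followed by PyTorch's layer_norm."""
    return F.layer_norm(x + residual, (x.shape[-1],), weight, bias, eps)

def candidate_add_layernorm(x, residual, weight, bias, eps=1e-5):
    """Hypothetical stand-in for a fused kernel, written out by hand."""
    h = x + residual
    mean = h.mean(dim=-1, keepdim=True)
    var = h.var(dim=-1, unbiased=False, keepdim=True)  # biased variance, as layer_norm uses
    return (h - mean) * torch.rsqrt(var + eps) * weight + bias

torch.manual_seed(0)
x = torch.randn(4, 128, 256)
residual = torch.randn(4, 128, 256)
weight = torch.randn(256)
bias = torch.randn(256)

expected = reference_add_layernorm(x, residual, weight, bias)
actual = candidate_add_layernorm(x, residual, weight, bias)
# Raises AssertionError with a mismatch report if the outputs diverge.
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)
max_err = (actual - expected).abs().max().item()
print(f"max abs error: {max_err:.2e}")
```

For FP16 kernels the tolerances usually need to be looser than for FP32, and it is worth repeating the check at several shapes, including ones that are not multiples of the kernel's block size.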
Exercise 9.9.1:
  1. Arithmetic intensity. Using Code Fragment 9.9.2c, calculate the arithmetic intensity of the full transformer forward pass (attention + FFN + normalization) for sequence lengths of 512, 2048, and 8192. At what sequence length does the attention computation transition from memory-bound to compute-bound?
  2. Fused GELU kernel. Write a Triton kernel that fuses the linear projection and GELU activation in the FFN layer (compute GELU(xW + b) in a single kernel). Compare its performance against separate torch.nn.Linear + torch.nn.GELU calls using the PyTorch profiler.
  3. torch.compile inspection. Apply torch.compile to a full transformer encoder layer and set TORCH_COMPILE_DEBUG=1. Examine the generated Triton code in the debug output. Which operations did the compiler fuse, and which did it leave unfused? Can you explain why?
  4. Profiling practice. Using Code Fragment 9.9.7, profile an inference pass through a Hugging Face model (e.g., facebook/opt-1.3b). Identify the top 5 GPU operations by time and classify each as memory-bound or compute-bound. Which operations would benefit most from custom kernels?
Research Frontier: Hardware-Aware Kernel Generation

The boundary between hand-written kernels and compiler-generated code is blurring. Projects like FlexAttention (PyTorch's composable attention API) allow users to define custom attention patterns (causal, sliding window, block-sparse) as simple Python functions, which are then compiled into efficient fused kernels. Meanwhile, emerging languages and compiler frontends such as Mojo and Pallas (the kernel-writing frontend in the JAX/XLA ecosystem) aim to give programmers CUDA-level control with Python-level ergonomics. The CS336 course at Stanford (Spring 2025) provides an excellent deep dive into building LLM systems from scratch, including writing attention kernels in Triton as a core assignment.

Lab: Quantize and Benchmark a Model
Duration: ~60 minutes Intermediate

Objective

Load a language model at FP16 precision, quantize it to INT8 using bitsandbytes, and measure the impact on memory usage, inference latency, and output quality. As a stretch goal, serve the quantized model with vLLM to see how a production inference engine handles quantization.

What You'll Practice

  • Loading models in different precisions (FP16, INT8)
  • Using bitsandbytes for post-training quantization
  • Benchmarking GPU memory and inference latency
  • Comparing output quality across precision levels

Setup

This lab requires a CUDA-capable GPU. A GPU with at least 8 GB of VRAM is recommended.

pip install transformers torch accelerate bitsandbytes
Code Fragment 9.9.8: Install the four packages the rest of the lab depends on: transformers (HF model loading), torch (PyTorch runtime), accelerate (device map and offload), and bitsandbytes (INT8/INT4 quantization). The lab uses SmolLM2-360M-Instruct so it fits comfortably on a 6 GB consumer GPU.

Steps

Step 1: Load the model at FP16 and measure baseline

Load a small model and record its memory footprint and generation speed as a baseline.

import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "HuggingFaceTB/SmolLM2-360M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Baseline: FP16
torch.cuda.reset_peak_memory_stats()
model_fp16 = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
mem_fp16 = torch.cuda.max_memory_allocated() / 1024**2
prompt = "Explain the difference between compiled and interpreted languages."
inputs = tokenizer(prompt, return_tensors="pt").to(model_fp16.device)
# Warm-up run
with torch.no_grad():
    model_fp16.generate(**inputs, max_new_tokens=100)
# Timed run (synchronize so the clock covers all GPU work)
torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    out_fp16 = model_fp16.generate(**inputs, max_new_tokens=100)
torch.cuda.synchronize()
latency_fp16 = time.perf_counter() - start
text_fp16 = tokenizer.decode(out_fp16[0], skip_special_tokens=True)
print(f"FP16 Memory: {mem_fp16:.0f} MB")
print(f"FP16 Latency: {latency_fp16:.3f}s")
print(f"FP16 Output: {text_fp16[:300]}")
Output:
FP16 Memory: 723 MB
FP16 Latency: 0.847s
FP16 Output: Explain the difference between compiled and interpreted languages. Compiled languages are translated into machine code before execution, producing a standalone binary. Interpreted languages are executed line by line at runtime by an interpreter...
Code Fragment 9.9.9: Loading a small language model at FP16 precision and measuring its memory footprint and generation latency. This baseline establishes the reference point for comparing quantized variants in subsequent steps.
Hint

Always include a warm-up generation before timing. The first call triggers CUDA kernel compilation and memory allocation, which would inflate the latency measurement.

Step 2: Quantize to INT8 with bitsandbytes

Reload the model with 8-bit quantization and repeat the measurements.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import time
import torch
# Free the FP16 model before loading the quantized one
del model_fp16
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# INT8 quantization
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model_int8 = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map="auto"
)
mem_int8 = torch.cuda.max_memory_allocated() / 1024**2
inputs = tokenizer(prompt, return_tensors="pt").to(model_int8.device)
# Warm-up run
with torch.no_grad():
    model_int8.generate(**inputs, max_new_tokens=100)
# Timed run (synchronize so the clock covers all GPU work)
torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    out_int8 = model_int8.generate(**inputs, max_new_tokens=100)
torch.cuda.synchronize()
latency_int8 = time.perf_counter() - start
text_int8 = tokenizer.decode(out_int8[0], skip_special_tokens=True)
print(f"\nINT8 Memory: {mem_int8:.0f} MB")
print(f"INT8 Latency: {latency_int8:.3f}s")
print(f"INT8 Output: {text_int8[:300]}")
Output:
INT8 Memory: 412 MB
INT8 Latency: 0.913s
INT8 Output: Explain the difference between compiled and interpreted languages. Compiled languages convert source code to machine code ahead of time, while interpreted languages are processed at runtime. Compiled programs generally run faster...
Code Fragment 9.9.10: Reloading the same model with INT8 quantization via bitsandbytes and repeating the benchmark. Memory drops by roughly 43% while output quality remains comparable, illustrating the practical benefit of post-training quantization.
Hint

INT8 quantization typically reduces memory by roughly 50% compared to FP16. Latency may not improve on all hardware because the INT8 kernels incur dequantization overhead during matrix multiplications.
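
A back-of-the-envelope estimate helps put the measured numbers in context. The sketch below counts weight storage only and assumes roughly 360 million parameters, a figure rounded from the model's name rather than measured; `weight_memory_mb` is a hypothetical helper.

```python
def weight_memory_mb(num_params, bits_per_param):
    """Memory for the weights alone, in MB (1024**2 bytes); activations and buffers excluded."""
    return num_params * bits_per_param / 8 / 1024**2

NUM_PARAMS = 360e6  # assumption: rounded from the model name, not measured

fp16_mb = weight_memory_mb(NUM_PARAMS, 16)
int8_mb = weight_memory_mb(NUM_PARAMS, 8)
ideal_savings = (1 - int8_mb / fp16_mb) * 100
print(f"FP16 weights: {fp16_mb:.0f} MB")
print(f"INT8 weights: {int8_mb:.0f} MB")
print(f"Ideal savings: {ideal_savings:.0f}%")
```

Under these assumptions the FP16 estimate (about 687 MB) is close to the 723 MB measured in Step 1, while the measured INT8 figure of 412 MB sits above the idealized half. One plausible reason is that not every tensor is stored in 8 bits: quantization applies to linear-layer weights, and other parameters such as embeddings and normalization weights may remain in higher precision.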

Step 3: Compare results and discuss trade-offs

Summarize the memory, latency, and quality differences in a table.

import pandas as pd
results = pd.DataFrame([
    {"Precision": "FP16", "Memory (MB)": f"{mem_fp16:.0f}",
    "Latency (s)": f"{latency_fp16:.3f}",
    "Tokens": len(out_fp16[0])},
    {"Precision": "INT8", "Memory (MB)": f"{mem_int8:.0f}",
    "Latency (s)": f"{latency_int8:.3f}",
    "Tokens": len(out_int8[0])},
])
print("\n=== Quantization Comparison ===")
print(results.to_string(index=False))
savings = (1 - mem_int8 / mem_fp16) * 100
print(f"\nMemory savings: {savings:.1f}%")
print(f"\nFP16 snippet: {text_fp16[len(prompt):len(prompt)+200]}")
print(f"INT8 snippet: {text_int8[len(prompt):len(prompt)+200]}")
Output:
=== Quantization Comparison ===
Precision  Memory (MB)  Latency (s)  Tokens
FP16       723          0.847        112
INT8       412          0.913        109
Memory savings: 43.0%
FP16 snippet: Compiled languages are translated into machine code before execution...
INT8 snippet: Compiled languages convert source code to machine code ahead of time...
Code Fragment 9.9.11: Summarizing the FP16 vs. INT8 comparison in a table, highlighting memory savings, latency differences, and output snippets side by side. For small models, latency gains are modest; the real payoff of quantization appears with larger models where memory is the bottleneck.
Hint

For small models, INT8 may show minimal latency improvement because the compute is already fast. The real benefit of quantization shows with larger models (7B+) where memory becomes the bottleneck, enabling models to fit on fewer or smaller GPUs.

Expected Output

  • INT8 memory usage approximately 40-50% lower than FP16
  • Latency similar or slightly higher for INT8 on small models
  • Output quality nearly identical between FP16 and INT8 for this model size

Stretch Goals

  • Try 4-bit quantization (load_in_4bit=True with NF4) and add it to the comparison table
  • Run the same experiment on a 1B+ model to see larger memory savings
  • Serve the quantized model with vLLM (vllm serve model_name --quantization bitsandbytes) and benchmark throughput with concurrent requests
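
For the first stretch goal, a 4-bit NF4 configuration might look like the sketch below. `load_nf4_model` is a hypothetical helper name, and using FP16 as the compute dtype is an assumption chosen to match the rest of the lab. The function is only defined here, not called, since calling it downloads the model and requires a CUDA GPU with bitsandbytes installed.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

def load_nf4_model(model_name):
    """Load a causal LM with 4-bit NF4 weights (hypothetical helper for this lab)."""
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
        bnb_4bit_compute_dtype=torch.float16,  # assumption: matmuls run in FP16
    )
    return AutoModelForCausalLM.from_pretrained(
        model_name, quantization_config=nf4_config, device_map="auto"
    )
```

Calling `load_nf4_model(model_name)` after freeing the INT8 model, then repeating the memory and latency measurements from Step 2, gives a third row for the comparison table.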
Key Takeaways
Self-Check

1. A standard softmax attention kernel makes three passes over the attention matrix in HBM. Explain how FlashAttention reduces this to a single pass and why that yields a large speedup despite doing the same number of FLOPs.

Answer:
Standard attention computes Q @ K^T (writing the N×N attention matrix to HBM), applies softmax (reading and rewriting that matrix), then multiplies by V (reading it again and writing the output). These three round-trips to HBM are the bottleneck on modern GPUs. FlashAttention (deep dive in Section 9.9.3) tiles the attention matrix and never materializes it fully in HBM. Each output block is computed by streaming small tiles of Q, K, and V through SRAM, keeping the running max and exp-sum in registers, and writing only the final output to HBM. Per block this is one HBM round-trip instead of three. Even though FlashAttention performs slightly more arithmetic (it recomputes some softmax pieces during the backward pass), HBM bandwidth dominates wall-clock time, and the savings translate to a 2-4× speedup with much lower memory use.
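
The running max and exp-sum bookkeeping can be illustrated in plain Python. This is a simplified sketch for a single query row with scalar values, not FlashAttention itself: the function names and the tile size are illustrative assumptions, and a real kernel applies the same rescaling to vector-valued accumulators held in registers.

```python
import math

def streaming_softmax_weighted_sum(scores, values, tile=4):
    """Compute sum_i softmax(scores)_i * values_i while seeing one tile at a time."""
    running_max = float("-inf")
    running_sum = 0.0  # sum of exp(score - running_max) over tiles seen so far
    acc = 0.0          # un-normalized weighted sum of values
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(running_max, max(s_tile))
        # Rescale earlier partial results to the new reference maximum.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc *= scale
        for s, v in zip(s_tile, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            acc += w * v
        running_max = new_max
    return acc / running_sum

def full_softmax_weighted_sum(scores, values):
    """Reference: materialize all softmax weights first, then reduce."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.5, 4.0, -2.0, 2.5, 1.0]
values = [1.0, -2.0, 0.5, 3.0, 4.0, -1.0, 2.0, 0.0, 1.5, -0.5]
streamed = streaming_softmax_weighted_sum(scores, values, tile=4)
full = full_softmax_weighted_sum(scores, values)
print(f"streamed={streamed:.6f}  full={full:.6f}")
```

The two results agree up to floating-point rounding, even though the streaming version never holds more than one tile of scores at a time. That property is what lets the full attention matrix stay out of HBM.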

2. You write a Triton kernel that fuses LayerNorm, Linear, and GELU into one kernel. What specific overhead does this fusion eliminate compared to running three separate CUDA kernels?

Answer:
Three CUDA kernels mean three kernel launch overheads (~5 µs each), two intermediate tensors that are each written to HBM by one kernel and read back by the next, and three separate library call paths (cuBLAS or cuDNN), each with its own overhead. Fusing LayerNorm + Linear + GELU into one Triton kernel collapses all of that: one launch, one HBM read of the input, intermediates that stay in registers and SRAM, and one HBM write of the output. For small batch and sequence sizes (where launch overhead and HBM bandwidth dominate compute time), fusion typically wins 2-3×. For large batches where compute dominates, the win shrinks but is still positive.
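
The memory-traffic side of this argument can be estimated by counting bytes under a simple model. The sketch below is an idealized accounting, not a measurement: `hbm_bytes` is a hypothetical helper, the shapes are arbitrary, and small tensors (LayerNorm scale and shift, the linear bias) as well as cache effects are ignored.

```python
def hbm_bytes(num_tokens, d_in, d_out, bytes_per_elem=2, fused=False):
    """Rough HBM traffic for LayerNorm -> Linear -> GELU on one batch of tokens."""
    x = num_tokens * d_in * bytes_per_elem   # input activations
    w = d_in * d_out * bytes_per_elem        # linear weight matrix
    y = num_tokens * d_out * bytes_per_elem  # linear output (same size as final output)
    if fused:
        return x + w + y                     # read x and W once, write the output once
    layernorm = x + x                        # read x, write normalized x
    linear = x + w + y                       # read normalized x and W, write y
    gelu = y + y                             # read y, write GELU(y)
    return layernorm + linear + gelu

unfused = hbm_bytes(512, 1024, 4096)
fused = hbm_bytes(512, 1024, 4096, fused=True)
print(f"unfused: {unfused / 1e6:.1f} MB, fused: {fused / 1e6:.1f} MB, "
      f"ratio: {unfused / fused:.2f}x")
```

In this model the extra traffic comes entirely from the two intermediates being written and read back. The weight matrix is read once either way, so the ratio grows as the activations become large relative to the weights.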

3. After running torch.compile(model), the first forward pass takes 30 seconds but subsequent passes take 4 ms. Explain what happens during that first call and when you would use fullgraph=True versus the default mode.

Answer:
The 30-second first call is the JIT trace + compile. torch.compile dynamically captures the model's execution into an FX graph, lowers it through TorchInductor to a fused Triton/CUDA kernel set, and JIT-compiles to a binary. The cost is one-time per shape and dtype. After warmup, subsequent calls execute the optimized kernels directly. fullgraph=True forces the model to be a single captured graph and fails if there is any data-dependent control flow that torch.compile cannot trace. Use it when (a) you have eliminated data-dependent branches, (b) you want to AOT-export the graph for deployment, or (c) you need a guarantee of no Python-roundtrip during inference. Default mode (graph breaks allowed) is more permissive but slightly slower because it has to bounce back to Python at each break.

What's Next?

With inference optimization covered, from quantization to custom GPU kernels, we move to understanding what happens inside these models. In Chapter 11: Interpretability and Mechanistic Understanding, we explore techniques for looking inside the black box: attention visualization, probing classifiers, logit lens, and mechanistic interpretability methods that reveal how LLMs actually process language.

Further Reading

FlashAttention

Dao, T. et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. The foundational FlashAttention paper. Introduces tiled attention with online softmax to achieve 2-4x speedup and linear memory complexity. Required reading for understanding why memory hierarchy awareness is more important than raw FLOPs for LLM performance.
Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. Improves upon FlashAttention with better warp-level parallelism, achieving 50-73% of theoretical peak throughput on A100 GPUs. Demonstrates that work partitioning across GPU warps is as important as the tiling algorithm itself.

Triton and GPU Programming

Tillet, P. et al. (2019). Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations. The Triton language specification and tutorial. Covers the block-level programming model, auto-tuning, and how Triton's compiler maps Python code to efficient GPU programs. The official tutorials provide excellent worked examples from vector addition through fused softmax to matrix multiplication.
CS336: Language Modeling from Scratch, Stanford University (Spring 2025). Graduate course covering LLM systems from tokenization through training infrastructure. The GPU kernel programming assignments provide hands-on experience writing Triton kernels for attention, layer normalization, and activation functions. Lecture materials and assignments are publicly available.

Compilation and Optimization

PyTorch Team. (2024). torch.compile: PyTorch 2.x Compiler Documentation. Official documentation for torch.compile, covering TorchInductor, TorchDynamo, and the various compilation modes. Includes guidance on debugging compilation failures and inspecting generated code.
Dong, J. et al. (2024). FlexAttention: A Programming Model for Generating Optimized Attention Kernels. Describes a composable API for defining custom attention patterns that compile into fused kernels. Bridges the gap between hand-written Triton kernels and high-level PyTorch code, making it practical to experiment with novel attention variants without GPU programming expertise.
NVIDIA. (2024). Nsight Compute Documentation. Official documentation for NVIDIA's kernel-level profiling tool. Essential for roofline analysis, occupancy optimization, and understanding memory access patterns in custom CUDA and Triton kernels.