Part 2: Understanding LLMs
Chapter 9: Inference Optimization

GPU Kernel Programming for LLM Optimization

The fastest code is the code the GPU never has to wait for.

Quant Quant, Kernel Obsessed AI Agent

Prerequisites

This section assumes familiarity with the transformer architecture from Section 4.2, the attention mechanism in particular. The quantization and memory optimization techniques from Section 9.1 and Section 9.2 provide important context for understanding why custom kernels are needed. Basic familiarity with GPU hardware concepts (threads, blocks, shared memory) is helpful but not required.

Big Picture

The performance of LLM inference is ultimately determined by how efficiently we use GPU hardware. High-level frameworks like PyTorch compose operations from a library of pre-written CUDA kernels, but this composition introduces overhead: each kernel launch reads from and writes to global memory, even when the output of one operation feeds directly into the next. Custom kernels eliminate this overhead by fusing multiple operations into a single GPU program. FlashAttention, the single most impactful LLM optimization of recent years, is fundamentally a custom kernel that tiles the attention computation to fit in fast on-chip SRAM instead of slow off-chip HBM. This section teaches you how to write, profile, and reason about custom GPU kernels using Triton, OpenAI's Python-based kernel programming language.

1. Why Custom Kernels Matter

A standard PyTorch attention implementation performs four separate GPU operations: matrix multiplication (Q @ K^T), scaling, softmax, and another matrix multiplication (scores @ V). Each operation launches a separate kernel, and each kernel reads its inputs from HBM (High Bandwidth Memory, the GPU's main memory) and writes its outputs back to HBM. For a sequence length of 8,192, the attention scores matrix of a single head alone is 8192 x 8192 x 2 bytes = 128 MB in FP16; its size depends on the sequence length, not on the hidden dimension. This matrix is written to HBM by one kernel and read back by the next, even though it is a temporary intermediate that is never needed again.

The key insight is that, apart from large matrix multiplications, most LLM operations are limited by memory bandwidth rather than by compute. An NVIDIA A100 can perform 312 TFLOPS of dense FP16 Tensor Core computation but can move data to and from HBM at only about 2 TB/s (80 GB model). The ratio of peak compute to memory bandwidth (the arithmetic intensity threshold, also called the ridge point) is therefore about 312 / 2 = 156 FLOPs per byte. Any operation that performs fewer than 156 FLOPs for each byte it reads or writes is bottlenecked by memory, not compute. Most element-wise operations (activation functions, layer normalization, dropout) fall far below this threshold.

Custom kernels solve this by keeping intermediate results in on-chip SRAM (registers and shared memory; each of the A100's 108 streaming multiprocessors has 192 KB of combined L1/shared memory, roughly 20 MB in total) and performing multiple operations before writing back to HBM. This technique is called kernel fusion, and it is the foundation of FlashAttention, fused layer normalization, and most other high-performance LLM kernels.
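To make the savings concrete, here is a back-of-the-envelope sketch. It uses a simplified model that assumes each operation reads one input tensor and writes one output tensor and ignores caches; the function name hbm_traffic_bytes is hypothetical. A chain of element-wise operations costs one HBM read and one write per operation when unfused, but only one read and one write in total when fused.

```python
def hbm_traffic_bytes(n_elements, n_ops, dtype_bytes=2, fused=False):
    """Approximate HBM traffic for a chain of n_ops single-input element-wise ops."""
    tensor_bytes = n_elements * dtype_bytes
    if fused:
        # One kernel: read the input once, write the final output once
        return 2 * tensor_bytes
    # One kernel per op: each reads its input from HBM and writes its output back
    return 2 * tensor_bytes * n_ops


n = 4096 * 4096  # elements in a 4096 x 4096 FP16 activation
unfused = hbm_traffic_bytes(n, n_ops=3)
fused = hbm_traffic_bytes(n, n_ops=3, fused=True)
print(f"unfused: {unfused / 2**20:.0f} MiB, fused: {fused / 2**20:.0f} MiB")
# unfused: 192 MiB, fused: 64 MiB
```

Because such operations are memory-bound, runtime roughly tracks traffic, so in this idealized model the fused version could run up to about 3x faster.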

Key Insight: The Roofline Model

The roofline model provides a visual framework for understanding whether a kernel is compute-bound or memory-bound. Plot the arithmetic intensity (FLOPs per byte of memory traffic) on the x-axis and attainable throughput (TFLOPS) on the y-axis, usually on log-log scales. The "roofline" is formed by two lines: a flat ceiling at the GPU's peak compute rate, and a sloped line (bandwidth times intensity) representing the memory bandwidth limit. The two meet at the ridge point (about 156 FLOPs/byte on the A100). Operations whose arithmetic intensity lies to the left of the ridge point are memory-bound; operations to its right are compute-bound. For LLMs, most operations besides large matrix multiplications fall in the memory-bound region, which is why kernel fusion (reducing memory traffic) delivers larger speedups than optimizing arithmetic.
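The roofline itself is a single line of arithmetic: attainable throughput is the minimum of peak compute and bandwidth times arithmetic intensity. The sketch below assumes the A100 figures quoted above (312 TFLOPS FP16, 2 TB/s); the function name is illustrative.

```python
def attainable_tflops(intensity, peak_tflops=312.0, bandwidth_tb_s=2.0):
    """Roofline bound: min(compute ceiling, memory slope at this intensity)."""
    # TB/s x FLOPs/byte = TFLOP/s
    memory_bound = bandwidth_tb_s * intensity
    return min(peak_tflops, memory_bound)


ridge_point = 312.0 / 2.0  # 156 FLOPs/byte: where the two roofs meet
for name, intensity in [("GELU", 2.5), ("ridge point", ridge_point), ("large GEMM", 1365.3)]:
    bound = attainable_tflops(intensity)
    print(f"{name:>12}: {intensity:7.1f} FLOPs/byte -> at most {bound:5.1f} TFLOPS")
```

Under these assumptions GELU can reach at most about 5 TFLOPS, under 2% of peak, however cleverly its arithmetic is written; only reducing its memory traffic (for example, by fusion) raises that bound.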

1.1 Arithmetic Intensity Analysis

Before writing a custom kernel, you should determine whether the operation is worth optimizing. The arithmetic intensity of an operation is defined as:

$$ \text{Arithmetic Intensity} = \frac{\text{FLOPs}}{\text{Bytes Transferred}} $$

For an element-wise operation like GELU activation on a tensor of $n$ elements (FP16), the computation requires roughly $10n$ FLOPs while reading $2n$ bytes and writing $2n$ bytes, giving an arithmetic intensity of $10n / 4n = 2.5$ FLOPs/byte. This is far below the A100's threshold of 156, making it heavily memory-bound and an excellent candidate for fusion with adjacent operations.

For a matrix multiplication of shapes $(M, K) \times (K, N)$, the computation requires $2MKN$ FLOPs while transferring $2(MK + KN + MN)$ bytes in FP16, giving an arithmetic intensity that grows with matrix size. Large GEMM operations in LLMs (for example, during prefill or large-batch inference) are typically compute-bound and well-served by cuBLAS; the opportunity for custom kernels lies in the surrounding element-wise and reduction operations.
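As a worked check of the formula above (assuming FP16, so 2 bytes per element, and ignoring cache reuse), consider square matrices with $M = N = K = n$, and then the single-token decoding case $M = 1$:

$$ \text{AI}_{\text{square}} = \frac{2n^3}{2 \cdot 3n^2} = \frac{n}{3}, \qquad n = 4096 \;\Rightarrow\; \text{AI} \approx 1365 \text{ FLOPs/byte} $$

$$ \text{AI}_{M=1} = \frac{2KN}{2(K + KN + N)} \approx 1 \text{ FLOP/byte} \quad (K, N \gg 1) $$

The first case sits far above the A100's threshold of 156 FLOPs/byte; the second sits far below it, which is why small-batch, token-by-token decoding is memory-bound even though it consists of matrix multiplications.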


# Arithmetic intensity calculator for common LLM operations
# Use this to decide whether a custom kernel is worthwhile

def arithmetic_intensity(op_name, M=4096, N=4096, K=4096, dtype_bytes=2):
    """Print and return FLOPs/byte for common operations.

    Element-wise ops act on an MxN matrix; matmul is (M, K) x (K, N);
    fused_attention uses M queries, N keys/values, and head dimension K.
    """
    if op_name == "gelu":
        # Element-wise: ~10 FLOPs per element, read + write
        flops = 10 * M * N
        bytes_transferred = 2 * dtype_bytes * M * N  # read + write
    elif op_name == "layernorm":
        # Two reductions (mean, variance) + normalize: ~8 FLOPs per element
        flops = 8 * M * N
        bytes_transferred = 2 * dtype_bytes * M * N
    elif op_name == "softmax":
        # Max reduction + exp + sum reduction + divide: ~8 FLOPs per element
        flops = 8 * M * N
        bytes_transferred = 2 * dtype_bytes * M * N
    elif op_name == "matmul":
        # 2*M*K*N FLOPs, read A(M,K) + B(K,N) + write C(M,N)
        flops = 2 * M * K * N
        bytes_transferred = dtype_bytes * (M * K + K * N + M * N)
    elif op_name == "fused_attention":
        # Q@K^T + scale + softmax + @V; the MxN score matrix stays in SRAM,
        # so HBM traffic is only Q (M,K), K and V (N,K each), and O (M,K)
        flops = 2 * M * N * K + 8 * M * N + 2 * M * N * K
        bytes_transferred = dtype_bytes * (2 * M * K + 2 * N * K)
    else:
        raise ValueError(f"Unknown operation: {op_name}")

    intensity = flops / bytes_transferred
    a100_threshold = 156  # FLOPs/byte for A100 at FP16
    bound = "COMPUTE" if intensity > a100_threshold else "MEMORY"

    print(f"{op_name:>20}: {intensity:8.1f} FLOPs/byte [{bound}-bound]")
    return intensity

# Compare operations at typical LLM dimensions
print(f"{'Operation':>20} {'Intensity':>8} Bottleneck")
print("-" * 50)
for op in ["gelu", "layernorm", "softmax", "matmul"]:
    arithmetic_intensity(op)
# Attention for one head: 4,096 queries and keys, head dimension 128
arithmetic_intensity("fused_attention", M=4096, N=4096, K=128)
           Operation Intensity Bottleneck
--------------------------------------------------
                gelu:      2.5 FLOPs/byte [MEMORY-bound]
           layernorm:      2.0 FLOPs/byte [MEMORY-bound]
             softmax:      2.0 FLOPs/byte [MEMORY-bound]
              matmul:   1365.3 FLOPs/byte [COMPUTE-bound]
     fused_attention:   2080.0 FLOPs/byte [COMPUTE-bound]

Code Fragment 9.7.1: Arithmetic intensity analysis for common LLM operations. Element-wise and reduction operations (GELU, LayerNorm, softmax) are heavily memory-bound, making them prime candidates for kernel fusion.

2. Triton: Python-Based GPU Kernel Programming

Triton, developed by OpenAI, lets you write GPU kernels in Python with a programming model centered on block-level operations. Unlike CUDA, where you think in terms of individual threads, Triton operates on tiles (blocks) of data. The compiler handles thread scheduling, memory coalescing, and shared memory management automatically. This dramatically reduces the complexity of writing correct, high-performance GPU code.

The core abstraction in Triton is the program instance, which operates on a whole block (tile) of data. Instead of computing a single output element per thread, a Triton program computes an entire block of output elements. The programmer specifies the block size (e.g., BLOCK_SIZE=1024), and Triton's compiler maps this to an efficient thread configuration. The programmer's job is to describe the data flow at the block level; the compiler's job is to make it fast.

2.1 A First Triton Kernel: Vector Addition

The following snippet defines a simple Triton kernel that performs element-wise vector addition on the GPU.


# Triton vector addition kernel
# The simplest possible Triton kernel, showing the block programming model

import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(
    x_ptr,  # pointer to first input vector
    y_ptr,  # pointer to second input vector
    output_ptr,  # pointer to output vector
    n_elements,  # total number of elements
    BLOCK_SIZE: tl.constexpr,  # number of elements per block (compile-time)
):
    # Each program instance handles one block of BLOCK_SIZE elements.
    # pid tells us which block we are.
    pid = tl.program_id(axis=0)

    # Compute the range of indices this block handles
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Mask out-of-bounds accesses (last block may be partial)
    mask = offsets < n_elements

    # Load a block of data from each input (block-level load)
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Compute and store the result
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

def vector_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Launch the Triton vector addition kernel."""
    output = torch.empty_like(x)
    n_elements = output.numel()
    BLOCK_SIZE = 1024

    # Grid: number of blocks needed to cover all elements
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    vector_add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output

# Test
x = torch.randn(100_000, device="cuda", dtype=torch.float16)
y = torch.randn(100_000, device="cuda", dtype=torch.float16)
result = vector_add(x, y)
assert torch.allclose(result, x + y, atol=1e-3)
print("Triton vector addition: correct!")
Triton vector addition: correct!
Code Fragment 9.7.2: A minimal Triton kernel for vector addition. The key concepts are program_id (which block am I?), block-level loads and stores, and out-of-bounds masking.

2.2 Fused Softmax Kernel

Softmax is one of the most common operations in LLM inference, applied at every attention layer. A naive implementation composed of separate PyTorch operations (max, exp, sum, divide) launches a kernel for each step, and each step reads its input from HBM and writes its result back. A fused Triton kernel performs all three logical steps (maximum, exponentials, normalization) in a single kernel launch, keeping intermediate results in registers and SRAM. PyTorch's built-in torch.softmax already dispatches to a fused kernel, so the example below uses it as a correctness reference rather than as the unfused baseline.


# Fused softmax in Triton: max, exp, and normalize in one kernel launch
# Compared with softmax built from separate PyTorch ops, this avoids writing
# and re-reading intermediate tensors in HBM

import torch
import triton
import triton.language as tl

@triton.jit
def fused_softmax_kernel(
    input_ptr, output_ptr,
    n_cols,
    input_row_stride,
    output_row_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles one row of the input matrix
    row_idx = tl.program_id(0)
    row_start = row_idx * input_row_stride

    # Load one row into registers (fast on-chip memory); padding lanes get -inf
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row = tl.load(input_ptr + row_start + col_offsets, mask=mask, other=-float("inf"))
    row = row.to(tl.float32)  # compute exp and sum in FP32 for accuracy

    # Step 1: find the maximum for numerical stability (in registers)
    row_max = tl.max(row, axis=0)

    # Step 2: compute exponentials (still in registers)
    numerator = tl.exp(row - row_max)

    # Step 3: normalize (still in registers)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # Single write to HBM, cast back to the output dtype
    out_start = row_idx * output_row_stride
    tl.store(output_ptr + out_start + col_offsets,
             softmax_output.to(output_ptr.dtype.element_ty), mask=mask)

def fused_softmax(x: torch.Tensor) -> torch.Tensor:
    """Fused softmax along the last dimension of a 2D tensor (each row must fit on-chip)."""
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    output = torch.empty_like(x)

    fused_softmax_kernel[(n_rows,)](
        x, output,
        n_cols,
        x.stride(0), output.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output

# Correctness check against PyTorch's (already fused) softmax
x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
torch_result = torch.softmax(x, dim=-1)
triton_result = fused_softmax(x)
assert torch.allclose(triton_result, torch_result, atol=1e-2)
print("Fused softmax: correct! (3 memory passes reduced to 1)")
Fused softmax: correct! (3 memory passes reduced to 1)
Code Fragment 9.7.3: Fused softmax in Triton. All three steps (max, exp, normalize) happen in registers, so each row is read from HBM once and written once. Every full pass over a 4096x4096 FP16 matrix moves 32 MB, so each intermediate write-and-read that fusion eliminates saves about 64 MB of HBM traffic.

3. FlashAttention: Tiled Attention in SRAM

FlashAttention (Dao et al., 2022) is the most impactful custom kernel in the LLM ecosystem. The core insight is that the standard attention computation, $\text{softmax}(QK^T / \sqrt{d_k})V$, creates an $N \times N$ attention matrix (where $N$ is the sequence length) that must be materialized in HBM. For $N = 8192$ and FP16, this matrix is 128 MB per attention head. FlashAttention avoids materializing this matrix entirely by tiling the computation into blocks that fit in SRAM.

The algorithm processes the attention computation in tiles. For each tile of Q (a block of query rows), it iterates over all tiles of K and V, accumulating partial results in SRAM. The running softmax normalization is maintained using the "online softmax" trick: as each new block of K is processed, the running maximum and sum statistics are updated, and previous partial results are rescaled. The result is exact, not an approximation, yet the full attention matrix is never materialized.

FlashAttention's extra memory grows as $O(N)$ rather than $O(N^2)$ in the sequence length, enabling much longer sequences. FlashAttention-2 roughly doubled throughput through better parallelism and work partitioning across GPU warps, and FlashAttention-3 (for Hopper GPUs) exploits hardware-level asynchronous operations for a further 1.5x to 2x speedup.

FlashAttention Memory Savings

Consider a transformer layer with 32 attention heads, sequence length 8,192, and head dimension 128, at batch size 1. Standard attention materializes 32 attention matrices of size 8192 x 8192 in FP16, consuming 32 x 128 MB = 4 GB of HBM just for intermediate storage. FlashAttention instead keeps only small tiles on chip, roughly 32 x (2 x BLOCK_SIZE x 128 x 2 bytes) for the K and V tiles, which at BLOCK_SIZE=256 is about 4 MB. This roughly 1000x reduction in intermediate memory is a key enabler of 128K+ context lengths: at 131,072 tokens, a single head's FP16 attention matrix would alone occupy 32 GB.

The following simplified Triton kernel implements this tiled forward pass for a single attention head (no batching, no causal mask).


# FlashAttention conceptual implementation in Triton (simplified: one head, no causal mask)
# Production code: use the flash_attn package or torch SDPA

import torch
import triton
import triton.language as tl

@triton.jit
def flash_attention_fwd_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    stride_qm, stride_qd,
    stride_kn, stride_kd,
    stride_vn, stride_vd,
    stride_om, stride_od,
    seq_len, head_dim,
    scale,
    BLOCK_M: tl.constexpr,  # tile size for queries
    BLOCK_N: tl.constexpr,  # tile size for keys/values
    BLOCK_D: tl.constexpr,  # power of 2 >= head_dim
):
    # Each program handles one BLOCK_M-sized tile of queries
    pid_m = tl.program_id(0)
    start_m = pid_m * BLOCK_M

    m_offsets = start_m + tl.arange(0, BLOCK_M)
    d_offsets = tl.arange(0, BLOCK_D)
    d_mask = d_offsets[None, :] < head_dim  # ignore padding columns of the head dim

    # Running softmax statistics (online softmax trick), kept on-chip
    m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)  # running max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # running sum of exp
    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)  # output accumulator

    # Load Q tile once; it stays on-chip for all K/V tiles
    q_ptrs = Q_ptr + m_offsets[:, None] * stride_qm + d_offsets[None, :] * stride_qd
    q_mask = (m_offsets[:, None] < seq_len) & d_mask
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)

    # Iterate over all K/V tiles
    for start_n in range(0, seq_len, BLOCK_N):
        n_offsets = start_n + tl.arange(0, BLOCK_N)
        kv_mask = (n_offsets[:, None] < seq_len) & d_mask

        # Load K tile
        k_ptrs = K_ptr + n_offsets[:, None] * stride_kn + d_offsets[None, :] * stride_kd
        k = tl.load(k_ptrs, mask=kv_mask, other=0.0)

        # Attention scores for this tile: [BLOCK_M, BLOCK_N]
        scores = tl.dot(q, tl.trans(k)) * scale
        # Key positions past seq_len are padding and must get zero weight
        scores = tl.where(n_offsets[None, :] < seq_len, scores, -float("inf"))

        # Online softmax update
        m_new = tl.maximum(m_i, tl.max(scores, axis=1))
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(scores - m_new[:, None])

        # Rescale previous statistics and accumulator, then add new contribution
        l_i = l_i * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None]

        # Load V tile and accumulate
        v_ptrs = V_ptr + n_offsets[:, None] * stride_vn + d_offsets[None, :] * stride_vd
        v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
        acc += tl.dot(p.to(tl.float16), v)

        m_i = m_new

    # Final normalization
    acc = acc / l_i[:, None]

    # Write output tile to HBM
    o_ptrs = O_ptr + m_offsets[:, None] * stride_om + d_offsets[None, :] * stride_od
    o_mask = (m_offsets[:, None] < seq_len) & d_mask
    tl.store(o_ptrs, acc.to(tl.float16), mask=o_mask)

def flash_attention(q, k, v, BLOCK_M=64, BLOCK_N=64):
    """q, k, v: [seq_len, head_dim] FP16 CUDA tensors for a single head."""
    seq_len, head_dim = q.shape
    o = torch.empty_like(q)
    grid = (triton.cdiv(seq_len, BLOCK_M),)
    flash_attention_fwd_kernel[grid](
        q, k, v, o,
        q.stride(0), q.stride(1),
        k.stride(0), k.stride(1),
        v.stride(0), v.stride(1),
        o.stride(0), o.stride(1),
        seq_len, head_dim,
        head_dim ** -0.5,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        BLOCK_D=triton.next_power_of_2(head_dim),
    )
    return o

print("FlashAttention kernel defined (see flash_attn package for production use)")
FlashAttention kernel defined (see flash_attn package for production use)
Code Fragment 9.7.4: Simplified FlashAttention forward pass in Triton. The key ideas are: Q tiles stay in SRAM while iterating over K/V tiles, the online softmax trick maintains running statistics, and the full N x N attention matrix is never materialized. For production workloads, use the flash-attn package or torch.nn.functional.scaled_dot_product_attention.
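The online softmax update at the heart of FlashAttention can be checked without a GPU. The following NumPy sketch is an illustration, not code from the flash-attn package, and its function name is hypothetical. It processes a single query vector (rather than a BLOCK_M tile) against K and V in blocks, maintaining the same running max, running sum, and rescaled accumulator as m_i, l_i, and acc in the Triton kernel of Code Fragment 9.7.4, and compares the result with ordinary softmax attention.

```python
import numpy as np


def online_softmax_attention(q, k, v, block_n=4):
    """Attention output for one query vector q of shape (d,), with k and v of
    shape (n, d) processed in blocks of block_n rows (online softmax)."""
    d = q.shape[0]
    m = -np.inf          # running max of scores seen so far
    l = 0.0              # running sum of exp(score - m)
    acc = np.zeros(d)    # running unnormalized sum of exp(score - m) * v
    for start in range(0, k.shape[0], block_n):
        s = k[start:start + block_n] @ q / np.sqrt(d)  # scores for this block
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)   # rescales statistics computed with the old max
        p = np.exp(s - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ v[start:start + block_n]
        m = m_new
    return acc / l  # final normalization, as in the kernel


rng = np.random.default_rng(0)
q = rng.standard_normal(8)
k = rng.standard_normal((10, 8))
v = rng.standard_normal((10, 8))

# Reference: materialize all scores, then apply an ordinary softmax
scores = k @ q / np.sqrt(8)
weights = np.exp(scores - scores.max())
weights /= weights.sum()
reference = weights @ v

out = online_softmax_attention(q, k, v)
print(np.allclose(out, reference))  # True
```

The 10 keys are split into blocks of 4, 4, and 2, so the check also covers a partial final block, the case the kernel's masks handle.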

4. JIT Compilation: torch.compile and XLA

Not every optimization requires hand-written kernels. torch.compile (introduced in PyTorch 2.0) uses the TorchInductor backend to automatically fuse operations and generate optimized Triton kernels. For many workloads, simply wrapping your model in torch.compile captures a large share of the speedup that a hand-written kernel would provide, with a one-line code change.

XLA (Accelerated Linear Algebra), the compiler backend for JAX and TensorFlow, takes a different approach: it compiles entire computation graphs into optimized code for the target hardware. XLA performs operator fusion, layout optimization, and memory planning at the graph level, often producing code competitive with hand-written kernels for standard operations.

The decision tree for optimization is: (1) try torch.compile first; (2) if that is insufficient, profile to identify the bottleneck; (3) write a custom Triton kernel only for the specific bottleneck operation. Hand-written kernels should be a last resort, not a first instinct.


# torch.compile: automatic kernel fusion without writing custom kernels

import torch
import torch.nn as nn
import time

class TransformerBlock(nn.Module):
    """A simplified transformer block for benchmarking."""
    def __init__(self, d_model=1024, n_heads=16):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        # Pre-norm architecture
        h = self.norm1(x)
        h, _ = self.attn(h, h, h)
        x = x + h
        x = x + self.ffn(self.norm2(x))
        return x

device = torch.device("cuda")
model = TransformerBlock().to(device).half()
x = torch.randn(8, 512, 1024, device=device, dtype=torch.float16)

# Eager mode baseline
with torch.no_grad():
    for _ in range(10):  # warmup
        _ = model(x)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(100):
        _ = model(x)
    torch.cuda.synchronize()
    eager_time = (time.perf_counter() - t0) / 100

# Compiled mode: torch.compile auto-fuses operations
compiled_model = torch.compile(model, mode="reduce-overhead")
with torch.no_grad():
    for _ in range(10):  # warmup (includes compilation)
        _ = compiled_model(x)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(100):
        _ = compiled_model(x)
    torch.cuda.synchronize()
    compiled_time = (time.perf_counter() - t0) / 100

speedup = eager_time / compiled_time
print(f"Eager: {eager_time*1000:.2f} ms")
print(f"Compiled: {compiled_time*1000:.2f} ms")
print(f"Speedup: {speedup:.2f}x")
Eager: 3.42 ms
Compiled: 2.18 ms
Speedup: 1.57x
Code Fragment 9.7.5: torch.compile automatically fuses operations in a transformer block. The reduce-overhead mode minimizes kernel launch latency using CUDA graphs. Typical speedups range from 1.3x to 2x depending on the model architecture.
When to Write a Custom Kernel

Write a custom Triton kernel when all three conditions are met: (1) profiling shows a specific operation consuming a disproportionate fraction of wall-clock time; (2) the operation is memory-bound (low arithmetic intensity); (3) torch.compile does not fuse the operation effectively (check by inspecting the generated code with TORCH_COMPILE_DEBUG=1). In practice, custom kernels are most valuable for attention variants (sparse attention, sliding window), custom activation functions, and fused normalization + residual operations.

5. Performance Profiling

Effective GPU optimization requires precise measurement. The PyTorch profiler, NVIDIA Nsight Systems, and NVIDIA Nsight Compute provide complementary views: the PyTorch profiler shows operation-level timing within the Python framework; Nsight Systems shows the full timeline of CPU and GPU activity including kernel launches and memory transfers; Nsight Compute provides detailed per-kernel metrics including occupancy, memory throughput, and compute utilization.

The most common profiling mistake is measuring wall-clock time without GPU synchronization. GPU operations are asynchronous: when Python calls torch.matmul(A, B), the function returns as soon as the kernel is queued, while the GPU works in the background. Without calling torch.cuda.synchronize() before starting and before stopping the timer, you mostly measure the CPU-side overhead of launching kernels, not the actual GPU execution time.
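One way to avoid this mistake is to time GPU work with CUDA events, which are recorded on the GPU stream itself. The helper below is a sketch, and its name and defaults are illustrative. It uses torch.cuda.Event(enable_timing=True), synchronizes before reading the elapsed time, and falls back to wall-clock timing on CPU so that it also runs on machines without a GPU.

```python
import time

import torch


def time_gpu_op(fn, warmup=3, iters=20):
    """Return the mean runtime of fn() in milliseconds."""
    for _ in range(warmup):  # exclude one-time costs (allocation, autotuning)
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure warm-up work has finished
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()  # wait until the GPU reaches the end event
        return start.elapsed_time(end) / iters  # elapsed_time is in ms
    # CPU fallback: execution is synchronous, so wall-clock time is accurate
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) * 1000 / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)
ms = time_gpu_op(lambda: a @ b)
print(f"1024x1024 matmul: {ms:.3f} ms")
```

Recording both events on the stream means the measured interval covers exactly the queued GPU work, regardless of how quickly Python returns from each call.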


# PyTorch profiler: identifying kernel-level bottlenecks

import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.TransformerEncoderLayer(
    d_model=1024, nhead=16, batch_first=True
).cuda().half()

x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)

# Profile with GPU activity tracking
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with torch.no_grad():
        for _ in range(5):
            _ = model(x)
        torch.cuda.synchronize()

# Print the top 20 operations by GPU time
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20,
    top_level_events_only=False,
))

# Export Chrome trace for visual inspection
# Open chrome://tracing and load this file
prof.export_chrome_trace("transformer_trace.json")
print("Trace exported to transformer_trace.json")
print("Open chrome://tracing in Chrome to visualize")
Name              Self CPU   Self CUDA   CPU total   CUDA total
----------------  ---------  ----------  ----------  ----------
aten::mm          1.205ms    4.821ms     1.205ms     4.821ms
aten::addmm       0.387ms    1.604ms     0.541ms     1.604ms
aten::_softmax    0.098ms    0.412ms     0.098ms     0.412ms
aten::layer_norm  0.187ms    0.389ms     0.321ms     0.389ms
...
Trace exported to transformer_trace.json
Open chrome://tracing in Chrome to visualize
Code Fragment 9.7.6: Using the PyTorch profiler to identify GPU bottlenecks. The table reports CPU and CUDA time per operator (self time, and total time including child operators), while the Chrome trace provides a visual timeline of CPU/GPU activity. Look for gaps between kernels (launch overhead) and for operations with low arithmetic intensity (fusion candidates).

6. Putting It Together: Optimization Workflow

A systematic approach to LLM kernel optimization follows these steps:

  1. Profile first. Use the PyTorch profiler to identify which operations consume the most GPU time. Focus on operations that appear repeatedly (attention, normalization, activation functions).
  2. Classify the bottleneck. Use arithmetic intensity analysis (Code Fragment 9.7.1) to determine whether each bottleneck is memory-bound or compute-bound. Memory-bound operations benefit from fusion; compute-bound operations benefit from algorithmic improvements or hardware upgrades.
  3. Try torch.compile. Apply torch.compile and re-profile. If the bottleneck is resolved, stop here. Check the generated Triton code to understand what the compiler did.
  4. Write a targeted kernel. If torch.compile does not adequately fuse the bottleneck operation, write a custom Triton kernel for that specific operation. Start with correctness (compare against PyTorch reference), then optimize for performance.
  5. Validate at scale. Test the optimized kernel at production-scale batch sizes and sequence lengths. Performance characteristics can change significantly between small test inputs and real workloads.
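Step 4's correctness-first rule can be captured in a small test harness. This is a sketch with hypothetical function names. It uses torch.testing.assert_close to compare a candidate against a reference on several shapes, including widths that are not powers of two and therefore exercise a kernel's masking. It is demonstrated here on CPU with an unfused softmax; for a Triton kernel such as fused_softmax from Code Fragment 9.7.3 you would pass device="cuda", dtype=torch.float16, and looser tolerances.

```python
import torch


def check_against_reference(candidate, reference, shapes, dtype=torch.float32,
                            device="cpu", atol=1e-5, rtol=1e-5, seed=0):
    """Compare candidate(x) with reference(x) on random inputs of several shapes.

    Raises AssertionError (from torch.testing.assert_close) on the first mismatch.
    """
    gen = torch.Generator().manual_seed(seed)  # CPU generator for reproducible inputs
    for shape in shapes:
        x = torch.randn(shape, generator=gen, dtype=dtype).to(device)
        torch.testing.assert_close(candidate(x), reference(x), atol=atol, rtol=rtol)
    return True


def naive_softmax(x):
    # Unfused softmax built from separate PyTorch ops: max, exp, sum, divide
    e = torch.exp(x - x.max(dim=-1, keepdim=True).values)
    return e / e.sum(dim=-1, keepdim=True)


# Widths 7 and 1000 are not powers of two, so a block-based kernel must mask them
ok = check_against_reference(
    naive_softmax,
    lambda x: torch.softmax(x, dim=-1),
    shapes=[(4, 7), (16, 1000), (3, 4096)],
)
print("all shapes match:", ok)
```

Only once a kernel passes a harness like this is it worth timing; a fast kernel with a masking bug on odd shapes would silently corrupt model outputs.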
Research Frontier: Hardware-Aware Kernel Generation

The boundary between hand-written kernels and compiler-generated code is blurring. Projects like FlexAttention (PyTorch's composable attention API) allow users to define custom attention patterns (causal, sliding window, block-sparse) as simple Python functions, which are then compiled into efficient fused kernels. Meanwhile, newer kernel programming tools, such as the Mojo language and Pallas (a JAX extension for writing custom GPU and TPU kernels), aim to give programmers CUDA-level control with Python-level ergonomics. The CS336 course at Stanford (Spring 2025) provides an excellent deep dive into building LLM systems from scratch, including writing attention kernels in Triton as a core assignment.

Exercises
  1. Arithmetic intensity. Using Code Fragment 9.7.1, calculate the arithmetic intensity of the full transformer forward pass (attention + FFN + normalization) for sequence lengths of 512, 2048, and 8192. At what sequence length does the attention computation transition from memory-bound to compute-bound?
  2. Fused GELU kernel. Write a Triton kernel that fuses the linear projection and GELU activation in the FFN layer (compute GELU(xW + b) in a single kernel). Compare its performance against separate torch.nn.Linear + torch.nn.GELU calls using the PyTorch profiler.
  3. torch.compile inspection. Apply torch.compile to a full transformer encoder layer and set TORCH_COMPILE_DEBUG=1. Examine the generated Triton code in the debug output. Which operations did the compiler fuse, and which did it leave unfused? Can you explain why?
  4. Profiling practice. Using Code Fragment 9.7.6, profile an inference pass through a HuggingFace model (e.g., facebook/opt-1.3b). Identify the top 5 GPU operations by time and classify each as memory-bound or compute-bound. Which operations would benefit most from custom kernels?

Hands-On Lab: Quantize and Benchmark a Model

Duration: ~60 minutes. Level: Intermediate.

Objective

Load a language model at FP16 precision, reload it with INT8 weights using bitsandbytes, and measure the impact on memory usage, inference latency, and output quality. As a stretch goal, serve the quantized model with vLLM to see how a production inference engine handles quantization.

What You'll Practice

  • Loading models in different precisions (FP16 and INT8)
  • Using bitsandbytes for post-training quantization
  • Benchmarking GPU memory and inference latency
  • Comparing output quality across precision levels

Setup

This lab requires a CUDA-capable GPU. A GPU with at least 8 GB of VRAM is recommended.

pip install transformers torch accelerate bitsandbytes pandas

Steps

Step 1: Load the model at FP16 and measure baseline

Load a small model and record its memory footprint and generation speed as a baseline.

import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceTB/SmolLM2-360M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Baseline: FP16
torch.cuda.reset_peak_memory_stats()
model_fp16 = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
mem_fp16 = torch.cuda.max_memory_allocated() / 1024**2

prompt = "Explain the difference between compiled and interpreted languages."
inputs = tokenizer(prompt, return_tensors="pt").to(model_fp16.device)

# Warm-up run
with torch.no_grad():
    model_fp16.generate(**inputs, max_new_tokens=100)

# Timed run (synchronize so all queued GPU work is inside the measurement)
torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    out_fp16 = model_fp16.generate(**inputs, max_new_tokens=100)
torch.cuda.synchronize()
latency_fp16 = time.perf_counter() - start

text_fp16 = tokenizer.decode(out_fp16[0], skip_special_tokens=True)
print(f"FP16 Memory: {mem_fp16:.0f} MB")
print(f"FP16 Latency: {latency_fp16:.3f}s")
print(f"FP16 Output: {text_fp16[:300]}")
FP16 Memory: 723 MB
FP16 Latency: 0.847s
FP16 Output: Explain the difference between compiled and interpreted languages. Compiled languages are translated into machine code before execution, producing a standalone binary. Interpreted languages are executed line by line at runtime by an interpreter...
Code Fragment 9.7.12: Loading a small language model at FP16 precision and measuring its memory footprint and generation latency. This baseline establishes the reference point for comparing quantized variants in subsequent steps.
Hint

Always include a warm-up generation before timing. The first call pays one-time costs, such as CUDA context and cuBLAS initialization, lazy loading of GPU kernels, and memory-allocator warm-up, which would inflate the latency measurement.

Step 2: Quantize to INT8 with bitsandbytes

Reload the model with 8-bit quantization and repeat the measurements.

from transformers import BitsAndBytesConfig

del model_fp16
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# INT8 quantization
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model_int8 = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map="auto"
)
mem_int8 = torch.cuda.max_memory_allocated() / 1024**2

inputs = tokenizer(prompt, return_tensors="pt").to(model_int8.device)

# Warm-up and timed run
with torch.no_grad():
    model_int8.generate(**inputs, max_new_tokens=100)

torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    out_int8 = model_int8.generate(**inputs, max_new_tokens=100)
torch.cuda.synchronize()
latency_int8 = time.perf_counter() - start

text_int8 = tokenizer.decode(out_int8[0], skip_special_tokens=True)
print(f"\nINT8 Memory: {mem_int8:.0f} MB")
print(f"INT8 Latency: {latency_int8:.3f}s")
print(f"INT8 Output: {text_int8[:300]}")
INT8 Memory: 412 MB
INT8 Latency: 0.913s
INT8 Output: Explain the difference between compiled and interpreted languages. Compiled languages convert source code to machine code ahead of time, while interpreted languages are processed at runtime. Compiled programs generally run faster...
Code Fragment 9.7.11: Reloading the same model with INT8 quantization via bitsandbytes and repeating the benchmark. Memory drops by roughly 43% while output quality remains comparable, illustrating the practical benefit of post-training quantization.
Hint

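Storing a weight in INT8 halves its size relative to FP16, but total savings are usually somewhat below 50% (43% in this run) because some parameters, such as the embedding table and normalization layers, stay in FP16. Latency may not improve on all hardware because the bitsandbytes INT8 matrix multiplication (LLM.int8()) adds quantization, dequantization, and outlier-handling overhead.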
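A quick way to see why measured savings fall short of 50% is to treat the model as a mix of quantized and unquantized weights. The sketch below is a simplified model that ignores quantization constants, outlier handling, and non-weight allocations; the function name is hypothetical, and the 86% quantized share is an illustrative figure chosen to show the effect, not a measured property of this model.

```python
def memory_savings_vs_fp16(quantized_fraction, bits=8):
    """Fraction of FP16 weight memory saved when `quantized_fraction` of the
    weights are stored at `bits` bits and the rest stay at 16 bits."""
    new_size = quantized_fraction * bits / 16 + (1 - quantized_fraction)
    return 1 - new_size


print(f"{memory_savings_vs_fp16(1.0):.0%}")   # all weights INT8: 50%
print(f"{memory_savings_vs_fp16(0.86):.0%}")  # hypothetical 86% quantized: 43%
print(f"{memory_savings_vs_fp16(0.86, bits=4):.1%}")  # same split at 4 bits
```

The unquantized share caps the achievable savings no matter how few bits the quantized weights use, which is also why 4-bit quantization of the same model saves less than the ideal 75%.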

Step 3: Compare results and discuss trade-offs

Summarize the memory, latency, and quality differences in a table.

import pandas as pd

# "Tokens" counts prompt tokens plus generated tokens
results = pd.DataFrame([
    {"Precision": "FP16", "Memory (MB)": f"{mem_fp16:.0f}",
     "Latency (s)": f"{latency_fp16:.3f}",
     "Tokens": len(out_fp16[0])},
    {"Precision": "INT8", "Memory (MB)": f"{mem_int8:.0f}",
     "Latency (s)": f"{latency_int8:.3f}",
     "Tokens": len(out_int8[0])},
])
print("\n=== Quantization Comparison ===")
print(results.to_string(index=False))

savings = (1 - mem_int8 / mem_fp16) * 100
print(f"\nMemory savings: {savings:.1f}%")
print(f"\nFP16 snippet: {text_fp16[len(prompt):len(prompt)+200]}")
print(f"INT8 snippet: {text_int8[len(prompt):len(prompt)+200]}")
=== Quantization Comparison ===
Precision Memory (MB) Latency (s) Tokens
     FP16         723       0.847    112
     INT8         412       0.913    109

Memory savings: 43.0%

FP16 snippet: Compiled languages are translated into machine code before execution...
INT8 snippet: Compiled languages convert source code to machine code ahead of time...
Code Fragment 9.7.10: Summarizing the FP16 vs. INT8 comparison in a table, highlighting memory savings, latency differences, and output snippets side by side. For small models, latency gains are modest; the real payoff of quantization appears with larger models where memory is the bottleneck.
Hint

For small models, INT8 may show little or no latency improvement (in this run it was slightly slower) because the compute is already fast. The real benefit of quantization shows with larger models (7B+) where memory becomes the bottleneck, enabling models to fit on fewer or smaller GPUs.

Expected Output

  • INT8 memory usage approximately 40-50% lower than FP16
  • Latency similar or slightly higher for INT8 on small models
  • Output quality nearly identical between FP16 and INT8 for this model size

Stretch Goals

  • Try 4-bit quantization (BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")) and add it to the comparison table
  • Run the same experiment on a 1B+ model to see larger memory savings
  • Serve the quantized model with vLLM (vllm serve model_name --quantization bitsandbytes) and benchmark throughput with concurrent requests
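A back-of-the-envelope calculation shows why the payoff grows with model size. The sketch below estimates weight memory only (no KV cache, activations, or runtime buffers) for an assumed round count of 7e9 parameters. The NF4 overhead follows the QLoRA paper's description: one FP32 absmax scale per 64-weight block adds 0.5 bits per parameter, and double quantization reduces that overhead to about 0.127 bits. Treat the results as approximate.

```python
# Weight-only memory estimate; ignores KV cache, activations, and runtime buffers.
BITS_PER_PARAM = {
    "FP16": 16.0,
    "INT8": 8.0,                                       # ignores small per-row scale overhead
    "NF4": 4.0 + 32 / 64,                              # 4-bit codes + FP32 absmax per 64 weights
    "NF4 + double quant": 4.0 + 8 / 64 + 32 / (64 * 256),  # scales themselves quantized to 8 bits
}

def weight_gib(n_params: float, bits: float) -> float:
    """Convert a parameter count and bits per parameter into GiB."""
    return n_params * bits / 8 / 2**30

n_params = 7e9  # assumed round parameter count for a "7B" model
for name, bits in BITS_PER_PARAM.items():
    print(f"{name:<20} {bits:6.3f} bits/param  {weight_gib(n_params, bits):6.2f} GiB")
```

Under these assumptions the FP16 weights alone need about 13 GiB, INT8 about half that, and NF4 brings them under 4 GiB, which is what lets a 7B model fit on much smaller GPUs.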
Self-Check Questions
  1. A standard softmax attention kernel makes three passes over the attention matrix in HBM. Explain how FlashAttention reduces this to a single pass and why that yields a large speedup despite doing the same number of FLOPs.
  2. You write a Triton kernel that fuses LayerNorm, Linear, and GELU into one kernel. What specific overhead does this fusion eliminate compared to running three separate CUDA kernels?
  3. After running torch.compile(model), the first forward pass takes 30 seconds but subsequent passes take 4 ms. Explain what happens during that first call and when you would pass fullgraph=True instead of relying on the default (fullgraph=False).
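For question 1, it helps to see the numerical trick that makes a single pass possible. The pure-Python sketch below compares a three-pass safe softmax (row max, sum of exponentials, normalize) against a one-pass online softmax that keeps a running max and rescales its running sums whenever the max changes. FlashAttention applies the same rescaling, blockwise, to its output accumulator. For simplicity the values here are scalars rather than vectors, and the function names are illustrative.

```python
import math

def softmax_weighted_sum_3pass(scores, values):
    """Reference: materialize the softmax weights, then weight the values."""
    m = max(scores)                                        # pass 1: row max
    denom = sum(math.exp(s - m) for s in scores)           # pass 2: normalizer
    weights = [math.exp(s - m) / denom for s in scores]    # pass 3: normalize
    return sum(w * v for w, v in zip(weights, values))

def softmax_weighted_sum_online(scores, values):
    """One pass: running max m, running normalizer l, running weighted sum acc."""
    m, l, acc = -math.inf, 0.0, 0.0
    for s, v in zip(scores, values):
        m_new = max(m, s)
        correction = math.exp(m - m_new)  # rescale earlier terms to the new max (0.0 on the first step)
        p = math.exp(s - m_new)
        l = l * correction + p
        acc = acc * correction + p * v
        m = m_new
    return acc / l                        # normalize once at the end

scores = [2.0, -1.0, 0.5, 3.0, 1.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
ref = softmax_weighted_sum_3pass(scores, values)
online = softmax_weighted_sum_online(scores, values)
print(f"three-pass: {ref:.6f}  one-pass online: {online:.6f}")
```

Because the online version never needs the full row of weights at once, it can consume scores block by block as they are produced in on-chip SRAM, so the full attention matrix never has to be written to HBM.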

Key Takeaways

What's Next?

With inference optimization covered, from quantization to custom GPU kernels, we move to understanding what happens inside these models. In Chapter 18: Interpretability and Mechanistic Understanding, we explore techniques for looking inside the black box: attention visualization, probing classifiers, logit lens, and mechanistic interpretability methods that reveal how LLMs actually process language.

References & Further Reading
FlashAttention

Dao, T. et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.

The foundational FlashAttention paper. Introduces tiled attention with online softmax to achieve 2-4x speedup and linear memory complexity. Required reading for understanding why memory hierarchy awareness is more important than raw FLOPs for LLM performance.

Foundational Paper

Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.

Improves on FlashAttention by reducing non-matmul FLOPs, parallelizing over the sequence length, and partitioning work better across warps, reaching 50-73% of theoretical peak FLOPs/s on A100 GPUs. Demonstrates that work partitioning across GPU warps is as important as the tiling algorithm itself.

Paper
Triton and GPU Programming

Tillet, P. et al. (2019). Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations. MAPL 2019.

The paper that introduced Triton's tile-based intermediate language and compiler. The official documentation for today's Python-based Triton covers the block-level programming model, auto-tuning, and how the compiler maps Python code to efficient GPU programs; its tutorials provide excellent worked examples from vector addition through fused softmax to matrix multiplication.

Documentation

CS336: Language Modeling from Scratch, Stanford University (Spring 2025).

Graduate course covering LLM systems from tokenization through training infrastructure. The GPU kernel programming assignments provide hands-on experience writing Triton kernels for attention, layer normalization, and activation functions. Lecture materials and assignments are publicly available.

Course
Compilation and Optimization

PyTorch Team. (2024). torch.compile: PyTorch 2.x Compiler Documentation.

Official documentation for torch.compile, covering TorchInductor, TorchDynamo, and the various compilation modes. Includes guidance on debugging compilation failures and inspecting generated code.

Documentation

Dong, J. et al. (2024). Flex Attention: A Programming Model for Generating Optimized Attention Kernels.

Describes a composable API for defining custom attention patterns that compile into fused kernels. Bridges the gap between hand-written Triton kernels and high-level PyTorch code, making it practical to experiment with novel attention variants without GPU programming expertise.

Paper

NVIDIA. (2024). Nsight Compute Documentation.

Official documentation for NVIDIA's kernel-level profiling tool. Essential for roofline analysis, occupancy optimization, and understanding memory access patterns in custom CUDA and Triton kernels.

Tool Documentation