Debugging Recipes

Section E.9
Debugging deep learning is plumbing inspection: the bug almost always hides in a junction labeled NaN, OOM, or shape mismatch, and the magnifying glass beats the wrench every time.

Deep learning code fails in characteristic ways. Tensors have the wrong shape. Gradients turn into NaN. Training runs out of GPU memory (OOM) at unexpected moments. Models give different results on every run despite a fixed seed. This section catalogs the five most common categories of PyTorch failures, the diagnostics that localize each, and the patterns that resolve them. The goal is to turn cryptic stack traces into a small number of well-rehearsed playbooks.

Recipe: Hunting NaN and Inf

A NaN in the loss is the most demoralizing failure mode because it propagates instantly: one NaN in a forward pass becomes NaN logits, NaN loss, NaN gradients, and, after the optimizer step, NaN parameters. The next iteration's forward pass then produces all NaNs from those parameters, and the run cannot recover without restoring an earlier checkpoint. The fix is to localize the source.

Step 1: Confirm it is happening. Insert assert torch.isfinite(loss).all() after the loss computation and assert all(torch.isfinite(p.grad).all() for p in model.parameters() if p.grad is not None) after backward. The earliest assertion that fires tells you whether the NaN first appears in the forward pass (the loss) or in the backward pass (the gradients), before the optimizer writes it into the weights.
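A minimal sketch of Step 1 folded into a training step, assuming a standard model, loss function, and optimizer. checked_step is a hypothetical helper name, and reporting the names of the offending parameters (rather than a bare True/False) is an addition to the two assertions above.

```python
import torch
import torch.nn as nn

def checked_step(model, loss_fn, opt, inputs, targets):
    """One training step that refuses to apply a non-finite update (hypothetical helper)."""
    opt.zero_grad()
    loss = loss_fn(model(inputs), targets)
    # Forward side: is the loss itself finite?
    assert torch.isfinite(loss).all(), f"non-finite loss: {loss.item()}"
    loss.backward()
    # Backward side: which parameters received non-finite gradients?
    bad = [name for name, p in model.named_parameters()
           if p.grad is not None and not torch.isfinite(p.grad).all()]
    assert not bad, f"non-finite gradients in: {bad}"
    opt.step()  # only reached when both checks pass
    return loss.item()

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
print(checked_step(model, nn.functional.mse_loss, opt, x, y))

x[0, 0] = float("nan")  # poison a single input value
try:
    checked_step(model, nn.functional.mse_loss, opt, x, y)
except AssertionError as e:
    print(e)  # fires on the forward side; the weights are left untouched
```

Plain assert statements are stripped when Python runs with -O; if the check must survive optimized runs, raise an exception explicitly instead.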

Step 2: Localize within the forward pass. Register forward hooks (Section E.2) on every layer that check output.isfinite().all(). The first layer to produce a non-finite output is the culprit.

Step 3: Use anomaly detection for the backward pass. Wrap the training step in with torch.autograd.detect_anomaly(): and the framework will record a stack trace for every forward operation; when a backward function returns NaN, it raises an error naming that function and prints the stored forward-pass trace of the operation that created it. The overhead is large; use it only during debugging.
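A small, CPU-only illustration of Step 3, assuming a backward-only NaN: the forward pass is finite, so forward hooks stay silent, but the gradient of sqrt at zero becomes 0 / 0. The exact wording of the error differs between PyTorch versions.

```python
import torch

x = torch.zeros(3, requires_grad=True)

def step():
    # Forward is finite: sqrt(0) = 0 and 0 * 0.0 = 0. Backward is not: the
    # gradient reaching sqrt is 0, and sqrt's backward divides by 2 * sqrt(0),
    # giving 0 / 0 = NaN.
    loss = (torch.sqrt(x) * 0.0).sum()
    loss.backward()

step()
plain_grad = x.grad.clone()
print(plain_grad)  # NaN gradients, with no hint of where they came from

x.grad = None
try:
    with torch.autograd.detect_anomaly():
        step()
except RuntimeError as e:
    anomaly_msg = str(e)
    print(anomaly_msg)  # names the backward function that returned NaN
```

Alongside the error, PyTorch emits a warning containing the forward-pass traceback of the sqrt call, which is the line to fix.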

Warning: Anomaly Detection Is a 10x Slowdown

torch.autograd.detect_anomaly wraps every forward op in a stack-trace recorder and re-validates every backward result. The slowdown is typically 5x to 10x; on a large model it can turn a one-minute step into ten. Reserve it for the single iteration where the NaN appears, ideally after narrowing the suspect range with forward hooks. Never leave it enabled in a production training run, and never wrap an entire epoch when one step suffices.

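import torch

def attach_nan_hooks(model):
    """Raise on the first module whose forward output contains NaN/Inf."""
    def make_hook(name):
        def hook(module, inputs, output):
            # Outputs may be a tensor or a tuple/list (e.g. attention layers).
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for t in outs:
                if not (torch.is_tensor(t) and t.is_floating_point()):
                    continue
                if not torch.isfinite(t).all():
                    n_nan = torch.isnan(t).sum().item()
                    n_inf = torch.isinf(t).sum().item()
                    # min()/max() would just print nan here, so report counts.
                    raise RuntimeError(
                        f"Non-finite values in `{name or '<root>'}`: "
                        f"{n_nan + n_inf} entries, {n_nan} NaN / {n_inf} Inf"
                    )
        return hook
    # Hook every module, not only leaves: a NaN created by a functional op inside
    # a parent's forward (a softmax, a log) never passes through a child module.
    # Children return before their parents, so the hook nearest the source fires first.
    return [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]

# model, compute_loss, features, labels: your own model, loss function, and batch
handles = attach_nan_hooks(model)
try:
    loss = compute_loss(model(features), labels)   # raises if NaN appears
finally:
    for h in handles:
        h.remove()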
Output (on failure): RuntimeError: Non-finite values in `transformer.h.5.attn`: 12 entries, ...
Code Fragment E.9.1: Forward-pass NaN/Inf detector. Raises with the exact layer name as soon as a non-finite output is produced, narrowing the search from "somewhere in the model" to "this specific layer."

The most common root causes of NaN in transformers are: a learning rate that is too high (especially without warmup); missing gradient clipping (a single outlier batch can produce a gradient spike large enough to overflow); softmax over very large logits or attention scores in FP16 (FP16 saturates at 65504, so larger values become Inf; switch to bfloat16, scale the inputs, or compute the softmax in FP32); a log of zero (guard with x + epsilon or a clamp); and a square root of a negative number, or of exactly zero, where the gradient is infinite (guard by clamping the input to a small positive value).
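The last few causes and their guards can be checked in a few lines. A sketch; the epsilon below is an illustrative assumption, not a recommendation, since the right floor depends on the scale of the inputs.

```python
import torch

eps = 1e-6  # illustrative floor; choose one suited to your input scale

# log(0) is -inf in the forward pass, and its gradient 1/x is +inf.
x = torch.tensor([0.0, 1.0], requires_grad=True)
torch.log(x).sum().backward()
print(x.grad)  # tensor([inf, 1.])

# Guard: shift the input by eps so both directions stay finite.
x.grad = None
torch.log(x + eps).sum().backward()
print(torch.isfinite(x.grad).all().item())  # True

# sqrt of a negative number is NaN; at exactly 0 the value is fine but the
# gradient 1 / (2 * sqrt(x)) is infinite. Clamping the input covers both cases.
print(torch.sqrt(torch.tensor(-1.0)))  # tensor(nan)
y = torch.tensor([-1.0, 0.0, 4.0], requires_grad=True)
torch.sqrt(y.clamp(min=eps)).sum().backward()
print(torch.isfinite(y.grad).all().item())  # True

# FP16 tops out at 65504, so larger logits overflow to inf; bfloat16 keeps
# FP32's exponent range and represents the same value as a finite number.
big = torch.tensor([70000.0])
print(big.half().isinf().item(), big.bfloat16().isinf().item())  # True False
```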

Recipe: Diagnosing OOM

OOM (out of memory) means the allocator tried to claim more GPU memory than was free. The diagnostic question is: what is using the memory? Investigate in this order: (1) confirm the batch size and sequence length; (2) check whether the OOM happens on the first iteration (a sizing problem) or only later (a leak, or a rare batch much longer than the rest); (3) capture a memory snapshot (Section E.8) and visualize it.
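One classic cause of a later-iteration OOM, described in the PyTorch FAQ, is accumulating the loss tensor itself across steps: each addition links the new loss, and its autograd history, into a chain that stays alive for the whole loop. The sketch below shows the symptom on CPU through grad_fn rather than by measuring memory; the toy model and the train helper are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)

def train(steps, leak):
    running = 0.0
    for _ in range(steps):
        x, y = torch.randn(32, 16), torch.randn(32, 1)
        loss = nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if leak:
            running = running + loss          # tensor: chains this step's history
        else:
            running = running + loss.item()   # Python float: history can be freed
    return running

leaky, clean = train(5, leak=True), train(5, leak=False)
print(type(leaky).__name__, leaky.grad_fn is not None)  # Tensor True
print(type(clean).__name__)                             # float
```

On a GPU the same bug shows up as torch.cuda.memory_allocated() climbing step after step; the fix is loss.item() (or loss.detach()) wherever a loss is kept for logging.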

Once the culprit is identified, the standard mitigations are: lower the batch size and use gradient accumulation to keep the effective batch size; enable activation checkpointing (a large reduction in activation memory at the cost of recomputing the forward pass of checkpointed segments during backward); switch to a low-memory optimizer (bitsandbytes.optim.AdamW8bit stores both of Adam's moment estimates in 8 bits); shard the model with FSDP (Section E.7); offload parameters, gradients, or optimizer state to CPU memory with FSDP's CPUOffload option or DeepSpeed's ZeRO-Offload; and reduce the sequence length or bucket sequences by length so that batches are not padded far beyond their longest member.
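The first two mitigations combine naturally. A minimal sketch, assuming a toy residual MLP (Block and Net are hypothetical stand-ins for transformer layers): each block is wrapped in torch.utils.checkpoint.checkpoint so its internal activations are recomputed during backward, and one batch of 32 is split into four micro-batches of 8.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Toy residual MLP block (hypothetical stand-in for a transformer layer)."""
    def __init__(self, d):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        return x + self.ff(x)

class Net(nn.Module):
    def __init__(self, d=32, n_blocks=4, use_checkpoint=True):
        super().__init__()
        self.blocks = nn.ModuleList(Block(d) for _ in range(n_blocks))
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Discard blk's internal activations; recompute them in backward.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x

model = Net()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps, micro_batch = 4, 8  # effective batch size 32

opt.zero_grad()
for _ in range(accum_steps):
    x = torch.randn(micro_batch, 32)
    # Divide so the summed gradients match one mean-reduced batch of 32.
    loss = model(x).pow(2).mean() / accum_steps
    loss.backward()  # .grad accumulates across micro-batches
opt.step()
opt.zero_grad()
```

use_reentrant=False is the variant current PyTorch recommends; among other things, it still produces parameter gradients when none of the checkpointed inputs requires grad, as with the first block here.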

Warning: torch.cuda.empty_cache Is Not What You Think

torch.cuda.empty_cache() returns cached but unallocated blocks to the system; it does not free memory that is currently in use by tensors. Sprinkling it through a training loop almost never helps and can slow training because the allocator has to re-request memory. The correct response to OOM is to find what is holding the memory, not to call empty_cache. The one useful place to call it is between training and evaluation if the eval workload has a fundamentally different memory profile and the cached blocks would not be reused.

Recipe: Shape Mismatches

Shape errors are the most common runtime failures. The error message reports the expected and actual shapes; the trick is to find where each shape came from. The quickest tool is to print each tensor's shape at every layer boundary: crude, but hard to beat for shape debugging.

A more durable approach is to annotate tensor shapes with named dimensions in comments (or with the einops library, whose pattern strings name every axis and are checked at runtime). The pattern # x: (B, L, D) alongside every tensor declaration trains the eye to spot mismatches at code-review time. To go beyond comments, the torchtyping library lets shape constraints be expressed in type annotations and verified at runtime; its author now recommends the successor library jaxtyping, which also supports PyTorch.

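import torch
from einops import rearrange

B, L, H, D = 2, 5, 4, 8          # batch, sequence length, heads, per-head dim
x = torch.randn(B, L, H * D)     # x: (B, L, H*D)

# Bad: a chain of view/transpose/reshape that is hard to read.
x_bad = x.view(B, L, H, D).transpose(1, 2).reshape(B * H, L, D)

# Good: einops makes the rearrangement explicit and self-documenting.
x_good = rearrange(x, "b l (h d) -> (b h) l d", h=H)

# Even better: assert what you expect.
assert x_good.shape == (B * H, L, D), f"Got {x_good.shape}, expected {(B * H, L, D)}"
assert torch.equal(x_bad, x_good)   # both spell the same memory layout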
Output: (no stdout when shapes match; assertion error if they do not)
Code Fragment E.9.2: Shape-safe tensor reshaping with einops and explicit assertions. The pattern reads as the math, fails loudly when wrong, and survives refactoring.

Recipe: Common Autograd Errors

Two autograd errors dominate the beginner experience; a third is less common but harder to diagnose.

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation
Some code mutated a parameter (or another leaf tensor with requires_grad=True) in place. The fix is to use the out-of-place variant or, if the mutation is intentional (as in an optimizer step), wrap it in with torch.no_grad():. Optimizer implementations use this pattern; user code rarely needs it.
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
An intermediate tensor that backward needs was overwritten by an in-place op (typically a += or .relu_()). The fix is to switch to out-of-place; the slight memory increase is far better than a broken backward pass. Use anomaly detection to identify the exact operation if the trace is not obvious.
RuntimeError: Trying to backward through the graph a second time
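A second backward call through the same graph without retain_graph=True. The call is either intentional (in which case pass retain_graph=True) or unintentional: a stale reference to an old graph, typically a tensor carried over from the previous iteration without .detach() (such as an RNN hidden state) or a hook closure holding an output. Work out which case applies before adding retain_graph reflexively; in the unintentional case, retaining the graph only hides the bug and keeps every old graph alive in memory.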
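A minimal CPU-only sketch that reproduces each of the three errors next to its fix. The helper error_message is hypothetical, and the message fragments it captures reflect recent PyTorch releases; exact wording varies between versions.

```python
import torch

def error_message(fn):
    """Run fn(); return the RuntimeError text, or None if nothing was raised."""
    try:
        fn()
    except RuntimeError as e:
        return str(e)
    return None

w = torch.randn(3, requires_grad=True)  # a leaf tensor, like a parameter

# 1. In-place update of a leaf that requires grad.
msg1 = error_message(lambda: w.add_(1.0))
with torch.no_grad():  # fix: the optimizer-step pattern
    w.add_(1.0)

# 2. sigmoid saves its output for backward; mul_ then overwrites it.
y = torch.sigmoid(w)
y.mul_(2)
msg2 = error_message(lambda: y.sum().backward())
y = torch.sigmoid(w) * 2  # fix: out-of-place multiply
y.sum().backward()

# 3. A second backward after the first one freed the saved tensors.
z = (w * w).sum()
z.backward()
msg3 = error_message(lambda: z.backward())
z = (w * w).sum()
z.backward(retain_graph=True)  # fix, only if two passes are really intended
z.backward()

for m in (msg1, msg2, msg3):
    print(m.splitlines()[0][:90])
```

If you assert on these errors in your own tests, match a short, stable fragment of the message rather than the full text.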
Library Shortcut: torchinfo for Instant Model Summaries

When the shape mismatch is buried somewhere inside a deep submodule, do not chain print statements. torchinfo.summary runs a forward pass on a dummy input of the given size and prints a Keras-style table of every layer's input shape, output shape, and parameter count. One call usually localizes the offending layer.

from torchinfo import summary

# input_size includes the batch dimension; summary runs one real forward pass.
summary(model, input_size=(4, 3, 224, 224),
        col_names=["input_size", "output_size", "num_params"])
Fun Fact: A Mildly Paranoid PyTorch Module

After enough debugging sessions, a PyTorch module starts to develop habits: an assert at every shape boundary, a torch.isfinite check on every loss, a forward hook that whispers the activation norm into a log file. Slightly paranoid. Also: rarely blows up at 3 AM on the cluster. The discipline pays for itself the first time a 12-hour training run does not have to be restarted.

Recipe: Reproducibility

"Run the same code twice, get the same numbers" is harder than it looks. A PyTorch training script draws on at least four independent random number generators: Python's random, NumPy's random, PyTorch's CPU generator, and PyTorch's CUDA generators (one per device); DataLoader worker processes add their own. Several CUDA kernels also use non-deterministic algorithms by default. The full reproducibility recipe is therefore a small ritual.

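import os
import random
import numpy as np
import torch

def set_reproducible(seed=42):
    # cuBLAS reads this when it initializes, so set it before any CUDA work.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    # The running interpreter fixed its hash seed at startup; this only reaches
    # child processes launched afterwards, so export it in the shell as well.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # seeds the CPU generator and every CUDA device
    torch.cuda.manual_seed_all(seed)   # explicit; silently ignored without CUDA
    # cuDNN: pick deterministic algorithms and disable the autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # PyTorch: warn when an op has no deterministic implementation
    # (warn_only=False raises an error instead).
    torch.use_deterministic_algorithms(True, warn_only=True)

set_reproducible(42)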
Output: (no stdout; subsequent random operations are seeded)
Code Fragment E.9.3: Full reproducibility setup. Aims at bit-for-bit reproduction on a fixed hardware and software stack, at a performance cost (deterministic cuDNN algorithms are typically slower than the autotuned non-deterministic ones).
Warning: Reproducibility Has a Cost

Deterministic mode forces PyTorch to pick the deterministic implementation of every operation, which is often slower (sometimes much slower) than the default. torch.backends.cudnn.benchmark = False disables the autotuner that finds the fastest kernel for the current shape. The combination can halve throughput. Reserve determinism for debugging and for the rare research project where exact reproduction matters; production training does not need it. Even with all of the above, results may still differ across GPU models (a 4090 and an A100 use different kernels) or across PyTorch versions; bitwise reproducibility across hardware should not be expected.
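The seeding ritual above does not touch the DataLoader, which has randomness of its own: the shuffle order and, with num_workers > 0, a separate seed per worker process. The pattern below follows the one in the PyTorch reproducibility notes (a seeded torch.Generator plus a worker_init_fn that reseeds NumPy and random inside each worker); seed_worker and make_loader are illustrative names, and the tiny TensorDataset stands in for a real dataset.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # Each worker's torch seed is derived from the loader's base seed;
    # reuse it for NumPy and Python's random inside that worker.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def make_loader(dataset, seed, num_workers=0):
    g = torch.Generator()
    g.manual_seed(seed)  # drives the shuffle order and the worker base seeds
    return DataLoader(dataset, batch_size=4, shuffle=True,
                      num_workers=num_workers,
                      worker_init_fn=seed_worker, generator=g)

data = TensorDataset(torch.arange(20).float())
order_a = [batch[0].tolist() for batch in make_loader(data, seed=0)]
order_b = [batch[0].tolist() for batch in make_loader(data, seed=0)]
print(order_a == order_b)  # True: same seed, same shuffle order
```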

General Debugging Discipline

A handful of habits pay back many times over.

  1. Overfit a single batch. A correctly implemented model should be able to drive the loss to near zero on a single repeated batch, typically within a few hundred steps. If it cannot, the bug is in the model, the loss, the optimizer, or the update loop itself, not in the data pipeline.
  2. Train on a tiny subset before scaling. Reducing the dataset to 100 samples lets a full training run complete in seconds; many bugs (wrong loss, wrong target shape, label miscoding) reveal themselves immediately.
  3. Log the gradient norm. A norm that rises without bound is divergence in progress; a norm that collapses to near zero is a model that has stopped learning. Both deserve investigation before training continues for hours.
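  4. Save a checkpoint at step zero. If a debugging session points to the model construction, the step-zero checkpoint is a fixed reference: reload it to rerun from identical weights, or compare a freshly built model's state_dict against it.
  5. Read the error message twice. PyTorch's error messages are often unusually informative, especially for shape mismatches; a careful read is faster than reaching for the debugger. CUDA errors are the exception: kernels run asynchronously, so the reported line can be far from the failing operation. Rerun with the environment variable CUDA_LAUNCH_BLOCKING=1 to get an accurate trace.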
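Items 1 and 3 combine into a single smoke test. A minimal sketch, assuming a toy classifier and one random batch of 8 examples; the gradient norm is read from torch.nn.utils.clip_grad_norm_, which returns the total norm before clipping, so passing max_norm=inf measures it without changing the gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 3))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
x, y = torch.randn(8, 10), torch.randint(0, 3, (8,))  # one fixed batch

for step in range(200):
    opt.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    # max_norm=inf: measure the total gradient norm without rescaling anything.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf"))
    opt.step()
    if step % 50 == 0:
        print(f"step {step:3d}  loss {loss.item():.4f}  grad_norm {grad_norm.item():.3f}")

final_loss = loss.item()
```

If the printed loss does not head toward zero, item 1 applies: look at the model, the loss, or the update code before touching the data pipeline.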
Key Insight

Most PyTorch failures fall into five buckets: NaN/Inf, OOM, shape mismatch, autograd error, and non-reproducibility. Each has a corresponding playbook: forward hooks plus anomaly detection for NaNs; memory snapshot plus activation checkpointing for OOM; einops plus assertions for shapes; reading the error message and switching to out-of-place ops for autograd errors; the full deterministic ritual for reproducibility. The discipline that pays back fastest is to overfit a single batch on a tiny dataset before scaling up; many bugs reveal themselves in the first minute of that exercise.

Exercise E.9.1: Plant a NaN and Hunt It Down

Objective. Practice the NaN-hunting recipe end-to-end on a controlled bug.

Task. Take a 3-layer MLP. Inside the second layer's forward, insert x = torch.log(x) once a global step counter exceeds 50 (the pre-activations include negative values, so this produces NaN). Train in a loop over a few small batches. When the loss becomes NaN, do not just print the loss: (a) register a forward hook on every layer that asserts torch.isfinite on every output; (b) rerun the failing step inside torch.autograd.set_detect_anomaly(True) and note which backward function it blames; (c) read the flagged module's forward, identify the offending line, and report the step number at which the NaN first appears.

Expected outcome. The hook fires on layer 2's output. Anomaly mode, by contrast, blames the first backward function that returns NaN; for a NaN created in the forward pass, that function sits near the loss rather than at the log, which is why forward hooks come first. Report both, then fix by clamping the log's input to a positive lower bound, e.g. torch.log(x.clamp(min=1e-6)).

Exercise E.9.2: Overfit a Single Batch as a Smoke Test

Objective. Apply the debugging discipline that pays back fastest: confirm a model can overfit one batch before scaling to a full dataset.

Task. Take any model and dataset. Sample a single batch of 8 items. Disable shuffle and dataloader workers. Train for 200 steps on that batch alone, logging loss every 10 steps. Plot the curve. A correctly wired model should drop the loss to near zero. If it does not, the bug is in the model, loss, or optimizer, not in the data pipeline.

Hint. If the loss stalls above zero, check three suspects in this order: a layer you froze unintentionally, a learning rate too low for a tiny problem, and a loss reduction (mean vs sum) inconsistent with the gradient magnitude. Rule out all three before suspecting the model architecture.

Further Reading

Debugging References

PyTorch Notes: Reproducibility. The authoritative reference for the full reproducibility ritual, including a table of non-deterministic operations and the environment variables that force CUBLAS determinism.
PyTorch Notes: CUDA Semantics. Memory allocator behavior, asynchronous execution, and the subtle interactions that produce confusing CUDA errors. Required reading for anyone diagnosing CUDA-side bugs.
Rogozhnikov, A. (2020). "einops: clear and reliable tensor operations." The library and its accompanying tutorial. The single biggest improvement most PyTorch codebases can make to shape-related readability.
Karpathy, A. (2019). "A Recipe for Training Neural Networks." The canonical blog post on the disciplined process of training a neural network. It popularized the "overfit a single batch" rule; the rest of the post is equally worth reading.