Training Loop Deep Dive

Section E.5

The canonical PyTorch training loop fits in about a dozen lines. Every call in it carries weight, and the order of the calls matters. This section dissects the loop, then layers on the extras that distinguish a research prototype from a production training script: gradient clipping, learning-rate schedulers, gradient accumulation, and checkpointing patterns.

The Canonical Loop

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NeuralNetwork(num_inputs=2, num_outputs=2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
num_epochs = 3

for epoch in range(num_epochs):
    model.train()
    for batch_idx, (features, labels) in enumerate(train_loader):
        features = features.to(device, non_blocking=True)
        labels   = labels.to(device, non_blocking=True)

        logits = model(features)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    model.eval()   # switch dropout and BatchNorm to inference behavior before validation
Output (example): (no stdout; train loss falls from 0.75 at epoch 0 to near 0 at epoch 2)
Code Fragment E.5.1: The canonical PyTorch training loop. Move batches to the device, forward, loss, zero-grad, backward, step. The order of zero_grad, backward, step is non-negotiable.

The three lines in the inner loop encode the heart of gradient-based learning. optimizer.zero_grad(set_to_none=True) erases gradients from the previous iteration; without it, .grad would accumulate across batches (a feature exploited by gradient accumulation, but a bug when unintended). loss.backward() populates .grad on every parameter that participated in the forward pass. optimizer.step() reads each .grad and updates the corresponding parameter. Swapping any two of these lines silently breaks training.
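The accumulation behavior described above is easy to observe directly. The following is a minimal sketch on a single hand-made tensor (the tensor `w` and the helper `loss_fn` are illustrative names, not part of the canonical loop): two backward passes without clearing add their gradients together, while resetting `.grad` restores the single-pass value.

```python
import torch

# A single trainable tensor; the loss w0^2 + w1^2 has gradient 2*w = [2, 4].
w = torch.tensor([1.0, 2.0], requires_grad=True)

def loss_fn():
    return (w ** 2).sum()

loss_fn().backward()
first = w.grad.clone()      # [2, 4]: gradient of one backward pass

loss_fn().backward()        # no clearing in between, so .grad is added to
second = w.grad.clone()     # [4, 8]: twice the true gradient

w.grad = None               # what zero_grad(set_to_none=True) does per parameter
loss_fn().backward()
third = w.grad.clone()      # [2, 4] again
```

An optimizer stepping on `second` would apply an update twice as large as intended, without any error or warning.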

Warning: cross_entropy Wants Logits, Not Probabilities

PyTorch's F.cross_entropy and nn.CrossEntropyLoss apply log-softmax internally for numerical stability. Pass raw logits, not the output of torch.softmax. Applying softmax before cross-entropy double-softmaxes the predictions, yielding a much weaker training signal and degraded accuracy. The same applies to F.binary_cross_entropy_with_logits versus F.binary_cross_entropy: the _with_logits variant fuses sigmoid and BCE for stability; the bare variant assumes probabilities. The fused variants are almost always the correct choice.
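A small numerical check makes the double-softmax problem concrete. This is a sketch with a hand-picked logit vector (the values and variable names are illustrative assumptions): the loss on raw logits matches log-softmax followed by NLL, while the loss on softmax outputs is stuck above a floor no matter how confident the model becomes.

```python
import math
import torch
import torch.nn.functional as F

logits = torch.tensor([[4.0, 0.0, 0.0]])   # a confident, correct prediction
target = torch.tensor([0])

# Correct usage: raw logits go straight into cross_entropy.
loss_ok = F.cross_entropy(logits, target)

# Equivalent by definition: log-softmax followed by negative log-likelihood.
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

# The mistake: softmax first, so cross_entropy applies softmax a second time.
probs = torch.softmax(logits, dim=1)
loss_double = F.cross_entropy(probs, target)

# Probabilities lie in [0, 1], so the second softmax can never be sharper
# than softmax([1, 0, 0]); the loss cannot drop below this floor.
num_classes = logits.shape[1]
floor = math.log(1 + (num_classes - 1) * math.exp(-1.0))

print(f"logits: {loss_ok.item():.4f}  double softmax: {loss_double.item():.4f}  floor: {floor:.4f}")
```

Because the double-softmax loss is bounded away from zero, its gradients stay small even for badly wrong predictions, which is the weaker training signal the warning refers to.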

Gradient Clipping

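Gradient clipping caps the size of the gradient before the optimizer step. It matters most for sequence models and for deep models trained with high learning rates, where occasional exploding gradients (caused by rare but extreme batches) can destabilize training in a single step. PyTorch offers two variants in torch.nn.utils: clip_grad_norm_, which rescales all gradients together so that their global norm does not exceed max_norm and therefore preserves the gradient direction, and clip_grad_value_, which clamps each gradient element independently to [-clip_value, clip_value] and can change the direction.

Both functions mutate gradients in place. clip_grad_norm_ also returns the (pre-clipping) total norm, which is worth logging because a steadily rising norm signals trouble; clip_grad_value_ returns None.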

import torch

optimizer.zero_grad(set_to_none=True)
loss.backward()
# Cap the global gradient norm at 1.0. The returned value is the
# pre-clipping norm; log it to monitor training stability.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
Output: (no stdout; gradients are clipped in place)
Code Fragment E.5.2: Inserting global-norm gradient clipping between backward and step. The clipping must happen after backward (gradients exist) and before step (so the optimizer sees the clipped values).
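To see how norm clipping differs from element-wise value clipping, the sketch below applies both to the same hand-set gradient. The helper `param_with_grad` is hypothetical and exists only to build a parameter whose `.grad` is known in advance.

```python
import torch
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

def param_with_grad(values):
    """Hypothetical helper: a parameter whose .grad is set by hand."""
    p = torch.nn.Parameter(torch.zeros(len(values)))
    p.grad = torch.tensor(values)
    return p

# The gradient [3, 4] has L2 norm 5.
p_norm = param_with_grad([3.0, 4.0])
returned = clip_grad_norm_([p_norm], max_norm=1.0)
# `returned` is the pre-clipping norm (5.0). The gradient is rescaled to
# norm ~1 along the same direction: roughly [0.6, 0.8].

p_value = param_with_grad([3.0, 4.0])
clip_grad_value_([p_value], clip_value=1.0)
# Each element is clamped to [-1, 1], giving [1, 1]: a different direction.
```

Norm clipping is the usual default because it shrinks the step without rotating it; value clipping treats every coordinate independently.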

Learning Rate Schedulers

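A learning rate scheduler adjusts optimizer.param_groups[i]["lr"] over time. PyTorch ships several pre-built schedules in torch.optim.lr_scheduler; the ones used in this section are LinearLR (linear warmup), CosineAnnealingLR (cosine decay), SequentialLR (chains schedulers at given milestones), OneCycleLR (warmup and annealing in a single schedule), and ReduceLROnPlateau (cuts the LR when a monitored metric stops improving).

Whichever scheduler is used, the pattern is the same: build it after the optimizer and call scheduler.step() after optimizer.step(), not before; PyTorch warns when the order is reversed, because the first value of the schedule would be skipped. A scheduler counts calls to step(), so the unit of arguments such as total_iters, T_max, and milestones is whatever interval separates those calls. The example below steps once per iteration; ReduceLROnPlateau is normally stepped once per validation pass and takes the monitored metric as an argument.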

import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

warmup_steps = 500
total_steps  = 10_000

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
warmup   = LinearLR(optimizer, start_factor=0.01, end_factor=1.0,
                    total_iters=warmup_steps)
decay    = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps,
                             eta_min=3e-5)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay],
                         milestones=[warmup_steps])

for step in range(total_steps):
    train_one_step()                      # forward, backward, optimizer.step
    scheduler.step()                      # advance the LR by one step
Output: (no stdout; LR rises linearly to 3e-4 over 500 steps, then cosine-decays to 3e-5)
Code Fragment E.5.3: Warmup followed by cosine decay, composed with SequentialLR. The same pattern appears in essentially every modern transformer training recipe.
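ReduceLROnPlateau follows a different calling convention from the schedulers in the fragment above, so a separate sketch may help. The validation losses here are made-up stand-ins, and the single dummy parameter replaces a real model; only the scheduler calls are the point.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

param = torch.nn.Parameter(torch.tensor([1.0]))
optimizer = torch.optim.SGD([param], lr=0.1)
# Halve the LR once the metric has failed to improve for more than
# `patience` consecutive checks.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

# Stand-in validation losses: one improvement, then a plateau.
val_losses = [1.0, 0.8, 0.8, 0.8, 0.8]
lrs = []
for val_loss in val_losses:
    # One "epoch" of training, reduced here to a single optimizer step.
    optimizer.zero_grad(set_to_none=True)
    (param ** 2).sum().backward()
    optimizer.step()

    scheduler.step(val_loss)            # the monitored metric is passed in
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

The LR stays at 0.1 while the metric improves and through the first stalled check, then drops to 0.05 on the second consecutive check without improvement.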

Gradient Accumulation

Gradient accumulation simulates a larger batch by running several forward and backward passes before each optimizer step. It is the standard trick when the target batch size does not fit in GPU memory but the recipe demands it (large-batch training for stability, contrastive losses that need many negatives, and so on). The implementation is a small twist on the canonical loop: skip zero_grad and step until accumulation_steps mini-batches have been processed.

accumulation_steps = 4

for step, (features, labels) in enumerate(train_loader):
    logits = model(features.to(device))
    # Scale the loss so its sum across micro-batches equals the
    # loss that would have been computed for the full virtual batch.
    loss = F.cross_entropy(logits, labels.to(device)) / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()
Output: (no stdout; effective batch size is `accumulation_steps * micro_batch`)
Code Fragment E.5.4: Gradient accumulation. Dividing the per-micro-batch loss by accumulation_steps keeps the gradient magnitude equivalent to the full virtual batch.

Warning: BatchNorm and Gradient Accumulation Do Not Mix

BatchNorm computes statistics over the actual mini-batch, not the virtual one. Accumulating gradients over four micro-batches of size 8 is not equivalent to one batch of 32 from BatchNorm's perspective: each micro-batch is normalized with its own, noisier statistics. BatchNorm-equipped models trained with gradient accumulation therefore behave as if trained at the micro-batch size, and can underfit or become unstable when the micro-batches are small. Workarounds: switch to LayerNorm or GroupNorm (which compute per-sample statistics), use SyncBatchNorm in distributed training (which aggregates statistics across replicas and so enlarges the batch BatchNorm sees), or avoid accumulation when BatchNorm is essential. Modern transformer training uses LayerNorm-like normalizations, which sidesteps the problem entirely.
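Both claims, that scaled accumulation reproduces the full-batch gradient and that BatchNorm breaks the equivalence, can be checked numerically. The sketch below uses random data and two tiny stand-in models; the helper names (`full_batch_grads`, `accumulated_grads`, `grads_match`) are illustrative and not part of any library.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
features = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))
accumulation_steps = 4                       # four micro-batches of size 2

def full_batch_grads(model):
    model.zero_grad(set_to_none=True)
    F.cross_entropy(model(features), labels).backward()
    return [p.grad.clone() for p in model.parameters()]

def accumulated_grads(model):
    model.zero_grad(set_to_none=True)
    for x, y in zip(features.chunk(accumulation_steps),
                    labels.chunk(accumulation_steps)):
        (F.cross_entropy(model(x), y) / accumulation_steps).backward()
    return [p.grad.clone() for p in model.parameters()]

def grads_match(model):
    # Use two identical copies so BatchNorm running stats cannot interfere.
    a = full_batch_grads(copy.deepcopy(model))
    b = accumulated_grads(copy.deepcopy(model))
    return all(torch.allclose(g1, g2, atol=1e-6) for g1, g2 in zip(a, b))

plain = nn.Linear(4, 3)
with_bn = nn.Sequential(nn.BatchNorm1d(4), nn.Linear(4, 3))

plain_ok = grads_match(plain)       # accumulation reproduces the full batch
bn_ok = grads_match(with_bn)        # per-micro-batch statistics break it
print(plain_ok, bn_ok)
```

The equivalence for the plain model relies on equal-sized micro-batches and a mean-reduced loss; with a ragged final micro-batch the match is only approximate.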

The set_to_none Optimization

Since PyTorch 2.0, optimizer.zero_grad() defaults to set_to_none=True, which sets each param.grad to None rather than writing zeros into the existing tensor. This skips a memory write per parameter per step and lowers the memory footprint between steps; the next backward pass then assigns the gradient instead of adding to a tensor of zeros. One behavioral difference follows: the built-in optimizers skip a parameter whose .grad is None, whereas a zero gradient still triggers an update, which matters when weight decay or momentum is in play.

The only caveat: code that inspects param.grad directly must handle the None case. Gradient norm monitoring, sparse logging, and custom regularizers that read .grad all need an if p.grad is not None guard. PyTorch's built-in clip_grad_norm_ already handles this correctly.
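As an illustration of the guard, here is a minimal gradient-norm monitor that tolerates None gradients. The function name `global_grad_norm` and the two-branch model are assumptions made for the example; the unused branch never receives a gradient, so its `.grad` stays None.

```python
import torch
import torch.nn as nn

def global_grad_norm(parameters):
    """L2 norm over all gradients, skipping parameters whose .grad is None."""
    norms = [p.grad.detach().norm(2) for p in parameters if p.grad is not None]
    if not norms:
        return torch.tensor(0.0)
    return torch.stack(norms).norm(2)

# Hypothetical model with a branch that this forward pass never touches.
model = nn.ModuleDict({"used": nn.Linear(3, 1), "unused": nn.Linear(3, 1)})
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad(set_to_none=True)
model["used"](torch.ones(2, 3)).sum().backward()

unused_grads = [p.grad for p in model["unused"].parameters()]   # all None
norm = global_grad_norm(model.parameters())
# Weight gradient is [2, 2, 2] and bias gradient is 2, so the norm is 4.
```

Without the `is not None` filter, the list comprehension would raise an AttributeError on the first parameter of the unused branch.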

Library Shortcut
Accelerate for One Training Loop, Any Hardware

Hugging Face's accelerate library lets a single training script run on CPU, one GPU, multi-GPU DDP, FSDP, or DeepSpeed without code changes. Wrap the model, optimizer, scheduler, and loader in accelerator.prepare(), replace loss.backward() with accelerator.backward(loss), and launch with accelerate launch script.py. No torchrun ceremony, no manual device placement.

from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")
model, optimizer, loader, scheduler = accelerator.prepare(
    model, optimizer, loader, scheduler
)

for batch in loader:
    optimizer.zero_grad(set_to_none=True)
    loss = compute_loss(model, batch)
    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()

Checkpointing Patterns

A checkpoint records enough state to resume training exactly where it stopped. The minimum is the model's state_dict; for resuming training, the optimizer's state_dict, the scheduler's state_dict, the current step, and the RNG state are all required. Saving everything in a single dict keeps the bookkeeping simple.

import os, torch

def atomic_save(state, path):
    """Write to a temp file, then rename: prevents corrupted checkpoints."""
    tmp = path + ".tmp"
    torch.save(state, tmp)
    os.replace(tmp, path)              # atomic on POSIX and Windows NTFS

def save_checkpoint(path, model, optimizer, scheduler, step, best_val):
    atomic_save({
        "model":     model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step":      step,
        "best_val":  best_val,
        "rng_cpu":   torch.get_rng_state(),
        "rng_cuda":  torch.cuda.get_rng_state_all(),
    }, path)
Output: (no stdout; checkpoint written to disk)
Code Fragment E.5.5: An atomic checkpoint save. Writing to a temp file and then renaming guarantees that path always contains a complete file; a job killed mid-write leaves the previous checkpoint intact.
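Saving is only half of the pattern; resumption needs the mirror image. The sketch below pairs a condensed copy of the save with a hypothetical `load_checkpoint` (the function name, the `build` helper, and the tiny model are assumptions for illustration) and checks that the random stream continues exactly where it left off.

```python
import os
import tempfile
import torch

def save_checkpoint(path, model, optimizer, scheduler, step):
    # Condensed version of the atomic save, repeated so this sketch runs alone.
    tmp = path + ".tmp"
    cuda_rng = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": step,
        "rng_cpu": torch.get_rng_state(),
        "rng_cuda": cuda_rng,
    }, tmp)
    os.replace(tmp, path)

def load_checkpoint(path, model, optimizer, scheduler):
    """Hypothetical counterpart to save_checkpoint; returns the saved step."""
    # Load onto the CPU first: load_state_dict copies tensors to wherever the
    # parameters live, and the CPU RNG state must remain a CPU tensor.
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    torch.set_rng_state(ckpt["rng_cpu"])
    if torch.cuda.is_available() and ckpt["rng_cuda"]:
        torch.cuda.set_rng_state_all(ckpt["rng_cuda"])
    return ckpt["step"]

def build():
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    return model, optimizer, scheduler

model, optimizer, scheduler = build()
model(torch.randn(4, 2)).sum().backward()
optimizer.step()
scheduler.step()

path = os.path.join(tempfile.mkdtemp(), "last.pt")
save_checkpoint(path, model, optimizer, scheduler, step=1)
expected_draw = torch.rand(3)             # first random draw after the save

# Simulate a restart: fresh objects, then restore everything from disk.
new_model, new_optimizer, new_scheduler = build()
resumed_step = load_checkpoint(path, new_model, new_optimizer, new_scheduler)
resumed_draw = torch.rand(3)              # repeats expected_draw
```

Restoring the RNG state last matters: constructing the fresh model consumes random numbers, so the state must be reset after all objects exist.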

Two checkpoints are usually maintained: a "last" checkpoint overwritten every N steps for crash recovery, and a "best" checkpoint updated only when validation improves. The best checkpoint is what gets evaluated on the test set and shipped to production; the last enables resumption after preemption. Some shops also keep a rolling window of the last few "last" checkpoints to defend against the rare case where the latest one captures a momentary instability.

Key Insight

The four lines that matter inside the training loop are zero_grad, forward, backward, step. Around them, three additions transform a toy script into a production trainer: gradient clipping (one line, prevents catastrophic single-step divergence), a warmup-plus-decay LR scheduler (a few lines, dramatically improves convergence), and atomic checkpointing (a dozen lines, prevents lost training time after crashes). Gradient accumulation simulates large batches when memory is tight, but it interacts poorly with BatchNorm and requires the per-micro-batch loss to be scaled by 1 / accumulation_steps.

Lab E.5: A Production-Quality Training Loop for FashionMNIST

Objective

Convert a 20-line toy training script into a production-quality loop with warmup, cosine decay, gradient clipping, gradient accumulation, AMP, atomic best-and-last checkpointing, and clean resumption after simulated crash. Use FashionMNIST so a full run completes on a single CPU or modest GPU in under 10 minutes.

Setup

  1. Install torch, torchvision. Confirm CUDA or MPS availability if present; the lab must also pass on CPU.
  2. Load FashionMNIST with the standard normalization and an 80/20 train-validation split.
  3. Define a small CNN: two conv-relu-pool blocks then a 128-unit linear head with 10 outputs.

Steps

  1. Baseline loop. Implement the four-line core (zero_grad, forward, backward, step) with AdamW(lr=1e-3). Train for 5 epochs. Record final validation accuracy and per-epoch wall time.
  2. Add a scheduler. Wrap with OneCycleLR(max_lr=3e-3, total_steps=...). Confirm via scheduler.get_last_lr() that the LR rises then falls. Compare final accuracy against the baseline.
  3. Add gradient clipping. Insert nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) after backward(). Log the pre-clip global norm each step and plot a histogram. Confirm that less than 5 percent of steps trigger clipping in steady state.
  4. Add gradient accumulation. Configure physical_batch=32 with accumulation_steps=4. Verify the loss is scaled by 1/4, and optimizer.step() fires once per 4 micro-batches.
  5. Add AMP. Wrap the forward pass in torch.autocast(device_type) and use a GradScaler (only on CUDA). Measure throughput change.
  6. Atomic checkpointing. Save a "last" checkpoint every 100 steps and a "best" checkpoint whenever validation accuracy improves. Write to a temp path and os.replace to the final path. The checkpoint dict must contain model, optimizer, scheduler, scaler, step, and rng_state.
  7. Crash and resume. Kill the training process partway through epoch 3 (use os._exit(1) from a callback). Restart; confirm the loop resumes at the next step after the last checkpoint and that validation accuracy continues smoothly. Plot the full curve across both runs as a single line.
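Steps 3 through 5 interact in one non-obvious way: when a GradScaler is active, gradients must be unscaled before they are clipped. The following is one possible sketch of a single accumulated update, not a reference solution for the lab. It assumes a linear stand-in for the lab CNN, synthetic tensors in place of the FashionMNIST loader, and the long-standing torch.cuda.amp.GradScaler entry point (recent releases also expose it as torch.amp.GradScaler).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_scaler = device.type == "cuda"            # loss scaling targets fp16 on CUDA
amp_dtype = torch.float16 if use_scaler else torch.bfloat16

torch.manual_seed(0)
# Stand-in for the lab CNN, to keep the sketch short.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)
accumulation_steps = 4

# Synthetic micro-batches shaped like FashionMNIST, in place of a DataLoader.
micro_batches = [(torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,)))
                 for _ in range(accumulation_steps)]

before = model[1].weight.detach().clone()
optimizer.zero_grad(set_to_none=True)
for step, (x, y) in enumerate(micro_batches):
    x, y = x.to(device), y.to(device)
    with torch.autocast(device_type=device.type, dtype=amp_dtype):
        logits = model(x)
    # Compute the loss in float32, scaled for accumulation.
    loss = F.cross_entropy(logits.float(), y) / accumulation_steps
    scaler.scale(loss).backward()

    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)            # clip true gradients, not scaled ones
        total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)                # skips the update on inf/nan grads
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
```

With the scaler disabled (the CPU path), scale, unscale_, step, and update reduce to plain pass-throughs, so the same loop body serves both devices.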

Expected Output

Expected time: 3 to 4 hours. Difficulty: intermediate. Artifact: a single training script (~150 lines) plus the per-step log, a clipping-norm histogram, and the resume-across-crash accuracy curve.

Further Reading

Training Loop References

PyTorch Documentation: torch.optim. The full catalog of optimizers and schedulers, with the parameter-group API that powers per-layer learning rates and weight decays.
Smith, L. N. (2017). "Cyclical Learning Rates for Training Neural Networks." WACV 2017. arXiv:1506.01186. The paper behind CyclicLR and a precursor of the one-cycle policy that OneCycleLR implements. Practical justification for the warmup-then-anneal shape that dominates modern recipes.
Pascanu, R., Mikolov, T., & Bengio, Y. (2013). "On the difficulty of training Recurrent Neural Networks." ICML 2013. arXiv:1211.5063. Introduced global-norm gradient clipping. The reasoning generalizes to any deep model with occasional gradient explosions.
PyTorch Recipe: Saving and Loading a General Checkpoint. Official walkthrough of the dict-of-state-dicts pattern used in production training scripts.