The canonical PyTorch training loop fits in about a dozen lines. Every line carries weight, the order of the lines matters, and the defaults reflect years of community experience. This section dissects the loop, then layers on the extras that distinguish a research prototype from a production training script: gradient clipping, learning-rate schedulers, gradient accumulation, and checkpointing patterns.
The Canonical Loop
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NeuralNetwork and train_loader are assumed to be defined earlier:
# any nn.Module and any DataLoader yielding (features, labels) will do.
model = NeuralNetwork(num_inputs=2, num_outputs=2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

num_epochs = 3
for epoch in range(num_epochs):
    model.train()                                  # enable training-mode layers
    for batch_idx, (features, labels) in enumerate(train_loader):
        features = features.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(features)                   # forward pass
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad(set_to_none=True)      # clear old gradients
        loss.backward()                            # compute new gradients
        optimizer.step()                           # update parameters
    model.eval()                                   # switch to eval mode for validation
```
zero_grad, backward, step is non-negotiable.
The three lines at the end of the inner loop encode the heart of gradient-based learning. optimizer.zero_grad(set_to_none=True) clears the gradients from the previous iteration; without it, .grad would accumulate across batches (a feature exploited by gradient accumulation, but a bug when unintended). loss.backward() populates .grad on every parameter that requires gradients and contributed to the loss. optimizer.step() reads each .grad and updates the corresponding parameter. Swapping any two of these lines silently breaks training: the optimizer ends up stepping either on gradients that were just cleared or on gradients that do not exist yet.
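A small experiment makes the accumulation behavior concrete. The sketch below uses a single hypothetical parameter w and the loss (w * x).sum(), whose gradient with respect to w is exactly x, so every number can be checked by hand.

```python
import torch

# Toy parameter and input; d/dw of (w * x).sum() is exactly x.
w = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
x = torch.tensor([3.0, 4.0])
optimizer = torch.optim.SGD([w], lr=0.1)

def backward_once():
    loss = (w * x).sum()
    loss.backward()

backward_once()
first = w.grad.clone()          # gradient from one backward pass: [3., 4.]
backward_once()                 # no zero_grad in between
accumulated = w.grad.clone()    # gradients were added, not replaced: [6., 8.]

optimizer.zero_grad(set_to_none=True)
cleared = w.grad                # None after set_to_none=True

before = w.detach().clone()
optimizer.step()                # no gradient present, so SGD skips w entirely
unchanged = torch.equal(w.detach(), before)
```

The last lines show why calling step before backward fails silently: the optimizer finds no gradient and simply leaves the parameter where it was.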
PyTorch's F.cross_entropy and nn.CrossEntropyLoss apply log-softmax internally for numerical stability. Pass raw logits, not the output of torch.softmax. Applying softmax before cross-entropy double-softmaxes the predictions, yielding a much weaker training signal and degraded accuracy. The same applies to F.binary_cross_entropy_with_logits versus F.binary_cross_entropy: the _with_logits variant fuses sigmoid and BCE for stability; the bare variant assumes probabilities. The fused variants are almost always the correct choice.
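The cost of double-softmaxing can be seen on a single, confidently correct prediction. This is an illustrative sketch with arbitrary numbers. When probabilities in [0, 1] are fed to cross_entropy as if they were logits, the loss is log(sum_j exp(p_j)) - p_y, which can never drop below log(e + C - 1) - 1 for C classes. The model therefore keeps receiving a gradient even when it is already right.

```python
import math
import torch
import torch.nn.functional as F

# A confident, correct prediction for class 0 out of 3 classes.
logits = torch.tensor([[20.0, 0.0, 0.0]])
target = torch.tensor([0])

correct = F.cross_entropy(logits, target)                        # raw logits: near-zero loss
double = F.cross_entropy(torch.softmax(logits, dim=1), target)   # softmax applied twice

# Floor of the double-softmaxed loss for C = 3 classes: log(e + 2) - 1, about 0.551.
floor = math.log(math.e + 2) - 1

# Binary case: the fused variant matches sigmoid followed by BCE on moderate
# inputs, but evaluates it in a numerically stable way.
torch.manual_seed(0)
bin_logits = torch.randn(5)
bin_targets = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])
fused_bce = F.binary_cross_entropy_with_logits(bin_logits, bin_targets)
bare_bce = F.binary_cross_entropy(torch.sigmoid(bin_logits), bin_targets)
```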
Gradient Clipping
Gradient clipping caps the size of the gradient before the optimizer step. It is standard practice for sequence models and for deep models trained with high learning rates, where an occasional exploding gradient (caused by a rare but extreme batch) can destabilize training in a single step. PyTorch offers two variants:
- torch.nn.utils.clip_grad_norm_(params, max_norm) scales every gradient by the same factor so that the total $L_2$ norm of all gradients combined is at most max_norm. This is the standard choice for transformers and most modern training recipes; max_norm=1.0 is a typical setting.
- torch.nn.utils.clip_grad_value_(params, clip_value) clamps each individual gradient element to [-clip_value, clip_value]. It is coarser than norm clipping because it changes the gradient's direction as well as its size, and it is rarely used in modern recipes.
Both functions modify gradients in place. clip_grad_norm_ also returns the total norm computed before clipping, which is worth logging because a steadily rising norm signals trouble; clip_grad_value_ returns None.
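A quick check of both functions on a deliberately oversized gradient. The toy model and the input scale are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
# Large inputs produce gradients far above the clipping threshold.
loss = model(torch.randn(8, 4) * 100).pow(2).mean()
loss.backward()

# Norm clipping: returns the pre-clipping global norm, rescales grads in place.
pre_clip = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
post_clip = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))

# Value clipping: clamps each element in place and returns None.
ret = nn.utils.clip_grad_value_(model.parameters(), clip_value=0.1)
max_abs = max(p.grad.abs().max().item() for p in model.parameters())
```

After norm clipping, post_clip sits just under 1.0, while pre_clip records how large the gradient really was. That difference is why the returned value is the one worth logging.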
```python
import torch

# Inside the training loop, once loss has been computed:
optimizer.zero_grad(set_to_none=True)
loss.backward()
# Cap the global gradient norm at 1.0. The returned value is the
# pre-clipping norm; log it to monitor training stability.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```
Clipping goes between backward and step. It must happen after backward (so the gradients exist) and before step (so the optimizer sees the clipped values).
Learning Rate Schedulers
A learning rate scheduler adjusts optimizer.param_groups[i]["lr"] over time. PyTorch ships several pre-built schedules in torch.optim.lr_scheduler; the most useful for deep learning are:
- LinearLR(start_factor, end_factor, total_iters): linearly interpolates the LR multiplier from start_factor to end_factor over total_iters steps. Used for warmup phases.
- CosineAnnealingLR(T_max, eta_min): decays smoothly from the initial LR to eta_min along half a cosine wave over T_max steps. A standard decay schedule for transformer training.
- OneCycleLR(max_lr, total_steps): implements Smith's one-cycle policy, warming up to max_lr and then annealing toward a very small final LR. Aggressive, and often the fastest schedule to converge on a fixed budget.
- ReduceLROnPlateau(factor, patience): monitors a metric and multiplies the LR by factor when the metric fails to improve for more than patience consecutive checks. Useful when the total iteration count is unknown.
- SequentialLR(optimizer, schedulers, milestones): chains other schedulers, switching from one to the next at each milestone. The standard transformer recipe is SequentialLR(optimizer, schedulers=[LinearLR(...), CosineAnnealingLR(...)], milestones=[warmup_steps]).
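A sketch that records OneCycleLR's learning rate at every step shows the shape of one of these schedules. The throwaway parameter and the 100-step budget are arbitrary. pct_start (0.3), div_factor (25), and final_div_factor (1e4) are left at their PyTorch defaults, so the LR should start at max_lr / 25, peak at step 29, and finish at max_lr / 25 / 1e4.

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR

# A throwaway parameter is enough to drive a scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-3)   # OneCycleLR takes over the LR
total_steps = 100
scheduler = OneCycleLR(optimizer, max_lr=3e-3, total_steps=total_steps)

lrs = []
for _ in range(total_steps):
    lrs.append(scheduler.get_last_lr()[0])
    optimizer.step()      # no gradients here, so AdamW leaves param alone
    scheduler.step()      # after optimizer.step(), as recommended

peak_step = lrs.index(max(lrs))
```

Plotting lrs gives the warmup-then-anneal curve that the lab later asks you to confirm via scheduler.get_last_lr().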
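Whichever scheduler is used, the pattern is the same: build it after the optimizer and call scheduler.step() after optimizer.step(). The order matters because optimizer.step() uses the current LR. Stepping the scheduler first skips the first value of the schedule, and PyTorch emits a warning when it detects this order. In the recipes above, schedulers step once per iteration. ReduceLROnPlateau is stepped once per validation pass (typically once per epoch) and takes the monitored metric as its argument, as in scheduler.step(val_loss).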
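The per-epoch ReduceLROnPlateau pattern can be sketched with a hypothetical sequence of validation losses that stalls after the second epoch. With patience=1, the LR is halved at the second consecutive check without improvement.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
# Halve the LR once the metric fails to improve for more than 1 check.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

val_losses = [1.0, 0.8, 0.8, 0.8, 0.8]   # hypothetical per-epoch validation losses
lrs = []
for val_loss in val_losses:
    optimizer.step()                     # stand-in for one epoch of training
    scheduler.step(val_loss)             # pass the monitored metric
    lrs.append(optimizer.param_groups[0]["lr"])
```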
```python
import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

warmup_steps = 500
total_steps = 10_000

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Warmup: ramp the LR from 1% to 100% of 3e-4 over warmup_steps.
warmup = LinearLR(optimizer, start_factor=0.01, end_factor=1.0,
                  total_iters=warmup_steps)
# Decay: half a cosine from 3e-4 down to 3e-5 over the remaining steps.
decay = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps,
                          eta_min=3e-5)
# Hand over from warmup to decay at step warmup_steps.
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay],
                         milestones=[warmup_steps])

for step in range(total_steps):
    train_one_step()    # placeholder: forward, backward, optimizer.step()
    scheduler.step()    # advance the LR by one step
```
Warmup plus cosine decay, chained with SequentialLR. The same pattern appears in most modern transformer training recipes.
Gradient Accumulation
Gradient accumulation simulates a larger batch by running several forward and backward passes before each optimizer step. It is the standard trick when the target batch size does not fit in GPU memory but the recipe demands it (large-batch training for stability, contrastive losses that need many negatives, and so on). The implementation is a small twist on the canonical loop: skip zero_grad and step until accumulation_steps mini-batches have been processed.
```python
accumulation_steps = 4

optimizer.zero_grad(set_to_none=True)        # start from clean gradients
for step, (features, labels) in enumerate(train_loader):
    logits = model(features.to(device))
    # Scale the loss so its sum across micro-batches equals the
    # loss that would have been computed for the full virtual batch.
    loss = F.cross_entropy(logits, labels.to(device)) / accumulation_steps
    loss.backward()                          # gradients add up in .grad

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                     # one update per virtual batch
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()                     # the schedule counts optimizer steps
```
Dividing the loss by accumulation_steps keeps the gradient magnitude equivalent to that of the full virtual batch (exactly so when every micro-batch has the same size and the loss uses mean reduction).
BatchNorm computes statistics over the actual micro-batch, not the virtual one. From BatchNorm's perspective, accumulating gradients over four micro-batches of size 8 is not equivalent to one batch of 32: each micro-batch is normalized with its own, noisier statistics. As a result, BatchNorm-equipped models trained with gradient accumulation can behave quite differently from the large-batch run they are meant to emulate, and small micro-batches make this worse. Workarounds: switch to LayerNorm or GroupNorm (which compute per-sample statistics and so are unaffected by how the batch is split), use SyncBatchNorm in distributed training (which aggregates statistics across replicas), or avoid accumulation when BatchNorm is essential. Modern transformer training uses LayerNorm-like normalizations, which sidesteps the problem entirely.
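Both the equivalence claim and the BatchNorm caveat can be checked numerically. This sketch compares full-batch gradients with gradients accumulated over four equal micro-batches. The two toy models and the random data are chosen only for illustration, and grads_full_vs_accumulated is a hypothetical helper.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

def grads_full_vs_accumulated(model, features, labels, accumulation_steps):
    """Return (full-batch grads, accumulated grads) from two identical copies."""
    full, accum = copy.deepcopy(model), copy.deepcopy(model)

    # One backward pass over the whole batch.
    F.cross_entropy(full(features), labels).backward()

    # The same batch split into equal micro-batches, loss scaled by 1/k.
    for x, y in zip(features.chunk(accumulation_steps), labels.chunk(accumulation_steps)):
        (F.cross_entropy(accum(x), y) / accumulation_steps).backward()

    return [p.grad for p in full.parameters()], [p.grad for p in accum.parameters()]

def same(grads_a, grads_b):
    return all(torch.allclose(a, b, atol=1e-6) for a, b in zip(grads_a, grads_b))

torch.manual_seed(0)
features, labels = torch.randn(32, 6), torch.randint(0, 3, (32,))

plain = nn.Sequential(nn.Linear(6, 16), nn.ReLU(), nn.Linear(16, 3))
with_bn = nn.Sequential(nn.Linear(6, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 3))

plain_match = same(*grads_full_vs_accumulated(plain, features, labels, 4))    # True
bn_match = same(*grads_full_vs_accumulated(with_bn, features, labels, 4))     # False
```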
The set_to_none Optimization
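The default in modern PyTorch (2.0 and later) is optimizer.zero_grad(set_to_none=True), which sets each param.grad to None rather than writing zeros into the existing tensor. This skips a memory write per parameter per step. The saving is small individually but adds up across millions of parameters and millions of steps. It can also lower the memory footprint, since the gradient tensors are released until the next backward() allocates them afresh. One behavioral difference follows: torch.optim optimizers skip a parameter whose .grad is None, whereas a zero gradient still triggers an update (which matters for momentum and weight decay).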
The only caveat: code that inspects param.grad directly must handle the None case. Gradient-norm monitoring, sparse logging, and custom regularizers that read .grad all need an if p.grad is not None guard. PyTorch's built-in clip_grad_norm_ already handles this correctly.
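A guarded gradient-norm monitor might look like the following sketch. global_grad_norm is a hypothetical helper name. The ModuleDict holds one layer that never takes part in the forward pass, so its gradients stay None.

```python
import torch
import torch.nn as nn

def global_grad_norm(parameters) -> float:
    """Global L2 norm of all gradients, skipping parameters whose .grad is None."""
    norms = [p.grad.detach().norm(2) for p in parameters if p.grad is not None]
    if not norms:
        return 0.0          # nothing has a gradient yet (e.g. right after zero_grad)
    return torch.norm(torch.stack(norms), 2).item()

# Only the "used" layer participates in the forward pass.
model = nn.ModuleDict({"used": nn.Linear(3, 1), "unused": nn.Linear(3, 1)})
model["used"](torch.ones(2, 3)).sum().backward()

norm = global_grad_norm(model.parameters())
```

Here the used layer's weight gradient is [2, 2, 2] and its bias gradient is 2, so the norm is sqrt(16) = 4. clip_grad_norm_ with a very large max_norm reports the same value.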
Hugging Face's accelerate library lets a single training script run on CPU, one GPU, multi-GPU DDP, FSDP, or DeepSpeed without code changes. Wrap the model, optimizer, scheduler, and loader in accelerator.prepare(), replace loss.backward() with accelerator.backward(loss), and launch with accelerate launch script.py. No torchrun ceremony, no manual device placement.
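```python
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")
# prepare() returns the objects in the order they were passed.
model, optimizer, loader, scheduler = accelerator.prepare(
    model, optimizer, loader, scheduler
)
for batch in loader:                        # batches arrive on the right device
    optimizer.zero_grad(set_to_none=True)
    loss = compute_loss(model, batch)       # compute_loss: your own forward + loss
    accelerator.backward(loss)              # replaces loss.backward()
    optimizer.step()
    scheduler.step()
```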
Checkpointing Patterns
A checkpoint records enough state to resume training exactly where it stopped. The minimum is the model's state_dict; for resuming training, the optimizer's state_dict, the scheduler's state_dict, the current step, and the RNG state are all required. Saving everything in a single dict keeps the bookkeeping simple.
```python
import os
import torch

def atomic_save(state, path):
    """Write to a temp file, then rename: prevents corrupted checkpoints."""
    tmp = path + ".tmp"
    torch.save(state, tmp)
    # os.replace is atomic on POSIX when tmp and path share a filesystem.
    os.replace(tmp, path)

def save_checkpoint(path, model, optimizer, scheduler, step, best_val):
    atomic_save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": step,
        "best_val": best_val,
        "rng_cpu": torch.get_rng_state(),
        "rng_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [],
    }, path)
```
Because the rename is atomic, path always contains a complete file; a job killed mid-write leaves the previous checkpoint intact.
Two checkpoints are usually maintained: a "last" checkpoint overwritten every N steps for crash recovery, and a "best" checkpoint updated only when validation improves. The best checkpoint is what gets evaluated on the test set and shipped to production; the last enables resumption after preemption. Some shops also keep a rolling window of the last few "last" checkpoints to defend against the rare case where the latest one captures a momentary instability.
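Resuming needs the mirror image of save_checkpoint. This is a minimal sketch that assumes the dictionary keys used above. load_checkpoint is a hypothetical helper name, and the caller is expected to restart its loop at step + 1.

```python
import torch

def load_checkpoint(path, model, optimizer, scheduler):
    """Restore state written by save_checkpoint; returns (step, best_val)."""
    # Load onto CPU first; load_state_dict copies into the existing
    # parameters, and the optimizer moves its state to each parameter's device.
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    # Restore RNG streams so dropout and shuffling continue where they stopped.
    torch.set_rng_state(ckpt["rng_cpu"])
    if torch.cuda.is_available() and ckpt["rng_cuda"]:
        torch.cuda.set_rng_state_all(ckpt["rng_cuda"])
    return ckpt["step"], ckpt["best_val"]
```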
The four lines that matter inside the training loop are zero_grad, forward, backward, step. Around them, three additions transform a toy script into a production trainer: gradient clipping (one line, guards against catastrophic single-step divergence), a warmup-plus-decay LR scheduler (a few lines, often a substantial improvement in convergence), and atomic checkpointing (a dozen lines, prevents lost training time after crashes). Gradient accumulation simulates large batches when memory is tight. It requires scaling each micro-batch loss by 1 / accumulation_steps, and it interacts poorly with BatchNorm, whose statistics are computed per micro-batch.
Objective
Convert a 20-line toy training script into a production-quality loop with warmup, cosine decay, gradient clipping, gradient accumulation, AMP, atomic best-and-last checkpointing, and clean resumption after simulated crash. Use FashionMNIST so a full run completes on a single CPU or modest GPU in under 10 minutes.
Setup
- Install torch and torchvision. Confirm CUDA or MPS availability if present; the lab must also pass on CPU.
- Load FashionMNIST with the standard normalization and an 80/20 train-validation split.
- Define a small CNN: two conv-relu-pool blocks, then a 128-unit linear head with 10 outputs.
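The model and the split can be sketched as follows. SmallCNN and split_train_val are hypothetical names. The channel widths (32 and 64) are illustrative choices the lab does not specify. The dataset itself would come from torchvision.datasets.FashionMNIST.

```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, random_split

class SmallCNN(nn.Module):
    """Two conv-ReLU-pool blocks and a 128-unit head for 1x28x28 inputs."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # -> 32x14x14
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # -> 64x7x7
        )
        self.head = nn.Sequential(
            nn.Flatten(), nn.Linear(64 * 7 * 7, 128), nn.ReLU(), nn.Linear(128, num_classes)
        )

    def forward(self, x):
        return self.head(self.features(x))

def split_train_val(dataset: Dataset, train_fraction: float = 0.8, seed: int = 0):
    """Reproducible train/validation split (80/20 by default)."""
    n_train = int(train_fraction * len(dataset))
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, len(dataset) - n_train], generator=generator)
```

Applied to the 60,000-image FashionMNIST training set (for example torchvision.datasets.FashionMNIST(root, train=True, download=True, transform=transform)), the split yields 48,000 training and 12,000 validation images. The fixed seed keeps the split identical across the crash-and-resume runs.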
Steps
- Baseline loop. Implement the four-line core (zero_grad, forward, backward, step) with AdamW(lr=1e-3). Train for 5 epochs. Record final validation accuracy and per-epoch wall time.
- Add a scheduler. Wrap the optimizer with OneCycleLR(max_lr=3e-3, total_steps=...). Confirm via scheduler.get_last_lr() that the LR rises then falls. Compare final accuracy against the baseline.
- Add gradient clipping. Insert nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) after backward(). Log the pre-clip global norm each step and plot a histogram. Confirm that fewer than 5 percent of steps trigger clipping in steady state.
- Add gradient accumulation. Configure physical_batch=32 with accumulation_steps=4. Verify that the loss is scaled by 1/4 and that optimizer.step() fires once per 4 micro-batches.
- Add AMP. Wrap the forward pass in torch.autocast(device_type) and use a GradScaler (only on CUDA). With the scaler active, call scaler.unscale_(optimizer) before clipping so the norm is measured on the true gradients. Measure the throughput change.
- Atomic checkpointing. Save a "last" checkpoint every 100 steps and a "best" checkpoint whenever validation accuracy improves. Write to a temp path and os.replace to the final path. The checkpoint dict must contain model, optimizer, scheduler, scaler, step, and rng_state.
- Crash and resume. Kill the training process partway through epoch 3 (use os._exit(1) from a callback). Restart; confirm that the loop resumes at the step after the last checkpoint and that validation accuracy continues smoothly. Plot the full curve across both runs as a single line.
Stretch Goals
- Add per-parameter-group weight decay: zero decay for biases and LayerNorm parameters, 0.01 for everything else. Re-measure accuracy.
- Wrap the loop in accelerate and confirm the diff is small. Note which lines you deleted.
- Replace float16 AMP with bfloat16 autocast on bfloat16-capable hardware; compare the two regimes.
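For the weight-decay stretch goal, optimizer parameter groups carry per-group hyperparameters. This sketch assumes the common heuristic that 1-D parameters (biases and normalization scales and shifts) are the ones exempt from decay. param_groups_with_decay is a hypothetical helper name.

```python
import torch
import torch.nn as nn

def param_groups_with_decay(model: nn.Module, weight_decay: float = 0.01):
    """Split parameters into a decayed group and an undecayed group.

    Heuristic (an assumption, not a PyTorch rule): 1-D tensors are biases or
    normalization parameters and get no decay; everything else is decayed.
    """
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        (no_decay if param.ndim <= 1 else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

model = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))
optimizer = torch.optim.AdamW(param_groups_with_decay(model), lr=1e-3)
```

The two Linear weight matrices land in the decayed group; the two biases plus the LayerNorm weight and bias land in the undecayed group.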
Expected Output
Expected time: 3 to 4 hours. Difficulty: intermediate. Artifact: a single training script (~150 lines) plus the per-step log, a clipping-norm histogram, and the resume-across-crash accuracy curve.
Further Reading
Training Loop References
OneCycleLR. Practical justification for the warmup-then-anneal shape that dominates modern recipes.