Mixed Precision

Section E.6
Mixed precision keeps two clocks in the same workshop: a fast, lightweight half-precision tick for the heavy math and a slow, careful float32 tick for the values that must not drift.

Modern accelerators execute half-precision matrix multiplications many times faster than single precision, and each half-precision tensor uses half the memory. Mixed-precision training exploits this without sacrificing convergence: most operations execute in float16 or bfloat16, while a small set of numerically sensitive ones (loss, softmax, layer norm, weight updates) stay in float32. The result is a training run that fits larger batches or longer sequences in the same memory and finishes substantially faster, with model quality nearly indistinguishable from full-precision training.

This section covers PyTorch's Automatic Mixed Precision (AMP) machinery, the differences between the two half-precision formats, when the gradient scaler is required, and the five pitfalls that most often bite practitioners moving from FP32 to AMP for the first time. Chapter 9 covers the related but distinct topic of low-precision inference via post-training quantization.

Why Mixed Precision Works

The motivation is hardware. NVIDIA's Tensor Cores (Volta and later), AMD's Matrix Cores (CDNA and later), Apple Silicon's Neural Engine, and Google's TPUs all expose dedicated matrix-multiply units that run half-precision math dramatically faster than the general-purpose float32 path. On an A100, dense float16 matmuls peak at roughly 312 TFLOPS versus 19.5 TFLOPS for standard float32 (the TF32 Tensor Core mode, which PyTorch can enable for float32 matmuls, peaks at 156 TFLOPS). The gap is large enough that mixed-precision training is the default for any serious workload.

The catch is numeric range. Float16 has 5 exponent bits, giving a largest finite value of 65504 (roughly $6 \times 10^{4}$) and a smallest positive normal of $2^{-14} \approx 6 \times 10^{-5}$. Below that, values become subnormal and lose precision, and anything smaller than half the smallest subnormal ($2^{-24} \approx 6 \times 10^{-8}$) silently rounds to zero. Many gradient values, especially in deep networks late in training, fall into this region. Bfloat16 has 8 exponent bits (the same as float32), giving essentially the same range as float32 but only 7 bits of mantissa precision. Mixed-precision training keeps the weights, gradients, and optimizer state in float32 (the master copy) while the forward and backward passes run in the half-precision format the hardware prefers.
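The range and precision trade-off can be checked numerically without a GPU. The sketch below is a standard-library illustration, not PyTorch code: it rounds Python floats to float16 with struct's 'e' (IEEE half-precision) format and emulates bfloat16 by keeping the top 16 bits of the float32 encoding with round-to-nearest-even. The helpers to_fp16 and to_bf16 are hypothetical names, and the bfloat16 emulation ignores NaN and infinity. Inside PyTorch, torch.finfo(torch.float16) and torch.finfo(torch.bfloat16) report the same limits.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE float16 using struct's 'e' format.
    Raises OverflowError if x is too large for float16 (max finite 65504)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    """Emulate bfloat16: keep the top 16 bits of x's float32 encoding,
    rounding to nearest-even. NaN and infinity are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even on the 16 dropped bits
    bits = (bits >> 16) << 16             # zero the low 16 bits (7 mantissa bits remain)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

# Range: a tiny gradient-sized value underflows in float16 but survives in bfloat16.
print("1e-8     ->", to_fp16(1e-8), to_bf16(1e-8))

# Precision: float16 resolves 1 + 2**-10; bfloat16 (7 mantissa bits) rounds it to 1.0.
print("1+2**-10 ->", to_fp16(1 + 2**-10), to_bf16(1 + 2**-10))

# Overflow: 1e5 exceeds float16's maximum but is routine for bfloat16.
try:
    to_fp16(1e5)
except OverflowError:
    print("1e5      -> float16 overflow,", to_bf16(1e5), "in bfloat16")
```

The three prints mirror the trade-off described above: float16 loses the tiny value and overflows at 1e5 but resolves 1 + 2^-10, while bfloat16 keeps both extremes at the cost of coarser rounding.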

torch.autocast: The Forward-Pass Wrapper

PyTorch's torch.autocast is a context manager that automatically chooses the precision for each operation inside its scope. It is a thin layer around the operator dispatch system: eligible operations such as matmul and convolution run in the chosen half precision, while precision-sensitive operations such as softmax, log, exp, and most loss functions run in float32. The user does not need to cast tensors manually; autocast inserts the conversions transparently.

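import torch
import torch.nn.functional as F

# NeuralNetwork and train_loader are assumed from earlier sections.
device = "cuda"
model = NeuralNetwork(50, 10).to(device)    # parameters stay float32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

model.train()
for features, labels in train_loader:
    features = features.to(device, non_blocking=True)
    labels   = labels.to(device, non_blocking=True)

    optimizer.zero_grad(set_to_none=True)
    # Forward pass and loss under autocast: eligible ops run in bfloat16.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(features)
        loss = F.cross_entropy(logits, labels)

    # Backward and step outside autocast; the .grad tensors are float32.
    loss.backward()
    optimizer.step()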
Output: (no stdout; training runs ~2x faster on Ampere or newer GPUs with no scaler needed)
Code Fragment E.6.1: bfloat16 mixed-precision training with torch.autocast. No gradient scaler is required because bfloat16 has the same exponent range as float32, so gradient underflow is not a practical concern.

The device_type argument selects which dispatch table to use; common values are "cuda", "cpu", "mps" (Apple Silicon), and "xpu" (Intel GPUs). The dtype argument selects the half-precision format: torch.float16 or torch.bfloat16 on CUDA; torch.bfloat16 on CPU (fast only on processors with native BF16 instructions); MPS supports both. The autocast scope should wrap the forward pass and the loss computation. The backward pass and the optimizer step belong outside it: backward ops automatically run in the precision autocast chose for the matching forward ops, and parameter gradients accumulate in the parameters' own float32 dtype.

bfloat16 vs float16

For new code, the choice is almost always bfloat16. Its float32-equivalent range means gradients do not underflow, the gradient scaler (covered below) is not needed, and training rarely requires any recipe changes from the float32 baseline. The only downsides are slightly less precision than float16 (7 vs 10 mantissa bits) and lack of support on pre-Ampere NVIDIA GPUs (V100 and older).

Float16, by contrast, has a wider hardware install base (every NVIDIA GPU from Pascal onward) and three more mantissa bits, hence finer precision, but its narrow exponent range demands gradient scaling. The recipe is to multiply the loss by a large factor before backward(), which pulls small gradients up into the representable range, and then divide the gradients by the same factor before clipping or the optimizer step. PyTorch's torch.amp.GradScaler automates this and adjusts the scale factor dynamically: it shrinks the scale (and skips the step) whenever gradients overflow to inf or NaN, and grows it again after a long run of finite steps, keeping it as large as possible without causing overflow.

import torch

# device, model, optimizer, and train_loader are as in Code Fragment E.6.1.
scaler = torch.amp.GradScaler("cuda")

for features, labels in train_loader:
    features = features.to(device, non_blocking=True)
    labels   = labels.to(device, non_blocking=True)

    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(features)
        loss = torch.nn.functional.cross_entropy(logits, labels)

    # Scale the loss to lift gradients into the float16 range.
    scaler.scale(loss).backward()
    # Unscale in place before clipping (clipping reads .grad directly).
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # step() skips the update if any gradient is inf/NaN;
    # update() then adjusts the scale for the next iteration.
    scaler.step(optimizer)
    scaler.update()
Output: (no stdout; gradients are scaled and unscaled transparently)
Code Fragment E.6.2: float16 mixed-precision training with the gradient scaler. The four scaler calls (scaler.scale, scaler.unscale_, scaler.step, scaler.update) are what distinguish a working FP16 recipe from a broken one.
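The dynamic adjustment GradScaler performs can be modeled in a few lines. The sketch below is a toy, not PyTorch's implementation: ToyLossScaler is a hypothetical class, float16 storage is emulated with struct's 'e' format, and gradients are plain Python floats. Its default constants (initial scale 2^16, growth factor 2.0, backoff factor 0.5, growth interval 2000) mirror the documented defaults of torch.amp.GradScaler, but the control flow is simplified.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE float16; values beyond float16's range become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

class ToyLossScaler:
    """Toy model of dynamic loss scaling (hypothetical, not PyTorch code)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.finite_streak = 0

    def backward(self, true_grads):
        """Stand-in for scaler.scale(loss).backward(): the gradients of the
        scaled loss are scale * g, stored here in float16."""
        return [to_fp16(self.scale * g) for g in true_grads]

    def step(self, fp16_grads):
        """Stand-in for unscale_ + step + update. Returns the unscaled
        gradients, or None when the step is skipped because of inf/NaN."""
        if not all(math.isfinite(g) for g in fp16_grads):
            self.scale *= self.backoff_factor   # overflow: skip the step, shrink the scale
            self.finite_streak = 0
            return None
        unscaled = [g / self.scale for g in fp16_grads]   # back to true magnitude
        self.finite_streak += 1
        if self.finite_streak == self.growth_interval:
            self.scale *= self.growth_factor    # long run of finite steps: probe a larger scale
            self.finite_streak = 0
        return unscaled

scaler = ToyLossScaler()
tiny = [1e-8, 3e-7]                        # gradients below float16's normal range
print([to_fp16(g) for g in tiny])          # without scaling, 1e-8 underflows to 0.0
print(scaler.step(scaler.backward(tiny)))  # scaled, stored, unscaled: both survive

big = [2.0]                                # 2.0 * 65536 overflows float16
print(scaler.step(scaler.backward(big)), scaler.scale)   # step skipped, scale halved
```

The design choice is asymmetric on purpose: one overflow immediately halves the scale, while growth waits for a long streak of clean steps, so the scale hovers just below the overflow threshold.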
Key Insight: bfloat16 Is the Default for Large Models

Most major open-weight LLMs since 2022 (Llama, Falcon, Mistral, Qwen, and others) have used bfloat16 as their primary training precision; a few recent models, such as DeepSeek-V3, push parts of the computation further down to FP8. The reason is operational simplicity: no scaler tuning, no overflow-induced training restarts, no per-recipe quirks. The slightly lower precision of bfloat16 is not generally reported to harm convergence at scale, and the engineering reliability is worth more than the precision difference. Reach for float16 only on hardware that lacks bfloat16 support (Volta-generation V100s, Turing-generation T4s) or when bit-for-bit reproduction of a published float16 recipe is required.

Common AMP Pitfalls

Five mistakes account for the majority of AMP-related training failures.

Using float16 without the scaler
Gradients silently underflow to zero; training appears to start but loss does not decrease. The fix is to add GradScaler; the alternative is to switch to bfloat16, which sidesteps the issue.
Calling clip_grad_norm_ on scaled gradients
The clipping threshold is interpreted against gradients that are still scaled up by the loss scaler. Either the clip threshold has to be multiplied by the current scale (fragile), or, much better, scaler.unscale_(optimizer) must be called before clipping. The pattern is in Code Fragment E.6.2 above.
Loss computation inside autocast on tiny values
Some custom loss functions accumulate many tiny intermediate values, and in float16 their sum is dominated by rounding error. After the forward pass, compute the loss inside a with torch.autocast(device_type="cuda", enabled=False): block and cast its inputs to float32 (for example, logits.float()), because disabling autocast does not convert tensors that are already half precision. The activations stay in half precision; only the reduction escapes to float32 accumulation.
NaN in a single batch propagating to all parameters
One overflow in a forward pass produces inf or NaN logits, then a NaN loss, NaN gradients, and NaN parameters after the next step. With float16, GradScaler's step() already skips the parameter update when any gradient is non-finite, so the occasional overflow is tolerated. With bfloat16 there is no scaler, so the training loop should check torch.isfinite(loss) and skip the step when it is false. Anomaly detection (Section E.9) helps localize the source.
Saving and loading checkpoints across precisions
Mixed-precision training keeps a float32 master copy of the weights, so the state dict contains float32 tensors regardless of the autocast format. This is correct: do not save half-precision weights for resumption, because doing so discards the precision the float32 master copy exists to preserve. When resuming float16 training, also save and restore scaler.state_dict() so the loss scale does not restart from its initial value. For deployment, a separate cast to half precision is fine and is common practice for inference.
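The clipping pitfall is easy to quantify. The sketch below uses plain Python lists as stand-in gradients and a hypothetical clip_by_global_norm helper that follows the same rule as torch.nn.utils.clip_grad_norm_ (minus the small epsilon it adds to the norm). It assumes a loss scale of 2^16 and a clip threshold of 1.0, and compares clipping before and after unscaling.

```python
import math

def global_norm(grads):
    """L2 norm over all gradient entries (clip_grad_norm_'s default norm)."""
    return math.sqrt(sum(g * g for g in grads))

def clip_by_global_norm(grads, max_norm):
    """Rescale every gradient if the global norm exceeds max_norm."""
    coef = min(1.0, max_norm / global_norm(grads))
    return [g * coef for g in grads]

scale = 2.0**16                  # assumed current loss scale
max_norm = 1.0
true_grads = [0.3, 0.4]          # true global norm 0.5: no clipping needed
scaled = [g * scale for g in true_grads]   # what .grad holds after scaled backward

# Wrong order: clip the scaled gradients, then unscale (as scaler.step would).
wrong = [g / scale for g in clip_by_global_norm(scaled, max_norm)]

# Right order: unscale first (scaler.unscale_), then clip.
right = clip_by_global_norm([g / scale for g in scaled], max_norm)

print("wrong order, gradient norm:", global_norm(wrong))   # about 1.5e-5
print("right order, gradient norm:", global_norm(right))   # about 0.5
```

In the wrong order the scaled norm (about 32768) always exceeds the threshold, so the gradients are clipped to norm 1.0 in scaled units and then divided by 65536 when the step unscales them: the update shrinks by a factor of tens of thousands even though the true gradients never needed clipping.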
Warning: BatchNorm and FP16 Still Need Care

BatchNorm's running statistics live in float32 by default, but the per-batch mean and variance are computed from the half-precision input. Under aggressive FP16 with a large batch, precision can be lost when the intermediate sums are tiny, and the variance can underflow. Workarounds: wrap the BatchNorm forward in autocast(enabled=False) with its input cast to float32, switch to nn.SyncBatchNorm for distributed training (which does its arithmetic in float32), or replace BatchNorm with LayerNorm or GroupNorm. As with the gradient-accumulation caveat in Section E.5, modern transformer training avoids BatchNorm altogether, so the problem does not arise.
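The precision loss behind both this warning and the loss-reduction pitfall above shows up in a toy running sum. The sketch below, using only the standard library, adds 1e-4 twenty thousand times (exact sum 2.0) while rounding the accumulator to float16 or float32 after every addition. It is a deliberate worst case: the hypothetical running_sum helper rounds at each step, whereas real GPU reduction kernels often accumulate in float32 internally even for half-precision inputs.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE float16 (struct 'e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round x to IEEE float32 (struct 'f' format)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def running_sum(values, round_fn):
    """Sum left to right, rounding the accumulator after every addition."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc

values = [1e-4] * 20_000          # exact sum: 2.0
print("float16 accumulator:", running_sum(values, to_fp16))
print("float32 accumulator:", running_sum(values, to_fp32))
```

The float16 accumulator stalls at about 0.25: there the gap between adjacent float16 values is 2^-12 (about 2.4e-4), so adding 1e-4 is less than half a gap and rounds back to the same total. The float32 accumulator stays close to 2.0.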

Memory Savings

Mixed precision roughly halves the memory used by activations saved for the backward pass (often the largest contributor at realistic batch sizes), but it does not halve total memory: parameters, their gradients, and optimizer state still live in float32. For a transformer trained with AdamW, the memory breakdown is approximately 4 bytes per parameter for weights, 4 bytes for gradients, 8 bytes for Adam's first- and second-moment buffers, and a variable amount for activations. AMP shrinks the activation cost, but parameter gradients stay in float32 because they match the dtype of the float32 parameters, so optimizer state remains the limiting factor for the largest models. The fix for that is sharded optimizer state, covered in Section E.7 on distributed training.
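The per-parameter arithmetic is worth running for a concrete size. The sketch below assumes a hypothetical 7-billion-parameter model and uses only the byte counts stated above (4 for float32 weights, 4 for float32 gradients, 8 for AdamW's two moment buffers). training_memory_gb is a hypothetical helper, and activations are deliberately left out because they depend on batch size and sequence length.

```python
def training_memory_gb(n_params, weight_bytes=4, grad_bytes=4, adam_bytes=8):
    """Static memory for AdamW training with float32 weights, float32 gradients,
    and two float32 moment buffers. Activations are excluded."""
    return n_params * (weight_bytes + grad_bytes + adam_bytes) / 1e9

n_params = 7e9                                       # hypothetical 7B-parameter model
print("static total (GB):", training_memory_gb(n_params))   # 112.0
print("Adam moments (GB):", n_params * 8 / 1e9)             # 56.0
```

AMP leaves all 112 GB of this static footprint untouched and only reduces the activation term on top, which is why sharding the optimizer state, rather than lowering precision further, is the lever for the largest models.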

Key Insight

Mixed precision is often the single largest performance lever in modern PyTorch training. Wrap the forward pass and loss in torch.autocast(device_type="cuda", dtype=torch.bfloat16), leave the backward pass and optimizer step outside, and the training loop typically runs substantially faster (commonly 1.5x to 3x, depending on hardware and model) with no other changes. On hardware that lacks bfloat16 support, use float16 with GradScaler and remember to call scaler.unscale_(optimizer) before gradient clipping. For new code on Ampere or newer GPUs, prefer bfloat16: it has the operational simplicity of full precision and the same Tensor Core throughput as float16.

Exercise E.6.1: AMP Throughput and Accuracy Sanity

Objective. Confirm with measurements that AMP delivers the expected speed-up without hurting validation accuracy.

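Task. Take the Lab E.5 training loop on FashionMNIST and run three configurations: float32 baseline, float16 with GradScaler, and bfloat16 (if supported). For each, log samples per second, peak GPU memory (torch.cuda.max_memory_allocated(), after calling torch.cuda.reset_peak_memory_stats() at the start of each run), and final validation accuracy. Call torch.cuda.synchronize() before reading a timer, since CUDA kernels run asynchronously. Tabulate the three side by side.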

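Expected outcome. AMP should run noticeably faster, often 1.5x to 3x depending on hardware, although on a model as small as a FashionMNIST classifier, data loading and kernel-launch overhead can hide much of the gain. Activation memory should roughly halve; peak memory drops by less because weights and optimizer state stay in float32. Final accuracy should be within about 0.5 percentage points of the float32 baseline. If accuracy diverges, inspect for NaN; see Section E.9 for the recipe.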

Exercise E.6.2: When AMP Is the Wrong Answer

Objective. Recognize the two situations where AMP causes problems and apply the textbook fixes.

Task. (a) Add a custom loss that calls torch.log(x.sum()) with very small x values, run it under float16 AMP, and observe a non-finite loss (-inf or NaN) when x underflows in half precision. Fix it by computing the loss inside torch.autocast(..., enabled=False) with its inputs cast to float32. (b) Add gradient clipping by global norm to a float16 AMP loop without calling scaler.unscale_(optimizer) first. Observe that the threshold is applied to scaled gradients: their norm almost always exceeds it, so every step is clipped, and once the scaler divides by the scale factor the update is far smaller than intended. Fix by unscaling, clipping, then calling scaler.step.

Hint. The AMP rule of thumb: reductions, log, and exp are float32 territory; matmuls and convolutions are float16 or bfloat16 territory.

Further Reading

Mixed Precision References

Micikevicius, P. et al. (2018). "Mixed Precision Training." ICLR 2018. arXiv:1710.03740. The foundational paper that introduced loss scaling and the master-weights pattern. Required reading to understand why AMP works.
PyTorch Documentation: Automatic Mixed Precision. The canonical reference for torch.autocast and torch.amp.GradScaler, including a table of which ops are autocast to which precision.
Kalamkar, D. et al. (2019). "A Study of BFLOAT16 for Deep Learning Training." arXiv:1905.12322. Empirical case for bfloat16: matches FP32 training quality on a wide range of tasks while keeping FP32's exponent range.
PyTorch Notes: Numerical Accuracy. Discussion of float16, bfloat16, and TF32 numerical behavior in PyTorch. Useful when reproducibility across hardware matters.