Mixed precision said 'I will save you memory.' Gradient checkpointing said 'I will save you more memory.' Ring attention said 'Hold my coffee.' Together they let me train a model that would have crushed any single device.
Scale, Distributedly Sharded AI Agent
This section continues from Section 6.6, which covered communication primitives and the core parallelism strategies: DDP, FSDP/ZeRO, tensor parallelism, and pipeline parallelism. Here we add the techniques that make those strategies actually fit on real hardware: mixed-precision training (BF16, FP8), gradient checkpointing, the recipe for combining parallelism strategies into 3D parallelism, and ring attention for training on contexts longer than any single GPU can hold.
Prerequisites
This section continues from Section 6.6. Familiarity with all-reduce / reduce-scatter / all-gather, DDP, FSDP, tensor parallelism, and pipeline parallelism from that section is assumed.
When Google first deployed bfloat16 for pretraining, an engineer spent three weeks chasing a 0.02% accuracy drop on the validation set. The fix? A single accumulator that was still being kept in float16. The lesson made it into countless follow-up systems: always keep reductions (loss, norms, optimizer state) in float32, even when everything else lives in 16 bits. Mixed precision is not a switch you flip; it is a careful taxonomy of where bits can be safely shaved off.
Having mapped the four parallelism axes in Section 6.6, we now look at the cross-cutting techniques (precision, recomputation, ring attention) and the composition recipes that turn raw parallelism into production-scale training.
6.6.6 Mixed Precision Training
Mixed precision reduces memory usage and increases throughput by using lower-precision number formats for most computations while keeping critical accumulations in higher precision.
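Why accumulations need the extra bits can be seen in a few lines. The sketch below is illustrative only: it uses the standard library's IEEE half-precision packing (`struct` format `'e'`) as a stand-in for a 16-bit accumulator, since plain Python has no BF16 type, and the helper names `round_to` and `accumulate` are made up for this example. BF16 has fewer mantissa bits than FP16, so the same stalling effect applies at least as strongly there.

```python
import struct

def round_to(fmt: str, x: float) -> float:
    """Round x to the nearest value representable in a struct format.

    'e' is IEEE half precision (FP16) and 'f' is single precision (FP32).
    """
    return struct.unpack(fmt, struct.pack(fmt, x))[0]

def accumulate(fmt: str, value: float, n: int) -> float:
    """Add `value` n times, rounding the running sum after every addition."""
    step = round_to(fmt, value)
    total = 0.0
    for _ in range(n):
        total = round_to(fmt, total + step)
    return total

# The exact answer is 10.0. Once the 16-bit running sum is large enough that
# 0.001 is less than half the gap between neighbouring values, it stops growing.
fp16_sum = accumulate("e", 0.001, 10_000)
fp32_sum = accumulate("f", 0.001, 10_000)
print(f"fp16 accumulator: {fp16_sum}   fp32 accumulator: {fp32_sum:.4f}")
```

The inputs are identical in both runs; only the precision of the running sum changes, which is the reason reductions are kept in FP32.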
| Format | Bits | Range | Use Case |
|---|---|---|---|
| FP32 | 32 | Very large | Master weights, loss accumulation |
| FP16 | 16 | Limited (needs loss scaling) | Older GPUs (V100) |
| BF16 | 16 | Same as FP32 | Standard for modern training |
| FP8 (E4M3) | 8 | Limited | Forward pass activations (Hopper+) |
| FP8 (E5M2) | 8 | Wider range, less precision | Gradients (Hopper+) |
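The "Range" column in the table above follows directly from how each format splits its bits between exponent and mantissa. The following sketch derives the largest finite value and the spacing near 1.0 from those bit counts. It assumes IEEE-style encodings for every format except E4M3, where it assumes the common "fn" variant that reuses the top exponent code for ordinary numbers; the function name `max_finite` is illustrative.

```python
# (exponent bits, mantissa bits) for each format in the table above
FORMATS = {
    "FP32": (8, 23),
    "FP16": (5, 10),
    "BF16": (8, 7),
    "FP8 E4M3": (4, 3),
    "FP8 E5M2": (5, 2),
}

def max_finite(exp_bits: int, man_bits: int, ieee_like: bool = True) -> float:
    """Largest finite value of a binary floating-point format."""
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_like:
        # The top exponent code is reserved for inf/NaN, so the largest usable
        # exponent equals the bias and the mantissa can be all ones.
        return (2 - 2.0 ** -man_bits) * 2.0 ** bias
    # Assumed "fn" variant of E4M3: the top exponent code holds normal numbers
    # and only its all-ones mantissa pattern encodes NaN.
    return (2 - 2.0 ** (1 - man_bits)) * 2.0 ** (bias + 1)

for name, (e, m) in FORMATS.items():
    top = max_finite(e, m, ieee_like=(name != "FP8 E4M3"))
    print(f"{name:9s} max finite = {top:.4g}   spacing near 1.0 = {2.0 ** -m:.3g}")
```

BF16 keeps FP32's eight exponent bits, which is why its range matches FP32 while its spacing near 1.0 is far coarser.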
FP8 Training at Scale
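DeepSeek V3 (covered in Section 7.3) demonstrated successful FP8 mixed-precision training at 671B parameters, reported by its authors as the first validation of FP8 training on a model of that scale. The conventional FP8 recipe uses E4M3 for forward-pass tensors (more precision, narrower range) and E5M2 for gradients (wider range, less precision), with per-tensor dynamic scaling factors to prevent overflow and underflow. DeepSeek V3's technical report departs from this recipe: it uses E4M3 for all tensors and replaces per-tensor scales with fine-grained scaling over small tiles and blocks, while keeping master weights, optimizer states, and other sensitive components in higher precision. Relative to BF16, FP8 halves the bytes per element for the tensors that are stored in FP8 and raises matrix-multiply throughput on hardware with FP8 tensor cores, with minimal quality degradation.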
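The role of a dynamic scaling factor can be sketched with a toy quantizer. This is a simplified model, not any library's implementation: `quantize_e4m3` rounds to an E4M3-like grid in pure Python, `fake_quant` applies a single scale to the whole tensor (the coarsest granularity), and the gradient values are hypothetical.

```python
import math

E4M3_MAX = 448.0      # largest finite E4M3 value
E4M3_MIN_EXP = -6     # exponent of the smallest normal number
MANTISSA_BITS = 3

def quantize_e4m3(x: float) -> float:
    """Round x to the nearest value on an E4M3-like grid, saturating at +-448."""
    if x == 0.0:
        return 0.0
    _, exp = math.frexp(abs(x))           # abs(x) = m * 2**exp with 0.5 <= m < 1
    exp = max(exp - 1, E4M3_MIN_EXP)      # subnormals share the minimum exponent
    step = 2.0 ** (exp - MANTISSA_BITS)   # spacing between neighbours at this scale
    q = round(abs(x) / step) * step       # round half to even
    return math.copysign(min(q, E4M3_MAX), x)

def fake_quant(values, use_scale: bool):
    """Quantize to E4M3 and back, optionally with one scale for the whole tensor."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if (use_scale and amax > 0) else 1.0
    return [quantize_e4m3(v * scale) / scale for v in values]

grads = [1e-4, -2e-4, 5e-5, 1.5e-4]       # hypothetical small gradient values
unscaled = fake_quant(grads, use_scale=False)
scaled = fake_quant(grads, use_scale=True)
print("no scaling  :", unscaled)           # values below the grid flush to zero
print("with scaling:", scaled)
```

Without a scale, every value falls below the smallest representable step and is lost; with the scale chosen from the tensor's maximum magnitude, the values land inside the representable range and survive with a few percent of rounding error.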
6.6.7 Gradient Checkpointing
During the backward pass, computing gradients requires the activations from the forward pass. Normally all activations are stored in memory, consuming enormous amounts of GPU memory (proportional to batch size, sequence length, hidden dimension, and number of layers). Gradient checkpointing (also called activation checkpointing) saves memory by storing only a subset of activations and recomputing the rest during the backward pass. The tradeoff is approximately 33% additional compute (the backward pass costs roughly twice the forward pass, so one extra forward pass adds about a third to the total) in exchange for a large reduction in activation memory.
```python
# Gradient checkpointing in PyTorch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedTransformer(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            # Recompute this layer's activations during backward instead of storing them
            x = checkpoint(layer, x, use_reentrant=False)
        return x


# Memory comparison (rough estimate for 16-bit activations)
seq_len, hidden, num_layers, batch, heads = 2048, 4096, 32, 8, 32
sbh = seq_len * batch * hidden

# Without checkpointing: every layer keeps all of its intermediate activations.
# Korthikanti et al. (2022) estimate s*b*h*(34 + 5*a*s/h) bytes per layer
# when the attention matrix is materialized.
no_ckpt = num_layers * sbh * (34 + 5 * heads * seq_len / hidden)

# With per-layer checkpointing: keep only each layer's BF16 input (2 bytes per element).
with_ckpt = num_layers * sbh * 2

print(f"Activation memory without checkpointing: {no_ckpt / 1e9:.1f} GB")
print(f"Activation memory with checkpointing:    {with_ckpt / 1e9:.2f} GB")
print(f"Memory saved: {(1 - with_ckpt/no_ckpt)*100:.0f}%")
```
torch.utils.checkpoint.checkpoint(..., use_reentrant=False) trades recompute for memory: the estimate at the bottom of the snippet shows a 32-layer model dropping stored activation memory from about 245 GB to 4.3 GB (a 98% reduction) at the cost of one extra forward pass per layer during backprop. During the backward pass, one layer's full activations (about 7.7 GB in this estimate) are rebuilt transiently on top of the stored inputs.
Expert Parallelism for MoE Training
Mixture-of-Experts models introduce a parallelism dimension that dense models do not have: expert parallelism. In an MoE layer, each token is routed to a small subset of experts (typically 1 or 2 out of dozens or hundreds). Expert parallelism distributes the experts themselves across GPUs, so that each GPU hosts a different subset of the expert FFN blocks. When a token is routed to an expert that lives on another GPU, an all-to-all communication operation transfers the token's hidden state to the appropriate device, and the result is sent back after computation.
The communication pattern differs fundamentally from tensor parallelism. In tensor parallelism, every GPU participates in computing every token (splitting the matrix multiply). In expert parallelism, each GPU only computes the tokens routed to its local experts. This means the amount of computation per GPU varies depending on how the router distributes tokens, making load balancing critically important. If one expert receives disproportionately many tokens, the GPU hosting that expert becomes the bottleneck while others sit idle.
DeepSeek-V3 provides a detailed case study. With 256 routed experts and 1 shared expert, the experts are distributed across GPUs such that each GPU hosts a manageable subset. DeepSeek-V3 uses an auxiliary-loss-free load balancing mechanism with a per-expert bias term that the system adjusts dynamically during training to maintain even expert utilization. This avoids the traditional auxiliary loss approach, which adds a balancing penalty to the training objective and can distort the learning signal. The all-to-all communication is overlapped with computation by pipelining the expert dispatch: while one micro-batch's tokens are being routed and computed on remote experts, the next micro-batch's routing decisions are already being made.
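The bias-based balancing idea described above can be illustrated with a small simulation. This is a toy sketch under stated assumptions, not DeepSeek-V3's implementation: the router scores are random numbers with a hypothetical built-in skew toward two experts, the update step `gamma` and all sizes are arbitrary, and the bias is used only to choose experts.

```python
import random

def route(scores, bias, k):
    """Pick the top-k experts by (score + bias); the bias only affects selection."""
    ranked = sorted(range(len(scores)), key=lambda e: scores[e] + bias[e], reverse=True)
    return ranked[:k]

def simulate(num_experts=8, top_k=2, tokens=256, steps=100, gamma=0.05, seed=0):
    rng = random.Random(seed)
    # Hypothetical router skew: experts 0 and 1 are intrinsically more popular.
    popularity = [2.0, 1.0] + [0.0] * (num_experts - 2)
    bias = [0.0] * num_experts
    history = []
    for _ in range(steps):
        load = [0] * num_experts
        for _ in range(tokens):
            scores = [p + rng.gauss(0.0, 1.0) for p in popularity]
            for e in route(scores, bias, top_k):
                load[e] += 1
        mean = sum(load) / num_experts
        history.append(max(load) / mean)   # 1.0 means perfectly balanced
        # After each step, nudge overloaded experts down and underloaded ones up.
        for e in range(num_experts):
            if load[e] > mean:
                bias[e] -= gamma
            elif load[e] < mean:
                bias[e] += gamma
    return history

imbalance = simulate()
print(f"max/mean load at step 0: {imbalance[0]:.2f}, at last step: {imbalance[-1]:.2f}")
```

No term is added to the loss here: the only feedback is the per-expert bias, which is why this style of balancing leaves the training objective untouched.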
In practice, expert parallelism is combined with the other strategies. A typical configuration for training a large MoE model on 256 GPUs might use tensor parallelism of 8 within each node, expert parallelism of 8 across nodes within a pod, and data parallelism of 4 across pods. The key constraint is that expert parallelism requires all-to-all communication, which has higher latency than the all-reduce used in data parallelism. Placing expert parallelism on a network tier with sufficient bandwidth (typically InfiniBand HDR or better) is essential. As models scale to hundreds or thousands of experts, expert parallelism has become the fourth essential dimension of the "4D parallelism" strategy used in modern large-scale training.
6.6.8 Combining Parallelism Strategies
Real-world large-scale training combines multiple parallelism strategies in a hierarchy. A common configuration for training a 70B model on 512 GPUs might use tensor parallelism with degree 8 within each node (leveraging fast NVLink), pipeline parallelism with degree 8 across nodes, and data parallelism with degree 8 across pipeline-parallel groups. This 3D parallelism approach matches each strategy to the communication bandwidth available at each level of the hardware hierarchy.
The choice of parallelism strategy depends on the hardware topology. Tensor parallelism demands the highest bandwidth and should use intra-node NVLink (900 GB/s on H100, 600 GB/s on A100). Pipeline parallelism can tolerate lower bandwidth and can span nodes connected via InfiniBand (400 Gb/s). Data parallelism is the most bandwidth-efficient and can span the widest network distances.
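The arithmetic behind such a layout is simple enough to check mechanically. The helper below is a hypothetical sketch (the name `data_parallel_degree` and the default of 8 GPUs per node are assumptions for illustration): it derives the data-parallel degree from the cluster size and rejects layouts where a tensor-parallel group would span nodes.

```python
def data_parallel_degree(world_size: int, tp: int, pp: int, gpus_per_node: int = 8) -> int:
    """Return the data-parallel degree implied by a (TP, PP) choice.

    Raises ValueError if a tensor-parallel group would cross a node boundary
    or the degrees do not divide the cluster evenly.
    """
    if tp > gpus_per_node or gpus_per_node % tp != 0:
        raise ValueError("keep each tensor-parallel group inside one node")
    if world_size % (tp * pp) != 0:
        raise ValueError("TP * PP must divide the number of GPUs")
    return world_size // (tp * pp)

# The 512-GPU example from the text: TP=8 and PP=8 leave DP=8.
print(data_parallel_degree(512, tp=8, pp=8))
```

Fixing TP and PP first and letting DP absorb the remaining GPUs mirrors the bandwidth ordering above: the most communication-hungry dimensions are pinned to the fastest links, and data parallelism takes whatever is left.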
As context windows grow to 128K tokens and beyond, a single sequence's activations may not fit in one GPU's memory. Sequence parallelism splits the sequence dimension across GPUs, with each GPU processing a contiguous chunk of the sequence. Context parallelism takes this further by also distributing the attention computation: the key-value pairs are partitioned across GPUs, and in ring attention, its best-known variant, they circulate around a ring topology so that each GPU can attend to the full context without materializing all KV pairs locally at once. These techniques are increasingly important for training long-context models and are used by Llama-3 and other recent systems. For more on these models, see Section 7.3.
torchtitan is PyTorch's official reference codebase for pretraining at scale. It composes FSDP2, tensor parallel, pipeline parallel, and context parallel into a single configurable trainer that runs on any cluster torchrun can reach. Use it when you want a clean, hackable starting point that tracks the latest PyTorch features (Float8 training, AsyncTP, expert parallelism) without depending on Megatron or DeepSpeed.
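```python
# Setup (run once in a shell):
#   git clone https://github.com/pytorch/torchtitan && cd torchtitan
#   pip install -r requirements.txt

# Train Llama-3 8B with 2D parallelism on 8 GPUs.
# The script path, config path, and flag names below follow the repository
# layout this example was written against; check the torchtitan README if
# your checkout differs.
import subprocess

subprocess.run([
    "torchrun", "--nproc_per_node=8",
    "torchtitan/train.py",
    "--job.config_file=train_configs/llama3_8b.toml",
    "--training.tensor_parallel_degree=2",
    "--training.data_parallel_replicate_degree=4",
], check=True)
```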
Who: An ML infrastructure engineer at a mid-size AI company with access to a cluster of 32 A100 80GB GPUs across 4 nodes.
Situation: The team needed to train a 30B parameter model from scratch. The model weights alone required approximately 60 GB in FP16, and with optimizer states in FP32, total per-GPU memory would exceed 240 GB using standard data parallelism.
Problem: Pure data parallelism was impossible (model plus optimizer states exceeded single-GPU memory). Tensor parallelism required high-bandwidth interconnects, and their inter-node bandwidth was only 100 Gb/s (InfiniBand HDR100), much slower than intra-node NVLink.
Dilemma: Full 3D parallelism (data + tensor + pipeline) offered maximum flexibility but was complex to configure. FSDP (ZeRO-3) was simpler but added communication overhead for all-gather operations at every forward pass.
Decision: They used a hybrid approach: tensor parallelism of 4 within each node (leveraging fast NVLink), combined with FSDP across the 4 nodes (using the slower inter-node network only for gradient synchronization).
How: Tensor parallelism split each layer's weight matrices across 4 GPUs within a node, keeping all-reduce operations on NVLink. FSDP sharded optimizer states and gradients across the 4 node-groups, communicating via InfiniBand only during gradient reduction.
Result: The configuration achieved 42% Model FLOPs Utilization (MFU), compared to 28% MFU with FSDP alone and 51% MFU for a fully NVLink-connected setup. Training completed in 18 days instead of the projected 26 days with FSDP-only.
Lesson: Match your parallelism strategy to your network topology: use tensor parallelism within high-bandwidth domains (NVLink) and data/FSDP parallelism across slower interconnects to minimize communication bottlenecks.
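The memory pressure in a scenario like this can be estimated with back-of-the-envelope arithmetic. The sketch below assumes one common mixed-precision accounting (2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, 12 for FP32 master weights plus two Adam moments); the case study does not state its own accounting, activations are ignored, and `per_gpu_state_gb` is a hypothetical helper.

```python
GB = 1e9

def per_gpu_state_gb(params: float, tp: int = 1, shard: int = 1) -> float:
    """Rough per-GPU memory for weights, gradients, and Adam state (no activations).

    Tensor parallelism divides everything by `tp`. Gradients and optimizer
    state (14 of the assumed 16 bytes per parameter) are additionally divided
    across `shard` data-parallel ranks; the 16-bit weights are not.
    """
    local_params = params / tp
    return local_params * (2 + (2 + 12) / shard) / GB

print(f"plain data parallel : {per_gpu_state_gb(30e9):.0f} GB")
print(f"TP=4                : {per_gpu_state_gb(30e9, tp=4):.0f} GB")
print(f"TP=4 + 4-way shard  : {per_gpu_state_gb(30e9, tp=4, shard=4):.1f} GB")
```

Under these assumptions only the last configuration fits on an 80 GB GPU with room left for activations, which is consistent with the hybrid choice described in the case study.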
The parallelism strategies above partition the model across devices, but they all assume the input sequence fits in a single GPU's memory. As context windows grow to 128K tokens and beyond, the attention matrix itself becomes the bottleneck. This requires parallelizing along a new dimension: the sequence itself.
6.6.9 Ring Attention and Sequence Parallelism for Long-Context Training
The push toward million-token context windows has made sequence-dimension parallelism another critical axis of distributed training, alongside data, tensor, and pipeline parallelism. Standard self-attention materializes an attention matrix of size $L \times L$ (where $L$ is the sequence length), which grows quadratically. For a 128K-token sequence, this matrix requires about 69 GB in FP32 for a single attention head ($131{,}072^{2} \times 4$ bytes), and a layer has many heads, so it cannot be held on any single GPU. Sequence parallelism and ring attention solve this by distributing the sequence dimension across multiple devices.
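The quadratic growth described above is easy to tabulate. This short sketch computes the size of one materialized $L \times L$ score matrix; the function name and the choice of sequence lengths are illustrative.

```python
def attention_matrix_gb(seq_len: int, bytes_per_elem: int = 4, heads: int = 1) -> float:
    """Memory needed to materialize L x L attention scores, in GB (1e9 bytes)."""
    return seq_len * seq_len * bytes_per_elem * heads / 1e9

for tokens in (8_192, 32_768, 131_072):
    print(f"L = {tokens:>7,}: {attention_matrix_gb(tokens):8.2f} GB per head in FP32")
```

Each 4x increase in sequence length costs 16x the memory, which is why long-context training has to split the sequence rather than simply use a larger GPU.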
Sequence Parallelism: Splitting the Sequence
Sequence parallelism partitions each input sequence into contiguous chunks, assigning each chunk to a different GPU. For non-attention operations (feed-forward layers, normalization, embedding lookups), this is straightforward because these operations process each token independently. The challenge arises with self-attention, where every token must attend to every other token. Naive partitioning would require each GPU to receive the full key-value (KV) pairs from all other GPUs at every attention layer, introducing prohibitive communication overhead. Megatron-LM's sequence parallelism (Korthikanti et al., 2022) sidesteps the attention problem by applying only to the non-attention components of transformer layers, splitting the LayerNorm and dropout operations along the sequence dimension. This alone saves significant activation memory, as the activations of these operations scale linearly with sequence length.
Ring Attention: Communication-Efficient Long Context
Ring Attention (Liu et al., 2023) provides an elegant solution for distributing the attention computation itself. The core idea is to arrange GPUs in a logical ring and overlap attention computation with communication. Each GPU holds a chunk of the query (Q) tokens permanently and receives key-value (KV) chunks from its neighbors in a round-robin fashion. The algorithm proceeds in $P$ steps (where $P$ is the number of GPUs in the ring):
- Each GPU computes attention between its local Q chunk and its local KV chunk.
- While computing, each GPU asynchronously sends the KV chunk it currently holds to the next GPU in the ring and receives a KV chunk from the previous GPU.
- In the next step, each GPU computes attention between its local Q and the newly received KV chunk.
- After $P$ steps, every Q chunk has attended to every KV chunk, and the attention output is complete.
The key insight is that the communication of KV chunks is overlapped with the computation of attention, hiding the communication latency behind the compute. With modern interconnects (NVLink at 900 GB/s on H100, InfiniBand at 400 Gb/s), the communication can be fully hidden as long as the per-chunk compute time exceeds the transfer time. In practice, this holds for chunks of 4K tokens or larger on H100 GPUs. Because each GPU only ever scores its $L/P$ local queries against one $L/P$-token KV chunk at a time, the attention scores materialized per GPU shrink from $O(L^{2})$ to $O(L^{2}/P^{2})$ per step (or $O(L^{2}/P)$ if all of a GPU's score blocks were kept), and Q, KV, and output storage is $O(L/P)$, so per-device memory scales down linearly with the number of ring participants.
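The steps above can be sketched as a single-process simulation. This is a minimal illustration, not the reference implementation: "devices" are just list indices, the KV rotation happens sequentially rather than overlapped with compute, attention is non-causal and omits the usual $1/\sqrt{d}$ scaling, and all function names are made up. What it does show is the running-maximum (online softmax) bookkeeping that lets each device fold in one KV block at a time and still match full attention.

```python
import math
import random

def attend_block(q, keys, values, state):
    """Fold one KV block into the running (max, denominator, output) for query q."""
    m, denom, out = state
    for k, v in zip(keys, values):
        s = sum(qi * ki for qi, ki in zip(q, k))
        new_m = max(m, s)
        correction = math.exp(m - new_m)   # rescale what was accumulated so far
        w = math.exp(s - new_m)
        denom = denom * correction + w
        out = [o * correction + w * vi for o, vi in zip(out, v)]
        m = new_m
    return m, denom, out

def ring_attention(q_chunks, k_chunks, v_chunks):
    """Device r keeps q_chunks[r]; KV blocks rotate one position per step."""
    P = len(q_chunks)
    dim = len(v_chunks[0][0])
    states = [[(-math.inf, 0.0, [0.0] * dim) for _ in q] for q in q_chunks]
    held = list(range(P))                  # KV block index held by each device
    for _ in range(P):
        for r in range(P):
            b = held[r]
            states[r] = [attend_block(q, k_chunks[b], v_chunks[b], st)
                         for q, st in zip(q_chunks[r], states[r])]
        held = [held[(r - 1) % P] for r in range(P)]   # receive from the previous device
    return [[[o / denom for o in out] for (_, denom, out) in dev] for dev in states]

def full_attention(queries, keys, values):
    """Reference: ordinary softmax attention over the whole sequence."""
    result = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        result.append([sum(w * v[d] for w, v in zip(weights, values)) / total
                       for d in range(len(values[0]))])
    return result

rng = random.Random(0)
P, chunk, dim = 4, 3, 4

def random_rows(n):
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]

def split(rows):
    return [rows[i * chunk:(i + 1) * chunk] for i in range(P)]

Q, K, V = random_rows(P * chunk), random_rows(P * chunk), random_rows(P * chunk)
ring_out = [row for dev in ring_attention(split(Q), split(K), split(V)) for row in dev]
ref_out = full_attention(Q, K, V)
max_err = max(abs(a - b) for r, f in zip(ring_out, ref_out) for a, b in zip(r, f))
print(f"max |ring - full| = {max_err:.2e}")
```

At no point does a simulated device hold more than one KV block or more than a chunk-by-chunk slab of scores, yet the result agrees with full attention up to floating-point rounding.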
Striped Attention and Hybrid Approaches
A limitation of basic ring attention is load imbalance when using causal (autoregressive) attention masks. Because causal attention is triangular (token $i$ attends only to tokens $0$ through $i$), GPUs holding later chunks do more work than those holding earlier chunks. Striped Attention (Brandon et al., 2023) addresses this by distributing tokens in an interleaved (striped) pattern rather than contiguous blocks: GPU 0 gets tokens 0, P, 2P, ...; GPU 1 gets tokens 1, P+1, 2P+1, ...; and so on. This ensures each GPU processes roughly equal amounts of causal attention work, improving utilization from approximately 50% (for contiguous chunking with causal masks) to near 100%.
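The imbalance can be quantified by counting how many query-key pairs each device must evaluate under a causal mask. The sketch below is a simplified model: it counts total work per device and ignores the step-by-step synchronization of a real ring, and the function names are illustrative.

```python
def causal_work(token_ids):
    """Query-key pairs evaluated for these queries: query i sees keys 0..i."""
    return sum(i + 1 for i in token_ids)

def contiguous(seq_len, P):
    """Device r owns one contiguous block of seq_len // P tokens."""
    chunk = seq_len // P
    return [range(r * chunk, (r + 1) * chunk) for r in range(P)]

def striped(seq_len, P):
    """Device r owns tokens r, r + P, r + 2P, ..."""
    return [range(r, seq_len, P) for r in range(P)]

def utilization(assignment):
    """Mean work divided by the work of the busiest device."""
    work = [causal_work(ids) for ids in assignment]
    return (sum(work) / len(work)) / max(work)

L, P = 4096, 8
contig_util = utilization(contiguous(L, P))
striped_util = utilization(striped(L, P))
print(f"contiguous: {contig_util:.1%}   striped: {striped_util:.1%}")
```

With contiguous chunks the last device does nearly twice the average work, so the others wait on it; with stripes every device receives a mix of early and late tokens and the totals almost coincide.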
Meta's Llama-3 training (2024) used a hybrid approach that combined context parallelism with tensor, pipeline, and data parallelism for training on sequences up to 128K tokens. According to the Llama 3 technical report, the parallelism dimensions are ordered tensor, context, pipeline, data, with tensor parallelism innermost so that the most communication-intensive dimension stays on intra-node NVLink. Their context parallelism is also not a ring: it all-gathers the K and V tensors and then computes attention for the local Q chunk, a choice the report motivates by simpler support for varied attention masks and by the small size of K and V under grouped-query attention. This hierarchical design matches communication intensity to available bandwidth at each level. DeepSeek-V2 and Qwen 2.5 also ran dedicated long-context training phases.
Hugging Face Accelerate wraps PyTorch distributed training so the same script runs on 1 GPU or many.
```python
# pip install accelerate
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")

# model, optimizer, and dataloader are assumed to be built as in a single-GPU script.
# Wrap model, optimizer, and dataloader in one call
model, optimizer, dataloader = accelerator.prepare(
    model, optimizer, dataloader
)

for batch in dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

# Launch with: accelerate launch --num_processes 4 train.py
```
Accelerator(mixed_precision="bf16") plus a single accelerator.prepare(...) call replaces the manual DistributedDataParallel, autocast, and GradScaler wiring. The exact same script launches with accelerate launch --num_processes 4 to scale from 1 to N GPUs.
DeepSpeed enables ZeRO Stage 3 sharding with a simple JSON config file.
```python
# DeepSpeed ZeRO-3: shard parameters, gradients, and optimizer states
# across all data-parallel ranks, plus offload params to CPU when idle.
# Write the config and launch as a Python subprocess for reproducibility.
import json, subprocess
from pathlib import Path

cfg = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "train_batch_size": 32,
    "gradient_accumulation_steps": 4,
}
Path("ds_zero3.json").write_text(json.dumps(cfg, indent=2))

subprocess.run(
    ["deepspeed", "--num_gpus=8", "train.py", "--deepspeed", "ds_zero3.json"],
    check=True,
)
```
The script writes ds_zero3.json and launches deepspeed with eight ranks. offload_param + offload_optimizer push idle tensors to CPU memory, which is what lets ZeRO-3 train models larger than the aggregate GPU VRAM.
Practical Considerations
When implementing sequence parallelism for long-context training, several practical considerations apply. First, ring attention typically requires that the sequence length is evenly divisible by the ring size; padding shorter sequences wastes compute proportionally. Second, position encodings (particularly RoPE) must be consistent across the distributed sequence: each GPU must use the correct positional indices for its assigned tokens, not indices starting from zero. Third, the gradients for attention computations across the ring require careful accumulation during the backward pass, as each partial attention result must be combined with the correct log-sum-exp normalization factors to produce numerically stable softmax gradients. Frameworks like Megatron-LM and DeepSpeed-Ulysses provide production-ready implementations of these techniques, and torchtitan (introduced above) includes a context parallelism implementation that can serve as a reference.
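The first two considerations come down to index arithmetic. The helpers below are a hypothetical sketch assuming contiguous chunking; `pad_to_multiple` and `local_positions` are illustrative names, not functions from any of the frameworks mentioned.

```python
def pad_to_multiple(seq_len: int, ring_size: int) -> int:
    """Smallest length >= seq_len that divides evenly across the ring."""
    return -(-seq_len // ring_size) * ring_size

def local_positions(rank: int, seq_len: int, ring_size: int) -> range:
    """Global position indices for the contiguous chunk owned by `rank`."""
    chunk = pad_to_multiple(seq_len, ring_size) // ring_size
    return range(rank * chunk, (rank + 1) * chunk)

seq_len, ring_size = 100_003, 8
padded = pad_to_multiple(seq_len, ring_size)
print("padded length:", padded, "wasted tokens:", padded - seq_len)

# Rank 3 must feed these global indices to RoPE, not 0..chunk-1.
positions = local_positions(3, seq_len, ring_size)
print("rank 3 covers positions", positions[0], "to", positions[-1])
```

Passing chunk-local indices to the position encoding instead of these global ones would make every device believe it holds the start of the sequence, which silently corrupts attention across chunk boundaries.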
At its core, distributed training is a communication problem, not a computation problem. Every parallelism strategy (data, tensor, pipeline) trades off computation for communication in different ways, and the optimal strategy depends on the ratio of compute speed to interconnect bandwidth. This mirrors a fundamental result from parallel computing theory: Amdahl's law says that the speedup from parallelization is limited by the sequential fraction of the workload, and in distributed training, communication is the sequential bottleneck. The progression from DDP to FSDP to 3D parallelism is essentially an engineering journey to minimize the communication-to-computation ratio. This is also why hardware co-design matters: interconnects such as NVLink, NVSwitch, and InfiniBand exist specifically to push the communication bottleneck further out. The same principle appears in Section 9.5, where inference serving frameworks must solve the same communication-computation tradeoff at deployment time.
For fine-tuning techniques that build on pretrained checkpoints, see Section 16.4. For inference-side cost-benefit consequences of scale choices, see Section 9.5. For frontier architectures and their scaling trade-offs, see Section 7.3.
Disaggregated training and heterogeneous clusters. Traditional distributed training assumes homogeneous GPU clusters connected by fast interconnects. Emerging approaches disaggregate computation, allowing training across heterogeneous hardware and even geographically distributed data centers. Ring attention (Liu et al., 2023) enables training on sequences longer than any single GPU's memory by distributing attention computation across a ring of devices. Meanwhile, FP4 training experiments (building on DeepSeek V3's FP8 success) promise further memory and communication savings, potentially enabling pretraining of 100B+ models on fewer GPUs than currently required.
- DDP is the simplest distributed training approach: replicate the model on each GPU and synchronize gradients via all-reduce.
- FSDP/ZeRO shards parameters, gradients, and optimizer states across GPUs to reduce per-GPU memory, enabling training of much larger models.
- Tensor parallelism splits individual layers across GPUs and requires fast intra-node interconnects (NVLink).
- Pipeline parallelism assigns different layers to different GPUs; the 1F1B schedule minimizes idle time.
- BF16 is the standard precision for LLM training; FP8 (demonstrated by DeepSeek V3) provides further memory and throughput improvements on Hopper GPUs.
- Gradient checkpointing trades ~33% extra compute for massive activation memory savings.
- Real-world training uses 3D parallelism, combining tensor, pipeline, and data parallelism to match hardware topology.
Exercises
Plain Distributed Data Parallel (DDP) replicates the full model on each GPU. (a) For a 70B fp16 model with Adam, why does DDP fail on an 8x A100-80GB node? (b) What is the largest model for which you would still default to DDP over FSDP for a single-node training run? (c) What does "gradient bucketing" do in DDP and why does it matter for throughput?
Answer Sketch
(a) Per-GPU memory needs: parameters (140GB fp16) + grads (140GB fp16) + Adam moments (560GB fp32) ~= 840GB on every GPU; an 80GB GPU cannot hold even the parameters. DDP simply cannot host this model and you need sharding (FSDP, ZeRO-3, or tensor parallelism). (b) Under the same accounting (12 bytes per parameter), an 80GB GPU tops out just under 7B parameters with no room left for activations, so DDP is a comfortable default only up to roughly 3B parameters (~36GB of model state). In that range DDP has lower communication overhead because it needs only a gradient all-reduce, overlapped with the backward pass, and no parameter all-gathers. (c) Bucketing groups small parameter gradients together so the all-reduce is amortized across many tensors, hiding latency and avoiding the per-tensor NCCL overhead that would otherwise dominate in models with thousands of small parameters (LayerNorm, biases).
You have a 64-GPU cluster and want to train a 175B model. Compute the typical (TP, PP, DP) allocation: (a) what tensor parallel degree is the sweet spot on a single node of 8 GPUs? (b) what pipeline parallel degree fits a 175B model in fp16 across 8 nodes? (c) what is the resulting data-parallel degree?
Answer Sketch
(a) TP=8 within a node: tensor parallelism is bandwidth-hungry and benefits enormously from NVLink (~600 GB/s) versus inter-node InfiniBand (~50 GB/s), so it should never cross node boundaries. (b) For 175B in fp16: 350GB params + 350GB grads + 1400GB optimizer ~= 2.1 TB. With TP=8, each pipeline stage holds 1/PP of this spread over 8 GPUs; PP=8 (one stage per node) leaves about 33GB of model state per 80GB GPU with headroom for activations, a typical Megatron-style choice. (c) Total = TP * PP * DP, so DP = 64 / (8 * 8) = 1; with 64 GPUs you barely have one data-parallel replica. To get DP=4 (higher throughput and larger global batches) you would need a 256-GPU cluster, which is why 175B-class models target 512+ GPUs in production.
You have a Transformer block TransformerBlock in PyTorch. Show how to wrap it with torch.utils.checkpoint.checkpoint so that activations are recomputed in the backward pass instead of stored. State the memory and compute tradeoff in concrete percentage terms for a typical 32-layer model.
Answer Sketch
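```python
from torch.utils.checkpoint import checkpoint

# inside model.forward:
x = checkpoint(self.transformer_block, x, use_reentrant=False)
```
For a 32-layer model, stored activation memory is roughly proportional to layer count times the activations kept per layer. Checkpointing every block keeps only each block's input instead of all of its intermediate tensors, which cuts stored activations by more than 10x (Korthikanti et al., 2022 estimate about 34 bytes per token per hidden unit without recomputation versus 2 bytes for the block input, before counting attention scores). Compute cost: each checkpointed block is recomputed once during backward, so total FLOPs increase by ~33% (one extra forward through the checkpointed regions). A cheaper alternative is "selective checkpointing": recompute only the attention core, which holds the largest activations because of the seq^2 attention matrix but is cheap to recompute; for GPT-3-scale models Korthikanti et al. report ~70% of activation memory saved for only a few percent extra compute.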
You set up 8-stage pipeline parallelism with micro-batch size 1 and find your GPUs are utilized only 30% of the time, with most time spent waiting. Diagnose: (a) what is the "pipeline bubble" and how does it scale with stages and micro-batches? (b) What single hyperparameter change would most improve utilization? (c) Why doesn't infinitely raising that hyperparameter solve the problem?
Answer Sketch
(a) The pipeline bubble is the fraction of time GPUs sit idle while waiting for the first or last micro-batches to fill or drain the pipeline. Bubble fraction is approximately $(P-1)/(M+P-1)$ for P stages and M micro-batches, so 8 stages with M=1 gives 7/8 = 87.5% idle time; the observed 30% utilization (70% idle) corresponds to only about M=3 micro-batches per step. (b) Increase the number of micro-batches per step. With M=32 and P=8, the bubble drops to ~18%, which is the GPipe-style sweet spot. (c) Under a GPipe-style schedule each micro-batch must hold its own activations until backward, so memory grows linearly with M and at some point activation memory crowds out parameter capacity; at a fixed global batch size, more micro-batches also means smaller ones, which lowers per-kernel efficiency. The 1F1B schedule (introduced with PipeDream) interleaves forward and backward to reduce activation pressure for the same M.
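The bubble formula in this answer can be tabulated directly. A minimal sketch, assuming the GPipe-style approximation $(P-1)/(M+P-1)$ quoted above; the function name is illustrative.

```python
def bubble_fraction(stages: int, micro_batches: int) -> float:
    """Idle fraction of a GPipe-style schedule: (P - 1) / (M + P - 1)."""
    return (stages - 1) / (micro_batches + stages - 1)

for m in (1, 3, 8, 32, 128):
    print(f"P=8, M={m:>3}: bubble = {bubble_fraction(8, m):.1%}")
```

The returns diminish quickly: moving from 1 to 32 micro-batches removes most of the idle time, while going from 32 to 128 buys comparatively little and costs more activation memory under a GPipe-style schedule.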
What's Next?
In the next section, Section 6.7: In-Context Learning Theory, we investigate in-context learning theory, understanding why LLMs can adapt to new tasks from just a few examples.