Distributed Training

Section E.7
A team of four friendly cartoon chefs in a sunny kitchen, each preparing one piece of a giant rainbow layer cake at their own station, with conveyor lines synchronizing the pieces at a central turntable, illustrating gradient synchronization across distributed workers
Distributed training is a team of chefs: each works on its own slice of the batch, then everyone synchronizes their gradients at the all-reduce so the recipe stays consistent across the kitchen.

When one GPU is no longer enough, distributed training spreads the model, the data, or both across multiple devices. PyTorch ships three increasingly capable approaches: the legacy single-process DataParallel (briefly covered for historical context), the modern multi-process DistributedDataParallel (the default for data-parallel training), and Fully Sharded Data Parallel (FSDP, the default for models that exceed a single GPU's memory). For most teams that do not want to write distributed code at all, Hugging Face Accelerate provides a thin layer (a library plus a launcher) on top of these primitives. This section explains when to reach for each.

The scaling results that motivate large distributed pretraining runs (the Chinchilla scaling laws, the compute-optimal token budget) are covered in Chapter 6. Distributed inference patterns for production serving are covered in Chapter 9.

DataParallel: The Legacy Single-Process Path

The original nn.DataParallel wraps a model and, on each forward pass, scatters the input batch across the visible GPUs, replicates the model onto each of them, runs the replicas in parallel threads, and gathers the outputs back onto the primary GPU. Everything runs in a single Python process. That sounds attractive, but the design pays three costs: the worker threads contend for the global interpreter lock (GIL), outputs and gradients are funneled through the primary GPU (device 0), and the model is re-replicated on every forward pass. In practice DataParallel often achieves only 50 to 70 percent of single-GPU efficiency per added device.

Do not use DataParallel in new code. It appears here only because legacy code still uses it. Its modern replacement is DistributedDataParallel, which runs one process per GPU and is faster and more scalable, even on a single machine.

DistributedDataParallel: The Modern Default

nn.parallel.DistributedDataParallel (DDP) replicates the model on every GPU, runs forward and backward independently on each replica, and then synchronizes gradients with an all-reduce collective so every replica holds the same averaged gradient and performs an identical optimizer step. Each GPU is driven by its own process, so there is no master-process bottleneck. The all-reduce overlaps with the backward pass: because backward runs from the output layer toward the input, gradients for the last layers are ready first, and DDP all-reduces them in buckets while earlier layers are still computing their gradients. Much of the communication cost is therefore hidden.
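Why averaging per-replica gradients reproduces the single-device step can be checked without any GPUs. The sketch below is a single-process simulation, not DDP itself: it assumes equal-sized shards and a loss that averages over the batch (as mse_loss and cross_entropy do by default), splits one batch across two simulated ranks, and averages their gradients the way DDP's all-reduce does.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
x, y = torch.randn(16, 8), torch.randn(16, 1)

def grads_for(xb, yb):
    """Gradients of the batch-mean loss on (xb, yb), as one replica would compute them."""
    model.zero_grad(set_to_none=True)
    torch.nn.functional.mse_loss(model(xb), yb).backward()
    return [p.grad.clone() for p in model.parameters()]

full = grads_for(x, y)                                        # one device, whole batch
shards = [grads_for(x[:8], y[:8]), grads_for(x[8:], y[8:])]   # two simulated ranks, half each
world_size = len(shards)

# Sum across ranks, then divide by world size: the result of DDP's gradient averaging.
averaged = [sum(gs) / world_size for gs in zip(*shards)]

matches = all(torch.allclose(f, a, atol=1e-6) for f, a in zip(full, averaged))
print(matches)
```

The equivalence depends on equal shard sizes; with uneven shards, a plain average of per-shard means would weight samples unequally.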

The launcher is torchrun, which sets the environment variables every DDP process needs (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT) and spawns one process per GPU. The training script then initializes the process group, builds the model on the local GPU, and wraps it in DDP.

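import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# NeuralNetwork, train_set, and num_epochs are assumed to be defined
# elsewhere (for example, carried over from the single-GPU script).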

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = NeuralNetwork(num_inputs=512, num_outputs=10).to(device)
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    loader  = torch.utils.data.DataLoader(train_set, batch_size=32,
                                          sampler=sampler)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)         # ensures different shuffle per epoch
        for features, labels in loader:
            features = features.to(device, non_blocking=True)
            labels   = labels.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            loss = torch.nn.functional.cross_entropy(model(features), labels)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

# Launch: torchrun --nproc-per-node=4 train_ddp.py
Output: (nothing is printed; add rank-0 logging, described below, to report progress)
Code Fragment E.7.1: A DDP training script. The launcher (torchrun --nproc-per-node=N) spawns N processes, each pinning itself to one local GPU. The DistributedSampler gives each rank its own non-overlapping slice of the dataset (by default it pads the last slices with a few repeated samples so every rank gets the same count).
Warning: set_epoch Is Not Optional

DistributedSampler derives its shuffle from its seed plus self.epoch. Without sampler.set_epoch(epoch) at the top of each epoch, every epoch reuses the same ordering, which silently degrades training. Gradient accumulation does not change the rule: call set_epoch once per epoch, never inside the accumulation inner loop. Forgetting set_epoch is one of the most common DDP bugs.
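The effect is easy to observe on CPU, because DistributedSampler accepts explicit num_replicas and rank arguments and then needs no process group. A minimal sketch with a toy 16-sample dataset and two simulated ranks (the dataset and seed are arbitrary):

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16))   # toy dataset: only the indices matter

# Explicit num_replicas/rank: no process group is needed to inspect the sampler.
rank0 = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)
rank1 = DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=True, seed=0)

epoch0_first = list(rank0)
epoch0_again = list(rank0)      # set_epoch not called: the same order repeats
rank0.set_epoch(1)
epoch1 = list(rank0)            # new epoch value, new shuffle

print("without set_epoch:", epoch0_first, epoch0_again)
print("after set_epoch(1):", epoch1)
print("ranks disjoint:", set(epoch0_first).isdisjoint(list(rank1)))
```

Both ranks must call set_epoch with the same value; otherwise their shuffles diverge and the slices start to overlap.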

Practical Example: Rank-0-Only Logging and Saving

Every DDP rank runs the same script, but logs, checkpoints, and external API calls should usually happen on rank 0 only. The idiom is to wrap the relevant block in if dist.get_rank() == 0:. To aggregate a metric across ranks, call dist.all_reduce(tensor, op=dist.ReduceOp.SUM) on every rank (it is a collective, so calling it only on rank 0 would hang the job), divide the sum by the world size, and then log on rank 0. Hugging Face Trainer handles this bookkeeping automatically and Accelerate provides helpers for it; raw DDP requires explicit care.
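The all-reduce-then-log pattern might look like the sketch below. global_mean and log_and_save are hypothetical helper names. To keep the sketch runnable without GPUs or torchrun, it sets up a one-process gloo group on CPU with an in-memory HashStore; in a real nccl job, every rank calls the helpers and passes its local CUDA device, since nccl collectives require GPU tensors.

```python
import os
import tempfile
import torch
import torch.distributed as dist

def global_mean(value: float, device: str = "cpu") -> float:
    """Average a scalar across ranks. Collective: every rank must call it."""
    t = torch.tensor([value], dtype=torch.float64, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)     # in place: t now holds the sum over ranks
    return t.item() / dist.get_world_size()

def log_and_save(step: int, local_loss: float, model: torch.nn.Module,
                 path: str, device: str = "cpu") -> float:
    mean_loss = global_mean(local_loss, device)  # every rank, never inside the rank-0 branch
    if dist.get_rank() == 0:                     # side effects happen exactly once
        print(f"step {step}: mean loss {mean_loss:.4f}")
        # For a DDP-wrapped model, save model.module.state_dict() instead so
        # the keys do not carry the "module." prefix.
        torch.save(model.state_dict(), path)
    return mean_loss

# One-process gloo group on CPU so the sketch runs anywhere.
dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
model = torch.nn.Linear(4, 2)
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
mean = log_and_save(step=10, local_loss=0.75, model=model, path=ckpt_path)
dist.destroy_process_group()
```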

Fully Sharded Data Parallel

DDP replicates the model on every GPU, so the maximum model size is bounded by single-GPU memory. For multi-billion-parameter models, this is the binding constraint long before compute becomes the bottleneck. Fully Sharded Data Parallel (FSDP) breaks the constraint by sharding the model's parameters, gradients, and optimizer state across GPUs: each rank holds only its slice and reconstructs full layers on demand using all-gather collectives, freeing them again after use.
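To see why replication becomes the binding constraint, here is a back-of-the-envelope sketch using the mixed-precision Adam accounting from the ZeRO paper (Rajbhandari et al., listed under Further Reading): 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for the fp32 master weights and the two Adam moments. The helper name and the 8 x 80 GB cluster are illustrative assumptions, and activations, buffers, and fragmentation are ignored, so real ceilings are lower.

```python
# Mixed-precision Adam, ZeRO accounting (bytes per parameter):
# 2 (16-bit weights) + 2 (16-bit grads) + 12 (fp32 master weights + two Adam moments)
BYTES_PER_PARAM = 2 + 2 + 12

def max_trainable_params(gpu_memory_gb: float, num_gpus: int, sharded: bool) -> float:
    """Largest model whose weights, grads, and optimizer state fit (activations ignored)."""
    usable_bytes = gpu_memory_gb * 1e9
    if sharded:
        # FSDP full shard: each GPU holds only 1/num_gpus of the state,
        # so the cluster's combined memory is what counts.
        usable_bytes *= num_gpus
    # DDP (sharded=False): every GPU must hold the full state by itself.
    return usable_bytes / BYTES_PER_PARAM

ddp_limit = max_trainable_params(80, num_gpus=8, sharded=False)
fsdp_limit = max_trainable_params(80, num_gpus=8, sharded=True)
print(f"DDP on 8 x 80 GB:  {ddp_limit / 1e9:.1f}B parameters")
print(f"FSDP on 8 x 80 GB: {fsdp_limit / 1e9:.1f}B parameters")
```

Under these assumptions, adding GPUs never raises DDP's ceiling, while FSDP's ceiling grows linearly with the number of GPUs.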

The intuition: for a transformer layer with $P$ parameters sharded across $N$ GPUs, each GPU stores only $P/N$ of that layer in steady state, but briefly materializes all $P$ during the layer's forward pass (and again during its backward pass). With activation checkpointing layered on top, even very large models can be trained on commodity GPU clusters. The cost is communication: each layer's forward and backward incur an all-gather, and the backward also incurs a reduce-scatter of gradients. The wall-clock penalty versus DDP is typically 10 to 30 percent, usually a good trade when the alternative is a model that does not fit at all.
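The $P/N$ steady state versus the transient $P$ also explains why Code Fragment E.7.2 shards each block as its own unit. A rough sketch of peak parameter memory per GPU, assuming fp32 shards and bf16 all-gathered copies (matching that fragment's mixed-precision policy); the function name and the 32-layer, 200M-parameters-per-layer, 8-GPU numbers are illustrative. Gradients, optimizer state, activations, and FSDP's prefetching of the next unit (which can briefly hold two gathered blocks) are ignored.

```python
def fsdp_param_peak_gb(num_layers: int, params_per_layer: float, num_gpus: int,
                       per_block_units: bool, shard_bytes: int = 4,
                       gathered_bytes: int = 2) -> float:
    """Peak parameter memory per GPU in GB (parameters only)."""
    # Steady state: each GPU keeps its 1/N shard of every layer, P/N per layer.
    sharded = num_layers * params_per_layer / num_gpus * shard_bytes
    # Transient: the all-gathered copy of whatever one FSDP unit contains,
    # one layer's P when each block is its own unit, the whole model otherwise.
    unit_params = params_per_layer if per_block_units else num_layers * params_per_layer
    gathered = unit_params * gathered_bytes
    return (sharded + gathered) / 1e9

per_block = fsdp_param_peak_gb(32, 200e6, 8, per_block_units=True)
whole_model = fsdp_param_peak_gb(32, 200e6, 8, per_block_units=False)
print(f"one unit per block: {per_block:.1f} GB")
print(f"one unit in total:  {whole_model:.1f} GB")
```

With a single unit for the whole model, every parameter is all-gathered at once, which erases most of the benefit of sharding.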

FSDP1 vs FSDP2

PyTorch has shipped two FSDP implementations. FSDP1 (the original, available since PyTorch 1.11) wraps modules in a FullyShardedDataParallel wrapper and flattens each wrapped module's parameters into a single FlatParameter for sharding. It works but has rough edges: limited support for parameter freezing, complex interaction with torch.compile, and a state-dict format that does not match a single-GPU model. FSDP2 (introduced in PyTorch 2.4 and refined in later releases) instead shards each parameter individually, building on the DTensor abstraction. It is simpler to use, composes cleanly with torch.compile and mixed precision, supports partial freezing, and produces state dicts whose keys match the unwrapped model (each value is a sharded DTensor), so converting to and from single-GPU checkpoints is straightforward. New code should use FSDP2.

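import os
import torch
import torch.distributed as dist

try:  # public import path in recent PyTorch releases
    from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
except ImportError:  # earlier FSDP2 releases exposed it under a private path
    from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy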

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# build_transformer() is a placeholder for your model constructor; the model
# is assumed to keep its blocks in model.transformer_blocks.
model = build_transformer().cuda()

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,        # all-gathered params computed in bf16; shards stay fp32
    reduce_dtype=torch.float32,        # gradient reduce-scatter in float32
)

# Shard each transformer block independently for finer-grained
# overlap of compute and communication.
for block in model.transformer_blocks:
    fully_shard(block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop is identical to DDP from this point on.
Output: (no stdout; FSDP2 shards parameters across all ranks)
Code Fragment E.7.2: FSDP2 with mixed-precision policy. Each transformer block is sharded as its own unit; the outer fully_shard(model) handles the remaining top-level parameters.

Sharding Strategies

A comparison diagram of three FSDP sharding strategies across four GPUs: full shard splits parameters, gradients, and optimizer state into quarters; shard grad and opt keeps full parameters per GPU but shards gradients and optimizer state; hybrid shard splits within each node and replicates across nodes

Figure E.7.1: Three FSDP sharding strategies trade memory savings for communication volume; full shard maximizes savings, hybrid shard balances intra-node speed against cross-node bandwidth.

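FSDP exposes several strategies that trade memory against communication:

  1. Full shard (FULL_SHARD in FSDP1; the default reshard_after_forward=True in FSDP2): parameters, gradients, and optimizer state are all split across every GPU, and parameters are re-gathered for the backward pass. This gives the largest memory savings at the highest communication cost.
  2. Shard gradients and optimizer state (SHARD_GRAD_OP in FSDP1; reshard_after_forward=False in FSDP2): full parameters stay on each GPU from forward through backward, while gradients and optimizer state are sharded. Skipping the backward re-gather cuts communication at the cost of memory.
  3. Hybrid shard (HYBRID_SHARD in FSDP1; a two-dimensional device mesh in FSDP2): state is fully sharded within each node and replicated across nodes, so the frequent all-gathers stay on fast intra-node links and only gradient reduction crosses the slower inter-node network.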
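A rough per-GPU comparison of these strategies, with DDP for reference, can be sketched with ZeRO-style accounting (2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, 12 for fp32 master weights and Adam moments). The function name, the 7B model, and the 2-node, 8-GPUs-per-node layout are illustrative assumptions; activations and communication buffers are ignored, and real FSDP dtypes can differ from this simplified accounting.

```python
PARAM, GRAD, OPT = 2, 2, 12   # bytes/param: 16-bit weights, 16-bit grads, fp32 master + Adam moments

def per_gpu_gb(num_params: float, num_gpus: int, gpus_per_node: int, strategy: str) -> float:
    """Rough per-GPU model-state memory (GB) under ZeRO-style accounting."""
    total = PARAM + GRAD + OPT
    if strategy == "ddp":               # full replica on every GPU
        b = total * num_params
    elif strategy == "full_shard":      # everything split across all GPUs
        b = total * num_params / num_gpus
    elif strategy == "shard_grad_op":   # full weights; grads and optimizer state sharded
        b = PARAM * num_params + (GRAD + OPT) * num_params / num_gpus
    elif strategy == "hybrid_shard":    # full shard within a node, replicate across nodes
        b = total * num_params / gpus_per_node
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return b / 1e9

for s in ["ddp", "full_shard", "shard_grad_op", "hybrid_shard"]:
    print(f"{s:>14}: {per_gpu_gb(7e9, num_gpus=16, gpus_per_node=8, strategy=s):6.1f} GB")
```

Hybrid shard gives up some memory savings relative to full shard in exchange for keeping the all-gathers inside each node.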

Combining with Activation Checkpointing

Activation checkpointing (also called gradient checkpointing) trades compute for memory: instead of saving every intermediate activation for the backward pass, only a subset is saved and the rest are recomputed. For transformers, the typical recipe is to checkpoint every transformer block, so only each block's input is kept and everything inside the block is recomputed during backward; this sharply reduces activation memory at the cost of roughly one extra forward pass per backward. Combined with FSDP, activation checkpointing is what makes 7B-to-70B-parameter models trainable on clusters of 80GB GPUs. The PyTorch API is torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False) or, for whole modules, the more ergonomic checkpoint_wrapper from torch.distributed.algorithms._checkpoint.checkpoint_wrapper (an underscore-prefixed, private module whose location may change).
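A small CPU check shows that checkpointing changes what is stored, not what is computed: the checkpointed block produces the same gradients as the plain one. The toy MLP stands in for a transformer block, and its sizes are arbitrary.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = torch.nn.Sequential(torch.nn.Linear(32, 128), torch.nn.GELU(), torch.nn.Linear(128, 32))
x = torch.randn(4, 32, requires_grad=True)

# Plain forward/backward: every intermediate activation is kept for backward.
block(x).sum().backward()
plain = [p.grad.clone() for p in block.parameters()]
block.zero_grad(set_to_none=True)
x.grad = None

# Checkpointed: only the block input is saved; the Linear/GELU intermediates
# are recomputed during backward. use_reentrant=False selects the
# non-reentrant implementation, which PyTorch recommends.
checkpoint(block, x, use_reentrant=False).sum().backward()
ckpt = [p.grad.clone() for p in block.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(plain, ckpt))
print(same)
```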

Warning: DDP Gradient-Sync Deadlocks

DDP synchronizes gradients during every backward pass, and each all-reduce completes only when every rank joins it. If one rank takes a different code path that skips a backward (an early exit on some condition, an exception swallowed in a try block, a rank that runs out of data before the others), the remaining ranks block forever (or until the process-group timeout fires) waiting for an all-reduce that will never arrive. The symptom is a job that hangs at constant GPU utilization with no progress logs. Mitigations: keep the forward and backward path identical on every rank (gate any divergence with dist.all_reduce on a flag tensor so all ranks make the same decision); set find_unused_parameters=True only when some parameters genuinely receive no gradient in a step (it adds its own overhead); and make sure all ranks see the same number of batches per epoch (use drop_last=True and a batch size that divides each rank's shard evenly).
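Gating a divergence with a flag tensor might look like the sketch below. any_rank_says_skip and guarded_step are hypothetical helper names, and the non-finite-loss check is just one example of a rank-local condition. The demo uses a one-process gloo group on CPU with an in-memory HashStore so it runs anywhere; in a real job every rank calls guarded_step on every step, and with nccl the flag tensor lives on the rank's GPU.

```python
import torch
import torch.distributed as dist

def any_rank_says_skip(local_skip: bool, device: str = "cpu") -> bool:
    """Collective: every rank must call this at the same point in every step."""
    flag = torch.tensor([int(local_skip)], dtype=torch.int32, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)   # 1 if any rank wants to skip
    return bool(flag.item())

def guarded_step(loss: torch.Tensor, optimizer: torch.optim.Optimizer) -> bool:
    # Turn a rank-local condition (here: a non-finite loss) into a global decision.
    if any_rank_says_skip(not torch.isfinite(loss).item(), device=str(loss.device)):
        optimizer.zero_grad(set_to_none=True)
        return False                  # every rank skips backward together
    loss.backward()                   # every rank enters DDP's all-reduce together
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return True

dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 4)

took_good_step = guarded_step(model(x).pow(2).mean(), optimizer)
took_bad_step = guarded_step(model(x).pow(2).mean() * float("nan"), optimizer)
dist.destroy_process_group()
print(took_good_step, took_bad_step)
```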

Hugging Face Accelerate

Hugging Face's accelerate library is a thin abstraction over PyTorch's distributed primitives that lets a single training script run unchanged on CPU, single GPU, multi-GPU DDP, FSDP, or DeepSpeed. The pattern is to construct an Accelerator, call accelerator.prepare(model, optimizer, loader), and then use accelerator.backward(loss) instead of loss.backward(). The accelerate config CLI walks the user through choosing a launch strategy and writes a config file; accelerate launch script.py then handles the equivalent of torchrun.

Accelerate is the recommended starting point for teams that want distributed training without authoring the boilerplate. The escape hatch (dropping to raw DDP or FSDP) is always available when fine-grained control is needed.

Library Shortcut
accelerator.prepare() Hides the Distributed Plumbing

A single call wraps the model in DDP or FSDP (according to the launch configuration), shards the DataLoader across processes, moves everything to the right device, and configures mixed precision. The training loop reads almost exactly like the single-GPU version (accelerator.backward(loss) replaces loss.backward()); the launcher decides at runtime whether to spawn one or many processes.

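from accelerate import Accelerator

# model, optimizer, loader, and num_epochs are assumed to be defined as in the
# single-GPU script; compute_loss is a placeholder for your loss function.
accelerator = Accelerator(mixed_precision="bf16")
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)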

for epoch in range(num_epochs):
    for batch in loader:
        loss = compute_loss(model, batch)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
Key Insight

Pick the distribution strategy based on what does not fit. If the model fits but training is slow, use DDP (torchrun plus DistributedDataParallel plus DistributedSampler). If the model does not fit on a single GPU, use FSDP2 with a mixed-precision policy and activation checkpointing. If standing up the distributed boilerplate is the bottleneck, use Hugging Face Accelerate to write hardware-agnostic training code. Skip DataParallel entirely; it exists only for backward compatibility.

Lab E.7: From Single-GPU to DDP to FSDP2

Objective

Take a working single-GPU training loop and migrate it through three stages: torchrun-launched DistributedDataParallel, FSDP2 with full sharding, and finally a hardware-agnostic Hugging Face Accelerate version. At each stage, measure throughput, peak memory per rank, and final accuracy to develop intuition for when each strategy pays off. Single-machine multi-GPU is fine; if only one GPU is available, simulate two ranks with gloo on CPU to exercise the API surface.

Setup

  1. Hardware: two or more GPUs preferred; otherwise CPU with gloo backend for API practice.
  2. Install torch>=2.4 (the first release with FSDP2; later releases are recommended) and accelerate.
  3. Workload: train a roughly 10M-parameter GPT-style transformer on the Tiny Shakespeare character-level dataset. Use a 6-layer, 384-dim, 6-head config so a single-rank run completes in roughly 10 minutes per epoch on a modest GPU.
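The parameter count for this config can be estimated with the usual 12·L·d² rule for transformer blocks (4d² for attention, 8d² for a 4x-expanded MLP) plus embeddings. A rough sketch: the function name is hypothetical, biases and LayerNorm weights are ignored, the output head is assumed to be tied to the token embedding, and the 65-character vocabulary and 256-token context are assumed values for a typical character-level setup.

```python
def gpt_param_count(n_layer: int, d_model: int, vocab_size: int, block_size: int) -> int:
    """Approximate GPT parameter count (ignores biases and LayerNorm weights)."""
    per_block = 12 * d_model ** 2        # attention 4*d^2 + MLP 8*d^2
    token_emb = vocab_size * d_model     # assumed tied with the output head
    pos_emb = block_size * d_model       # learned positional embeddings
    return n_layer * per_block + token_emb + pos_emb

# Assumed values: 65-character vocabulary, 256-token context.
n = gpt_param_count(n_layer=6, d_model=384, vocab_size=65, block_size=256)
print(f"{n / 1e6:.1f}M parameters")
```

At this size the blocks dominate the total, and the model state is far below single-GPU memory, so FSDP's memory savings in this lab are for practice rather than necessity.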

Steps

  1. Step 1: Single-GPU baseline. Train for one epoch. Log peak memory, samples per second, and final validation loss. This is the reference.
  2. Step 2: DDP migration. Wrap the model in DistributedDataParallel. Switch the dataloader to DistributedSampler(dataset, num_replicas=world_size, rank=rank) and call sampler.set_epoch(epoch) each epoch. Initialize with dist.init_process_group(backend="nccl"). Launch with torchrun --nproc_per_node=2 train.py. Measure how close throughput comes to linear scaling with world size (small per-rank batches scale worse, because communication is a larger share of each step), and confirm that final loss matches the baseline within 2 percent.
  3. Step 3: Add gradient accumulation under DDP. Use model.no_sync() on the inner micro-batches to skip the gradient all-reduce until the last micro-batch. Measure the gain (one all-reduce per K micro-batches instead of K).
  4. Step 4: FSDP2 migration. Replace the DDP wrap with fully_shard(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16)). Compare peak memory per rank against DDP. Parameters, gradients, and optimizer state are now sharded, so their footprint shrinks by up to a factor of the world size (at most 2x on two GPUs); activations are not sharded, so the overall peak drops by less. Throughput will drop slightly; quantify the trade-off.
  5. Step 5: Activation checkpointing. Route each transformer block's forward through torch.utils.checkpoint.checkpoint. Re-measure peak memory and step time. Expect a 30 to 40 percent memory drop in exchange for roughly 30 percent slower steps.
  6. Step 6: Accelerate rewrite. Rewrite the loop using Accelerator. Confirm that the same single source file launches correctly via accelerate launch --multi_gpu and via accelerate launch --num_processes=1 without changes. Count lines deleted versus added.
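For Step 3, a minimal sketch of the no_sync pattern, assuming K equal-sized micro-batches and a loss divided by K so the accumulated gradient equals the full-batch mean. To stay runnable without GPUs, it wraps a toy linear model in DDP on a one-process gloo group (CPU, in-memory HashStore); under torchrun with nccl, the same loop runs with the model on each rank's GPU.

```python
import contextlib
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
torch.manual_seed(0)

model = DDP(torch.nn.Linear(8, 1))      # CPU module: no device_ids
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
init_state = {k: v.clone() for k, v in model.module.state_dict().items()}

micro_batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(4)]
K = len(micro_batches)

optimizer.zero_grad(set_to_none=True)
for i, (xb, yb) in enumerate(micro_batches):
    is_last = i == K - 1
    # Inside no_sync(), forward and backward only accumulate local gradients;
    # the all-reduce of the accumulated gradients runs on the last backward.
    sync_ctx = contextlib.nullcontext() if is_last else model.no_sync()
    with sync_ctx:
        loss = torch.nn.functional.mse_loss(model(xb), yb) / K   # K terms sum to the mean
        loss.backward()

accumulated = [p.grad.clone() for p in model.parameters()]
optimizer.step()                         # one optimizer step per K micro-batches
dist.destroy_process_group()
```

The forward pass sits inside the no_sync() context along with the backward; DDP needs both inside for the skipped synchronization to behave correctly.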

Stretch Goals

Expected Output

Expected time: 6 to 8 hours. Difficulty: advanced. Artifact: three working training scripts (single-GPU, DDP, FSDP2) plus the Accelerate version, with a benchmark table comparing all four on the same workload.

Further Reading

Distributed Training References

Li, S. et al. (2020). "PyTorch Distributed: Experiences on Accelerating Data Parallel Training." VLDB 2020. arXiv:2006.15704. The DDP paper. Explains gradient bucketing, overlap with backward, and the rationale for one-process-per-GPU.
Zhao, Y. et al. (2023). "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel." VLDB 2023. arXiv:2304.11277. The FSDP paper. Covers sharding strategies, mixed precision, activation checkpointing, and the empirical scaling behavior at thousand-GPU scale.
Rajbhandari, S. et al. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models." SC20. arXiv:1910.02054. The DeepSpeed ZeRO paper, which articulated the sharding strategy FSDP inherits. The clearest exposition of why sharding optimizer state matters at scale.
Hugging Face Accelerate Documentation. The library's user guide, including launcher, config wizard, and worked examples for DDP, FSDP, and DeepSpeed back ends.