Distributed Training at Scale

Section 6.6

Training a large model on one GPU is like reading the entire internet through a keyhole. Distributed training lets you knock down the wall, provided all the GPUs agree on which wall and when.

Scale, Wall Demolishing AI Agent
Big Picture

No single GPU can train a modern LLM. A 70B parameter model requires about 140 GB just for its parameters in FP16 (16-bit floating point, 2 bytes per parameter), more than most single accelerators provide, and training adds gradients and optimizer states on top of that. Training such models demands distributing computation across dozens to thousands of GPUs, coordinating their work through high-speed interconnects. This section covers the four fundamental parallelism strategies (data, tensor, pipeline, and expert parallelism), the communication primitives that enable them, mixed-precision training (including the even narrower 8-bit FP8 format), and the memory optimization techniques that make large-scale training feasible. The GPU compute model from Section 3.6 explains why memory bandwidth, not raw FLOPS (floating-point operations per second), is the binding constraint.

Key Insight: Remember

Four parallelisms, four cuts of the cake: data parallelism splits the batch, tensor parallelism splits the matmul, pipeline parallelism splits the layers, expert parallelism splits the FFN into routed experts. Large training runs stack several of these at once (all four for mixture-of-experts models); the art is keeping the GPUs talking faster than they compute.

Prerequisites

This section assumes familiarity with PyTorch tensor operations from Section 0.2 and the transformer architecture from Section 3.1. Understanding of matrix multiplication is essential for the tensor parallelism discussion. The optimizer memory analysis from Section 6.5 motivates why distributed training is necessary.

6.6.1 Communication Primitives

Distributed training relies on collective communication operations to synchronize data between GPUs. Understanding these primitives is essential for reasoning about the communication overhead of different parallelism strategies.

Table 6.6.1: Primitive Comparison (as of 2026).
| Primitive | Input | Output | Use Case |
| --- | --- | --- | --- |
| All-Reduce | Each GPU has a tensor | All GPUs have the sum | Gradient synchronization (DDP, Distributed Data Parallel; see Section 6.6.2) |
| All-Gather | Each GPU has a shard | All GPUs have the full tensor | Parameter reconstruction (FSDP) |
| Reduce-Scatter | Each GPU has a tensor | Each GPU has a shard of the sum | Gradient sharding (FSDP) |
| Broadcast | One GPU has a tensor | All GPUs have it | Weight initialization |
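
The semantics in the table can be made concrete with a single-process sketch in which a Python list stands in for the set of GPUs (one NumPy array per rank). The function names below are illustrative only and are not the NCCL or `torch.distributed` API; the point is what each rank holds before and after the call.

```python
import numpy as np

def all_reduce(tensors):
    """Every rank ends up with the elementwise sum of all inputs."""
    total = np.sum(tensors, axis=0)
    return [total.copy() for _ in tensors]

def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = np.concatenate(shards)
    return [full.copy() for _ in shards]

def reduce_scatter(tensors):
    """Rank r ends up with the r-th slice of the elementwise sum."""
    total = np.sum(tensors, axis=0)
    return np.array_split(total, len(tensors))

def broadcast(tensors, src=0):
    """Every rank ends up with a copy of the source rank's tensor."""
    return [tensors[src].copy() for _ in tensors]

# Two simulated ranks: rank 0 holds [1, 2, 3, 4], rank 1 holds [2, 4, 6, 8]
ranks = [np.array([1.0, 2.0, 3.0, 4.0]) * (r + 1) for r in range(2)]

reduced = all_reduce(ranks)        # both ranks: [3, 6, 9, 12]
shards = reduce_scatter(ranks)     # rank 0: [3, 6], rank 1: [9, 12]
gathered = all_gather(shards)      # both ranks: [3, 6, 9, 12]
copied = broadcast(ranks, src=1)   # both ranks: [2, 4, 6, 8]
```

Note that a reduce-scatter followed by an all-gather produces the same result as an all-reduce, which is why FSDP can replace DDP's single all-reduce with the two sharded primitives.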

These operations are implemented efficiently using ring or tree topologies. In a ring all-reduce with $P$ GPUs, each GPU sends and receives $2(P-1)/P$ times the tensor size, giving near-optimal bandwidth utilization regardless of the number of GPUs. The NCCL library (NVIDIA Collective Communications Library) provides highly optimized implementations for NVIDIA GPUs.
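
To see where the $2(P-1)/P$ factor comes from, here is a minimal single-process simulation of a ring all-reduce. It is a sketch of the textbook algorithm (a reduce-scatter phase followed by an all-gather phase, each taking $P-1$ steps), not NCCL's actual implementation; the function name and the list-of-arrays representation of ranks are assumptions made for illustration.

```python
import numpy as np

def ring_all_reduce(tensors):
    """Simulate a ring all-reduce; returns per-rank results and elements sent by one rank."""
    P = len(tensors)
    # Each rank splits its buffer into P chunks
    chunks = [np.array_split(t.astype(np.float64), P) for t in tensors]
    sent = 0  # number of elements rank 0 sends in total

    # Phase 1 (reduce-scatter): after P-1 steps, rank r owns the full sum of chunk (r+1) % P
    for step in range(P - 1):
        # Snapshot all outgoing messages first so sends are "simultaneous"
        outgoing = [((r - step) % P, chunks[r][(r - step) % P].copy()) for r in range(P)]
        for r, (idx, payload) in enumerate(outgoing):
            chunks[(r + 1) % P][idx] += payload
        sent += outgoing[0][1].size

    # Phase 2 (all-gather): circulate the fully reduced chunks around the ring
    for step in range(P - 1):
        outgoing = [((r + 1 - step) % P, chunks[r][(r + 1 - step) % P].copy()) for r in range(P)]
        for r, (idx, payload) in enumerate(outgoing):
            chunks[(r + 1) % P][idx] = payload
        sent += outgoing[0][1].size

    return [np.concatenate(c) for c in chunks], sent

# Four simulated ranks, each holding 8 elements
inputs = [np.arange(8) * (r + 1) for r in range(4)]
results, sent_per_rank = ring_all_reduce(inputs)
expected_traffic = 2 * (4 - 1) / 4 * 8  # 2(P-1)/P times the tensor size
```

Each rank sends one chunk of size $N/P$ per step over $2(P-1)$ steps, which gives the $2(P-1)/P \cdot N$ total from the text.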

6.6.2 Data Parallelism (DDP)

Fun Fact

Training GPT-4 reportedly required tens of thousands of GPUs running in parallel for months. The electricity bill alone likely exceeded what it costs to run a small town for a year. "Scaling laws" sometimes feel less like scientific principles and more like a dare issued to the power grid.

Distributed Data Parallelism is the simplest and most widely used form of parallelism. Each GPU holds a complete copy of the model and processes a different subset of the training data. After each forward-backward pass, gradients are synchronized across all GPUs using all-reduce, ensuring that all copies perform identical parameter updates.

Figure 6.6.1a: In DDP, each GPU holds a full model copy and processes different data. Gradients are synchronized via all-reduce after each backward pass.
# DDP training with PyTorch
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size):
    # Assumes MASTER_ADDR and MASTER_PORT are set in the environment
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train_ddp(rank, world_size, model_class, dataloader):
    setup_ddp(rank, world_size)
    # Each GPU gets a full model copy
    model = model_class().to(rank)
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # The dataloader should yield a different data shard on each rank
    for batch in dataloader:
        # Reset gradients from previous step
        optimizer.zero_grad()
        # Assumes the batch is already on this rank's GPU and the model returns a scalar loss
        loss = model(batch)
        # Compute gradients via backpropagation
        loss.backward()  # DDP auto-syncs gradients via all-reduce
        # Update weights using computed gradients
        optimizer.step()
    # Tear down the process group once, after the training loop finishes
    dist.destroy_process_group()
Code Fragment 6.6.1b: DDP training with PyTorch.
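
The claim that synchronized replicas perform the same update as a single large-batch run can be checked without any GPUs. The sketch below uses a toy linear regression in NumPy (the model, the data, and the `mse_grad` helper are all illustrative assumptions) and simulates the all-reduce by averaging per-rank gradients, which matches full-batch training when every rank sees the same number of samples.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))   # global batch of 8 samples
y = rng.normal(size=8)
w = rng.normal(size=3)        # identical weights on every replica

def mse_grad(X, y, w):
    """Gradient of mean((Xw - y)^2) with respect to w."""
    residual = X @ w - y
    return 2.0 * X.T @ residual / len(y)

# Single-device reference: gradient over the whole batch
full_grad = mse_grad(X, y, w)

# Data parallel: 4 ranks, each computes a gradient on its own 2 samples
world_size = 4
X_shards = np.split(X, world_size)
y_shards = np.split(y, world_size)
local_grads = [mse_grad(Xs, ys, w) for Xs, ys in zip(X_shards, y_shards)]

# Simulated all-reduce followed by division by world size (gradient averaging)
synced_grad = np.sum(local_grads, axis=0) / world_size

assert np.allclose(synced_grad, full_grad)
```
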
# FSDP training with PyTorch (FSDP is introduced in Section 6.6.3)
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Assumes the process group is initialized and that `model` and `dataloader` already exist
# Mixed precision policy for FSDP
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
# Wrap model with FSDP (full sharding = ZeRO Stage 3)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mp_policy,
    device_id=torch.cuda.current_device(),
)
# Create the optimizer after wrapping so it tracks the sharded parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Training loop is identical to standard PyTorch
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)  # assumes the model returns a scalar loss
    loss.backward()
    optimizer.step()
Code Fragment 6.6.2: FSDP training with PyTorch.

6.6.3 Fully Sharded Data Parallelism (FSDP) and ZeRO

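DDP's limitation is that every GPU must hold a complete copy of the model, gradients, and optimizer states. For a 7B model trained with AdamW in mixed precision, that is ~112 GB per GPU: 16 bytes per parameter, made up of 2 for the 16-bit weights, 2 for the gradients, and 12 for the FP32 master weights and the two Adam moment estimates, before counting activations. FSDP (and the equivalent DeepSpeed ZeRO) resolves this by sharding these tensors across GPUs so each GPU stores only a fraction. These sharding techniques are equally important during fine-tuning of large models (see Section 16.4).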

ZeRO Optimization Stages
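
As described in the ZeRO paper listed under Further Reading, the stages shard progressively more state: Stage 1 partitions the optimizer states, Stage 2 additionally partitions the gradients, and Stage 3 additionally partitions the parameters. The following back-of-the-envelope calculator is a sketch under the usual mixed-precision AdamW accounting (2 bytes per parameter for weights, 2 for gradients, 12 for optimizer state); it ignores activations, buffers, and fragmentation, and the function name is hypothetical.

```python
def zero_memory_gb(num_params, num_gpus, stage):
    """Approximate per-GPU model-state memory in GB (1 GB = 1e9 bytes)."""
    params = 2 * num_params   # 16-bit weights
    grads = 2 * num_params    # 16-bit gradients
    optim = 12 * num_params   # FP32 master weights + two Adam moments
    if stage >= 1:
        optim /= num_gpus     # Stage 1: shard optimizer states
    if stage >= 2:
        grads /= num_gpus     # Stage 2: also shard gradients
    if stage >= 3:
        params /= num_gpus    # Stage 3: also shard parameters
    return (params + grads + optim) / 1e9

# 7B parameters on 8 GPUs; stage 0 corresponds to plain DDP
memory_by_stage = {s: zero_memory_gb(7_000_000_000, 8, s) for s in range(4)}
for stage, gb in memory_by_stage.items():
    print(f"stage {stage}: {gb:.2f} GB per GPU")
# stage 0: 112.00 GB per GPU
# stage 1: 38.50 GB per GPU
# stage 2: 26.25 GB per GPU
# stage 3: 14.00 GB per GPU
```

Only Stage 3 scales all three components with the number of GPUs, which is why it is the stage that lets a model larger than a single device's memory be trained at all.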

Key Insight

FSDP with full sharding (ZeRO Stage 3) trades communication for memory. Each layer's forward pass requires an all-gather to reconstruct the full parameters, and each layer's backward pass requires a second all-gather of the parameters followed by a reduce-scatter of the gradients. This means each parameter is communicated 3 times per training step (gather for forward, gather for backward, reduce-scatter for gradient), roughly 1.5x the volume of DDP's single all-reduce, which is itself equivalent to one reduce-scatter plus one all-gather. The communication overhead is significant but acceptable when the alternative is not being able to train the model at all.
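
The gather, compute, reduce-scatter, update cycle can be traced in a single process. The sketch below simulates two ranks with NumPy on a toy one-layer linear model; the model, data, and helper names are illustrative assumptions, and a real FSDP run would repeat the all-gather per wrapped layer in both the forward and backward passes. It checks that updating only the local shards yields the same weights as an unsharded step.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, dim, per_rank_batch, lr = 2, 4, 3, 0.1

w_full = rng.normal(size=dim)
# Each rank sees its own mini-batch (inputs, targets)
data = [(rng.normal(size=(per_rank_batch, dim)), rng.normal(size=per_rank_batch))
        for _ in range(world_size)]

def mse_grad(X, y, w):
    """Gradient of mean((Xw - y)^2) with respect to w."""
    return 2.0 * X.T @ (X @ w - y) / len(y)

# Sharded state: rank r permanently stores only shard r of the weights
shards = np.array_split(w_full.copy(), world_size)

# 1) All-gather: every rank temporarily reconstructs the full weight vector
gathered = np.concatenate(shards)

# 2) Each rank computes a full-size gradient on its own data
local_grads = [mse_grad(X, y, gathered) for X, y in data]

# 3) Reduce-scatter: average the gradients, rank r keeps only slice r
avg_grad = np.sum(local_grads, axis=0) / world_size
grad_shards = np.array_split(avg_grad, world_size)

# 4) Each rank updates only its own shard (plain SGD for brevity)
shards = [s - lr * g for s, g in zip(shards, grad_shards)]

# Reference: one unsharded SGD step on the combined batch
X_all = np.concatenate([X for X, _ in data])
y_all = np.concatenate([y for _, y in data])
w_ref = w_full - lr * mse_grad(X_all, y_all, w_full)

assert np.allclose(np.concatenate(shards), w_ref)
```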

# FSDP (Fully Sharded Data Parallel) training with PyTorch.
# Each rank holds only 1/N of the parameters; missing shards are gathered
# just-in-time during forward/backward and freed immediately after.
import os
from functools import partial

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# 1) Initialize the process group (one rank per GPU).
# LOCAL_RANK is set by torchrun; the global rank from dist.get_rank() is not
# a valid device index on multi-node jobs.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# 2) Build the model on each rank
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

# 3) Wrap with FSDP -- one wrapper per transformer block keeps memory low
wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer})

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    device_id=torch.cuda.current_device(),
)

# 4) Normal-looking training step -- FSDP handles all-gather/reduce-scatter
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
# train_loader is assumed to yield dicts of tensors (input_ids, labels, ...),
# with a different data shard on each rank
for batch in train_loader:
    batch = {k: v.to(local_rank) for k, v in batch.items()}
    out = model(**batch)
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

dist.destroy_process_group()
Code Fragment 6.6.3: FSDP training with PyTorch.

6.6.4 Tensor Parallelism

Tensor parallelism splits individual layers across GPUs. For a linear layer $Y = XW$, the weight matrix $W$ can be split along its columns (column parallelism) or rows (row parallelism). Each GPU computes a portion of the output, and an all-reduce or all-gather combines the partial results.

Column Parallelism

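Split $W$ into column blocks: $W = [W_{1} | W_{2}]$. GPU 0 computes $XW_{1}$ and GPU 1 computes $XW_{2}$, each from a full copy of $X$. The full output is the concatenation $Y = [XW_{1} | XW_{2}]$; when the next layer consumes the output in sharded form, this requires no communication in the forward pass (but an all-reduce in the backward pass, to sum the gradient with respect to $X$). This is typically used for the first linear layer in the feed-forward network.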

Row Parallelism

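Split $W$ into row blocks $W_{1}$ (top) and $W_{2}$ (bottom), and split the input along its columns to match: $X = [X_{1} | X_{2}]$. GPU 0 computes $X_{1}W_{1}$ and GPU 1 computes $X_{2}W_{2}$, so each GPU processes a different slice of the input. The partial outputs are summed via all-reduce in the forward pass, since $Y = X_{1}W_{1} + X_{2}W_{2}$. This is typically used for the second linear layer in the feed-forward network.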

In Megatron-LM style parallelism, column and row parallelism are combined so that the MLP block requires only one all-reduce in the forward pass and one in the backward pass. Tensor parallelism requires very fast interconnects (NVLink within a node) because communication happens at every layer. These same parallelism strategies are also essential for inference serving at scale, as discussed in Section 9.5.
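
The reason the combined MLP needs only one forward all-reduce is that an elementwise nonlinearity can be applied to each column shard independently, so the column-parallel output feeds straight into the row-parallel layer. A minimal NumPy sketch follows; the shapes are arbitrary, ReLU stands in for the activation actually used in a given model, and the final sum plays the role of the all-reduce.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(2, 4))   # (batch, d_model)
A = rng.normal(size=(4, 8))   # first linear layer, split by columns
B = rng.normal(size=(8, 4))   # second linear layer, split by rows

def relu(z):
    return np.maximum(z, 0.0)

# Single-device reference
reference = relu(X @ A) @ B

# Tensor parallel across 2 simulated GPUs
n_gpus = 2
A_cols = np.split(A, n_gpus, axis=1)   # each (4, 4)
B_rows = np.split(B, n_gpus, axis=0)   # each (4, 4)

# Each GPU runs its column shard, the activation, and its row shard locally
partials = [relu(X @ A_k) @ B_k for A_k, B_k in zip(A_cols, B_rows)]

# The only forward-pass communication: sum the partial outputs (all-reduce)
Z = np.sum(partials, axis=0)

assert np.allclose(Z, reference)
```

Splitting the first layer by rows instead would require a sum before the nonlinearity, which would force an extra synchronization point inside the block.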

Key Insight: Column-Parallel Matrix Multiply

Consider a feed-forward layer with input $X$ (batch = 2, $d = 4$) and weight $W$ ($4 \times 8$), split across 2 GPUs:

GPU 0: $W_{0}$ = first 4 columns of $W$ ($4 \times 4$). Computes $Y_{0} = XW_{0}$, producing a $2 \times 4$ result.

GPU 1: $W_{1}$ = last 4 columns of $W$ ($4 \times 4$). Computes $Y_{1} = XW_{1}$, producing a $2 \times 4$ result.

Combine: $Y = [Y_{0} | Y_{1}]$ via concatenation (no communication needed). Each GPU did half the work, and the result is identical to a single GPU computing $Y = XW$. The catch: the backward pass requires an all-reduce to sum gradients across GPUs.
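
The same example can be run numerically. This NumPy sketch uses the shapes from the text with random values (the upstream gradient `dY` is made up for illustration) and shows both halves of the story: the forward pass is a plain concatenation, while the gradient with respect to $X$ is a sum of per-GPU contributions, which is the all-reduce in the backward pass.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(2, 4))   # batch=2, d=4
W = rng.normal(size=(4, 8))

# Column split: each simulated GPU owns 4 of the 8 output columns
W0, W1 = W[:, :4], W[:, 4:]

# Forward: independent matmuls, then concatenate along the column axis
Y0, Y1 = X @ W0, X @ W1
Y = np.concatenate([Y0, Y1], axis=1)
assert np.allclose(Y, X @ W)

# Backward: each GPU receives its slice of the upstream gradient dL/dY
dY = rng.normal(size=Y.shape)
dY0, dY1 = dY[:, :4], dY[:, 4:]

# Each GPU can only compute a partial gradient with respect to X
dX_partial0 = dY0 @ W0.T
dX_partial1 = dY1 @ W1.T

# All-reduce (sum) across GPUs recovers the true input gradient
dX = dX_partial0 + dX_partial1
assert np.allclose(dX, dY @ W.T)
```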

Note: Assembly Line vs. Task Division

Think of distributed training strategies as ways to organize a factory. Data parallelism is like opening duplicate factories that each build complete products from different orders. Tensor parallelism is like splitting each workstation across two workers who each handle half the parts. Pipeline parallelism is like an assembly line where each station does one step. Expert parallelism is like a specialized factory floor where different workers handle different product types, and a router directs each order to the right specialist.

6.6.5 Pipeline Parallelism

Pipeline parallelism assigns different layers of the model to different GPUs. GPU 0 runs layers 0-15, GPU 1 runs layers 16-31, and so on. The input flows through the pipeline, with each GPU passing its output to the next.

The naive approach has a severe pipeline bubble problem. In a four-stage pipeline, while GPU 0 is processing the forward pass, GPUs 1-3 are idle, and while GPU 3 is processing the backward pass, GPUs 0-2 are idle. The 1F1B (one forward, one backward) schedule mitigates this by splitting each batch into micro-batches and interleaving forward and backward passes across micro-batches. This keeps all GPUs active most of the time, though a small bubble remains at the beginning and end of each batch.
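
A rough feel for the size of the bubble comes from a commonly quoted estimate for non-interleaved micro-batch schedules: with $p$ stages and $m$ micro-batches, the step takes $m + p - 1$ time slots of which $p - 1$ are fill and drain, so the idle fraction is $(p-1)/(m+p-1)$. The sketch below assumes every stage takes the same time per micro-batch and ignores communication; the function name is hypothetical.

```python
def bubble_fraction(num_stages, num_microbatches):
    """Fraction of a training step each GPU spends idle, under equal stage times."""
    p, m = num_stages, num_microbatches
    return (p - 1) / (m + p - 1)

# A 4-stage pipeline with an increasing number of micro-batches
fractions = {m: bubble_fraction(4, m) for m in (1, 4, 16, 64)}
for m, frac in fractions.items():
    print(f"micro-batches={m:>2}: bubble = {frac:.1%}")
# micro-batches= 1: bubble = 75.0%
# micro-batches= 4: bubble = 42.9%
# micro-batches=16: bubble = 15.8%
# micro-batches=64: bubble = 4.5%
```

With a single micro-batch (the naive schedule) three of the four GPUs are idle at any time; the bubble shrinks only when the number of micro-batches is large relative to the number of stages.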

Figure 6.6.2a: The 1F1B pipeline schedule interleaves forward (F) and backward (B) micro-batches to minimize idle time (bubbles).
Key Takeaways
Self-Check
1. What is the fundamental difference between DDP and FSDP?
DDP replicates the full model, optimizer states, and gradients on every GPU. Each GPU processes different data and synchronizes gradients via all-reduce. FSDP shards (splits) the model parameters, gradients, and optimizer states across GPUs so each GPU stores only a fraction. FSDP reconstructs full parameters on-demand for each layer's computation via all-gather, and reduces gradients via reduce-scatter. DDP uses more memory but less communication; FSDP uses less memory but more communication.
2. What causes pipeline bubbles and how does the 1F1B schedule mitigate them?
Pipeline bubbles are idle time on GPUs when they have no work to do. In naive pipeline parallelism, GPU 0 must finish the forward pass for the entire batch before GPU 1 can start, creating a cascade of idle time. The 1F1B schedule splits the batch into micro-batches and interleaves forward and backward passes. Once GPU 0 finishes the forward pass for micro-batch 1, it can start micro-batch 2's forward pass while GPU 1 processes micro-batch 1. This keeps all GPUs busy for most of the training step, though small bubbles remain at the start and end.

What's Next?

The discussion continues in Section 6.6a: Mixed Precision, Checkpointing, 3D Parallelism & Ring Attention, which covers BF16 / FP8 training, gradient checkpointing, how to compose tensor + pipeline + data parallelism into a 3D recipe, and ring attention for very long contexts. After that, Section 6.7 turns to in-context learning theory.

Further Reading

Data & Model Parallelism

Rajbhandari, S. et al. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models." SC 2020. Introduces the ZeRO family of optimizations that partition optimizer states, gradients, and parameters across GPUs. The foundation of DeepSpeed and the technique that makes training models larger than single-GPU memory feasible.
Shoeybi, M. et al. (2020). "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism." arXiv preprint arXiv:1909.08053. Presents efficient tensor parallelism strategies that split individual layers across GPUs within a node. The go-to reference for understanding how attention heads and MLP columns are distributed in practice.
Zhao, Y. et al. (2023). "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel." VLDB 2023. Documents the design and production experience of PyTorch's native FSDP implementation. Covers practical trade-offs between sharding strategies, communication overhead, and memory savings for practitioners using the PyTorch ecosystem.