Section 59.2

ZeRO and FSDP: Memory-Efficient Data Parallelism

"With N replicas, (N-1)/N of your VRAM is a redundant copy of someone else's optimizer state. ZeRO is the trick of shipping the bytes without storing them twice."

KVKV, Memory-Allergic AI Agent
Big Picture

A few terms up front. Data parallelism (DDP, Distributed Data Parallel): every GPU holds an identical copy of the model, processes a different slice of the mini-batch, and then averages gradients with its peers. BF16 is bfloat16, a 16-bit floating-point format with the same exponent range as FP32 (so it rarely overflows) but less mantissa precision. HBM is High-Bandwidth Memory, the on-package GPU memory (e.g., the 80 GB on an H100). All-reduce, reduce-scatter, and all-gather are the three collective communication primitives used here; their meanings are spelled out in 59.2.2.1-3 as they appear.

With those in hand: the single biggest waste in plain data parallelism is that every rank stores identical copies of every optimizer state, gradient, and parameter. With $N$ replicas, $(N-1)/N$ of that memory is redundant. ZeRO (DeepSpeed, 2019) and its PyTorch reimplementation FSDP shard this state across ranks and materialize each layer's full tensors only just-in-time during forward and backward. Stages 1 and 2 put the same total bytes on the wire as a vanilla all-reduce, and Stage 3 about $1.5\times$ as many; the savings come from never holding the full state on any one rank. ZeRO Stage 3 / FSDP Full-Shard is what makes 70B models trainable on commodity 80 GB GPUs. For LLM practitioners, this is the practical floor that decides whether an open-weight LLM (Llama, Qwen, Mistral) can be fine-tuned in-house on the GPUs you can actually rent, versus calling out to a paid training API; ZeRO and FSDP are the difference between LLM training as commodity and LLM training as a privilege.

Prerequisites

This section assumes familiarity with distributed training fundamentals from Section 59.1 and with mixed-precision fine-tuning from Section 13.1. Familiarity with PyTorch optimizer internals from Section 0.5 helps when reading the optimizer-state breakdown.

59.2.1 The Memory Budget of Mixed-Precision Training

Fun Fact

Microsoft's DeepSpeed team published the ZeRO paper in 2020 as a sequence of three stages, where each stage shards more of the training state across GPUs. The acronym, "Zero Redundancy Optimizer," is the rare instance of a research project where the marketing name is more honest than the mathematics: in ZeRO-3, every byte of the optimizer state, gradients, and parameters lives on exactly one rank at a time. PyTorch's FSDP is, more or less, a clean-room rewrite of ZeRO-3 that the PyTorch team can ship without a Microsoft logo, which is the kind of inter-vendor diplomacy that happens entirely inside a config file.

Before we can shard anything, we need to know what we are sharding. The per-rank memory footprint of a transformer training step decomposes into four bucket classes. Let $P$ be the parameter count, $N$ the data-parallel world size.

Table 59.2.1: Per-rank memory of mixed-precision training (bytes per parameter).
| Bucket | Bytes / param | For 70B model | Shardable? |
|---|---|---|---|
| BF16 parameters (forward / backward) | 2 | 140 GB | Yes (Stage 3) |
| BF16 gradients | 2 | 140 GB | Yes (Stage 2) |
| FP32 master weights (for optimizer) | 4 | 280 GB | Yes (Stage 1) |
| FP32 AdamW first moment $m$ | 4 | 280 GB | Yes (Stage 1) |
| FP32 AdamW second moment $v$ | 4 | 280 GB | Yes (Stage 1) |
| Total state | 16 | 1.12 TB | |
| + activations (typical) | varies | ~200-800 GB | Recompute (Section 6.6) |

The 14 bytes per parameter beyond the BF16 params are almost entirely redundant under data parallelism. Every rank performs identical AdamW updates on identical averaged gradients, so storing 1.12 TB on each of 16 GPUs spends 17.92 TB of cluster memory on a 1.12 TB working set. ZeRO erases that redundancy stage by stage.
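The budget above is simple enough to check in a few lines. The following is a minimal sketch, not library code: the bucket names and the helper functions `state_bytes` and `cluster_bytes` are illustrative names introduced here, and the byte counts are the ones from Table 59.2.1.

```python
# Bytes per parameter for mixed-precision AdamW training (values from Table 59.2.1).
BYTES_PER_PARAM = {
    "bf16_params": 2,
    "bf16_grads": 2,
    "fp32_master": 4,
    "fp32_adam_m": 4,
    "fp32_adam_v": 4,
}

def state_bytes(num_params: float) -> dict:
    """Per-rank state of unsharded DDP, per bucket, in bytes."""
    return {name: b * num_params for name, b in BYTES_PER_PARAM.items()}

def cluster_bytes(num_params: float, world_size: int) -> float:
    """Total bytes held across the cluster when every rank keeps a full copy."""
    return sum(state_bytes(num_params).values()) * world_size

P = 70e9                                   # 70B parameters
per_rank = sum(state_bytes(P).values())
print(f"per-rank state:  {per_rank / 1e12:.2f} TB")              # 1.12 TB
print(f"16-rank cluster: {cluster_bytes(P, 16) / 1e12:.2f} TB")  # 17.92 TB
```

Activations are deliberately left out; they depend on batch size and sequence length rather than on the parameter count alone.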

59.2.2 The ZeRO Progression

A vertical mountain-climbing scene with four cartoon climbers ascending, labelled Stage 0, Stage 1, Stage 2, and Stage 3, each carrying a noticeably smaller backpack than the climber below, with the summit flag labelled Trillion-parameter model.
Figure 59.2.1a: ZeRO climbers ascend the memory mountain. Stage 0 carries everything; Stage 3 carries just a daypack. The trade is communication for memory: higher stages talk to peers more often.

Rajbhandari et al. (SC'20) introduced ZeRO (Zero Redundancy Optimizer) as a series of three cumulative sharding stages, each picking off one more of the buckets above. Each later stage includes the earlier ones and trades more communication for more memory savings.

Warning: Common Misconception

"ZeRO-3 and FSDP are competing approaches" is the most-repeated source of confusion across blog posts, course slides, and pull-request reviews. They are not competing; FSDP is essentially PyTorch's clean-room reimplementation of ZeRO Stage 3 inside the framework, while DeepSpeed continues to ship the original C++ implementation. The communication patterns, math, and per-rank memory footprint are identical. The decision between them is operational (PyTorch-native integration, mixed-precision flavor, sharding granularity) rather than algorithmic. Picking one because it has "less communication overhead" is usually a sign that the benchmark was measuring something else, like activation checkpointing or kernel fusion.

Figure 59.2.2: ZeRO progressively shards the per-rank state. Dashed boxes are sharded (only $1/N$ stored per rank); solid boxes are fully replicated. Communication per step grows with each stage: Stage 0 needs 1 all-reduce; Stage 1 needs 1 all-reduce on the gradients plus 1 all-gather of the updated parameters; Stage 2 replaces the all-reduce with 1 reduce-scatter; Stage 3 adds 1 all-gather (forward) and 1 all-gather (backward) per layer on top. The memory savings come at the cost of these additional collectives, but the wire-time per step grows only modestly.

59.2.2.1 ZeRO Stage 1: Shard Optimizer State

The AdamW master weights, first moment, and second moment are 12 bytes per parameter and used only inside optimizer.step(). Stage 1 partitions them across the $N$ data-parallel ranks: rank $i$ stores $1/N$ of every optimizer tensor.

Forward and backward run as in DDP, producing a full BF16 gradient on every rank. Before the optimizer step, an all-reduce averages gradients (same as DDP). Then each rank applies the optimizer to its shard of the master weights, producing a $1/N$ shard of updated BF16 parameters. A final all-gather reconstructs the full BF16 parameter tensor on every rank.

Memory per rank drops from $16$ B/param to $4 + 12/N$ B/param. At $N=64$, that is roughly $4.2$ B/param: a 70B model still needs $\approx 290$ GB of state per rank, the HBM of about $3.6$ H100s, so Stage 1 alone does not make a 70B model fit on one GPU. Communication grows by one all-gather of the updated BF16 parameters (size $S$, the byte size of the BF16 parameter buffer). For most workloads, Stage 1 is nearly free.
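To see that sharding the optimizer state does not change the result of the update, here is a toy single-process simulation. It is a sketch under stated assumptions: plain Python lists stand in for tensors, SGD with momentum stands in for AdamW to keep the code short, and the helper names (`sgd_momentum_step`, `shard_bounds`, `zero1_step`) are made up for this illustration.

```python
def sgd_momentum_step(w, g, m, lr=0.1, beta=0.9):
    """Return updated (weights, momentum) for one SGD-with-momentum step."""
    new_m = [beta * mi + gi for mi, gi in zip(m, g)]
    new_w = [wi - lr * mi for wi, mi in zip(w, new_m)]
    return new_w, new_m

def shard_bounds(length, world_size, rank):
    """Contiguous [start, end) slice owned by `rank` (length divisible by world_size)."""
    size = length // world_size
    return rank * size, (rank + 1) * size

def zero1_step(weights, avg_grad, momentum_shards, world_size):
    """One Stage-1-style step: each rank updates only its shard, then all-gather."""
    updated_shards = []
    for rank in range(world_size):
        lo, hi = shard_bounds(len(weights), world_size, rank)
        w_shard, m_shard = sgd_momentum_step(
            weights[lo:hi], avg_grad[lo:hi], momentum_shards[rank]
        )
        momentum_shards[rank] = m_shard   # optimizer state stays sharded
        updated_shards.append(w_shard)
    # All-gather: concatenate the shards so every rank sees full weights again.
    return [w for shard in updated_shards for w in shard]

weights = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
avg_grad = [0.5] * 8                      # gradient after the all-reduce
world_size = 4
momentum_shards = [[0.0, 0.0] for _ in range(world_size)]

sharded = zero1_step(weights, avg_grad, momentum_shards, world_size)
reference, _ = sgd_momentum_step(weights, avg_grad, [0.0] * 8)
print(sharded == reference)               # True: same update, 1/N of the state per rank
```

Each simulated rank holds only two momentum entries instead of eight, which is the $12/N$ term in the formula above.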

59.2.2.2 ZeRO Stage 2: Also Shard Gradients

The gradient is fully needed only on the rank that owns the corresponding optimizer-state shard. So replace the all-reduce with a reduce-scatter: each rank ends up holding only $1/N$ of the gradient (the part corresponding to its optimizer shard). Memory per rank drops to $2 + 14/N$ B/param.

The communication math is striking: a reduce-scatter moves $S(N-1)/N$ bytes per rank, exactly the reduce-scatter half of a ring all-reduce, which moves $2S(N-1)/N$. We have saved nearly all of the gradient memory ($\approx 138$ GB of the 140 GB for a 70B model at $N=64$) while replacing the all-reduce with a collective that costs half as much.

59.2.2.3 ZeRO Stage 3: Also Shard Parameters

The big one. At any moment, only the parameters of the current layer need to materialize as full tensors; the rest can stay sharded. Stage 3 (also called FSDP Full-Shard) wraps each transformer block so that:

  1. Just before forward: an all-gather reconstructs the block's full parameters on every rank.
  2. During forward: the block computes normally; the gathered parameters are held in memory.
  3. After forward: the gathered parameters are freed; only the $1/N$ shard persists.
  4. Just before backward: another all-gather reconstructs the parameters (they were freed in step 3).
  5. After backward: a reduce-scatter on the gradients, producing only the $1/N$ gradient shard on each rank.

Memory per rank is $16/N$ B/param. At $N=64$, that is $0.25$ B/param: a 70B model fits in $17.5$ GB of state per rank. You can run a 70B model on 16 GPUs (two 8-GPU DGX nodes, 70 GB of state per rank), the configuration that put open-weight 70B models in reach of every research lab. The cost is three collectives per layer per step (forward gather, backward gather, gradient scatter), against one all-reduce per step for DDP. The communication volume grows by roughly $1.5\times$ (three passes over the parameter bytes instead of the all-reduce's two) but is still bounded by interconnect bandwidth, not by $N$.
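The three per-rank formulas ($4 + 12/N$, $2 + 14/N$, $16/N$) can be collected into one small function. This is an illustrative sketch; `zero_bytes_per_param` is a name introduced here, and it counts model state only, not activations.

```python
def zero_bytes_per_param(stage: int, n: int) -> float:
    """Per-rank state in bytes per parameter for ZeRO stage 0-3 with n ranks."""
    params, grads, optim = 2.0, 2.0, 12.0   # BF16 params, BF16 grads, FP32 AdamW state
    if stage >= 1:
        optim /= n                           # Stage 1 shards optimizer state
    if stage >= 2:
        grads /= n                           # Stage 2 also shards gradients
    if stage >= 3:
        params /= n                          # Stage 3 also shards parameters
    return params + grads + optim

P, N = 70e9, 64
for stage in range(4):
    b = zero_bytes_per_param(stage, N)
    print(f"stage {stage}: {b:.4f} B/param, {b * P / 1e9:.1f} GB per rank")
```

For $N=64$ this prints 16, 4.1875, 2.21875 and 0.25 bytes per parameter, i.e. 1120.0, 293.1, 155.3 and 17.5 GB per rank for a 70B model.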

Key Insight
All-reduce factors into reduce-scatter + all-gather

The mathematical identity that powers ZeRO is that all-reduce = reduce-scatter then all-gather. In DDP both halves run back to back on the gradients of every rank. ZeRO pulls them apart: the reduce-scatter runs on gradients after backward, and the all-gather is deferred and applied to parameters, in Stage 3 only for the layers about to execute. Stages 1 and 2 move the same total bytes as DDP; Stage 3 moves about $1.5\times$ as many because parameters are gathered once for forward and once for backward, while the peak per-rank state is $N\times$ lower. This single observation is responsible for nearly every memory-efficient parallelism scheme of the past five years.
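The identity can be checked with a toy simulation in which each "rank" is just a Python list. This is a sketch for intuition only: the three functions below are plain-Python stand-ins written for this example, not the NCCL or torch.distributed collectives.

```python
def reduce_scatter(per_rank_data):
    """Rank r ends up with the element-wise sum of chunk r from every rank."""
    n = len(per_rank_data)
    chunk = len(per_rank_data[0]) // n
    return [
        [sum(data[r * chunk + j] for data in per_rank_data) for j in range(chunk)]
        for r in range(n)
    ]

def all_gather(per_rank_chunks):
    """Every rank ends up with the concatenation of all ranks' chunks."""
    full = [x for chunk in per_rank_chunks for x in chunk]
    return [list(full) for _ in per_rank_chunks]

def all_reduce(per_rank_data):
    """Every rank ends up with the element-wise sum over ranks."""
    total = [sum(col) for col in zip(*per_rank_data)]
    return [list(total) for _ in per_rank_data]

grads = [[1, 2, 3, 4], [10, 20, 30, 40]]    # 2 ranks, 4 gradient elements each
shards = reduce_scatter(grads)               # [[11, 22], [33, 44]]
print(all_gather(shards) == all_reduce(grads))   # True
```

After `reduce_scatter`, each rank holds only its $1/N$ slice of the summed gradient, which is the state ZeRO Stage 2 keeps.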

59.2.3 PyTorch FSDP: ZeRO Re-Implemented Natively

FSDP (Zhao et al., 2023) is the PyTorch-native reimplementation of ZeRO Stage 3, designed to feel like DDP from the user's perspective. FSDP differs from DeepSpeed ZeRO in two important ways: (a) it integrates with torch.compile, AOT autograd, and the rest of the PyTorch ecosystem natively; (b) it uses explicit wrapping policies to control sharding granularity, which gives finer control over the activation-vs-comm trade-off.

59.2.3.1 Wrapping Policies

FSDP shards parameters at the granularity of wrap units. A wrap unit is one all-gather: every parameter inside the unit is gathered together. There are three common policies:

59.2.3.2 A Minimal FSDP Training Loop

# fsdp_minimal.py: torchrun --nproc_per_node=8 fsdp_minimal.py
import os, functools, torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy, MixedPrecision, BackwardPrefetch, CPUOffload,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # 1) Build the bare model on each rank (full materialization here; meta-device
    #    init is the production trick when even that is too big; see Section 59.2.8).
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3-70B")

    # 2) Wrap each transformer block as its own FSDP unit.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )

    # 3) Mixed precision: parameters in bf16, gradient reduction in bf16,
    #    keep loss / norm buffers in fp32 to avoid underflow.
    mp = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.float32,
    )

    # 4) FULL_SHARD = ZeRO Stage 3. SHARD_GRAD_OP is the Stage 2 equivalent.
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=wrap_policy,
        mixed_precision=mp,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,   # overlap gather w/ compute
        cpu_offload=CPUOffload(offload_params=False),      # set True if HBM tight
        device_id=local_rank,
    )

    # 5) The optimizer sees only the local shard, internally an fp32 master copy.
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

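    # train_loader is assumed to be a DataLoader (with a DistributedSampler)
    # defined elsewhere; it yields dicts of tensors that include "labels".
    for batch in train_loader:
        batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
        out = model(**batch)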
        out.loss.backward()
        # FSDP's hooks have already reduce-scattered grads by this point.
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
Code Fragment 59.2.1b: A complete FSDP loop. The only difference from DDP is the wrap step; all the magic happens in the FSDP(...) constructor's hook installation. BackwardPrefetch.BACKWARD_PRE issues the next block's all-gather before the current block's backward finishes, hiding the gather behind compute.

59.2.3.3 Prefetching and Overlap

The naive FSDP forward pass is: gather block $k$, compute block $k$, free block $k$, gather block $k+1$, compute block $k+1$, …. Each gather is a synchronous wait. With forward_prefetch=True, FSDP issues the gather for block $k+1$ while compute on block $k$ is still running; BackwardPrefetch.BACKWARD_PRE plays the same trick on the way back down. With good overlap, most of the communication hides behind compute; what cannot be hidden shows up as a slowdown of up to about $30\%$ relative to DDP for the largest models. How much of the gather time you can hide is the lever that decides whether FSDP is the right choice.
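A back-of-the-envelope timeline model makes the effect of prefetching concrete. The sketch below is a simplification, not a model of FSDP's actual scheduler: it assumes one gather in flight at a time, perfect overlap of communication with compute, and per-block times supplied by the caller. The function names are illustrative.

```python
def naive_forward_time(gather, compute):
    """Every gather blocks: total is the plain sum of all gathers and computes."""
    return sum(gather) + sum(compute)

def prefetched_forward_time(gather, compute):
    """Gather for block k+1 is issued when compute on block k starts."""
    t = gather[0]                            # the first gather cannot be hidden
    for k in range(len(compute)):
        next_gather = gather[k + 1] if k + 1 < len(gather) else 0.0
        t += max(compute[k], next_gather)    # wait for whichever finishes last
    return t

# Four blocks, times in milliseconds (made-up numbers for illustration).
compute = [5.0] * 4
fast_net = [2.0] * 4                         # gather shorter than compute
slow_net = [6.0] * 4                         # gather longer than compute

print(naive_forward_time(fast_net, compute), prefetched_forward_time(fast_net, compute))  # 28.0 22.0
print(naive_forward_time(slow_net, compute), prefetched_forward_time(slow_net, compute))  # 44.0 29.0
```

When each gather is shorter than the matching compute, only the first gather is exposed; once gathers exceed compute time, the step becomes communication-bound and prefetching can no longer hide the difference.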

59.2.3.4 CPU Offload

If HBM is still tight even after sharding, FSDP supports CPUOffload(offload_params=True): the $1/N$ parameter shards live in host memory and are paged in to HBM only during a forward or backward gather. CPU memory is plentiful (1-2 TB per node is common); the cost is PCIe traffic on every layer's gather. PCIe Gen5 x16 delivers about $63$ GB/s per direction, far slower than NVLink, so CPU offload can roughly halve training throughput. It is an "I cannot run otherwise" feature, not a default.

59.2.4 Memory Arithmetic: Llama 70B and 405B

To make the trade-offs concrete, here is the per-rank memory math for two scales on H100s with 80 GB. We assume mixed precision (BF16 + FP32 master), AdamW, transformer-block wrap, and gradient checkpointing.

Table 59.2.2a: Per-rank memory and minimum GPU count, by strategy.
| Strategy | State / GPU (70B) | Min GPUs (70B) | State / GPU (405B) | Min GPUs (405B) |
|---|---|---|---|---|
| DDP (Stage 0) | 1.12 TB | 14 | 6.48 TB | 81 |
| ZeRO-1 | 280 GB + 12/N·P | 4 | 1.62 TB | 21 |
| ZeRO-2 | 140 GB + 14/N·P | 3 | 810 GB | 11 |
| ZeRO-3 / FSDP Full-Shard | 16/N·P | 1 (N≥14) | 16/N·P | 1 (N≥81) |
| ZeRO-3 + CPU offload | ~10 GB | 1 (N≥4) | ~50 GB | 1 (N≥16) |

The interpretation: with 16-way FSDP Full-Shard, the per-rank state of a 70B model is $1.12 \text{ TB} / 16 \approx 70$ GB, which fits an 80 GB H100 but leaves only about 10 GB for activations. With 64-way sharding, the per-rank state is $17.5$ GB and you have 60+ GB for activations, KV caches, or a larger batch.
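The $N \ge 14$ and $N \ge 81$ thresholds in the Full-Shard row follow directly from $16P/N \le 80$ GB. A minimal sketch of that arithmetic, with hypothetical helper names and activations ignored:

```python
import math

def min_shard_ways(num_params, hbm_bytes=80e9, bytes_per_param=16):
    """Smallest N for which 16/N bytes per parameter of state fits in HBM."""
    return math.ceil(num_params * bytes_per_param / hbm_bytes)

def state_gb_per_rank(num_params, n, bytes_per_param=16):
    """Per-rank Full-Shard state in GB for n-way sharding."""
    return num_params * bytes_per_param / n / 1e9

print(min_shard_ways(70e9))           # 14
print(min_shard_ways(405e9))          # 81
print(state_gb_per_rank(70e9, 16))    # 70.0
print(state_gb_per_rank(70e9, 64))    # 17.5
```

These are lower bounds on $N$: a real run also needs room for activations, gathered layers, and communication buffers.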

Real-World Scenario: Llama-3 405B on 1024 GPUs

Meta's Llama-3 405B was trained on 16,000 H100s. The training plan was FSDP Hybrid-Shard inside each pod ($N=128$ per shard group, replicated across pods) combined with tensor parallelism inside each node ($T=8$). At pod scale, per-rank FSDP state was $405 \times 10^9 \cdot 16 / 128 \approx 50$ GB, leaving 30 GB on each H100 for activations, KV caches in eval, and the working set of NCCL buffers. With tensor parallelism cutting the per-replica work by another factor of 8, the effective model state per GPU was $\approx 6$ GB. This recipe is the modern frontier-LLM default.

59.2.5 ZeRO++ and Recent Improvements

The three classic ZeRO stages have been refined by half a decade of follow-up work. The most impactful additions:

59.2.5.1 Quantized Weights on the Wire

ZeRO++ (Wang et al., 2023) observed that the all-gather of BF16 parameters is bandwidth-bound. By block-quantizing the parameter shards to INT8 or INT4 before transmission and dequantizing on the receiver, you reduce the wire bytes by $2\times$ (INT8) to $4\times$ (INT4) at minimal accuracy loss. Quantization is per-block (one scale per ~128 elements) to preserve dynamic range. The communication-overhead reduction translates directly to throughput.
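Per-block quantization is easy to sketch in plain Python. The code below is an illustration of the general idea (symmetric quantization with one scale per block), not the ZeRO++ kernels; the function names, the FP32 scale format, and the toy weight values are assumptions made for this example.

```python
def quantize_blocks(values, block_size=128, bits=8):
    """Symmetric per-block quantization: one float scale per block of values."""
    qmax = 2 ** (bits - 1) - 1                       # 127 for INT8, 7 for INT4
    blocks = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        scale = max(abs(v) for v in block) / qmax or 1.0   # all-zero block: use 1.0
        codes = [round(v / scale) for v in block]
        blocks.append((scale, codes))
    return blocks

def dequantize_blocks(blocks):
    """Receiver side: multiply each integer code by its block's scale."""
    return [scale * c for scale, codes in blocks for c in codes]

weights = [0.01 * i for i in range(-128, 128)]       # 256 toy values, 2 blocks
blocks = quantize_blocks(weights, block_size=128, bits=8)
restored = dequantize_blocks(blocks)
max_err = max(abs(a - b) for a, b in zip(weights, restored))

wire_bytes_bf16 = 2 * len(weights)                   # 2 bytes per BF16 value
wire_bytes_int8 = len(weights) + 4 * len(blocks)     # 1 byte per code + one FP32 scale per block
print(len(blocks), wire_bytes_bf16, wire_bytes_int8) # 2 512 264
```

The reconstruction error of each value is bounded by half of its block's scale, which is why a per-block scale preserves dynamic range better than a single scale for the whole tensor.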

59.2.5.2 Hierarchical Communication (Hybrid-Shard)

When the data-parallel group spans multiple nodes, a flat all-gather has to cross the slow inter-node fabric for every gather. Hybrid-Shard partitions the world into intra-node groups: each node fully shards within itself ($N=8$ for an H100 server), and across nodes runs plain DDP. Memory savings are $8\times$, all sharding collectives stay on NVLink, and the only inter-node traffic is one DDP-style all-reduce of the gradient shards per step. The shard group does not have to be a single node: the Llama-3 recipe in the scenario above applies the same idea with a larger shard group.
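The rank layout behind Hybrid-Shard can be written down explicitly. This is a sketch that assumes node-major rank numbering (ranks 0-7 on node 0, 8-15 on node 1, and so on, as `torchrun` assigns them by default); `hybrid_shard_groups` is a hypothetical helper, not a PyTorch API.

```python
def hybrid_shard_groups(world_size, gpus_per_node):
    """Return (shard_groups, replicate_groups) as lists of global ranks."""
    assert world_size % gpus_per_node == 0
    num_nodes = world_size // gpus_per_node
    # Shard group: all ranks on one node; parameters are sharded across them.
    shard_groups = [
        list(range(node * gpus_per_node, (node + 1) * gpus_per_node))
        for node in range(num_nodes)
    ]
    # Replicate group: ranks with the same local index on every node. They hold
    # the same shard and all-reduce its gradients across nodes.
    replicate_groups = [
        [node * gpus_per_node + local for node in range(num_nodes)]
        for local in range(gpus_per_node)
    ]
    return shard_groups, replicate_groups

shard, replicate = hybrid_shard_groups(world_size=16, gpus_per_node=8)
print(shard[1])       # [8, 9, 10, 11, 12, 13, 14, 15]
print(replicate[3])   # [3, 11]
```

All-gathers and reduce-scatters run inside a shard group; only the gradient all-reduce inside each replicate group crosses the inter-node fabric.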

59.2.5.3 Overlap and Pipelining Improvements

FSDP2 (PyTorch 2.3+, 2024) and DTensor-based sharding (2025) replaced the original wrapper-based implementation with a per-parameter sharding model that prefetches and overlaps more aggressively. Combined with torch.compile, FSDP2 closes most of the throughput gap with hand-tuned Megatron-DeepSpeed configurations.

59.2.6 FSDP vs Tensor Parallelism: When to Choose Which

FSDP and Megatron-style tensor parallelism (Section 59.3) both make a too-large model fit. They are not interchangeable; the right choice depends on cluster topology and model shape.

Table 59.2.3: FSDP versus tensor parallelism: trade-off matrix.
| Property | FSDP (ZeRO-3) | Tensor Parallel (Megatron) |
|---|---|---|
| Where it works | Any cluster, any topology | Only within fast (NVLink) groups |
| Communication per layer | 2 all-gathers + 1 reduce-scatter | 2 all-reduces (in MLP + attn) |
| Activation memory | Full activations on each rank | Sharded $1/T$ with sequence parallel |
| Compute / FLOP overhead | None | None |
| Code complexity | One wrapper | Custom modules per layer type |
| Best for | Models that fit at TP=1 | Models that don't fit at TP=1 |

The rule of thumb in 2026: use tensor parallelism within a node (TP $\le$ 8), use FSDP across nodes. Tensor parallelism exploits NVLink's bandwidth advantage; FSDP exploits the fact that data-parallel scaling is bandwidth-bounded, not latency-bounded, so it tolerates InfiniBand. The 3D-parallel composition in Section 59.4 makes this concrete.

Warning: FSDP and gradient accumulation interact subtly

With no_sync() active across gradient-accumulation micro-steps, FSDP skips the reduce-scatter on intermediate backwards, so each rank accumulates full, unsharded gradients until the final micro-step. Those accumulated gradients stay in BF16 across accumulation steps, which can underflow on long sequences. The fix is to use MixedPrecision(reduce_dtype=torch.float32) for gradient reduction; the checkpoint format (StateDictType.SHARDED_STATE_DICT versus a full state dict) is a separate choice that matters only at checkpoint time. Many "training diverged at step 5000 with no obvious cause" stories in 2024 were variants of this; the fix is in the FSDP2 default config.
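A closely related way that BF16 accumulation loses information can be reproduced without a GPU: once the running sum is large, small contributions fall below half a BF16 step and vanish. The sketch below emulates BF16 and FP32 rounding with the standard `struct` module; the helper names and the toy numbers are assumptions for illustration and say nothing about any particular training run.

```python
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (round-to-nearest, ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)      # rounding bias on the dropped 16 bits
    bits &= 0xFFFF0000                       # keep sign, 8 exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def accumulate(values, round_fn):
    """Sum values, rounding the running total to the given format after every add."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + round_fn(v))
    return acc

contributions = [1.0] + [0.001] * 1000       # one large term, many small ones
bf16_sum = accumulate(contributions, to_bf16)
fp32_sum = accumulate(contributions, to_fp32)
print(bf16_sum, round(fp32_sum, 3))
```

In this toy case the BF16 accumulator stays at exactly 1.0 because every 0.001 increment is rounded away, while the FP32 accumulator lands close to the true sum of 2.0. That is the motivation for reducing and accumulating gradients in FP32.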

59.2.7 Activation Checkpointing and FSDP

FSDP shards parameters and gradients, but it does not shard activations. The forward pass produces $\mathcal{O}(B L d \cdot n_\text{layers})$ activation memory which is fully materialized on every rank. For a Llama-3 70B model at $B=4, L=8192$ this is roughly 160 GB per rank, well over an 80 GB H100's budget.

Activation (gradient) checkpointing recomputes activations during backward rather than storing them. Wrap the transformer block in torch.utils.checkpoint and only the block's input is stored; the intermediate activations are recomputed from that input during backward. Memory drops by roughly $4\times$ per checkpointed block (the savings depend on what the block computes internally); compute cost rises by about $33\%$ (the recomputed forward).
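Both numbers in this paragraph can be sanity-checked with rough arithmetic. The sketch below assumes the public Llama-3 70B shape (hidden size 8192, 80 layers), BF16 activations, and the common rule of thumb that a backward pass costs about twice a forward pass; the function names are illustrative.

```python
def block_input_bytes(batch, seq_len, hidden, n_layers, bytes_per_elem=2):
    """Memory for storing only each block's input activation (what checkpointing keeps)."""
    return batch * seq_len * hidden * n_layers * bytes_per_elem

def recompute_overhead(forward_cost=1.0, backward_cost=2.0):
    """Extra compute from re-running the forward pass once during backward."""
    return forward_cost / (forward_cost + backward_cost)

kept = block_input_bytes(batch=4, seq_len=8192, hidden=8192, n_layers=80)
print(f"checkpointed block inputs: {kept / 1e9:.1f} GB")    # 42.9 GB
print(f"recompute overhead: {recompute_overhead():.0%}")    # 33%
```

Keeping only block inputs costs about 43 GB at $B=4, L=8192$, consistent with a roughly $4\times$ drop from the 160 GB quoted above, and one extra forward on top of a forward plus backward is the source of the $33\%$ figure.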

FSDP and activation checkpointing compose with two subtleties:

59.2.8 Meta-Device Initialization: Skipping the OOM at Startup

For 70B+ models, even constructing the model on a single rank to wrap it with FSDP overflows memory. AutoModelForCausalLM.from_pretrained("llama-70b") materializes 140 GB of BF16 parameters; one rank cannot hold that.

The standard workaround is the meta device: construct the model with device="meta", which allocates only tensor metadata (shape, dtype) but no actual storage. Then FSDP's wrap pass shards the meta tensors, materializes only the local shard on each rank's GPU, and finally loads parameters shard-by-shard from disk. Total per-rank materialization is $\approx P/N$ parameters, never the full $P$.

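import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Same per-block wrap policy as in the minimal training loop (59.2.3.2).
wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
)

# Build the model structure on the meta device (no real allocation).
# from_config builds the architecture only; from_pretrained would also try
# to load the weights, which is exactly what we want to defer.
config = AutoConfig.from_pretrained("meta-llama/Llama-3-70B")
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)

# FSDP's wrap pass materializes parameters unit by unit and keeps only
# this rank's shard on GPU.
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=wrap_policy,
    # Allocate uninitialized storage per module; real values are loaded below.
    param_init_fn=lambda mod: mod.to_empty(device=torch.device("cuda"), recurse=False),
    device_id=torch.cuda.current_device(),
)

# Finally, stream the actual weights from disk to the local shards.
# load_sharded_state_dict and checkpoint_path are placeholders for your own
# checkpoint loader and path; they are not library names.
load_sharded_state_dict(model, checkpoint_path)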
Code Fragment 59.2.2b: Meta-device initialization. The with torch.device("meta") context defers all tensor materialization until FSDP's wrap pass, at which point only the local shard is allocated on each rank.

59.2.9 DeepSpeed-ZeRO vs PyTorch FSDP: 2026 Status

The two stacks have converged in capability but diverge in ecosystem fit:

Table 59.2.4: DeepSpeed-ZeRO vs PyTorch FSDP, mid-2026.
| Aspect | DeepSpeed ZeRO | PyTorch FSDP / FSDP2 |
|---|---|---|
| Origin | Microsoft, 2019 | PyTorch Core, 2022 (FSDP2: 2024) |
| Integration | Hugging Face Trainer (default) | Hugging Face Accelerate, native |
| torch.compile | Limited (eager-mode hooks) | Native, including DTensor |
| Pipeline parallel | Yes, well-integrated | Via PiPPy or Megatron-Core |
| MoE / Expert parallel | Tutel / DeepSpeed-MoE | Megablocks / DTensor |
| Configuration | JSON config files | Python kwargs |
| Best for | HF Trainer recipes, drop-in | Custom training loops, deep integration |

In 2026, FSDP2 is the default in new training stacks (Megatron-Core uses it under the hood for the DP axis); DeepSpeed remains the path of least resistance when starting from Hugging Face Trainer recipes. Both reach 35-45% MFU on H100s at 70B scale; the gap between them is now within measurement noise.

Figure 59.2.4a: DeepSpeed ZeRO and PyTorch FSDP/FSDP2 implement the same sharding algorithm (ZeRO Stage 3 == FSDP Full Shard). They differ in ecosystem fit: ZeRO is the path of least resistance from a HuggingFace Trainer recipe; FSDP is the path of least resistance from a custom torch.compile training loop. MFU at 70B scale is within measurement noise.
Real-World Scenario
Picking FSDP vs ZeRO for a 70B Llama Fine-Tune

Who: A 2026 enterprise team fine-tuning Llama-3-70B for a regulated customer-support agent.

Situation: 16x H100 SXM5 GPUs (2 nodes), HuggingFace Trainer-based code from a previous 7B run, dataset of 200k preference pairs for DPO.

Decision: The team started with DeepSpeed ZeRO-3 because the existing Trainer-based training loop already had a working deepspeed_config.json for the 7B model. They updated the JSON to enable cpu_offload and activation_checkpointing for the 70B model and reached 38% MFU within a day. Later, they experimented with FSDP2 + torch.compile in a fork; this required rewriting the training loop and yielded 42% MFU, a 10% wall-clock saving. They kept ZeRO in production for the next quarter and migrated to FSDP only when they decided to integrate sequence parallelism, which FSDP2 + DTensor supports natively.

Lesson: Pick FSDP when starting fresh with torch.compile or planning hybrid sharding. Pick ZeRO when migrating an existing HF Trainer recipe. The throughput delta is real but small; the engineering-friction delta is large.

Library Shortcut
accelerate launch --use_fsdp for FSDP without the FSDP API

The minimal FSDP training loop in 59.2.3.2 shows what FSDP does; in practice, you almost never call FullyShardedDataParallel(...) directly. Hugging Face accelerate wraps the wrap-policy, mixed-precision, and activation-checkpointing knobs into a YAML config plus a launcher, so a single-GPU training script becomes a multi-node FSDP run by changing how it is launched. Prefer accelerate launch as the FSDP entry point; drop to raw FSDP only when you need a custom wrap policy or a non-standard backward hook.

pip install accelerate
# Interactive wizard writes ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config
#   -> choose: multi-GPU, FSDP, FULL_SHARD, BACKWARD_PRE,
#      TRANSFORMER_BASED_WRAP (auto-wraps each transformer block),
#      bf16 mixed precision

# Same train.py, now FSDP across 8 GPUs on one node:
accelerate launch --num_processes=8 --use_fsdp \
    --fsdp_sharding_strategy=FULL_SHARD \
    --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
    --mixed_precision=bf16 train.py
Code Fragment 59.2.9.1: Turning any HF training script into a sharded FSDP run with two commands.
Looking Back: Summary

ZeRO and FSDP take the redundant replication out of plain data parallelism. Stage 1 shards optimizer state; Stage 2 shards gradients; Stage 3 / FSDP Full-Shard shards parameters and materializes each layer just-in-time. The peak per-rank state is $16/N$ bytes per parameter at Stage 3, a $64\times$ reduction over DDP at $N=64$, against roughly $1.5\times$ the communication volume, most of which prefetching and overlap can hide. Hybrid-Shard combines FSDP within a node with DDP across nodes, exploiting the bandwidth hierarchy. The next section turns to the orthogonal axis: tensor parallelism, which shards within a layer rather than across the data dimension.

What's Next

ZeRO and FSDP shard data-parallel state; the next leap shards the model itself across devices. Continue to Section 59.3: Megatron-LM and Tensor Parallelism.

Further Reading
Rajbhandari, S., Rasley, J., Ruwase, O., He, Y. "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models." SC '20. The foundational ZeRO paper.
Rajbhandari, S. et al. "ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning." SC '21. ZeRO-Infinity adds NVMe offload.
Wang, G. et al. "ZeRO++: Extremely Efficient Collective Communication for Giant Model Training." arXiv:2306.10209 (2023). Quantized comms and hierarchical sharding.
Zhao, Y. et al. "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel." arXiv:2304.11277 (2023). The PyTorch FSDP design paper.
Rasley, J., Rajbhandari, S., Ruwase, O., He, Y. "DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters." KDD '20. DeepSpeed launch.
Touvron, H. et al. "Llama 2: Open Foundation and Fine-Tuned Chat Models." arXiv:2307.09288 (2023). FSDP at 70B scale.
Dubey, A. et al. "The Llama 3 Herd of Models." arXiv:2407.21783 (2024). FSDP Hybrid-Shard for 405B.
PyTorch Team. "FSDP2: Per-parameter sharding for fully sharded data parallel." PyTorch blog, 2024.